2025-11-27 17:35:51 +01:00

153 lines
4.1 KiB
C#

using System.Collections.Generic;
using UnityEngine;
using System.Linq;
/// <summary>
/// One weighted input slot of a <see cref="Neuroid"/>: the sender (which may be null),
/// the most recently received value, and the weight applied to that value.
/// </summary>
public class Synapse {
    /// <summary>Sending neuroid, or null for inputs that were not wired to a sender.</summary>
    public Neuroid neuroid;
    /// <summary>Most recent value delivered through this synapse.</summary>
    public Vector3 value;
    /// <summary>Multiplier applied to <see cref="value"/> when the owner sums its inputs.</summary>
    public float weight;

    /// <summary>Creates a synapse with the given sender, initial value and weight.</summary>
    /// <param name="neuroid">Sender of this input; may be null.</param>
    /// <param name="value">Initial input value.</param>
    /// <param name="weight">Weight applied to the value.</param>
    public Synapse(Neuroid neuroid, Vector3 value, float weight) {
        (this.neuroid, this.value, this.weight) = (neuroid, value, weight);
    }
}
/// <summary>
/// Keeps a flat list of <see cref="Neuroid"/>s and advances their staleness counter
/// (read by <c>Neuroid.IsStale</c>) once per call to <see cref="Update"/>.
/// </summary>
public class NeuroidNetwork {
    /// <summary>Every neuroid constructed with this network (they add themselves in their constructor).</summary>
    public List<Neuroid> neuroids = new();

    /// <summary>Creates a neuroid bound to this network.</summary>
    /// <returns>The new neuroid; its constructor has already appended it to <see cref="neuroids"/>.</returns>
    public Neuroid AddNeuron() => new(this);

    /// <summary>Increments the <c>stale</c> counter of every registered neuroid by one.</summary>
    public void Update() {
        for (int i = 0; i < neuroids.Count; i++)
            neuroids[i].stale += 1;
    }
}
/// <summary>
/// A node that combines weighted <see cref="Vector3"/> inputs into one output vector and
/// pushes that output to every registered receiver.
/// </summary>
/// <remarks>
/// Inputs live in <see cref="newSynapses"/>, keyed by the sending neuroid. Every
/// <c>SetInput</c> overload stores the new value and immediately calls
/// <see cref="UpdateState"/>, which recomputes <see cref="outputValue"/> and forwards it to
/// <see cref="outputNeuroids"/>, so a change propagates downstream synchronously.
/// <see cref="SetWeight"/>, <see cref="ResetWeights"/>, <see cref="AddSynapse"/> and
/// <see cref="GetInputFrom"/> do NOT recompute the output.
/// </remarks>
public class Neuroid {
    /// <summary>Label used in warning messages.</summary>
    public string name;
    /// <summary>Layer index for callers; not read by this class.</summary>
    public int layerIx;
    /// <summary>
    /// Ticks since the output was last recomputed: incremented by
    /// <see cref="NeuroidNetwork.Update"/>, reset to 0 at the end of <see cref="UpdateState"/>.
    /// </summary>
    public int stale = 0;
    /// <summary>Input synapses, keyed by the neuroid (or placeholder) that feeds them.</summary>
    public readonly Dictionary<Neuroid, Synapse> newSynapses = new();
    /// <summary>Result of the most recent <see cref="UpdateState"/>.</summary>
    public Vector3 outputValue;
    /// <summary>Neuroids that receive <see cref="outputValue"/> after every recompute.</summary>
    public HashSet<Neuroid> outputNeuroids = new();

    /// <summary>How the weighted inputs are combined into the output.</summary>
    public enum Mode {
        /// <summary>Output is the weighted sum of all inputs.</summary>
        Sum,
        /// <summary>Output is the weighted sum divided by the synapse count (zero when there are none).</summary>
        Average,
    }
    public Mode mode = Mode.Sum;
    /// <summary>Network this neuroid registered with at construction; may be null.</summary>
    public NeuroidNetwork net;

    // Number of NeuroidNetwork.Update ticks without a recompute that IsStale() tolerates.
    private const int StaleTickLimit = 2;

    // True while UpdateState runs on this instance; used to stop infinite recursion when
    // the receiver graph contains a cycle (A -> B -> A, or a neuroid feeding itself).
    private bool isUpdating;

    /// <summary>Creates a neuroid and, when <paramref name="net"/> is non-null, appends it to that network.</summary>
    /// <param name="net">Owning network, or null for a free-standing neuroid.</param>
    public Neuroid(NeuroidNetwork net) {
        this.net = net;
        this.net?.neuroids.Add(this);
    }

    /// <summary>Wires <paramref name="input"/> to feed this neuroid with weight 1.</summary>
    /// <param name="input">Sending neuroid; must not be null.</param>
    public void AddSynapse(Neuroid input) => GetInputFrom(input, 1.0f);

    /// <summary>Adds <paramref name="receiver"/> to the set notified after each recompute (duplicates are ignored).</summary>
    public void AddReceiver(Neuroid receiver) {
        this.outputNeuroids.Add(receiver);
    }

    /// <summary>Sets every input weight back to 1. Does not recompute the output.</summary>
    public void ResetWeights() {
        foreach (Synapse synapse in this.newSynapses.Values)
            synapse.weight = 1.0f;
    }

    /// <summary>
    /// Sets the weight of the synapse from <paramref name="input"/>, creating it (with a zero
    /// value) if absent. Does not register as a receiver and does not recompute the output.
    /// </summary>
    public void SetWeight(Neuroid input, float weight) {
        if (this.newSynapses.TryGetValue(input, out Synapse synapse))
            synapse.weight = weight;
        else
            this.newSynapses[input] = new(input, Vector3.zero, weight);
    }

    /// <summary>
    /// Registers this neuroid as a receiver of <paramref name="input"/> and installs a fresh
    /// synapse for it, replacing any existing one (so its stored value resets to zero).
    /// </summary>
    /// <param name="input">Sending neuroid; must not be null.</param>
    /// <param name="weight">Weight applied to values arriving from <paramref name="input"/>.</param>
    public void GetInputFrom(Neuroid input, float weight = 1.0f) {
        input.AddReceiver(this);
        this.newSynapses[input] = new(input, Vector3.zero, weight);
    }

    /// <summary>
    /// Stores <paramref name="value"/> as the input from <paramref name="input"/>, keeping an
    /// existing weight (or using 1 for a new synapse), then recomputes and propagates the output.
    /// </summary>
    public void SetInput(Neuroid input, Vector3 value) {
        if (this.newSynapses.TryGetValue(input, out Synapse synapse))
            synapse.value = value;
        else
            // NOTE(review): unwired inputs get a null Synapse.neuroid, unlike SetWeight /
            // GetInputFrom which store the sender. Kept as-is; confirm this is intended.
            this.newSynapses[input] = new(null, value, 1.0f);
        UpdateState();
    }

    /// <summary>
    /// Stores <paramref name="value"/> and <paramref name="weight"/> for the input from
    /// <paramref name="input"/>, then recomputes and propagates the output.
    /// </summary>
    public void SetInput(Neuroid input, Vector3 value, float weight) {
        if (this.newSynapses.TryGetValue(input, out Synapse synapse)) {
            synapse.value = value;
            synapse.weight = weight;
        }
        else
            this.newSynapses[input] = new(null, value, weight);
        UpdateState();
    }

    /// <summary>
    /// Placeholder neuroids created on demand by <see cref="SetInput(int, Vector3, float, NeuroidNetwork)"/>
    /// so that inputs identified only by an integer id can be keyed in <see cref="newSynapses"/>.
    /// </summary>
    public readonly Dictionary<int, Neuroid> fakeNeuroids = new();

    /// <summary>
    /// Sets the input identified by <paramref name="thingId"/>, creating a placeholder neuroid
    /// for that id on first use, then recomputes and propagates the output.
    /// </summary>
    /// <param name="net">
    /// Network the placeholder is registered with when it is first created. This is the
    /// parameter, not the <see cref="net"/> field.
    /// </param>
    public void SetInput(int thingId, Vector3 value, float weight, NeuroidNetwork net) {
        if (!fakeNeuroids.TryGetValue(thingId, out Neuroid fakeInput)) {
            fakeInput = new(net);
            fakeNeuroids[thingId] = fakeInput;
        }
        // Same update-or-create-then-recompute logic as the Neuroid overload. This also copes
        // with a placeholder whose synapse was removed externally (previously a KeyNotFoundException).
        SetInput(fakeInput, value, weight);
    }

    /// <summary>
    /// Recomputes <see cref="outputValue"/> from the weighted inputs, pushes it to every
    /// receiver, and resets <see cref="stale"/>.
    /// </summary>
    /// <remarks>
    /// Re-entrant calls (only possible when the receiver graph has a cycle) return immediately.
    /// The caller's <c>SetInput</c> has already stored the new value, so only the recursive
    /// recompute is skipped, instead of recursing until the stack overflows.
    /// </remarks>
    protected virtual void UpdateState() {
        if (this.isUpdating)
            return;
        this.isUpdating = true;
        try {
            Vector3 sum = Vector3.zero;
            foreach (Synapse synapse in this.newSynapses.Values)
                sum += synapse.value * synapse.weight;
            this.outputValue = Activation(sum);
            foreach (Neuroid neuroid in this.outputNeuroids)
                neuroid?.SetInput(this, this.outputValue);
            this.stale = 0;
        }
        finally {
            // Always clear the flag so an exception downstream cannot wedge this neuroid.
            this.isUpdating = false;
        }
    }

    /// <summary>Maps the weighted input sum to the output according to <see cref="mode"/>.</summary>
    Vector3 Activation(Vector3 sum) {
        int count = this.newSynapses.Count;
        if (count == 0 && mode == Mode.Average)
            Debug.LogWarning($"{this.name} has zero synapses for average");
        if (float.IsNaN(sum.magnitude))
            Debug.LogWarning($"{this.name} sum is nan");
        return mode switch {
            Mode.Sum => sum,
            // Dividing by zero would yield a NaN vector that then spreads to every
            // downstream neuroid; the mean of no inputs is reported as zero instead.
            Mode.Average => count == 0 ? Vector3.zero : sum / count,
            _ => sum,
        };
    }

    /// <summary>True when more than <see cref="StaleTickLimit"/> network ticks have passed without a recompute.</summary>
    public bool IsStale() => this.stale > StaleTickLimit;
}