153 lines
4.1 KiB
C#
153 lines
4.1 KiB
C#
using System.Collections.Generic;
|
|
using UnityEngine;
|
|
using System.Linq;
|
|
|
|
/// <summary>
/// One weighted input slot of a <see cref="Neuroid"/>: the neuroid that feeds it,
/// the latest value received through it, and the multiplier applied to that value.
/// </summary>
public class Synapse {
    /// <summary>Source neuroid. Callers in this file sometimes pass null (e.g. for externally driven inputs).</summary>
    public Neuroid neuroid;

    /// <summary>Most recent value delivered through this synapse.</summary>
    public Vector3 value;

    /// <summary>Multiplier applied to <see cref="value"/> when the owning neuroid sums its inputs.</summary>
    public float weight;

    /// <summary>Creates a synapse with the given source, initial value and weight.</summary>
    /// <param name="neuroid">Source neuroid; may be null.</param>
    /// <param name="value">Initial input value.</param>
    /// <param name="weight">Initial weight.</param>
    public Synapse(Neuroid neuroid, Vector3 value, float weight) =>
        (this.neuroid, this.value, this.weight) = (neuroid, value, weight);
}
|
|
|
|
/// <summary>
/// Holds a flat list of <see cref="Neuroid"/>s and advances their staleness counters.
/// </summary>
public class NeuroidNetwork {
    /// <summary>
    /// Every neuroid constructed with this network. Neuroids append themselves here
    /// from their constructor.
    /// </summary>
    public List<Neuroid> neuroids = new();

    /// <summary>
    /// Creates a neuroid owned by this network. Registration in <see cref="neuroids"/>
    /// happens inside the <see cref="Neuroid"/> constructor, not here.
    /// </summary>
    /// <returns>The newly created neuroid.</returns>
    public Neuroid AddNeuron() => new Neuroid(this);

    /// <summary>Increments the <c>stale</c> counter of every registered neuroid by one.</summary>
    public void Update() {
        for (int i = 0; i < neuroids.Count; i++) {
            neuroids[i].stale++;
        }
    }
}
|
|
|
|
/// <summary>
/// A node that combines weighted <see cref="Vector3"/> inputs (one <see cref="Synapse"/> per
/// source) into <see cref="outputValue"/>, and pushes that output to every neuroid in
/// <see cref="outputNeuroids"/> whenever it is recomputed.
/// </summary>
public class Neuroid {
    /// <summary>Display name, used in warning messages.</summary>
    public string name;

    /// <summary>Layer index. Not read by this class.</summary>
    public int layerIx;

    /// <summary>
    /// Staleness counter. Incremented externally (e.g. by the owning network) and reset to 0
    /// each time <see cref="UpdateState"/> completes.
    /// </summary>
    public int stale = 0;

    /// <summary>Incoming synapses, keyed by the neuroid (real or fake) that feeds them.</summary>
    public readonly Dictionary<Neuroid, Synapse> newSynapses = new();

    /// <summary>Result of the most recent <see cref="UpdateState"/>.</summary>
    public Vector3 outputValue;

    /// <summary>Neuroids that receive <see cref="outputValue"/> after each update.</summary>
    public HashSet<Neuroid> outputNeuroids = new();

    /// <summary>How weighted inputs are combined into the output.</summary>
    public enum Mode {
        /// <summary>Output is the weighted sum of all inputs.</summary>
        Sum,

        /// <summary>Output is the weighted sum divided by the number of synapses.</summary>
        Average,
    }

    /// <summary>Combination mode used by the activation step.</summary>
    public Mode mode = Mode.Sum;

    /// <summary>Owning network; may be null.</summary>
    public NeuroidNetwork net;

    /// <summary>
    /// Placeholder input neuroids created by <c>SetInput(int, Vector3, float, NeuroidNetwork)</c>,
    /// keyed by the caller-supplied id.
    /// </summary>
    public readonly Dictionary<int, Neuroid> fakeNeuroids = new();

    // True while UpdateState is executing on this instance. A cycle in the receiver graph
    // (A -> ... -> A, including a self-loop) would otherwise recurse until the stack overflows.
    private bool isUpdating;

    /// <summary>Creates a neuroid and, if <paramref name="net"/> is non-null, appends it to <c>net.neuroids</c>.</summary>
    /// <param name="net">Owning network, or null for a standalone neuroid.</param>
    public Neuroid(NeuroidNetwork net) {
        this.net = net;
        this.net?.neuroids.Add(this);
    }

    /// <summary>
    /// Connects <paramref name="input"/> to this neuroid with weight 1. Equivalent to
    /// <c>GetInputFrom(input, 1.0f)</c>.
    /// </summary>
    public void AddSynapse(Neuroid input) => GetInputFrom(input);

    /// <summary>Registers <paramref name="receiver"/> to get this neuroid's output after each update.</summary>
    public void AddReceiver(Neuroid receiver) => this.outputNeuroids.Add(receiver);

    /// <summary>Sets the weight of every incoming synapse back to 1.</summary>
    public void ResetWeights() {
        foreach (Synapse synapse in this.newSynapses.Values) {
            synapse.weight = 1.0f;
        }
    }

    /// <summary>
    /// Sets the weight of the synapse from <paramref name="input"/>. If none exists, one is
    /// created with a zero value. It does NOT register this neuroid as a receiver of
    /// <paramref name="input"/>.
    /// </summary>
    public void SetWeight(Neuroid input, float weight) {
        if (this.newSynapses.TryGetValue(input, out Synapse synapse)) {
            synapse.weight = weight;
        } else {
            this.newSynapses[input] = new(input, Vector3.zero, weight);
        }
    }

    /// <summary>
    /// Connects <paramref name="input"/> to this neuroid. The connection is two-way:
    /// <paramref name="input"/> will push its output here, and a fresh synapse (zero value,
    /// given weight) replaces any existing one for that source.
    /// </summary>
    /// <param name="input">Source neuroid.</param>
    /// <param name="weight">Initial synapse weight (default 1).</param>
    public void GetInputFrom(Neuroid input, float weight = 1.0f) {
        input.AddReceiver(this);
        this.newSynapses[input] = new(input, Vector3.zero, weight);
    }

    /// <summary>
    /// Stores <paramref name="value"/> on the synapse from <paramref name="input"/> (keeping its
    /// weight), then recomputes and propagates this neuroid's output. If no synapse exists yet,
    /// one is created with weight 1.
    /// </summary>
    public void SetInput(Neuroid input, Vector3 value) {
        if (this.newSynapses.TryGetValue(input, out Synapse synapse)) {
            synapse.value = value;
        } else {
            // NOTE(review): the new synapse's `neuroid` is null even though `input` is known;
            // AddSynapse/SetWeight store `input`. Confirm whether callers rely on the null.
            this.newSynapses[input] = new(null, value, 1.0f);
        }

        UpdateState();
    }

    /// <summary>
    /// Stores <paramref name="value"/> and <paramref name="weight"/> on the synapse from
    /// <paramref name="input"/> (creating it if absent), then recomputes and propagates the output.
    /// </summary>
    public void SetInput(Neuroid input, Vector3 value, float weight) {
        if (this.newSynapses.TryGetValue(input, out Synapse synapse)) {
            synapse.value = value;
            synapse.weight = weight;
        } else {
            this.newSynapses[input] = new(null, value, weight);
        }

        UpdateState();
    }

    /// <summary>
    /// Drives an input identified by an arbitrary integer id rather than by a neuroid. On first
    /// use for an id, a placeholder neuroid is created in <paramref name="net"/> and remembered in
    /// <see cref="fakeNeuroids"/>. The output is then recomputed and propagated.
    /// </summary>
    /// <param name="thingId">Caller-chosen id of the external input.</param>
    /// <param name="value">Input value.</param>
    /// <param name="weight">Synapse weight.</param>
    /// <param name="net">Network the placeholder neuroid is registered in (may be null).</param>
    public void SetInput(int thingId, Vector3 value, float weight, NeuroidNetwork net) {
        if (fakeNeuroids.TryGetValue(thingId, out Neuroid fakeInput)) {
            Synapse synapse = this.newSynapses[fakeInput];
            synapse.value = value;
            synapse.weight = weight;
        } else {
            fakeInput = new Neuroid(net);
            fakeNeuroids[thingId] = fakeInput;
            this.newSynapses[fakeInput] = new(null, value, weight);
        }

        UpdateState();
    }

    /// <summary>
    /// Recomputes <see cref="outputValue"/> from the weighted inputs, pushes it to every receiver,
    /// and resets <see cref="stale"/>. A re-entrant call made while this instance is already
    /// updating returns immediately. That can only happen through a cycle of receivers. The
    /// value delivered by that call is still stored on the synapse and is used on the next update.
    /// </summary>
    protected virtual void UpdateState() {
        if (isUpdating) {
            return;
        }

        isUpdating = true;
        try {
            Vector3 sum = Vector3.zero;
            foreach (Synapse synapse in this.newSynapses.Values) {
                sum += synapse.value * synapse.weight;
            }

            this.outputValue = Activation(sum);

            foreach (Neuroid receiver in outputNeuroids) {
                receiver?.SetInput(this, this.outputValue);
            }

            this.stale = 0;
        } finally {
            isUpdating = false;
        }
    }

    /// <summary>
    /// Applies <see cref="mode"/> to the weighted sum. Logs a warning if the sum is NaN.
    /// In Average mode with no synapses it logs a warning and returns <see cref="Vector3.zero"/>
    /// instead of dividing by zero.
    /// </summary>
    Vector3 Activation(Vector3 sum) {
        int synapseCount = this.newSynapses.Count;

        if (mode == Mode.Average && synapseCount == 0) {
            Debug.LogWarning($"{this.name} has zero synapses for average");
            return Vector3.zero;
        }

        if (float.IsNaN(sum.magnitude))
            Debug.LogWarning($"{this.name} sum is nan");

        return mode switch {
            Mode.Sum => sum,
            Mode.Average => sum / synapseCount,
            _ => sum,
        };
    }

    /// <summary>True when <see cref="stale"/> exceeds 2.</summary>
    public bool IsStale() => this.stale > 2;
}
|
|
|