using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
#if UNITY_EDITOR
// AnimationUtility (used for curve tangent modes) only exists in the editor assembly.
using UnityEditor;
#endif
using Unity.Mathematics;
using static Unity.Mathematics.math;

/// <summary>
/// A nucleus that sums the weighted outputs of its incoming synapses, applies an
/// activation function to the magnitude of that sum and forwards the result to
/// its receivers.
/// </summary>
[Serializable]
public class Neuron : INucleus
{
    [SerializeField]
    protected string _name;
    /// <summary>Display name of this neuron.</summary>
    public virtual string name
    {
        get => _name;
        set => _name = value;
    }

    [SerializeField]
    private List<Synapse> _synapses = new();
    /// <summary>Incoming connections: the nuclei feeding into this neuron, with their weights.</summary>
    public List<Synapse> synapses => _synapses;

    [SerializeReference]
    private List<INucleus> _receivers = new();
    /// <summary>Outgoing connections: the nuclei that are updated when this neuron's output changes.</summary>
    public List<INucleus> receivers
    {
        get { return _receivers; }
        set { _receivers = value; }
    }

    [SerializeReference]
    private NucleusArray _array;
    /// <summary>The nucleus array associated with this neuron, if any.</summary>
    public NucleusArray array
    {
        get { return _array; }
        set { _array = value; }
    }

    #region Serialization

    /// <summary>Available activation functions.</summary>
    public enum CurvePresets
    {
        Linear,
        Power,
        Sqrt,
        Reciprocal,
        Custom
    }

    [SerializeField]
    private CurvePresets _curvePreset;
    /// <summary>
    /// The selected activation function. Setting it regenerates <see cref="curve"/>
    /// and resets <see cref="curveMax"/> through <see cref="GenerateCurve"/>.
    /// </summary>
    public CurvePresets curvePreset
    {
        get { return _curvePreset; }
        set
        {
            _curvePreset = value;
            this.curve = GenerateCurve();
        }
    }

    /// <summary>Activation curve. It is only evaluated at runtime for the Custom preset.</summary>
    public AnimationCurve curve;
    public float curveMax = 1.0f;

    #region Parameters

    /// <summary>When true, the weighted sum is divided by the number of non-zero inputs.</summary>
    public bool average = false;

    #endregion Parameters

    /// <summary>
    /// Builds the curve matching <see cref="curvePreset"/> and updates <see cref="curveMax"/>.
    /// </summary>
    /// <returns>A newly generated preset curve, or the current <see cref="curve"/> for Custom.</returns>
    public AnimationCurve GenerateCurve()
    {
        switch (this.curvePreset)
        {
            case CurvePresets.Linear:
                this.curveMax = 1;
                return Presets.Linear(1);
            case CurvePresets.Power:
                this.curveMax = 1;
                return Presets.Power(2.0f, 1);
            case CurvePresets.Sqrt:
                this.curveMax = 1;
                return Presets.Power(0.5f, 1);
            case CurvePresets.Reciprocal:
                // NOTE(review): Presets.Reciprocal starts at x = 0.001 and therefore peaks at 1000,
                // while this value corresponds to x = 0.01 - confirm which one is intended.
                this.curveMax = 1 / 0.01f;
                return Presets.Reciprocal(1);
            default:
                this.curveMax = 1;
                return this.curve;
        }
    }

    /// <summary>Hook for derived classes; the base implementation does nothing.</summary>
    public virtual void Deserialize(Neuron nucleus) { }

    #endregion Serialization

    #region Runtime state (not serialized)

    public ClusterPrefab cluster { get; set; }
    public Cluster parent { get; set; }

    #region Activation

    /// <summary>Factory methods for the preset activation curves.</summary>
    public static class Presets
    {
        private const int samples = 32;
        private const int reciprocalSamples = 128;

        /// <summary>Straight line from (0, 0) to (1000, 1000 * weight).</summary>
        public static AnimationCurve Linear(float weight)
        {
            return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
        }

        /// <summary>Samples weight * t^exponent for t in [0, 1].</summary>
        public static AnimationCurve Power(float exponent, float weight)
        {
            Keyframe[] keys = new Keyframe[samples];
            for (int i = 0; i < samples; i++)
            {
                float t = i / (float)(samples - 1);
                keys[i] = new Keyframe(t, Mathf.Pow(t, exponent) * weight);
            }

            AnimationCurve curve = new(keys);
            // Auto (smooth) tangents; use linear tangents if you prefer straight segments.
            SetTangentModes(curve, true);
            return curve;
        }

        /// <summary>Samples weight / x for x in [0.001, 1].</summary>
        public static AnimationCurve Reciprocal(float weight)
        {
            const float xMin = 0.001f;
            const float xMax = 1;

            Keyframe[] keys = new Keyframe[reciprocalSamples];
            for (int i = 0; i < reciprocalSamples; i++)
            {
                float t = i / (float)(reciprocalSamples - 1);
                float x = Mathf.Lerp(xMin, xMax, t);
                keys[i] = new Keyframe(x, 1f / x * weight);
            }

            AnimationCurve curve = new(keys);
            SetTangentModes(curve, false);
            return curve;
        }

        /// <summary>
        /// Sets both tangents of every key to Auto (smooth) or Linear.
        /// Tangent modes are an editor-only feature, so this is a no-op in player builds.
        /// </summary>
        private static void SetTangentModes(AnimationCurve curve, bool smooth)
        {
#if UNITY_EDITOR
            AnimationUtility.TangentMode mode = smooth
                ? AnimationUtility.TangentMode.Auto
                : AnimationUtility.TangentMode.Linear;
            for (int i = 0; i < curve.length; i++)
            {
                AnimationUtility.SetKeyLeftTangentMode(curve, i, mode);
                AnimationUtility.SetKeyRightTangentMode(curve, i, mode);
            }
#endif
        }
    }

    #endregion Activation

    protected float3 _outputValue;
    /// <summary>Current output of this neuron. Writing it resets the staleness counter.</summary>
    public virtual float3 outputValue
    {
        get { return _outputValue; }
        set
        {
            this.stale = 0;
            _outputValue = value;
        }
    }

    // Number of UpdateNuclei calls since outputValue was last written.
    [NonSerialized]
    private int stale = 1000;

    /// <summary>True while the output is exactly zero.</summary>
    public bool isSleeping => lengthsq(this.outputValue) == 0;

    /// <summary>
    /// Ages the output by one step and clears it once it has not been refreshed
    /// for more than two steps.
    /// </summary>
    public void UpdateNuclei()
    {
        this.stale++;
        if (this.stale > 2)
        {
            // Written to the field directly so the staleness counter is not reset.
            _outputValue = float3.zero;
        }
    }

    #endregion Runtime state

    /// <summary>Creates a neuron and registers it with the given runtime cluster, if any.</summary>
    public Neuron(Cluster parent, string name)
    {
        this.parent = parent;
        this.name = name;
        this.parent?.nuclei.Add(this);
    }

    /// <summary>Creates a neuron and registers it with the given cluster prefab, if any.</summary>
    public Neuron(ClusterPrefab parent, string name)
    {
        this.cluster = parent;
        this.name = name;
        this.cluster?.nuclei.Add(this);
    }

    /// <summary>Clones this neuron into <paramref name="newParent"/> without its synapses and receivers.</summary>
    public virtual IReceptor ShallowCloneTo(Cluster newParent)
    {
        Neuron clone = new(newParent, this.name);
        CopySettingsTo(clone);
        return clone;
    }

    /// <summary>Clones this neuron into <paramref name="newParent"/> without its synapses and receivers.</summary>
    public virtual IReceptor ShallowCloneTo(ClusterPrefab newParent)
    {
        Neuron clone = new(newParent, this.name);
        CopySettingsTo(clone);
        return clone;
    }

    /// <summary>Clones this neuron, including its synapses and receivers, into <paramref name="parent"/>.</summary>
    public virtual IReceptor CloneTo(ClusterPrefab parent)
    {
        Neuron clone = new(parent, this.name);
        CopySettingsTo(clone);
        CopyConnectionsTo(clone);
        return clone;
    }

    /// <summary>Clones this neuron, including its synapses and receivers, into the same cluster prefab.</summary>
    public virtual IReceptor Clone()
    {
        Neuron clone = new(this.cluster, this.name);
        CopySettingsTo(clone);
        CopyConnectionsTo(clone);
        return clone;
    }

    // Copies the array reference and the activation settings to the clone.
    private void CopySettingsTo(Neuron clone)
    {
        clone.array = this.array;
        clone.curve = this.curve;
        // The preset setter regenerates curve and curveMax, so curveMax is restored after it.
        clone.curvePreset = this.curvePreset;
        clone.curveMax = this.curveMax;
        clone.average = this.average;
    }

    // Recreates this neuron's incoming synapses and outgoing receiver links on the clone.
    private void CopyConnectionsTo(Neuron clone)
    {
        foreach (Synapse synapse in this.synapses)
        {
            Synapse clonedSynapse = clone.AddSynapse(synapse.nucleus);
            clonedSynapse.weight = synapse.weight;
        }
        foreach (INucleus receiver in this.receivers)
        {
            // NOTE(review): the receiver gets a synapse to the clone with the default weight;
            // the weight of the original link is not copied - confirm this is intended.
            clone.AddReceiver(receiver);
        }
    }

    /// <summary>
    /// Connects this neuron's output to <paramref name="receivingNucleus"/> and adds the
    /// matching synapse on the receiver.
    /// </summary>
    public virtual void AddReceiver(INucleus receivingNucleus, float weight = 1)
    {
        this._receivers.Add(receivingNucleus);
        receivingNucleus.AddSynapse(this, weight);
    }

    /// <summary>Removes the link to <paramref name="receiverNucleus"/> on both ends.</summary>
    public void RemoveReceiver(INucleus receiverNucleus)
    {
        this._receivers.RemoveAll(receiver => receiver == receiverNucleus);
        receiverNucleus.synapses.RemoveAll(synapse => synapse.nucleus == this);
    }

    /// <summary>
    /// Deletes a nucleus: detaches it from its inputs and receivers and removes it from
    /// its cluster. Input neurons that only fed the deleted nucleus are deleted as well.
    /// </summary>
    public static void Delete(INucleus nucleus)
    {
        Delete(nucleus, new HashSet<INucleus>());
    }

    // The visited set stops the cascade from recursing forever on cyclic connections.
    private static void Delete(INucleus nucleus, HashSet<INucleus> visited)
    {
        if (!visited.Add(nucleus))
        {
            return;
        }

        // Iterate over a snapshot: deleting an input neuron removes its synapse from
        // nucleus.synapses, which would otherwise invalidate the running enumeration.
        foreach (Synapse synapse in nucleus.synapses.ToArray())
        {
            if (synapse.nucleus is not Neuron inputNeuron)
            {
                continue;
            }

            if (inputNeuron._receivers.Count > 1)
            {
                // The input neuron also feeds other nuclei: only unlink it.
                inputNeuron._receivers.RemoveAll(r => r == nucleus);
            }
            else
            {
                // No other links, delete it.
                Delete(inputNeuron, visited);
            }
        }

        foreach (INucleus receiver in nucleus.receivers.ToArray())
        {
            if (receiver != null && receiver.synapses != null)
            {
                receiver.synapses.RemoveAll(s => s.nucleus == nucleus);
            }
        }

        if (nucleus.cluster != null)
        {
            nucleus.cluster.nuclei.RemoveAll(n => n == nucleus);
            nucleus.cluster.GarbageCollection();
        }
    }

    /// <summary>Adds an incoming synapse from <paramref name="sendingNucleus"/>.</summary>
    /// <returns>The newly created synapse.</returns>
    public Synapse AddSynapse(IReceptor sendingNucleus, float weight = 1.0f)
    {
        Synapse synapse = new(sendingNucleus, weight);
        this.synapses.Add(synapse);
        return synapse;
    }

    /// <summary>Delegates the update to the parent cluster, if there is one.</summary>
    public virtual void UpdateState()
    {
        this.parent?.UpdateState();
    }

    /// <summary>
    /// Recomputes the output from <paramref name="inputValue"/> plus the weighted synapse
    /// inputs and propagates the result to the receivers.
    /// </summary>
    public virtual void UpdateState(float3 inputValue)
    {
        UpdateResult(Activate(SumInputs(inputValue)));
    }

    /// <summary>Recomputes the output without bias and without notifying the receivers.</summary>
    public virtual void UpdateStateIsolated()
    {
        UpdateStateIsolated(new float3(0, 0, 0));
    }

    /// <summary>
    /// Recomputes the output from <paramref name="bias"/> plus the weighted synapse inputs
    /// without notifying the receivers.
    /// </summary>
    public virtual void UpdateStateIsolated(float3 bias)
    {
        this.outputValue = Activate(SumInputs(bias));
    }

    // Adds the weighted synapse outputs to the bias; optionally averages over the active inputs.
    private float3 SumInputs(float3 bias)
    {
        float3 sum = bias;
        int activeInputs = 0;

        // Applying the weight factors
        foreach (Synapse synapse in this.synapses)
        {
            float3 input = synapse.nucleus.outputValue;
            sum += synapse.weight * input;
            // Perhaps synapses should be removed when the output value goes to 0....
            if (lengthsq(input) != 0)
            {
                activeInputs++;
            }
        }

        if (this.average && activeInputs > 0)
        {
            sum /= activeInputs;
        }
        return sum;
    }

    // Applies the activation function to the magnitude of the sum, keeping its direction.
    private float3 Activate(float3 sum)
    {
        if (this.curvePreset == CurvePresets.Linear)
        {
            return sum;
        }

        float magnitude = length(sum);
        if (magnitude == 0)
        {
            // A zero vector has no direction: normalize() would yield NaN and the
            // reciprocal would divide by zero, so the neuron simply stays silent.
            return float3.zero;
        }

        float activatedValue = this.curvePreset switch
        {
            CurvePresets.Sqrt => MathF.Sqrt(magnitude),
            CurvePresets.Power => MathF.Pow(magnitude, 2),
            CurvePresets.Reciprocal => 1 / magnitude,
            _ => this.curve.Evaluate(magnitude),
        };
        return normalize(sum) * activatedValue;
    }

    /// <summary>Stores the new output, logs it when non-zero and updates all receivers.</summary>
    public virtual void UpdateResult(Vector3 result)
    {
        this.outputValue = result;
        if (lengthsq(outputValue) != 0)
        {
            // parent is null for neurons that were created for a ClusterPrefab.
            Debug.Log($"{this.parent?.name}.{this.name}: {this.outputValue}");
        }

        foreach (INucleus receiver in this.receivers)
        {
            receiver?.UpdateState();
        }
    }
}