using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_EDITOR
// AnimationUtility is editor-only; guarding the using keeps player builds compiling.
using UnityEditor;
#endif
using Unity.Mathematics;
using static Unity.Mathematics.math;

/// <summary>
/// A single node in a cluster graph. Its output is produced by combining the
/// weighted outputs of its input synapses (see <see cref="Combinator"/>) and
/// passing the combined vector through the activation function selected by
/// <see cref="curvePreset"/> (see <see cref="Activator"/>).
/// </summary>
[Serializable]
public class Neuron : Nucleus
{
    /// <summary>
    /// Creates a neuron inside a runtime cluster and, when <paramref name="parent"/>
    /// is not null, registers it in the cluster's <c>clusterNuclei</c> list.
    /// </summary>
    public Neuron(Cluster parent, string name)
    {
        this.parent = parent;
        this.name = name;
        this.parent?.clusterNuclei.Add(this);
    }

    /// <summary>
    /// Creates a neuron inside a cluster prefab and registers it in the prefab's
    /// <c>nuclei</c> list. Logs an error when <paramref name="prefab"/> is null.
    /// </summary>
    public Neuron(ClusterPrefab prefab, string name)
    {
        this.clusterPrefab = prefab;
        this.name = name;
        if (this.clusterPrefab != null)
        {
            this.clusterPrefab.nuclei.Add(this);
        }
        else
        {
            Debug.LogError("No prefab when adding neuron to prefab");
        }
    }

    #region Serialization

    /// <summary>How the weighted inputs are combined into a single vector.</summary>
    public enum CombinatorType { Sum, Product, Max }

    public CombinatorType combinator = CombinatorType.Sum;

    /// <summary>Selects the activation function (and the matching preview curve).</summary>
    public enum CurvePresets { Linear, Power, Sqrt, Reciprocal, Custom }

    [SerializeField]
    public CurvePresets _curvePreset;

    /// <summary>
    /// Activation preset. Assigning it regenerates <see cref="curve"/> and
    /// <see cref="curveMax"/> through <see cref="GenerateCurve"/>.
    /// </summary>
    public CurvePresets curvePreset
    {
        get { return _curvePreset; }
        set
        {
            _curvePreset = value;
            this.curve = GenerateCurve();
        }
    }

    /// <summary>
    /// Curve for the current preset. Only <see cref="CurvePresets.Custom"/> evaluates it
    /// at runtime; the other presets compute their activation analytically.
    /// </summary>
    public AnimationCurve curve;

    public float curveMax = 1.0f;

    /// <summary>
    /// Builds the curve that matches <see cref="curvePreset"/> and updates
    /// <see cref="curveMax"/> as a side effect. For <see cref="CurvePresets.Custom"/>
    /// the existing <see cref="curve"/> is returned unchanged.
    /// </summary>
    public AnimationCurve GenerateCurve()
    {
        switch (this.curvePreset)
        {
            case CurvePresets.Linear:
                this.curveMax = 1;
                return Presets.Linear(1);
            case CurvePresets.Power:
                this.curveMax = 1;
                return Presets.Power(2.0f, 1);
            case CurvePresets.Sqrt:
                this.curveMax = 1;
                return Presets.Power(0.5f, 1);
            case CurvePresets.Reciprocal:
                // NOTE(review): Presets.Reciprocal samples down to x = 0.001 (peak y = 1000),
                // while curveMax here is 1 / 0.01 = 100 — confirm the mismatch is intended.
                this.curveMax = 1 / 0.01f * 1;
                return Presets.Reciprocal(1);
            default:
                // Custom: keep the user-authored curve.
                this.curveMax = 1;
                return this.curve;
        }
    }

    /// <summary>Factories for the built-in activation preview curves.</summary>
    public static class Presets
    {
        private const int samples = 32;

        /// <summary>Straight line from (0, 0) to (1000, weight * 1000).</summary>
        public static AnimationCurve Linear(float weight)
        {
            return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
        }

        /// <summary>
        /// Samples y = t^exponent * weight at <c>samples</c> evenly spaced points with t in [0, 1].
        /// In the editor the keys get Auto (smooth) tangents.
        /// </summary>
        public static AnimationCurve Power(float exponent, float weight)
        {
            Keyframe[] keys = new Keyframe[samples];
            for (int i = 0; i < samples; i++)
            {
                float t = i / (float)(samples - 1);
                float v = Mathf.Pow(t, exponent) * weight;
                keys[i] = new Keyframe(t, v);
            }

            AnimationCurve curve = new(keys);
#if UNITY_EDITOR
            // Auto = smooth tangents. Use Linear if you prefer straight segments.
            for (int i = 0; i < curve.length; i++)
            {
                AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
                AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
            }
#endif
            return curve;
        }

        /// <summary>
        /// Samples y = weight / x at 128 evenly spaced points with x in [0.001, 1].
        /// In the editor the keys get Linear tangents.
        /// </summary>
        public static AnimationCurve Reciprocal(float weight)
        {
            // Uses its own sample count instead of the shared `samples` constant.
            const int reciprocalSamples = 128;
            const float xMin = 0.001f; // keeps x away from zero so 1/x stays finite
            const float xMax = 1;

            var keys = new Keyframe[reciprocalSamples];
            for (int i = 0; i < reciprocalSamples; i++)
            {
                float t = i / (float)(reciprocalSamples - 1);
                float x = Mathf.Lerp(xMin, xMax, t);
                float y = 1f / x * weight;
                keys[i] = new Keyframe(x, y);
            }

            var curve = new AnimationCurve(keys);
#if UNITY_EDITOR
            for (int i = 0; i < curve.length; i++)
            {
                AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
                AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
            }
#endif
            return curve;
        }
    }

    #endregion Serialization

    protected float3 _outputValue;

    /// <summary>
    /// Current output vector. Setting it invokes <see cref="WhenFiring"/> whenever the
    /// new value makes <see cref="isFiring"/> true.
    /// </summary>
    public virtual float3 outputValue
    {
        get { return _outputValue; }
        set
        {
            _outputValue = value;
            if (this.isFiring)
            {
                WhenFiring?.Invoke();
            }
        }
    }

    /// <summary>True when the stored output's magnitude exceeds 0.5.</summary>
    public bool isFiring => length(_outputValue) > 0.5f;

    /// <summary>Invoked from the <see cref="outputValue"/> setter while firing.</summary>
    public Action WhenFiring;

    /// <summary>True when the output vector is exactly zero.</summary>
    public virtual bool isSleeping => lengthsq(this.outputValue) == 0;

    /// <summary>Reset to 0 by <see cref="ProcessStimulusDirect"/>.</summary>
    [NonSerialized]
    public int stale = 1000;

    public readonly int staleValueForSleep = 20;

    /// <summary>
    /// Clones this neuron into <paramref name="newParent"/>, copying its configuration
    /// fields but not its synapses or receivers.
    /// </summary>
    public override Nucleus ShallowCloneTo(Cluster newParent)
    {
        Neuron clone = new(newParent, this.name);
        CloneFields(clone);
        return clone;
    }

    /// <summary>
    /// Clones this neuron into <paramref name="prefab"/>. The clone gets synapses to the
    /// same input nuclei (with the same weights) and is added as a receiver-source to the
    /// same receivers as this neuron (with the default weight).
    /// </summary>
    public override Nucleus Clone(ClusterPrefab prefab)
    {
        Neuron clone = new(prefab, this.name);
        // NOTE(review): CloneFields overwrites clone.clusterPrefab with this.clusterPrefab,
        // although the constructor registered the clone in `prefab.nuclei` — confirm this is
        // only called with prefab == this.clusterPrefab.
        CloneFields(clone);

        foreach (Synapse synapse in this.synapses)
        {
            Synapse clonedSynapse = clone.AddSynapse(synapse.nucleus);
            clonedSynapse.weight = synapse.weight;
        }

        foreach (Nucleus receiver in this.receivers)
        {
            clone.AddReceiver(receiver);
        }

        return clone;
    }

    /// <summary>
    /// Copies the configuration fields (prefab, bias, combinator and curve settings)
    /// from this neuron onto <paramref name="clone"/>.
    /// </summary>
    protected virtual void CloneFields(Neuron clone)
    {
        clone.clusterPrefab = this.clusterPrefab;
        clone.bias = this.bias;
        clone.combinator = this.combinator;
        // The AnimationCurve instance is shared with the clone, not deep-copied.
        clone.curve = this.curve;
        // Assigning the preset regenerates clone.curve (non-Custom presets) and resets
        // clone.curveMax, so curveMax has to be copied afterwards.
        clone.curvePreset = this.curvePreset;
        clone.curveMax = this.curveMax;
    }

    /// <summary>
    /// Removes <paramref name="nucleus"/> from the graph:
    /// <list type="bullet">
    /// <item>input neurons that feed several receivers only drop <paramref name="nucleus"/> from their receiver list;</item>
    /// <item>input neurons with no other receivers are deleted recursively;</item>
    /// <item>synapses pointing at the nucleus (or, for a cluster, at its outputs) are removed from downstream receivers;</item>
    /// <item>the nucleus is removed from its prefab, which is then refreshed and garbage-collected.</item>
    /// </list>
    /// </summary>
    public static void Delete(Nucleus nucleus)
    {
        if (nucleus == null)
        {
            return;
        }

        if (nucleus.synapses != null)
        {
            // Iterate a snapshot: deleting an input recursively removes the synapse pointing
            // at it from this nucleus's own synapse list (receiver cleanup below), which would
            // otherwise invalidate the enumerator.
            List<Synapse> inputs = new(nucleus.synapses);
            foreach (Synapse synapse in inputs)
            {
                if (synapse.nucleus is not Neuron inputNeuron)
                {
                    continue;
                }

                if (inputNeuron.receivers.Count > 1)
                {
                    // Another nucleus also feeds from this input; just unlink it.
                    inputNeuron.receivers.RemoveAll(r => r == nucleus);
                }
                else
                {
                    // No other links, delete it.
                    Delete(inputNeuron);
                }
            }
        }

        if (nucleus is Neuron neuron)
        {
            foreach (Nucleus receiver in neuron.receivers)
            {
                if (receiver != null && receiver.synapses != null)
                {
                    receiver.synapses.RemoveAll(s => s.nucleus == nucleus);
                }
            }
        }
        else if (nucleus is Cluster cluster)
        {
            // Remove every synapse that reads from one of this cluster's outputs.
            foreach (Neuron output in cluster.outputs)
            {
                foreach (Nucleus receiver in output.receivers)
                {
                    if (receiver != null && receiver.synapses != null)
                    {
                        receiver.synapses.RemoveAll(s => s.nucleus == output);
                    }
                }
            }
        }

        if (nucleus.clusterPrefab != null)
        {
            nucleus.clusterPrefab.nuclei.RemoveAll(n => n == nucleus);
            nucleus.clusterPrefab.RefreshOutputs();
            nucleus.clusterPrefab.GarbageCollection();
        }
    }

    /// <summary>Recomputes <see cref="outputValue"/>: combine the inputs, then activate.</summary>
    public override void UpdateStateIsolated()
    {
        float3 result = Combinator();
        this.outputValue = Activator(result);
    }

    #region Combinator

    /// <summary>Combinator selected by <see cref="combinator"/>; falls back to Sum.</summary>
    protected Func<float3> Combinator => combinator switch
    {
        CombinatorType.Sum => CombinatorSum,
        CombinatorType.Product => CombinatorProduct,
        CombinatorType.Max => CombinatorMax,
        _ => CombinatorSum
    };

    /// <summary>Bias plus the weighted sum of all neuron inputs.</summary>
    public float3 CombinatorSum()
    {
        float3 sum = this.bias;
        foreach (Synapse synapse in this.synapses)
        {
            if (synapse.nucleus is Neuron neuron)
            {
                sum += synapse.weight * neuron.outputValue;
            }
        }
        return sum;
    }

    /// <summary>Component-wise product of the bias and all weighted neuron inputs.</summary>
    public float3 CombinatorProduct()
    {
        // NOTE(review): the product starts from the bias, so a zero bias component forces that
        // component to zero regardless of inputs — confirm this is intended.
        float3 product = this.bias;
        foreach (Synapse synapse in this.synapses)
        {
            if (synapse.nucleus is Neuron neuron)
            {
                product *= synapse.weight * neuron.outputValue;
            }
        }
        return product;
    }

    /// <summary>
    /// Returns whichever of the bias and the weighted neuron inputs has the largest
    /// magnitude (ties keep the earlier candidate).
    /// </summary>
    public float3 CombinatorMax()
    {
        float3 max = this.bias;
        float maxLength = length(max);

        foreach (Synapse synapse in this.synapses)
        {
            if (synapse.nucleus is Neuron neuron)
            {
                float3 input = synapse.weight * neuron.outputValue;
                float inputLength = length(input);
                if (inputLength > maxLength)
                {
                    max = input;
                    maxLength = inputLength;
                }
            }
        }
        return max;
    }

    #endregion Combinator

    #region Activator

    /// <summary>
    /// Activation function selected by <see cref="curvePreset"/>. Each one keeps the
    /// input's direction and remaps its magnitude; Custom evaluates <see cref="curve"/>.
    /// </summary>
    public Func<float3, float3> Activator => this.curvePreset switch
    {
        CurvePresets.Linear => ActivatorLinear,
        CurvePresets.Sqrt => ActivatorSqrt,
        CurvePresets.Power => ActivatorPower,
        CurvePresets.Reciprocal => ActivatorReciprocal,
        _ => ActivatorCustom
    };

    protected float3 ActivatorLinear(float3 input)
    {
        return input;
    }

    protected float3 ActivatorSqrt(float3 input)
    {
        // normalizesafe yields zero for a zero-length input; plain normalize would yield NaN.
        return normalizesafe(input) * System.MathF.Sqrt(length(input));
    }

    protected float3 ActivatorPower(float3 input)
    {
        // normalizesafe yields zero for a zero-length input; plain normalize would yield NaN.
        return normalizesafe(input) * System.MathF.Pow(length(input), 2);
    }

    protected float3 ActivatorReciprocal(float3 input)
    {
        float magnitude = length(input);
        if (magnitude == 0)
        {
            return new float3(0, 0, 0);
        }
        return normalize(input) * (1 / magnitude);
    }

    protected float3 ActivatorCustom(float3 input)
    {
        float activatedValue = this.curve.Evaluate(length(input));
        // normalizesafe yields zero for a zero-length input; plain normalize would yield NaN.
        return normalizesafe(input) * activatedValue;
    }

    #endregion Activator

    #region Receivers

    [SerializeReference]
    private List<Nucleus> _receivers = new();

    /// <summary>Nuclei that take this neuron's output as an input.</summary>
    public virtual List<Nucleus> receivers
    {
        get { return _receivers; }
        set { _receivers = value; }
    }

    /// <summary>
    /// Adds <paramref name="receiverToAdd"/> to this neuron's receivers and creates the
    /// matching synapse (with <paramref name="weight"/>) on the receiver.
    /// </summary>
    public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1)
    {
        this._receivers.Add(receiverToAdd);
        receiverToAdd.AddSynapse(this, weight);
    }

    /// <summary>
    /// Unlinks <paramref name="receiverToRemove"/>. When this neuron is an
    /// <c>IReceptor</c>, the link is removed for every neuron in its <c>nucleiArray</c>;
    /// otherwise only for this neuron.
    /// </summary>
    public virtual void RemoveReceiver(Nucleus receiverToRemove)
    {
        if (this is IReceptor receptor)
        {
            foreach (Nucleus element in receptor.nucleiArray)
            {
                if (element is Neuron neuron)
                {
                    neuron._receivers.RemoveAll(receiver => receiver == receiverToRemove);
                    receiverToRemove.synapses.RemoveAll(synapse => synapse.nucleus == neuron);
                }
            }
        }
        else
        {
            this._receivers.RemoveAll(receiver => receiver == receiverToRemove);
            receiverToRemove.synapses.RemoveAll(synapse => synapse.nucleus == this);
        }
    }

    #endregion Receivers

    /// <summary>
    /// Delivers a stimulus. When the parent is a <c>ClusterReceptor</c> the call is
    /// delegated to it; otherwise the stimulus is applied via <see cref="ProcessStimulusDirect"/>.
    /// </summary>
    public override void ProcessStimulus(Vector3 inputValue, int thingId = 0, string thingName = null)
    {
        if (this.parent is ClusterReceptor clusterReceptor)
        {
            clusterReceptor.ProcessStimulus(this, inputValue, thingId, thingName);
        }
        else
        {
            ProcessStimulusDirect(inputValue, thingId, thingName);
        }
    }

    /// <summary>
    /// Applies a stimulus to this neuron: resets <see cref="stale"/>, stores the input as
    /// the bias and, when the neuron has a parent cluster, asks it to update from this neuron.
    /// </summary>
    public void ProcessStimulusDirect(Vector3 inputValue, int thingId = 0, string thingName = null)
    {
        this.stale = 0;
        this.bias = inputValue;
        // The parent can be null (e.g. prefab neurons); the constructor treats it as nullable too.
        this.parent?.UpdateFromNucleus(this);
    }
}