using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_EDITOR
// AnimationUtility lives in UnityEditor, which is not available in player builds.
using UnityEditor;
#endif
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif

namespace NanoBrain {

    /// <summary>
    /// A neuron is a basic Nucleus
    /// </summary>
    /// A Neuron combines its bias and the weighted outputs of the neurons on its synapses
    /// (see <see cref="combinator"/>) and passes the result through an activation function
    /// (see <see cref="curvePreset"/>) to produce its <see cref="outputValue"/>.
    [Serializable]
    public class Neuron : Nucleus {

        /// <summary>
        /// Output magnitude above which the neuron counts as firing
        /// </summary>
        private const float FiringThreshold = 0.5f;

        /// <summary>
        /// Create a new Neuron in a Cluster instance
        /// </summary>
        /// <param name="parent">The parent cluster in which the new Neuron should be created.
        /// May be null, in which case the neuron is not registered anywhere.</param>
        /// <param name="name">The name of the new Neuron</param>
        public Neuron(Cluster parent, string name) {
            this.parent = parent;
            this.name = name;
            this.parent?.nuclei.Add(this);
        }

        /// <summary>
        /// Create a new Neuron in a Cluster Prefab
        /// </summary>
        /// <param name="prefab">The Cluster Prefab in which the new Neuron should be created</param>
        /// <param name="name">The name of the new Neuron</param>
        // public Neuron(ClusterPrefab prefab, string name) {
        //     this.clusterPrefab = prefab;
        //     this.name = name;
        //     if (this.clusterPrefab != null) {
        //         this.clusterPrefab.cluster.nuclei.Add(this);
        //         this.clusterPrefab.cluster.RefreshOutputs();
        //     }
        //     else
        //         Debug.LogError("No prefab when adding neuron to prefab");
        // }

        #region Serialization

        /// <summary>
        /// The bias
        /// </summary>
        /// The bias is a value which is always added to the combined value of the neuron.
        /// It does not have a synapse and therefore no weight or source nucleus.
        public Vector3 bias = Vector3.zero;

        #region Synapses

        [SerializeField]
        private List<Synapse> _synapses = new();
        /// <summary>
        /// The synapses of the nucleus
        /// </summary>
        public List<Synapse> synapses => _synapses;

        /// <summary>
        /// Add a new synapse to this nucleus
        /// </summary>
        /// <param name="sendingNucleus">The nucleus from which the signals may originate</param>
        /// <param name="weight">The weight applied to the input. Default value = 1</param>
        /// <returns>The created Synapse</returns>
        /// This will add a new input to this nucleus with the given weight.
        public Synapse AddSynapse(Neuron sendingNucleus, float weight = 1) {
            Synapse synapse = new(sendingNucleus, weight);
            this.synapses.Add(synapse);
            return synapse;
        }
        // public Synapse AddSynapse(ClusterPrefab clusterPrefab, string neuronName, float weight = 1) {
        // }

        /// <summary>
        /// Find a synapse
        /// </summary>
        /// <param name="sender">The sender of the input to the Synapse</param>
        /// <returns>The found Synapse or null when the sender has no synapse to this nucleus.</returns>
        public Synapse GetSynapse(Nucleus sender) {
            foreach (Synapse synapse in this.synapses) {
                if (synapse.neuron == sender)
                    return synapse;
            }
            return null;
        }

        /// <summary>
        /// Remove a synapse from a Nucleus
        /// </summary>
        /// <param name="sendingNucleus">The sender whose synapses to this Nucleus are removed</param>
        public void RemoveSynapse(Nucleus sendingNucleus) {
            this.synapses.RemoveAll(synapse => synapse.neuron == sendingNucleus);
        }

        #endregion Synapses

        /// <summary>
        /// Set the bias and ask the parent cluster (if any) to update from this Nucleus
        /// </summary>
        /// <param name="inputValue">The new bias value</param>
        public virtual void SetBias(Vector3 inputValue) {
            this.bias = inputValue;
            this.parent?.UpdateFromNucleus(this);
        }

        /// <summary>
        /// The type of combinators
        /// </summary>
        /// A combinator combines the weighted values of the synapses to a single value
        public enum CombinatorType {
            /// <summary>Add the weighted values together</summary>
            Sum,
            /// <summary>Multiply the weighted values</summary>
            Product,
            /// <summary>Take the maximum of all the weighted values</summary>
            Max,
        }
        /// <summary>
        /// The type of combinator used for this Neuron
        /// </summary>
        public CombinatorType combinator = CombinatorType.Sum;

        /// <summary>
        /// The type of activation function
        /// </summary>
        /// The activation function maps the combined value to the output value of the neuron
        public enum ActivationType {
            Linear,
            Power,
            Sqrt,
            Reciprocal,
            Tanh,
            Binary,
            Normalized,
            Custom
        }

        [SerializeField]
        public ActivationType _curvePreset;
        /// <summary>
        /// The activation preset of this Neuron
        /// </summary>
        /// Setting the preset regenerates <see cref="curve"/> (and <see cref="curveMax"/>)
        /// through <see cref="GenerateCurve"/>.
        public ActivationType curvePreset {
            get { return _curvePreset; }
            set {
                _curvePreset = value;
                this.curve = GenerateCurve();
            }
        }

        /// <summary>
        /// The activation curve; evaluated directly by the Custom activator
        /// </summary>
        public AnimationCurve curve;
        /// <summary>
        /// Maximum value associated with the current curve; written by <see cref="GenerateCurve"/>
        /// </summary>
        public float curveMax = 1.0f;

        /// <summary>
        /// Build the curve belonging to the current <see cref="curvePreset"/>
        /// </summary>
        /// <returns>The generated curve, or the existing <see cref="curve"/> for the Custom preset</returns>
        /// As a side effect <see cref="curveMax"/> is updated.
        public AnimationCurve GenerateCurve() {
            switch (this.curvePreset) {
                case ActivationType.Linear:
                    this.curveMax = 1;
                    return Presets.Linear(1);
                case ActivationType.Power:
                    this.curveMax = 1;
                    return Presets.Power(2.0f, 1);
                case ActivationType.Sqrt:
                    this.curveMax = 1;
                    return Presets.Power(0.5f, 1);
                case ActivationType.Reciprocal:
                    // NOTE(review): this evaluates to 100, while Presets.Reciprocal samples from
                    // x = 0.001 and therefore peaks at 1000 - confirm which value is intended.
                    this.curveMax = 1 / 0.01f * 1;
                    return Presets.Reciprocal(1);
                case ActivationType.Tanh:
                    this.curveMax = 1;
                    return Presets.Tanh(1);
                case ActivationType.Binary:
                    this.curveMax = 1;
                    return Presets.Binary();
                case ActivationType.Normalized:
                    // NOTE(review): Normalized reuses the Binary curve - confirm this is intentional.
                    this.curveMax = 1;
                    return Presets.Binary();
                default:
                    // Custom: keep whatever curve is currently assigned
                    this.curveMax = 1;
                    return this.curve;
            }
        }

        /// <summary>
        /// Factory functions for the predefined activation curves
        /// </summary>
        public static class Presets {
            // Number of keyframes sampled for the Power and Tanh curves
            private const int DefaultSamples = 32;
            // The reciprocal changes steeply near zero and therefore uses more keyframes
            private const int ReciprocalSamples = 128;

            /// <summary>
            /// A straight line from (0, 0) to (1000, 1000 * weight)
            /// </summary>
            /// <param name="weight">The slope of the line</param>
            public static AnimationCurve Linear(float weight) {
                return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
            }

            /// <summary>
            /// The curve y = weight * x^exponent sampled on [0, 1]
            /// </summary>
            /// <param name="exponent">The exponent applied to x</param>
            /// <param name="weight">The factor applied to the result</param>
            public static AnimationCurve Power(float exponent, float weight) {
                Keyframe[] keys = new Keyframe[DefaultSamples];
                for (int i = 0; i < DefaultSamples; i++) {
                    float t = i / (float)(DefaultSamples - 1);
                    keys[i] = new Keyframe(t, Mathf.Pow(t, exponent) * weight);
                }

                AnimationCurve curve = new(keys);
                // Smooth (Auto) tangents. Use linear if you prefer straight segments.
                ApplyTangents(curve, false);
                return curve;
            }

            /// <summary>
            /// The curve y = weight / x sampled on [0.001, 1]
            /// </summary>
            /// <param name="weight">The factor applied to the result</param>
            public static AnimationCurve Reciprocal(float weight) {
                const float xMin = 0.001f;
                const float xMax = 1;

                Keyframe[] keys = new Keyframe[ReciprocalSamples];
                for (int i = 0; i < ReciprocalSamples; i++) {
                    float t = i / (float)(ReciprocalSamples - 1);
                    float x = Mathf.Lerp(xMin, xMax, t);
                    keys[i] = new Keyframe(x, 1f / x * weight);
                }

                AnimationCurve curve = new(keys);
                ApplyTangents(curve, true);
                return curve;
            }

            /// <summary>
            /// The curve y = tanh(weight * x) sampled on [0.001, 1]
            /// </summary>
            /// <param name="weight">The factor applied to x before taking the tanh</param>
            public static AnimationCurve Tanh(float weight) {
                const float xMin = 0.001f;
                const float xMax = 1;

                Keyframe[] keys = new Keyframe[DefaultSamples];
                for (int i = 0; i < DefaultSamples; i++) {
                    float t = i / (float)(DefaultSamples - 1);
                    float x = Mathf.Lerp(xMin, xMax, t);
                    keys[i] = new Keyframe(x, MathF.Tanh(x * weight));
                }

                AnimationCurve curve = new(keys);
                ApplyTangents(curve, true);
                return curve;
            }

            /// <summary>
            /// A straight line from (0, 0) to (1, 1)
            /// </summary>
            public static AnimationCurve Binary() {
                return AnimationCurve.Linear(0, 0, 1, 1);
            }

            // Gives every key of the curve linear or smooth tangents.
            // In the editor this uses AnimationUtility tangent modes. In player builds, where
            // UnityEditor does not exist, the tangents are computed with runtime API instead.
            private static void ApplyTangents(AnimationCurve curve, bool linear) {
#if UNITY_EDITOR
                AnimationUtility.TangentMode mode = linear
                    ? AnimationUtility.TangentMode.Linear
                    : AnimationUtility.TangentMode.Auto;
                for (int i = 0; i < curve.length; i++) {
                    AnimationUtility.SetKeyLeftTangentMode(curve, i, mode);
                    AnimationUtility.SetKeyRightTangentMode(curve, i, mode);
                }
#else
                int last = curve.length - 1;
                for (int i = 0; i <= last; i++) {
                    if (!linear) {
                        curve.SmoothTangents(i, 0f);
                        continue;
                    }
                    // Linear: each tangent is the slope of the segment to the neighbouring key
                    Keyframe key = curve[i];
                    if (i > 0)
                        key.inTangent = Slope(curve[i - 1], key);
                    if (i < last)
                        key.outTangent = Slope(key, curve[i + 1]);
                    curve.MoveKey(i, key);
                }
#endif
            }

#if !UNITY_EDITOR
            // Slope of the straight line between two keys; key times are strictly increasing
            private static float Slope(Keyframe from, Keyframe to) {
                return (to.value - from.value) / (to.time - from.time);
            }
#endif
        }

        #endregion Serialization

#if UNITY_MATHEMATICS
        protected float3 _outputValue;
        /// <summary>
        /// The output value of the neuron
        /// </summary>
        /// Setting the value invokes <see cref="WhenFiring"/> when the neuron is firing afterwards.
        public virtual float3 outputValue {
            get { return _outputValue; }
            set {
                _outputValue = value;
                if (this.isFiring)
                    WhenFiring?.Invoke();
            }
        }
        /// <summary>The length of the output value</summary>
        public float outputMagnitude => length(_outputValue);
        /// <summary>The squared length of the output value</summary>
        public float outputSqrMagnitude => lengthsq(_outputValue);
#else
        protected Vector3 _outputValue;
        /// <summary>
        /// The output value of the neuron
        /// </summary>
        /// Setting the value invokes <see cref="WhenFiring"/> when the neuron is firing afterwards.
        public virtual Vector3 outputValue {
            get { return _outputValue; }
            set {
                _outputValue = value;
                if (this.isFiring)
                    WhenFiring?.Invoke();
            }
        }
        /// <summary>The length of the output value</summary>
        public float outputMagnitude => _outputValue.magnitude;
        /// <summary>The squared length of the output value</summary>
        public float outputSqrMagnitude => _outputValue.sqrMagnitude;
#endif

        /// <summary>
        /// True when the output magnitude exceeds the firing threshold (0.5)
        /// </summary>
        /// Reading this property first runs <see cref="SleepCheck"/>, which may reset the output.
        public bool isFiring {
            get {
                SleepCheck();
                return this.outputMagnitude > FiringThreshold;
            }
        }

        /// <summary>
        /// Invoked from the <see cref="outputValue"/> setter when the neuron is firing
        /// </summary>
        public Action WhenFiring;

        /// <summary>
        /// True when the last update is more than <see cref="timeToSleep"/> seconds ago
        /// </summary>
        public virtual bool isSleeping => Time.time - this.lastUpdate > this.timeToSleep;

        /// <summary>
        /// Reset the output value to zero when the neuron is sleeping
        /// </summary>
        /// This writes the backing field directly, so <see cref="WhenFiring"/> is not invoked.
        public void SleepCheck() {
            if (!this.isSleeping)
                return;

#if UNITY_MATHEMATICS
            this._outputValue = new float3(0, 0, 0);
#else
            this._outputValue = new Vector3(0, 0, 0);
#endif
        }

        /// <summary>
        /// Toggle for printing debugging trace data
        /// </summary>
        public bool trace = false;

        /// <summary>
        /// The value of Time.time at the last update of this neuron
        /// </summary>
        [NonSerialized]
        public float lastUpdate = 0;
        /// <summary>
        /// The time in seconds without updates after which the neuron is considered sleeping
        /// </summary>
        public readonly float timeToSleep = 1f;

        /// \copydoc NanoBrain::Nucleus::ShallowCloneTo
        public override Nucleus ShallowCloneTo(Cluster newParent) {
            // prefabNucleus = this
            Neuron clone = new(newParent, this.name);
            CloneFields(clone);
            return clone;
        }

        /// \copydoc NanoBrain::Nucleus::Clone
        public override Nucleus Clone(ClusterPrefab prefab) {
            Neuron clone = new(prefab.cluster, this.name);
            CloneFields(clone);

            // The clone listens to the same senders, with the same weights
            foreach (Synapse synapse in this.synapses) {
                Synapse clonedSynapse = clone.AddSynapse(synapse.neuron);
                clonedSynapse.weight = synapse.weight;
            }
            // The clone feeds the same receivers; AddReceiver also creates the synapse on them
            foreach (Nucleus receiver in this.receivers)
                clone.AddReceiver(receiver);

            return clone;
        }

        /// <summary>
        /// Copy the bias, combinator and activation settings to a clone
        /// </summary>
        /// <param name="clone">The neuron receiving the values</param>
        protected virtual void CloneFields(Neuron clone) {
            clone.bias = this.bias;
            clone.combinator = this.combinator;
            clone.curve = this.curve;
            // The curvePreset setter regenerates curve and curveMax on the clone,
            // which is why curveMax is assigned after it.
            clone.curvePreset = this.curvePreset;
            clone.curveMax = this.curveMax;
        }

        /// <summary>
        /// Delete a nucleus and the connections to it
        /// </summary>
        /// <param name="nucleus">The nucleus to delete; null is ignored</param>
        /// For a Neuron, senders which only feed this neuron are deleted as well. When the
        /// parent of the nucleus has a prefab, the nucleus is removed from the prefab's cluster.
        public static void Delete(Nucleus nucleus) {
            if (nucleus == null)
                return;

            if (nucleus is Neuron neuron) {
                // Iterate over a snapshot: the recursive Delete below removes entries from
                // neuron.synapses, which would invalidate an enumerator on the list itself.
                // NOTE(review): two neurons which are each other's only receiver still
                // recurse without end here - confirm such cycles cannot occur.
                foreach (Synapse synapse in neuron.synapses.ToArray()) {
                    if (synapse.neuron is not Neuron sender)
                        continue;

                    if (sender.receivers.Count > 1) {
                        // there is another nucleus fed by this input nucleus
                        sender.receivers.RemoveAll(r => r == nucleus);
                    }
                    else {
                        // No other links, delete it.
                        Neuron.Delete(sender);
                    }
                }

                foreach (Nucleus receiver in neuron.receivers) {
                    if (receiver is Neuron receiverNeuron && receiverNeuron.synapses != null)
                        receiverNeuron.synapses.RemoveAll(s => s.neuron == nucleus);
                }
            }
            else if (nucleus is Cluster cluster) {
                // remove all receivers for this cluster
                foreach (Nucleus clusterNucleus in cluster.nuclei) {
                    if (clusterNucleus is not Neuron output)
                        continue;

                    foreach (Nucleus receiver in output.receivers) {
                        if (receiver is Neuron receiverNeuron)
                            receiverNeuron.synapses.RemoveAll(s => s.neuron == output);
                    }
                }
            }

            // A nucleus may have been created without a parent, see the constructor
            var prefab = nucleus.parent?.prefab;
            if (prefab != null) {
                prefab.cluster.nuclei.RemoveAll(n => n == nucleus);
                prefab.cluster.RefreshOutputs();
                prefab.GarbageCollection();
            }
        }

        /// <summary>
        /// Recalculate the output value from the bias and the synapses
        /// </summary>
        /// The inputs are combined, the result is activated and stored in
        /// <see cref="outputValue"/>, after which <see cref="lastUpdate"/> is set to the current time.
        public override void UpdateStateIsolated() {
            CheckSleepingSynapses();

            var result = Combinator();
            this.outputValue = Activator(result);
            this.lastUpdate = Time.time;
        }

        /// <summary>
        /// Reset the output of the sender of every sleeping synapse to zero
        /// </summary>
        protected void CheckSleepingSynapses() {
            foreach (Synapse synapse in this.synapses) {
                if (synapse.isSleeping)
                    synapse.neuron.outputValue = Vector3.zero;
            }
        }

        #region Combinator

#if UNITY_MATHEMATICS
        /// <summary>
        /// The combinator function selected by <see cref="combinator"/>
        /// </summary>
        protected Func<float3> Combinator => combinator switch {
            CombinatorType.Sum => CombinatorSum,
            CombinatorType.Product => CombinatorProduct,
            CombinatorType.Max => CombinatorMax,
            _ => CombinatorSum
        };

        /// <summary>
        /// The bias plus the sum of all weighted inputs
        /// </summary>
        public float3 CombinatorSum() {
            float3 sum = this.bias;
            foreach (Synapse synapse in this.synapses)
                sum += synapse.weight * synapse.neuron.outputValue;
            return sum;
        }

        /// <summary>
        /// The bias multiplied component-wise by every weighted input
        /// </summary>
        public float3 CombinatorProduct() {
            float3 product = this.bias;
            foreach (Synapse synapse in this.synapses)
                product *= synapse.weight * synapse.neuron.outputValue;
            return product;
        }

        /// <summary>
        /// The longest vector among the bias and the weighted inputs
        /// </summary>
        public float3 CombinatorMax() {
            float3 max = this.bias;
            float maxLength = length(max);

            // Applying the weight factors
            foreach (Synapse synapse in this.synapses) {
                float3 input = synapse.weight * synapse.neuron.outputValue;
                float inputLength = length(input);
                if (inputLength > maxLength) {
                    max = input;
                    maxLength = inputLength;
                }
            }
            return max;
        }
#else
        /// <summary>
        /// The combinator function selected by <see cref="combinator"/>
        /// </summary>
        protected Func<Vector3> Combinator => combinator switch {
            CombinatorType.Sum => CombinatorSum,
            CombinatorType.Product => CombinatorProduct,
            CombinatorType.Max => CombinatorMax,
            _ => CombinatorSum
        };

        /// <summary>
        /// The bias plus the sum of all weighted inputs
        /// </summary>
        public Vector3 CombinatorSum() {
            Vector3 sum = this.bias;
            foreach (Synapse synapse in this.synapses)
                sum += synapse.weight * synapse.neuron.outputValue;
            return sum;
        }

        /// <summary>
        /// The bias multiplied component-wise by every weighted input
        /// </summary>
        public Vector3 CombinatorProduct() {
            Vector3 product = this.bias;
            foreach (Synapse synapse in this.synapses) {
                // Vector3 has no component-wise operator *, hence Vector3.Scale
                product = Vector3.Scale(product, synapse.weight * synapse.neuron.outputValue);
            }
            return product;
        }

        /// <summary>
        /// The longest vector among the bias and the weighted inputs
        /// </summary>
        public Vector3 CombinatorMax() {
            Vector3 max = this.bias;
            float maxLength = max.magnitude;

            // Applying the weight factors
            foreach (Synapse synapse in this.synapses) {
                Vector3 input = synapse.weight * synapse.neuron.outputValue;
                float inputLength = input.magnitude;
                if (inputLength > maxLength) {
                    max = input;
                    maxLength = inputLength;
                }
            }
            return max;
        }
#endif

        #endregion Combinator

        #region Activator

#if UNITY_MATHEMATICS
        /// <summary>
        /// The activation function selected by <see cref="curvePreset"/>
        /// </summary>
        public Func<float3, float3> Activator => this.curvePreset switch {
            ActivationType.Linear => ActivatorLinear,
            ActivationType.Sqrt => ActivatorSqrt,
            ActivationType.Power => ActivatorPower,
            ActivationType.Reciprocal => ActivatorReciprocal,
            ActivationType.Tanh => ActivatorTanh,
            ActivationType.Binary => ActivatorBinary,
            ActivationType.Normalized => ActivatorNormalized,
            _ => ActivatorCustom
        };

        // The activators below use normalizesafe instead of normalize: normalize returns NaN
        // for a zero vector, normalizesafe returns zero like Vector3.normalized does.

        /// <summary>Returns the input unchanged</summary>
        protected float3 ActivatorLinear(float3 input) {
            return input;
        }

        /// <summary>Keeps the direction, takes the square root of the magnitude</summary>
        protected float3 ActivatorSqrt(float3 input) {
            return normalizesafe(input) * System.MathF.Sqrt(length(input));
        }

        /// <summary>Keeps the direction, squares the magnitude</summary>
        protected float3 ActivatorPower(float3 input) {
            return normalizesafe(input) * System.MathF.Pow(length(input), 2);
        }

        /// <summary>Keeps the direction, inverts the magnitude; zero stays zero</summary>
        protected float3 ActivatorReciprocal(float3 input) {
            float magnitude = length(input);
            if (magnitude == 0)
                return new float3(0, 0, 0);

            return normalize(input) * (1 / magnitude);
        }

        /// <summary>Keeps the direction, takes the tanh of the magnitude</summary>
        protected float3 ActivatorTanh(float3 input) {
            return normalizesafe(input) * MathF.Tanh(length(input));
        }

        /// <summary>Returns the magnitude clamped to [0, 1] in all three components</summary>
        protected float3 ActivatorBinary(float3 input) {
            float value = Mathf.Clamp01(length(input));
            return float3(value, value, value);
        }

        /// <summary>Returns the unit vector in the direction of the input; zero stays zero</summary>
        protected float3 ActivatorNormalized(float3 input) {
            if (lengthsq(input) == 0)
                return input;

            return normalize(input);
        }

        /// <summary>Keeps the direction, maps the magnitude through <see cref="curve"/></summary>
        protected float3 ActivatorCustom(float3 input) {
            float activatedValue = this.curve.Evaluate(length(input));
            return normalizesafe(input) * activatedValue;
        }
#else
        /// <summary>
        /// The activation function selected by <see cref="curvePreset"/>
        /// </summary>
        public Func<Vector3, Vector3> Activator => this.curvePreset switch {
            ActivationType.Linear => ActivatorLinear,
            ActivationType.Sqrt => ActivatorSqrt,
            ActivationType.Power => ActivatorPower,
            ActivationType.Reciprocal => ActivatorReciprocal,
            ActivationType.Tanh => ActivatorTanh,
            ActivationType.Binary => ActivatorBinary,
            ActivationType.Normalized => ActivatorNormalized,
            _ => ActivatorCustom
        };

        /// <summary>Returns the input unchanged</summary>
        protected Vector3 ActivatorLinear(Vector3 input) {
            return input;
        }

        /// <summary>Keeps the direction, takes the square root of the magnitude</summary>
        protected Vector3 ActivatorSqrt(Vector3 input) {
            return input.normalized * System.MathF.Sqrt(input.magnitude);
        }

        /// <summary>Keeps the direction, squares the magnitude</summary>
        protected Vector3 ActivatorPower(Vector3 input) {
            return input.normalized * System.MathF.Pow(input.magnitude, 2);
        }

        /// <summary>Keeps the direction, inverts the magnitude; zero stays zero</summary>
        protected Vector3 ActivatorReciprocal(Vector3 input) {
            float magnitude = input.magnitude;
            if (magnitude == 0)
                return new Vector3(0, 0, 0);

            return input.normalized * (1 / magnitude);
        }

        /// <summary>Keeps the direction, takes the tanh of the magnitude</summary>
        protected Vector3 ActivatorTanh(Vector3 input) {
            return input.normalized * MathF.Tanh(input.magnitude);
        }

        /// <summary>Returns the magnitude clamped to [0, 1] in all three components</summary>
        protected Vector3 ActivatorBinary(Vector3 input) {
            float value = Mathf.Clamp01(input.magnitude);
            return new Vector3(value, value, value);
        }

        /// <summary>Returns the unit vector in the direction of the input; zero stays zero</summary>
        protected Vector3 ActivatorNormalized(Vector3 input) {
            return input.normalized;
        }

        /// <summary>Keeps the direction, maps the magnitude through <see cref="curve"/></summary>
        protected Vector3 ActivatorCustom(Vector3 input) {
            float activatedValue = this.curve.Evaluate(input.magnitude);
            return input.normalized * activatedValue;
        }
#endif

        #endregion Activator

        #region Receivers

        [SerializeReference]
        private List<Nucleus> _receivers = new();
        /// <summary>
        /// The nuclei receiving the output of this neuron
        /// </summary>
        public virtual List<Nucleus> receivers {
            get { return _receivers; }
            set { _receivers = value; }
        }

        /// <summary>
        /// Connect the output of this neuron to a receiver
        /// </summary>
        /// <param name="receiverToAdd">The receiver; anything which is not a Neuron is ignored</param>
        /// <param name="weight">The weight of the synapse created on the receiver. Default value = 1</param>
        public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1) {
            if (receiverToAdd is not Neuron receiverNeuron)
                return;

            this._receivers.Add(receiverNeuron);
            receiverNeuron.AddSynapse(this, weight);
            //Debug.Log($"Add synapse {this.clusterPrefab.name}.{this.name} -> {receiverToAdd.name} --- [{this.receivers.Count}]");
        }

        /// <summary>
        /// Disconnect the output of this neuron from a receiver
        /// </summary>
        /// <param name="receiverToRemove">The receiver; anything which is not a Neuron is ignored</param>
        /// The receiver is removed from the receivers list and the synapses from this neuron
        /// are removed from the receiver.
        public virtual void RemoveReceiver(Nucleus receiverToRemove) {
            if (receiverToRemove is not Neuron receiverNeuron)
                return;

            this._receivers.RemoveAll(receiver => receiver == receiverNeuron);
            receiverNeuron.synapses.RemoveAll(synapse => synapse.neuron == this);

            // Nucleus prefabReceiver = receiverToRemove.prefabNucleus;
            // if (this.prefabNucleus is Neuron prefabNeuron && prefabReceiver != null) {
            //     prefabNeuron.receivers.RemoveAll(receiver => receiver == prefabReceiver);
            //     prefabReceiver.synapses.RemoveAll(synapse => synapse.neuron == prefabNeuron);
            // }
        }

        #endregion Receivers

        /// <summary>
        /// Process an external stimulus
        /// </summary>
        /// <param name="inputValue">The stimulus, which is stored as the bias of this neuron</param>
        /// The neuron is marked as updated now and the parent cluster (if any) is asked to
        /// update from this nucleus.
        public override void ProcessStimulus(Vector3 inputValue) {
            this.lastUpdate = Time.time;
            this.bias = inputValue;
            this.parent?.UpdateFromNucleus(this);
        }
    }
}