using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif

namespace NanoBrain {

    /// <summary>
    /// A neuron is a basic Nucleus
    /// </summary>
    /// <remarks>
    /// A neuron combines the weighted input from other neurons and applies an activation function to it
    /// to compute the output value:
    /// \code
    /// Vector3 combination = NanoBrain::Neuron::Combinator(bias, synapses);
    /// Vector3 output = NanoBrain::Neuron::Activator(combination);
    /// \endcode
    /// The synapses are connections to other neurons.
    /// Each connection has a weight which is used to multiply the output of that other neuron
    /// before it is used by the combinator.
    /// </remarks>
    [Serializable]
    public class Neuron : Nucleus {

        /// <summary>
        /// Create a new Neuron in a Cluster instance
        /// </summary>
        /// <param name="parent">The parent cluster in which the new Neuron should be created (may be null)</param>
        /// <param name="name">The name of the new Neuron</param>
        /// <remarks>
        /// When a parent is given, the new neuron is appended to the parent's nuclei list,
        /// creating that list first when it does not exist yet.
        /// </remarks>
        public Neuron(Cluster parent, string name) {
            this.parent = parent;
            this.name = name;
            if (this.parent != null) {
                this.parent.nuclei ??= new();
                this.parent.nuclei.Add(this);
            }
        }

        #region Serialization

        /// <summary>
        /// The bias
        /// </summary>
        /// <remarks>
        /// The bias is a value which is always part of the combined value of the neuron:
        /// it is the start value of the sum (Sum combinator) or of the product (Product combinator).
        /// It does not have a synapse and therefore no weight or source nucleus.
        /// </remarks>
        //[HideInInspector]
        public Vector3 bias = Vector3.zero;

        #region Synapses

        [SerializeField]
        private List<Synapse> _synapses = new();

        /// <summary>
        /// The synapses of the nucleus
        /// </summary>
        public List<Synapse> synapses => _synapses;

        /// <summary>
        /// Add a new synapse to this nucleus
        /// </summary>
        /// <param name="sendingNucleus">The nucleus from which the signals may originate</param>
        /// <param name="weight">The weight applied to the input. Default value = 1</param>
        /// <returns>The created Synapse</returns>
        /// <remarks>
        /// This will add a new input to this nucleus with the given weight.
        /// It does not register this nucleus as a receiver of the sender;
        /// AddReceiver on the sender links both sides.
        /// </remarks>
        public Synapse AddSynapse(Neuron sendingNucleus, float weight = 1) {
            Synapse synapse = new(sendingNucleus, weight);
            this.synapses.Add(synapse);
            return synapse;
        }

        /// <summary>
        /// Find a synapse
        /// </summary>
        /// <param name="sender">The sender of the input to the Synapse</param>
        /// <returns>The first matching Synapse or null when the sender has no synapse to this nucleus.</returns>
        public Synapse GetSynapse(Nucleus sender) {
            foreach (Synapse synapse in this.synapses) {
                if (synapse.neuron == sender)
                    return synapse;
            }
            return null;
        }

        /// <summary>
        /// Remove all synapses from the given sender to this Nucleus
        /// </summary>
        /// <param name="sendingNucleus">The nucleus whose synapses to this nucleus are removed</param>
        /// <remarks>The receiver list of the sender is not updated by this method.</remarks>
        public void RemoveSynapse(Nucleus sendingNucleus) {
            this.synapses.RemoveAll(synapse => synapse.neuron == sendingNucleus);
        }

        #endregion Synapses

        /// <summary>
        /// Set the bias, recalculate the output and update all Nuclei receiving from this Nucleus
        /// </summary>
        /// <param name="inputValue">The new bias value</param>
        public virtual void SetBias(Vector3 inputValue) {
            this.bias = inputValue;
            this.lastUpdate = Time.time;
            this.parent?.UpdateFromNucleus(this);
        }

        /// <summary>
        /// The type of combinators
        /// </summary>
        /// <remarks>A combinator combines the bias and the weighted values of the synapses to a single value</remarks>
        public enum CombinatorType {
            /// <summary>Add the weighted values together</summary>
            Sum,
            /// <summary>Multiply the weighted values component-wise</summary>
            Product,
        }

        /// <summary>
        /// The type of combinator used for this Neuron
        /// </summary>
        [HideInInspector]
        public CombinatorType combinator = CombinatorType.Sum;

        /// <summary>
        /// The type of activation function
        /// </summary>
        public enum ActivationType {
            /// <summary>The output equals the combined input</summary>
            Linear,
            /// <summary>Keeps the direction, squares the magnitude</summary>
            Power,
            /// <summary>Keeps the direction, takes the square root of the magnitude</summary>
            Sqrt,
            /// <summary>Keeps the direction, uses 1/magnitude as magnitude (zero stays zero)</summary>
            Reciprocal,
            /// <summary>Keeps the direction, uses tanh(magnitude) as magnitude</summary>
            Tanh,
            /// <summary>Every component equals the input magnitude clamped to [0, 1]</summary>
            Binary,
            /// <summary>The unit vector in the direction of the input (zero stays zero)</summary>
            Normalized,
            /// <summary>No dedicated implementation; the activator falls back to Linear</summary>
            Custom
        }

        /// <summary>
        /// The activation function
        /// </summary>
        [SerializeField]
        [HideInInspector]
        public ActivationType _activator;

        /// <summary>
        /// The activation function
        /// </summary>
        public ActivationType activator {
            get { return _activator; }
            set {
                _activator = value;
                //this.curve = GenerateCurve();
            }
        }

        #endregion Serialization

#if UNITY_MATHEMATICS
        /// <summary>
        /// The output value of the neuron
        /// </summary>
        [HideInInspector]
        protected float3 _outputValue;

        /// <summary>
        /// The output value of the neuron
        /// </summary>
        /// <remarks>Setting the value invokes WhenFiring when the neuron is firing after the update.</remarks>
        public virtual float3 outputValue {
            get { return _outputValue; }
            set {
                _outputValue = value;
                if (this.isFiring)
                    WhenFiring?.Invoke();
            }
        }

        /// <summary>
        /// The magnitude of the neuron output
        /// </summary>
        public float outputMagnitude => length(_outputValue);

        /// <summary>
        /// The squared magnitude of the neuron output
        /// </summary>
        public float outputSqrMagnitude => lengthsq(_outputValue);
#else
        /// <summary>
        /// The output value of the neuron
        /// </summary>
        protected Vector3 _outputValue;

        /// <summary>
        /// The output value of the neuron
        /// </summary>
        /// <remarks>Setting the value invokes WhenFiring when the neuron is firing after the update.</remarks>
        public virtual Vector3 outputValue {
            get { return _outputValue; }
            set {
                _outputValue = value;
                if (this.isFiring)
                    WhenFiring?.Invoke();
            }
        }

        /// <summary>
        /// The magnitude of the neuron output
        /// </summary>
        public float outputMagnitude => _outputValue.magnitude;

        /// <summary>
        /// The squared magnitude of the neuron output
        /// </summary>
        public float outputSqrMagnitude => _outputValue.sqrMagnitude;
#endif

        /// <summary>
        /// True if the magnitude of the neuron output is larger than 0.5
        /// </summary>
        public bool isFiring => this.outputMagnitude > 0.5f;

        /// <summary>
        /// An action which is called every time the output value is set while the neuron is firing
        /// </summary>
        public Action WhenFiring;

        /// <summary>
        /// When true, the value will not be reset after timeToSleep.
        /// </summary>
        public bool persistOutput = false;

        /// <summary>
        /// True when the neuron is not persisting and has not been updated for timeToSleep seconds
        /// </summary>
        public virtual bool isSleeping => !persistOutput && (Time.time - this.lastUpdate > timeToSleep);

        /// <summary>
        /// Check if the neuron is sleeping.
        /// </summary>
        /// <remarks>
        /// This will reset the output value to zero if it is sleeping.
        /// The backing field is written directly, so WhenFiring is not invoked.
        /// </remarks>
        public void SleepCheck() {
            if (this.isSleeping && this.outputSqrMagnitude > 0) {
#if UNITY_MATHEMATICS
                this._outputValue = new float3(0, 0, 0);
#else
                this._outputValue = new Vector3(0, 0, 0);
#endif
            }
        }

        /// <summary>
        /// The time at which the last update has been done
        /// </summary>
        [HideInInspector]
        public float lastUpdate = 0;

        /// <summary>
        /// Time in seconds after the last update the neuron can go to sleep
        /// </summary>
        public static readonly float timeToSleep = 0.5f;

        /// <summary>
        /// When true, Unity will pause execution when this neuron is updated
        /// </summary>
        /// <remarks>
        /// Pausing is implemented using [Debug.Break()](https://docs.unity3d.com/ScriptReference/Debug.Break.html)
        /// </remarks>
        public bool breakOnUpdate = false;

        /// \copydoc NanoBrain::Nucleus::ShallowCloneTo
        public override Nucleus ShallowCloneTo(Cluster parent) {
            Neuron clone = new(parent, this.name);
            CloneFields(clone);
            return clone;
        }

        /// <summary>
        /// Copy relevant fields of this neuron to the given neuron
        /// </summary>
        /// <param name="clone">The neuron receiving the copied field values</param>
        /// <remarks>Synapses and receivers are not copied.</remarks>
        protected virtual void CloneFields(Neuron clone) {
            clone.bias = this.bias;
            clone.persistOutput = this.persistOutput;
            clone.combinator = this.combinator;
            clone.activator = this.activator;
            clone.breakOnUpdate = this.breakOnUpdate;
        }

        /// <summary>
        /// Delete the given nucleus
        /// </summary>
        /// <param name="nucleus">The nucleus to delete. Null is ignored.</param>
        /// <remarks>
        /// For a Neuron, sending neurons which only feed this neuron are deleted recursively,
        /// other senders are only unlinked, and the synapses of all receivers to this neuron are removed.
        /// For a Cluster, the synapses from its neurons to their receivers are removed.
        /// </remarks>
        public static void Delete(Nucleus nucleus) {
            if (nucleus == null)
                return;

            if (nucleus is Neuron neuron) {
                // Iterate over a snapshot: the recursive Delete below removes the sender's synapse
                // from this neuron's own synapse list, which would invalidate a live enumerator.
                foreach (Synapse synapse in neuron.synapses.ToArray()) {
                    if (synapse.neuron is not Neuron sender)
                        continue;

                    if (sender.receivers.Count > 1) {
                        // There is another nucleus fed by this sender: only unlink it from this neuron
                        sender.receivers.RemoveAll(r => r == nucleus);
                    }
                    else {
                        // No other links, delete it.
                        Neuron.Delete(sender);
                    }
                }

                // Snapshot again: the recursive deletes above may touch receiver lists
                foreach (Nucleus receiver in neuron.receivers.ToArray()) {
                    if (receiver is Neuron receiverNeuron && receiverNeuron.synapses != null)
                        receiverNeuron.synapses.RemoveAll(s => s.neuron == nucleus);
                }
            }
            else if (nucleus is Cluster cluster) {
                // remove all receivers for this cluster
                foreach (Nucleus clusterNucleus in cluster.nuclei) {
                    if (clusterNucleus is not Neuron output)
                        continue;

                    foreach (Nucleus receiver in output.receivers) {
                        if (receiver is not Neuron receiverNeuron)
                            continue;
                        receiverNeuron.synapses.RemoveAll(s => s.neuron == output);
                    }
                }
            }

            // A nucleus without a parent has no nuclei list to be removed from
            if (nucleus.parent?.prefab != null) {
                nucleus.parent.nuclei.RemoveAll(n => n == nucleus);
                nucleus.parent.RefreshOutputs();
            }
        }

        /// \copydoc NanoBrain::Nucleus::UpdateStateIsolated
        /// <remarks>
        /// Combines the bias with the synapse inputs, applies the activation function,
        /// stores the result in outputValue and records the update time.
        /// </remarks>
        public override void UpdateStateIsolated() {
            if (breakOnUpdate)
                Debug.Break();

            var combination = Combinator(this.bias, this.synapses);
            this.outputValue = Activator(combination);
            this.lastUpdate = Time.time;
        }

        #region Combinator

#if UNITY_MATHEMATICS
        /// <summary>
        /// The combinator which combines the bias with the values from all synapses
        /// </summary>
        /// <param name="bias">The bias of the neuron</param>
        /// <param name="synapses">The synapses of the neuron</param>
        /// <returns>The combined value, using the combinator selected for this neuron</returns>
        protected float3 Combinator(float3 bias, List<Synapse> synapses) {
            switch (combinator) {
                case CombinatorType.Sum:
                    return CombinatorSum(bias, synapses);
                case CombinatorType.Product:
                    return CombinatorProduct(bias, synapses);
                default:
                    return CombinatorSum(bias, synapses);
            }
        }

        /// <summary>
        /// Sum the bias and the weighted synapse outputs together
        /// </summary>
        /// <param name="bias">The bias of the neuron</param>
        /// <param name="synapses">The synapses of the neuron</param>
        /// <returns>bias + sum(weight * sender output)</returns>
        /// <remarks>Each sender is sleep-checked first, so a sleeping sender contributes zero.</remarks>
        public static float3 CombinatorSum(float3 bias, List<Synapse> synapses) {
            float3 sum = bias;
            foreach (Synapse synapse in synapses) {
                synapse.neuron.SleepCheck();
                sum += synapse.weight * synapse.neuron.outputValue;
            }
            return sum;
        }

        /// <summary>
        /// Multiply the bias and the weighted synapse outputs together (component-wise)
        /// </summary>
        /// <param name="bias">The bias of the neuron</param>
        /// <param name="synapses">The synapses of the neuron</param>
        /// <returns>The result of the multiplication</returns>
        /// <remarks>The product starts from the bias, so a zero bias always yields zero.</remarks>
        public static float3 CombinatorProduct(float3 bias, List<Synapse> synapses) {
            float3 product = bias;
            foreach (Synapse synapse in synapses) {
                synapse.neuron.SleepCheck();
                product *= synapse.weight * synapse.neuron.outputValue;
            }
            return product;
        }
#else
        /// <summary>
        /// The combinator which combines the bias with the values from all synapses
        /// </summary>
        /// <param name="bias">The bias of the neuron</param>
        /// <param name="synapses">The synapses of the neuron</param>
        /// <returns>The combined value, using the combinator selected for this neuron</returns>
        protected Vector3 Combinator(Vector3 bias, List<Synapse> synapses) {
            switch (combinator) {
                case CombinatorType.Sum:
                    return CombinatorSum(bias, synapses);
                case CombinatorType.Product:
                    return CombinatorProduct(bias, synapses);
                default:
                    return CombinatorSum(bias, synapses);
            }
        }

        /// <summary>
        /// Sum the bias and the weighted synapse outputs together
        /// </summary>
        /// <param name="bias">The bias of the neuron</param>
        /// <param name="synapses">The synapses of the neuron</param>
        /// <returns>bias + sum(weight * sender output)</returns>
        /// <remarks>Each sender is sleep-checked first, so a sleeping sender contributes zero.</remarks>
        public static Vector3 CombinatorSum(Vector3 bias, List<Synapse> synapses) {
            // Vector3 (not float3): Unity.Mathematics is not available in this build
            Vector3 sum = bias;
            foreach (Synapse synapse in synapses) {
                synapse.neuron.SleepCheck();
                sum += synapse.weight * synapse.neuron.outputValue;
            }
            return sum;
        }

        /// <summary>
        /// Multiply the bias and the weighted synapse outputs together (component-wise)
        /// </summary>
        /// <param name="bias">The bias of the neuron</param>
        /// <param name="synapses">The synapses of the neuron</param>
        /// <returns>The result of the multiplication</returns>
        /// <remarks>The product starts from the bias, so a zero bias always yields zero.</remarks>
        public static Vector3 CombinatorProduct(Vector3 bias, List<Synapse> synapses) {
            Vector3 product = bias;
            foreach (Synapse synapse in synapses) {
                synapse.neuron.SleepCheck();
                // Vector3 has no component-wise '*' operator; Scale matches float3 '*='
                product = Vector3.Scale(product, synapse.weight * synapse.neuron.outputValue);
            }
            return product;
        }
#endif

        #endregion Combinator

        #region Activator

#if UNITY_MATHEMATICS
        /// <summary>
        /// Apply the activation function to the input
        /// </summary>
        /// <param name="inputValue">The combined input value</param>
        /// <returns>The result of applying the activation function. Custom falls back to Linear.</returns>
        // This does not allocate memory and seems faster than a switch expression
        protected float3 Activator(float3 inputValue) {
            switch (activator) {
                case ActivationType.Linear: return ActivatorLinear(inputValue);
                case ActivationType.Sqrt: return ActivatorSqrt(inputValue);
                case ActivationType.Power: return ActivatorPower(inputValue);
                case ActivationType.Reciprocal: return ActivatorReciprocal(inputValue);
                case ActivationType.Tanh: return ActivatorTanh(inputValue);
                case ActivationType.Binary: return ActivatorBinary(inputValue);
                case ActivationType.Normalized: return ActivatorNormalized(inputValue);
                default: return ActivatorLinear(inputValue);
            }
        }

        /// <summary>
        /// Linear activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>The unchanged value</returns>
        protected float3 ActivatorLinear(float3 input) {
            return input;
        }

        /// <summary>
        /// Square root activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with the square root of the input magnitude</returns>
        protected float3 ActivatorSqrt(float3 input) {
            // normalizesafe yields zero for a zero vector where normalize would yield NaN
            float3 result = normalizesafe(input) * MathF.Sqrt(length(input));
            return result;
        }

        /// <summary>
        /// Power activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with the squared input magnitude</returns>
        protected float3 ActivatorPower(float3 input) {
            // normalizesafe yields zero for a zero vector where normalize would yield NaN
            float3 result = normalizesafe(input) * MathF.Pow(length(input), 2);
            return result;
        }

        /// <summary>
        /// Reciprocal activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with magnitude 1/|input|, or zero for a zero input</returns>
        protected float3 ActivatorReciprocal(float3 input) {
            float magnitude = length(input);
            if (magnitude == 0)
                return new float3(0, 0, 0);

            float3 result = normalize(input) * (1 / magnitude);
            return result;
        }

        /// <summary>
        /// Tanh activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with magnitude tanh(|input|)</returns>
        protected float3 ActivatorTanh(float3 input) {
            float magnitude = length(input);
            // normalizesafe yields zero for a zero vector where normalize would yield NaN
            float3 result = normalizesafe(input) * MathF.Tanh(magnitude);
            return result;
        }

        /// <summary>
        /// Binary activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A uniform vector whose components all equal the input magnitude clamped to [0, 1]</returns>
        protected float3 ActivatorBinary(float3 input) {
            float magnitude = length(input);
            float value = Mathf.Clamp01(magnitude);
            return float3(value, value, value);
        }

        /// <summary>
        /// Normalize activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>The normalized vector, or the input itself when it has zero length</returns>
        protected float3 ActivatorNormalized(float3 input) {
            if (lengthsq(input) == 0)
                return input;

            float3 result = normalize(input);
            return result;
        }
#else
        /// <summary>
        /// Apply the activation function to the input
        /// </summary>
        /// <param name="inputValue">The combined input value</param>
        /// <returns>The result of applying the activation function. Custom falls back to Linear.</returns>
        // This does not allocate memory and seems faster than a switch expression
        protected Vector3 Activator(Vector3 inputValue) {
            switch (activator) {
                case ActivationType.Linear: return ActivatorLinear(inputValue);
                case ActivationType.Sqrt: return ActivatorSqrt(inputValue);
                case ActivationType.Power: return ActivatorPower(inputValue);
                case ActivationType.Reciprocal: return ActivatorReciprocal(inputValue);
                case ActivationType.Tanh: return ActivatorTanh(inputValue);
                case ActivationType.Binary: return ActivatorBinary(inputValue);
                case ActivationType.Normalized: return ActivatorNormalized(inputValue);
                default: return ActivatorLinear(inputValue);
            }
        }

        /// <summary>
        /// Linear activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>The unchanged value</returns>
        protected Vector3 ActivatorLinear(Vector3 input) {
            return input;
        }

        /// <summary>
        /// Square root activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with the square root of the input magnitude</returns>
        protected Vector3 ActivatorSqrt(Vector3 input) {
            Vector3 result = input.normalized * System.MathF.Sqrt(input.magnitude);
            return result;
        }

        /// <summary>
        /// Power activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with the squared input magnitude</returns>
        protected Vector3 ActivatorPower(Vector3 input) {
            Vector3 result = input.normalized * System.MathF.Pow(input.magnitude, 2);
            return result;
        }

        /// <summary>
        /// Reciprocal activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with magnitude 1/|input|, or zero for a zero input</returns>
        protected Vector3 ActivatorReciprocal(Vector3 input) {
            float magnitude = input.magnitude;
            if (magnitude == 0)
                return new Vector3(0, 0, 0);

            Vector3 result = input.normalized * (1 / magnitude);
            return result;
        }

        /// <summary>
        /// Tanh activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A vector in the input direction with magnitude tanh(|input|)</returns>
        protected Vector3 ActivatorTanh(Vector3 input) {
            Vector3 result = input.normalized * System.MathF.Tanh(input.magnitude);
            return result;
        }

        /// <summary>
        /// Binary activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>A uniform vector whose components all equal the input magnitude clamped to [0, 1]</returns>
        protected Vector3 ActivatorBinary(Vector3 input) {
            float value = Mathf.Clamp01(input.magnitude);
            return new Vector3(value, value, value);
        }

        /// <summary>
        /// Normalize activation function
        /// </summary>
        /// <param name="input">Input value</param>
        /// <returns>The normalized vector, or the input itself when it has zero length</returns>
        protected Vector3 ActivatorNormalized(Vector3 input) {
            if (input.sqrMagnitude == 0)
                return input;

            return input.normalized;
        }
#endif

        #endregion Activator

        #region Receivers

        /// <summary>
        /// The nuclei which have a synapse to this neuron
        /// </summary>
        [SerializeReference]
        [HideInInspector]
        private List<Nucleus> _receivers = new();

        /// <summary>
        /// The nuclei which have a synapse to this neuron
        /// </summary>
        public virtual List<Nucleus> receivers {
            get { return _receivers; }
            set { _receivers = value; }
        }

        /// <summary>
        /// Add a new receiver to this neuron
        /// </summary>
        /// <param name="receiverToAdd">The receiver to add. Non-Neuron nuclei are ignored.</param>
        /// <param name="weight">The weight to use for the synapse to this neuron</param>
        /// <remarks>Also creates the matching synapse on the receiver.</remarks>
        public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1) {
            if (receiverToAdd is not Neuron receiverNeuron)
                return;

            this._receivers.Add(receiverNeuron);
            receiverNeuron.AddSynapse(this, weight);
        }

        /// <summary>
        /// Remove a receiver from this neuron
        /// </summary>
        /// <param name="receiverToRemove">The receiver to remove. Non-Neuron nuclei are ignored.</param>
        /// <remarks>Also removes all synapses from this neuron on the receiver.</remarks>
        public virtual void RemoveReceiver(Nucleus receiverToRemove) {
            if (receiverToRemove is not Neuron receiverNeuron)
                return;

            this._receivers.RemoveAll(receiver => receiver == receiverNeuron);
            receiverNeuron.synapses.RemoveAll(synapse => synapse.neuron == this);
            // Nucleus prefabReceiver = receiverToRemove.prefabNucleus;
            // if (this.prefabNucleus is Neuron prefabNeuron && prefabReceiver != null) {
            //     prefabNeuron.receivers.RemoveAll(receiver => receiver == prefabReceiver);
            //     prefabReceiver.synapses.RemoveAll(synapse => synapse.neuron == prefabNeuron);
            // }
        }

        #endregion Receivers

        /// <summary>
        /// Process an external stimulus
        /// </summary>
        /// <param name="inputValue">The value of the stimulus, stored as the bias</param>
        public virtual void ProcessStimulus(Vector3 inputValue) {
            this.lastUpdate = Time.time;
            this.bias = inputValue;
            this.parent?.UpdateFromNucleus(this);
        }
    }

}