using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_EDITOR
// AnimationUtility (editor-only) is used by Neuron.Presets; UnityEditor is not available in player builds.
using UnityEditor;
#endif
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif
namespace NanoBrain {
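/// <summary>
/// A neuron is the basic Nucleus: it combines the weighted outputs of its input synapses with its bias and applies an activation function to the result.
/// </summary>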
[Serializable]
public class Neuron : Nucleus {
/// <summary>
/// Create a new Neuron and register it with its parent Cluster.
/// </summary>
/// <param name="parent">The cluster that owns the new Neuron. May be null, in which case the Neuron is not added to any cluster.</param>
/// <param name="name">The name of the new Neuron.</param>
public Neuron(Cluster parent, string name) {
    this.parent = parent;
    this.name = name;
    // Only register with a cluster when one was supplied.
    var owner = this.parent;
    owner?.nuclei.Add(this);
}
///
/// Create a new Neuron in a Cluster Prefab
///
/// The Cluster Prefab in which the new Neuron should be created
/// The name of the new Neuron
// public Neuron(ClusterPrefab prefab, string name) {
// this.clusterPrefab = prefab;
// this.name = name;
// if (this.clusterPrefab != null) {
// this.clusterPrefab.cluster.nuclei.Add(this);
// this.clusterPrefab.cluster.RefreshOutputs();
// }
// else
// Debug.LogError("No prefab when adding neuron to prefab");
// }
#region Serialization
/// <summary>
/// The bias of this neuron.
/// </summary>
/// <remarks>
/// The bias has no synapse, so it has neither a source nucleus nor a weight.
/// The combinators use it as their starting value: the Sum combinator adds the weighted inputs to it,
/// and the Product combinator multiplies it by them. SetBias and ProcessStimulus overwrite it.
/// </remarks>
public Vector3 bias = new(0f, 0f, 0f);
#region Synapses
// Backing list of the incoming synapses; marked [SerializeField] so Unity persists it although it is private.
[SerializeField]
private List _synapses = new();
/// <summary>
/// The synapses through which this neuron receives its inputs.
/// </summary>
/// <remarks>
/// Returns the live backing list, not a copy; AddSynapse, RemoveSynapse and Delete modify it in place.
/// </remarks>
public List synapses => _synapses;
/// <summary>
/// Add a new synapse (input) to this nucleus.
/// </summary>
/// <param name="sendingNucleus">The neuron from which the signals originate.</param>
/// <param name="weight">The weight applied to the input. Default value = 1.</param>
/// <returns>The newly created Synapse.</returns>
/// <remarks>
/// No check is made for an existing synapse from the same sender; calling this twice creates two synapses.
/// </remarks>
public Synapse AddSynapse(Neuron sendingNucleus, float weight = 1) {
    var created = new Synapse(sendingNucleus, weight);
    _synapses.Add(created);
    return created;
}
// public Synapse AddSynapse(ClusterPrefab clusterPrefab, string neuronName, float weight = 1) {
// }
/// <summary>
/// Find the synapse through which <paramref name="sender"/> feeds this nucleus.
/// </summary>
/// <param name="sender">The sender of the input to the Synapse.</param>
/// <returns>The first matching Synapse, or null when the sender has no synapse to this nucleus.</returns>
public Synapse GetSynapse(Nucleus sender) {
    return _synapses.Find(candidate => candidate.neuron == sender);
}
/// <summary>
/// Remove every synapse from <paramref name="sendingNucleus"/> to this Nucleus.
/// </summary>
/// <param name="sendingNucleus">The nucleus whose synapses to this Nucleus are removed.</param>
/// <remarks>
/// Only this nucleus' synapse list is changed; the sender's receiver list is left untouched
/// (RemoveReceiver updates both sides).
/// </remarks>
public void RemoveSynapse(Nucleus sendingNucleus) {
    bool ComesFromSender(Synapse candidate) => candidate.neuron == sendingNucleus;
    _synapses.RemoveAll(ComesFromSender);
}
#endregion Synapses
/// <summary>
/// Set the bias and let the parent cluster (if any) update starting from this nucleus.
/// </summary>
/// <param name="inputValue">The new bias value.</param>
/// <remarks>
/// Also stores the current Time.time in lastUpdate, which keeps the neuron from being considered asleep.
/// </remarks>
public virtual void SetBias(Vector3 inputValue) {
    bias = inputValue;
    lastUpdate = Time.time;
    var owner = parent;
    owner?.UpdateFromNucleus(this);
}
/// <summary>
/// The ways in which the weighted synapse values of a Neuron can be combined.
/// </summary>
/// <remarks>
/// A combinator reduces the bias and the weighted values of all synapses to a single value.
/// </remarks>
public enum CombinatorType {
    /// <summary>Add the weighted values to the bias.</summary>
    Sum = 0,
    /// <summary>Multiply the bias component-wise by the weighted values.</summary>
    Product = 1,
}
/// <summary>
/// The combinator used by this Neuron. Defaults to Sum.
/// </summary>
public CombinatorType combinator = CombinatorType.Sum;
/// <summary>
/// The activation functions that can be applied to the combined value of a Neuron.
/// </summary>
/// <remarks>
/// Apart from Binary, the Unity.Mathematics activators keep the direction of the combined vector
/// and only transform its magnitude.
/// </remarks>
public enum ActivationType {
    /// <summary>Output equals the input.</summary>
    Linear = 0,
    /// <summary>Magnitude is squared.</summary>
    Power = 1,
    /// <summary>Magnitude is replaced by its square root.</summary>
    Sqrt = 2,
    /// <summary>Magnitude is replaced by its reciprocal.</summary>
    Reciprocal = 3,
    /// <summary>Magnitude is replaced by its hyperbolic tangent.</summary>
    Tanh = 4,
    /// <summary>All components are set to the magnitude clamped to [0, 1].</summary>
    Binary = 5,
    /// <summary>Output is scaled to unit length.</summary>
    Normalized = 6,
    /// <summary>Magnitude is mapped through the neuron's curve.</summary>
    Custom = 7,
}
// Serialized backing field of curvePreset. Assign through the property so that the curve is regenerated.
[SerializeField]
public ActivationType _curvePreset;
/// <summary>
/// The activation preset of this neuron.
/// </summary>
/// <remarks>
/// Assigning a value regenerates curve (and curveMax) through GenerateCurve.
/// </remarks>
public ActivationType curvePreset {
    get => _curvePreset;
    set {
        _curvePreset = value;
        curve = GenerateCurve();
    }
}
/// <summary>
/// The activation curve. Evaluated by the Custom activator; regenerated whenever curvePreset is assigned.
/// </summary>
public AnimationCurve curve;
/// <summary>
/// A range value set by GenerateCurve: 1 for most presets, 100 for Reciprocal.
/// </summary>
public float curveMax = 1.0f;
/// <summary>
/// Build the curve belonging to the current curvePreset and update curveMax accordingly.
/// </summary>
/// <returns>A newly generated preset curve, or the existing curve for Custom (and any unknown preset).</returns>
public AnimationCurve GenerateCurve() {
    (float maximum, AnimationCurve generated) = this.curvePreset switch {
        ActivationType.Linear => (1f, Presets.Linear(1)),
        ActivationType.Power => (1f, Presets.Power(2.0f, 1)),
        ActivationType.Sqrt => (1f, Presets.Power(0.5f, 1)),
        // 1/x grows without bound near zero, so Reciprocal gets a larger range value.
        ActivationType.Reciprocal => (1 / 0.01f * 1, Presets.Reciprocal(1)),
        ActivationType.Tanh => (1f, Presets.Tanh(1)),
        ActivationType.Binary => (1f, Presets.Binary()),
        // NOTE(review): Normalized reuses the Binary curve — confirm this is intended.
        ActivationType.Normalized => (1f, Presets.Binary()),
        _ => (1f, this.curve),
    };
    this.curveMax = maximum;
    return generated;
}
/// <summary>
/// Factory methods that build AnimationCurves for the activation presets.
/// </summary>
/// <remarks>
/// Inside the Unity Editor the key tangents are configured with the editor-only AnimationUtility API.
/// In player builds (no UNITY_EDITOR define) equivalent tangents are computed with runtime APIs,
/// so this class does not need the UnityEditor assembly there.
/// </remarks>
public static class Presets {
    // Number of keyframes sampled by Power and Tanh.
    private const int samples = 32;

    /// <summary>
    /// A straight line from (0, 0) to (1000, <paramref name="weight"/> * 1000).
    /// </summary>
    public static AnimationCurve Linear(float weight) {
        return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
    }

    /// <summary>
    /// Samples y = t^exponent * weight for t in [0, 1] with smooth (Auto) tangents.
    /// </summary>
    /// <param name="exponent">The exponent applied to t.</param>
    /// <param name="weight">The factor applied to the result.</param>
    public static AnimationCurve Power(float exponent, float weight) {
        Keyframe[] keys = new Keyframe[samples];
        for (int i = 0; i < samples; i++) {
            float t = i / (float)(samples - 1);
            float v = Mathf.Pow(t, exponent) * weight;
            keys[i] = new Keyframe(t, v);
        }
        AnimationCurve curve = new(keys);
        ApplyTangentMode(curve, linear: false);
        return curve;
    }

    /// <summary>
    /// Samples y = (1 / x) * weight for x in [0.001, 1] with linear tangents.
    /// </summary>
    /// <remarks>Sampling starts just above zero because 1/x is undefined at x = 0.</remarks>
    public static AnimationCurve Reciprocal(float weight) {
        const int reciprocalSamples = 128;
        const float xMin = 0.001f;
        const float xMax = 1;
        var keys = new Keyframe[reciprocalSamples];
        for (int i = 0; i < reciprocalSamples; i++) {
            float t = i / (float)(reciprocalSamples - 1);
            float x = Mathf.Lerp(xMin, xMax, t);
            float y = 1f / x * weight;
            keys[i] = new Keyframe(x, y);
        }
        var curve = new AnimationCurve(keys);
        ApplyTangentMode(curve, linear: true);
        return curve;
    }

    /// <summary>
    /// Samples y = tanh(x * weight) for x in [0.001, 1] with linear tangents.
    /// </summary>
    public static AnimationCurve Tanh(float weight) {
        const float xMin = 0.001f;
        const float xMax = 1;
        var keys = new Keyframe[samples];
        for (int i = 0; i < samples; i++) {
            float t = i / (float)(samples - 1);
            float x = Mathf.Lerp(xMin, xMax, t);
            float y = MathF.Tanh(x * weight);
            keys[i] = new Keyframe(x, y);
        }
        var curve = new AnimationCurve(keys);
        ApplyTangentMode(curve, linear: true);
        return curve;
    }

    /// <summary>
    /// A straight line from (0, 0) to (1, 1).
    /// </summary>
    public static AnimationCurve Binary() {
        return AnimationCurve.Linear(0, 0, 1, 1);
    }

    // Sets the tangents of every key either to straight segments (linear) or to smooth automatic tangents.
    private static void ApplyTangentMode(AnimationCurve curve, bool linear) {
#if UNITY_EDITOR
        AnimationUtility.TangentMode mode = linear
            ? AnimationUtility.TangentMode.Linear
            : AnimationUtility.TangentMode.Auto;
        for (int i = 0; i < curve.length; i++) {
            AnimationUtility.SetKeyLeftTangentMode(curve, i, mode);
            AnimationUtility.SetKeyRightTangentMode(curve, i, mode);
        }
#else
        if (!linear) {
            // Runtime approximation of the editor's Auto tangent mode.
            for (int i = 0; i < curve.length; i++)
                curve.SmoothTangents(i, 0f);
            return;
        }
        // Runtime equivalent of the Linear tangent mode: each tangent points straight at the neighbouring key.
        static float Slope(Keyframe from, Keyframe to) => (to.value - from.value) / (to.time - from.time);
        Keyframe[] keys = curve.keys;
        for (int i = 0; i < keys.Length; i++) {
            float inSlope = i > 0 ? Slope(keys[i - 1], keys[i]) : 0f;
            float outSlope = i < keys.Length - 1 ? Slope(keys[i], keys[i + 1]) : inSlope;
            if (i == 0)
                inSlope = outSlope;
            keys[i].inTangent = inSlope;
            keys[i].outTangent = outSlope;
        }
        curve.keys = keys;
#endif
    }
}
#endregion Serialization
#if UNITY_MATHEMATICS
// Backing field of outputValue; writing it directly (as SleepCheck does) bypasses WhenFiring.
protected float3 _outputValue;
/// <summary>
/// The current output of this neuron.
/// </summary>
/// <remarks>Assigning a value invokes WhenFiring when the new value makes the neuron fire.</remarks>
public virtual float3 outputValue {
    get => _outputValue;
    set {
        _outputValue = value;
        if (!isFiring)
            return;
        WhenFiring?.Invoke();
    }
}
/// <summary>The length of the output vector.</summary>
public float outputMagnitude => length(_outputValue);
/// <summary>The squared length of the output vector.</summary>
public float outputSqrMagnitude => lengthsq(_outputValue);
#else
// Backing field of outputValue; writing it directly (as SleepCheck does) bypasses WhenFiring.
protected Vector3 _outputValue;
/// <summary>
/// The current output of this neuron.
/// </summary>
/// <remarks>Assigning a value invokes WhenFiring when the new value makes the neuron fire.</remarks>
public virtual Vector3 outputValue {
    get => _outputValue;
    set {
        _outputValue = value;
        if (!isFiring)
            return;
        WhenFiring?.Invoke();
    }
}
/// <summary>The length of the output vector.</summary>
public float outputMagnitude => _outputValue.magnitude;
/// <summary>The squared length of the output vector.</summary>
public float outputSqrMagnitude => _outputValue.sqrMagnitude;
#endif
/// <summary>True when the magnitude of the output exceeds 0.5.</summary>
public bool isFiring => 0.5f < this.outputMagnitude;
/// <summary>Invoked by the outputValue setter whenever a newly assigned output makes the neuron fire.</summary>
public Action WhenFiring;
/// <summary>When true the neuron never counts as sleeping, so SleepCheck keeps its output.</summary>
public bool persistOutput = false;
/// <summary>
/// True when the output is not persistent and more than timeToSleep seconds have passed since lastUpdate.
/// </summary>
public virtual bool isSleeping {
    get {
        if (persistOutput)
            return false;
        return Time.time - this.lastUpdate > this.timeToSleep;
    }
}
/// <summary>
/// Reset the output to zero when the neuron is sleeping.
/// </summary>
/// <remarks>The backing field is written directly, so WhenFiring is not invoked.</remarks>
public void SleepCheck() {
    if (!this.isSleeping)
        return;
#if UNITY_MATHEMATICS
    this._outputValue = float3.zero;
#else
    this._outputValue = Vector3.zero;
#endif
}
// Toggle for printing debugging trace data (currently disabled):
//public bool trace = false;
//[NonSerialized]
/// <summary>
/// The Time.time at which this neuron was last updated (set by SetBias, ProcessStimulus and UpdateStateIsolated).
/// </summary>
public float lastUpdate = 0f;
/// <summary>
/// The number of seconds without an update after which the neuron counts as sleeping.
/// </summary>
public readonly float timeToSleep = 1.0f;
/// \copydoc NanoBrain::Nucleus::ShallowCloneTo
public override Nucleus ShallowCloneTo(Cluster newParent) {
    // The constructor registers the copy with newParent; CloneFields copies the configuration,
    // synapses and receivers are not copied here. (prefabNucleus = this is disabled.)
    var copy = new Neuron(newParent, this.name);
    CloneFields(copy);
    return copy;
}
// \copydoc NanoBrain::Nucleus::Clone
// public override Nucleus Clone(ClusterPrefab prefab) {
// Neuron clone = new(prefab.cluster, this.name);
// CloneFields(clone);
// foreach (Synapse synapse in this.synapses) {
// Synapse clonedSynapse = clone.AddSynapse(synapse.neuron);
// clonedSynapse.weight = synapse.weight;
// }
// foreach (Nucleus receiver in this.receivers) {
// clone.AddReceiver(receiver);
// }
// return clone;
// }
/// <summary>
/// Copy the configurable fields of this neuron to <paramref name="clone"/>.
/// </summary>
/// <param name="clone">The neuron receiving the copied values.</param>
/// <remarks>
/// The curve is duplicated instead of shared. AnimationCurve is a reference type, so sharing it would
/// make edits to the clone's Custom curve silently change this neuron's curve as well.
/// </remarks>
protected virtual void CloneFields(Neuron clone) {
    clone.bias = this.bias;
    clone.persistOutput = this.persistOutput;
    clone.combinator = this.combinator;
    clone.curve = CopyCurve(this.curve);
    // Assigning the preset regenerates the curve for non-Custom presets; Custom keeps the copy made above.
    clone.curvePreset = this.curvePreset;
    clone.curveMax = this.curveMax;
}

// Returns an independent copy of source (keys and wrap modes), or null when source is null.
private static AnimationCurve CopyCurve(AnimationCurve source) {
    if (source == null)
        return null;
    return new AnimationCurve(source.keys) {
        preWrapMode = source.preWrapMode,
        postWrapMode = source.postWrapMode,
    };
}
/// <summary>
/// Delete a nucleus from the network.
/// </summary>
/// <param name="nucleus">The nucleus to delete; null is ignored.</param>
/// <remarks>
/// For a Neuron: input neurons whose only receiver is this neuron are deleted recursively, other input
/// neurons just drop this neuron from their receivers, and every receiver drops its synapses from this neuron.
/// For a Cluster: receivers of the cluster's neurons drop their synapses from those neurons.
/// Finally, when the parent cluster has a prefab, the nucleus is removed from the prefab's cluster and
/// RefreshOutputs and GarbageCollection are called.
/// </remarks>
public static void Delete(Nucleus nucleus) {
    DeleteRecursive(nucleus, new HashSet<Nucleus>());
}

// Implements Delete while tracking visited nuclei, so cyclic connections cannot recurse forever.
private static void DeleteRecursive(Nucleus nucleus, HashSet<Nucleus> visited) {
    if (nucleus == null || !visited.Add(nucleus))
        return;
    if (nucleus is Neuron neuron) {
        // Iterate over snapshots: the recursive deletes below remove synapses from these very lists,
        // and modifying a List<T> during foreach throws InvalidOperationException.
        foreach (Synapse synapse in neuron.synapses.ToArray()) {
            if (synapse.neuron is not Neuron sender)
                continue;
            if (sender.receivers.Count > 1) {
                // Another nucleus still receives from this sender; only unlink this one.
                sender.receivers.RemoveAll(r => r == nucleus);
            }
            else {
                // No other links, delete the sender as well.
                DeleteRecursive(sender, visited);
            }
        }
        foreach (Nucleus receiver in neuron.receivers.ToArray()) {
            if (receiver is Neuron receiverNeuron && receiverNeuron.synapses != null)
                receiverNeuron.synapses.RemoveAll(s => s.neuron == nucleus);
        }
    }
    else if (nucleus is Cluster cluster) {
        // Remove all synapses that receive from this cluster's neurons.
        foreach (Nucleus clusterNucleus in cluster.nuclei) {
            if (clusterNucleus is not Neuron output)
                continue;
            foreach (Nucleus receiver in output.receivers) {
                if (receiver is Neuron receiverNeuron)
                    receiverNeuron.synapses.RemoveAll(s => s.neuron == output);
            }
        }
    }
    // The parent is optional (the Neuron constructor accepts null), so guard before dereferencing it.
    // NOTE(review): without a prefab the nucleus is not removed from its parent here — confirm callers handle that.
    var parent = nucleus.parent;
    if (parent != null && parent.prefab != null) {
        parent.prefab.cluster.nuclei.RemoveAll(n => n == nucleus);
        parent.prefab.cluster.RefreshOutputs();
        parent.prefab.GarbageCollection();
    }
}
/// <summary>
/// Recompute this neuron's own output: combine the inputs, apply the activation and record the update time.
/// </summary>
public override void UpdateStateIsolated() {
    // The outputValue setter may invoke WhenFiring; lastUpdate is stored afterwards.
    this.outputValue = ApplyActivator(Combinator());
    this.lastUpdate = Time.time;
}
/// <summary>
/// Let every sending neuron reset its output to zero if it has fallen asleep.
/// </summary>
protected void CheckSleepingSynapses() {
    for (int index = 0; index < _synapses.Count; index++)
        _synapses[index].neuron.SleepCheck();
}
#region Combinator
#if UNITY_MATHEMATICS
/// <summary>
/// The combine function selected by combinator. Unknown values fall back to CombinatorSum.
/// </summary>
protected Func<float3> Combinator => combinator switch {
    CombinatorType.Sum => CombinatorSum,
    CombinatorType.Product => CombinatorProduct,
    _ => CombinatorSum
};
/// <summary>
/// Start from the bias and add the weighted output of every sending neuron.
/// </summary>
/// <remarks>Sleeping senders are reset to zero output (SleepCheck) before they are read.</remarks>
public float3 CombinatorSum() {
    float3 sum = this.bias;
    foreach (Synapse synapse in this.synapses) {
        synapse.neuron.SleepCheck();
        sum += synapse.weight * synapse.neuron.outputValue;
    }
    return sum;
}
/// <summary>
/// Start from the bias and multiply it component-wise by the weighted output of every sending neuron.
/// </summary>
/// <remarks>
/// Sleeping senders are reset to zero output (SleepCheck) before they are read.
/// NOTE(review): the bias is the starting factor, so a zero bias (the default) always yields zero — confirm intended.
/// </remarks>
public float3 CombinatorProduct() {
    float3 product = this.bias;
    foreach (Synapse synapse in this.synapses) {
        synapse.neuron.SleepCheck();
        product *= synapse.weight * synapse.neuron.outputValue;
    }
    return product;
}
#else
/// <summary>
/// The combine function selected by combinator. Unknown values fall back to CombinatorSum.
/// </summary>
/// <remarks>
/// CombinatorType has no Max member, so CombinatorMax can only be called directly.
/// </remarks>
protected Func<Vector3> Combinator => combinator switch {
    CombinatorType.Sum => CombinatorSum,
    CombinatorType.Product => CombinatorProduct,
    _ => CombinatorSum
};
/// <summary>
/// Start from the bias and add the weighted output of every sending neuron.
/// </summary>
/// <remarks>Sleeping senders are reset to zero output (SleepCheck) before they are read.</remarks>
public Vector3 CombinatorSum() {
    Vector3 sum = this.bias;
    foreach (Synapse synapse in this.synapses) {
        synapse.neuron.SleepCheck();
        sum += synapse.weight * synapse.neuron.outputValue;
    }
    return sum;
}
/// <summary>
/// Start from the bias and multiply it component-wise by the weighted output of every sending neuron.
/// </summary>
/// <remarks>Sleeping senders are reset to zero output (SleepCheck) before they are read.</remarks>
public Vector3 CombinatorProduct() {
    Vector3 product = this.bias;
    foreach (Synapse synapse in this.synapses) {
        synapse.neuron.SleepCheck();
        product = Vector3.Scale(product, synapse.weight * synapse.neuron.outputValue);
    }
    return product;
}
/// <summary>
/// Return the bias or weighted input with the largest magnitude (the first one wins on ties).
/// </summary>
/// <remarks>Sleeping senders are reset to zero output (SleepCheck) before they are read.</remarks>
public Vector3 CombinatorMax() {
    Vector3 max = this.bias;
    float maxLength = max.magnitude;
    foreach (Synapse synapse in this.synapses) {
        synapse.neuron.SleepCheck();
        Vector3 input = synapse.weight * synapse.neuron.outputValue;
        float inputLength = input.magnitude;
        if (inputLength > maxLength) {
            max = input;
            maxLength = inputLength;
        }
    }
    return max;
}
#endif
#endregion Combinator
#region Activator
#if UNITY_MATHEMATICS
// Applies the activation selected by curvePreset to x.
// A switch is used instead of the Activator delegate because it does not allocate memory
// (and was observed by the author to be faster).
float3 ApplyActivator(float3 x) {
    switch (curvePreset) {
        case ActivationType.Linear: return ActivatorLinear(x);
        case ActivationType.Sqrt: return ActivatorSqrt(x);
        case ActivationType.Power: return ActivatorPower(x);
        case ActivationType.Reciprocal: return ActivatorReciprocal(x);
        case ActivationType.Tanh: return ActivatorTanh(x);
        case ActivationType.Binary: return ActivatorBinary(x);
        case ActivationType.Normalized: return ActivatorNormalized(x);
        default: return ActivatorCustom(x);
    }
}
/// <summary>
/// The activation function selected by curvePreset, as a delegate. Unknown presets use ActivatorCustom.
/// </summary>
public Func<float3, float3> Activator => this.curvePreset switch {
    ActivationType.Linear => ActivatorLinear,
    ActivationType.Sqrt => ActivatorSqrt,
    ActivationType.Power => ActivatorPower,
    ActivationType.Reciprocal => ActivatorReciprocal,
    ActivationType.Tanh => ActivatorTanh,
    ActivationType.Binary => ActivatorBinary,
    ActivationType.Normalized => ActivatorNormalized,
    _ => ActivatorCustom
};
// Note: normalize() of a zero vector yields NaN (rsqrt(0) * 0), which would then propagate through
// the network. normalizesafe() returns a zero vector instead, so a zero input gives a zero output.

/// <summary>Returns the input unchanged.</summary>
protected float3 ActivatorLinear(float3 input) {
    return input;
}
/// <summary>Keeps the direction and replaces magnitude m by sqrt(m); zero input gives zero.</summary>
protected float3 ActivatorSqrt(float3 input) {
    float3 result = normalizesafe(input) * System.MathF.Sqrt(length(input));
    return result;
}
/// <summary>Keeps the direction and replaces magnitude m by m squared; zero input gives zero.</summary>
protected float3 ActivatorPower(float3 input) {
    float3 result = normalizesafe(input) * System.MathF.Pow(length(input), 2);
    return result;
}
/// <summary>Keeps the direction and replaces magnitude m by 1/m; zero input gives zero.</summary>
protected float3 ActivatorReciprocal(float3 input) {
    float magnitude = length(input);
    if (magnitude == 0)
        return new float3(0, 0, 0);
    float3 result = normalize(input) * (1 / magnitude);
    return result;
}
/// <summary>Keeps the direction and replaces magnitude m by tanh(m); zero input gives zero.</summary>
protected float3 ActivatorTanh(float3 input) {
    float magnitude = length(input);
    float3 result = normalizesafe(input) * MathF.Tanh(magnitude);
    return result;
}
/// <summary>Sets all three components to the input magnitude clamped to [0, 1].</summary>
protected float3 ActivatorBinary(float3 input) {
    float magnitude = length(input);
    float value = Mathf.Clamp01(magnitude);
    return float3(value, value, value);
}
/// <summary>Scales the input to unit length; zero input gives zero.</summary>
protected float3 ActivatorNormalized(float3 input) {
    return normalizesafe(input);
}
/// <summary>
/// Keeps the direction and replaces magnitude m by curve.Evaluate(m); zero input gives zero (for finite curve values).
/// </summary>
protected float3 ActivatorCustom(float3 input) {
    float activatedValue = this.curve.Evaluate(length(input));
    float3 result = normalizesafe(input) * activatedValue;
    return result;
}
#else
// Applies the activation selected by curvePreset to x.
// A switch is used instead of the Activator delegate because it does not allocate memory.
Vector3 ApplyActivator(Vector3 x) {
    switch (curvePreset) {
        case ActivationType.Linear: return ActivatorLinear(x);
        case ActivationType.Sqrt: return ActivatorSqrt(x);
        case ActivationType.Power: return ActivatorPower(x);
        case ActivationType.Reciprocal: return ActivatorReciprocal(x);
        case ActivationType.Tanh: return ActivatorTanh(x);
        case ActivationType.Binary: return ActivatorBinary(x);
        case ActivationType.Normalized: return ActivatorNormalized(x);
        default: return ActivatorCustom(x);
    }
}
/// <summary>
/// The activation function selected by curvePreset, as a delegate. Unknown presets use ActivatorCustom.
/// </summary>
public Func<Vector3, Vector3> Activator => this.curvePreset switch {
    ActivationType.Linear => ActivatorLinear,
    ActivationType.Sqrt => ActivatorSqrt,
    ActivationType.Power => ActivatorPower,
    ActivationType.Reciprocal => ActivatorReciprocal,
    ActivationType.Tanh => ActivatorTanh,
    ActivationType.Binary => ActivatorBinary,
    ActivationType.Normalized => ActivatorNormalized,
    _ => ActivatorCustom
};
// Vector3.normalized returns a zero vector for (near-)zero input, so these stay finite for zero input.

/// <summary>Returns the input unchanged.</summary>
protected Vector3 ActivatorLinear(Vector3 input) {
    return input;
}
/// <summary>Keeps the direction and replaces magnitude m by sqrt(m).</summary>
protected Vector3 ActivatorSqrt(Vector3 input) {
    Vector3 result = input.normalized * System.MathF.Sqrt(input.magnitude);
    return result;
}
/// <summary>Keeps the direction and replaces magnitude m by m squared.</summary>
protected Vector3 ActivatorPower(Vector3 input) {
    Vector3 result = input.normalized * System.MathF.Pow(input.magnitude, 2);
    return result;
}
/// <summary>Keeps the direction and replaces magnitude m by 1/m; zero input gives zero.</summary>
protected Vector3 ActivatorReciprocal(Vector3 input) {
    float magnitude = input.magnitude;
    if (magnitude == 0)
        return new Vector3(0, 0, 0);
    Vector3 result = input.normalized * (1 / magnitude);
    return result;
}
/// <summary>Keeps the direction and replaces magnitude m by tanh(m).</summary>
protected Vector3 ActivatorTanh(Vector3 input) {
    Vector3 result = input.normalized * MathF.Tanh(input.magnitude);
    return result;
}
/// <summary>Sets all three components to the input magnitude clamped to [0, 1].</summary>
protected Vector3 ActivatorBinary(Vector3 input) {
    float value = Mathf.Clamp01(input.magnitude);
    return new Vector3(value, value, value);
}
/// <summary>Scales the input to unit length (zero input gives zero).</summary>
protected Vector3 ActivatorNormalized(Vector3 input) {
    return input.normalized;
}
/// <summary>Keeps the direction and replaces magnitude m by curve.Evaluate(m).</summary>
protected Vector3 ActivatorCustom(Vector3 input) {
    float activatedValue = this.curve.Evaluate(input.magnitude);
    Vector3 result = input.normalized * activatedValue;
    return result;
}
#endif
#endregion Activator
#region Receivers
// Backing list of the nuclei receiving this neuron's output; serialized by Unity through [SerializeReference].
[SerializeReference]
private List _receivers = new();
/// <summary>
/// The nuclei that receive this neuron's output.
/// </summary>
/// <remarks>
/// Virtual, but note that AddReceiver and RemoveReceiver operate on the private backing list directly.
/// </remarks>
public virtual List receivers {
get { return _receivers; }
set { _receivers = value; }
}
/// <summary>
/// Connect this neuron to <paramref name="receiverToAdd"/>.
/// </summary>
/// <param name="receiverToAdd">The nucleus that should receive this neuron's output. Anything other than a Neuron is ignored.</param>
/// <param name="weight">The weight of the synapse created on the receiver. Default value = 1.</param>
/// <remarks>Updates both sides: this neuron's receiver list and the receiver's synapses.</remarks>
public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1) {
    if (receiverToAdd is Neuron target) {
        _receivers.Add(target);
        target.AddSynapse(this, weight);
    }
}
/// <summary>
/// Disconnect this neuron from <paramref name="receiverToRemove"/>.
/// </summary>
/// <param name="receiverToRemove">The receiver to disconnect. Anything other than a Neuron is ignored.</param>
/// <remarks>
/// Removes the receiver from this neuron's receiver list, and every synapse from this neuron from the receiver.
/// </remarks>
public virtual void RemoveReceiver(Nucleus receiverToRemove) {
    if (receiverToRemove is Neuron target) {
        _receivers.RemoveAll(candidate => candidate == target);
        target.synapses.RemoveAll(link => link.neuron == this);
    }
}
#endregion Receivers
/// <summary>
/// Process an external stimulus.
/// </summary>
/// <param name="inputValue">The value of the stimulus; it replaces the bias of this neuron.</param>
/// <remarks>
/// Records the current Time.time in lastUpdate, then lets the parent cluster (if any) update starting from this nucleus.
/// </remarks>
public virtual void ProcessStimulus(Vector3 inputValue) {
    var now = Time.time;
    lastUpdate = now;
    bias = inputValue;
    parent?.UpdateFromNucleus(this);
}
}
}