556 lines
20 KiB
C#
556 lines
20 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using UnityEngine;
|
|
using UnityEditor;
|
|
#if UNITY_MATHEMATICS
|
|
using Unity.Mathematics;
|
|
using static Unity.Mathematics.math;
|
|
#endif
|
|
|
|
namespace NanoBrain {

    /// <summary>
    /// A neuron is a basic Nucleus
    /// </summary>
    /// A Neuron combines its bias and the weighted outputs of its input synapses into a
    /// single vector (selected by combinator) and passes that vector through an activation
    /// function (selected by curvePreset) to produce outputValue.
    [Serializable]
    public class Neuron : Nucleus {

        /// <summary>
        /// Create a new Neuron in a Cluster instance
        /// </summary>
        /// <param name="parent">The parent cluster in which the new Neuron should be created.
        /// May be null, in which case the Neuron is not added to any cluster.</param>
        /// <param name="name">The name of the new Neuron</param>
        public Neuron(Cluster parent, string name) {
            this.parent = parent;
            this.name = name;
            this.parent?.clusterNuclei.Add(this);
        }

        /// <summary>
        /// Create a new Neuron in a Cluster Prefab
        /// </summary>
        /// <param name="prefab">The Cluster Prefab in which the new Neuron should be created.
        /// When null, an error is logged and the Neuron is not added to any prefab.</param>
        /// <param name="name">The name of the new Neuron</param>
        /// Note that this constructor does not set parent.
        public Neuron(ClusterPrefab prefab, string name) {
            this.clusterPrefab = prefab;
            this.name = name;
            if (this.clusterPrefab != null) {
                this.clusterPrefab.nuclei.Add(this);
            }
            else {
                Debug.LogError("No prefab when adding neuron to prefab");
            }
        }

        #region Serialization

        /// <summary>
        /// The type of combinators
        /// </summary>
        /// A combinator combines the weighted values of the synapses (and the bias) to a single value.
        /// NOTE(review): if this enum is stored by its integer value (as Unity does for enum fields),
        /// only append new members; reordering would change saved data.
        public enum CombinatorType {
            /// Add the bias and the weighted values together
            Sum,
            /// Multiply the bias component-wise with each weighted value
            Product,
            /// Take the candidate (bias or weighted value) with the largest length
            Max,
        }

        /// <summary>
        /// The type of combinator used for this Neuron
        /// </summary>
        public CombinatorType combinator = CombinatorType.Sum;

        /// <summary>
        /// The type of activation function applied to the combined input of a Neuron
        /// </summary>
        /// Except for Binary, every activation keeps the direction of the combined vector
        /// and only reshapes its length.
        /// NOTE(review): if this enum is stored by its integer value, only append new members.
        public enum ActivationType {
            /// The combined input is passed through unchanged
            Linear,
            /// The length is squared
            Power,
            /// The square root of the length is taken
            Sqrt,
            /// The length is replaced by its reciprocal (zero input gives zero output)
            Reciprocal,
            /// The length is passed through tanh
            Tanh,
            /// Every component is set to the length clamped to [0, 1]
            Binary,
            /// The vector is scaled to unit length (zero input gives zero output)
            Normalized,
            /// The length is mapped through the user-defined curve
            Custom
        }

        /// <summary>
        /// Backing field of curvePreset.
        /// </summary>
        /// Assigning this field directly does not regenerate curve; assign curvePreset for that.
        [SerializeField]
        public ActivationType _curvePreset;

        /// <summary>
        /// The activation function used by this Neuron.
        /// </summary>
        /// Assigning it regenerates curve and curveMax through GenerateCurve().
        public ActivationType curvePreset {
            get => _curvePreset;
            set {
                _curvePreset = value;
                this.curve = GenerateCurve();
            }
        }

        /// <summary>
        /// Curve describing the activation function.
        /// </summary>
        /// Only ActivatorCustom evaluates this curve; for the other presets it is regenerated by
        /// GenerateCurve() and not read by the activators in this class.
        public AnimationCurve curve;

        /// <summary>
        /// Upper value associated with curve; set by GenerateCurve().
        /// </summary>
        /// NOTE(review): presumably used to scale the curve display — confirm against the editor code.
        public float curveMax = 1.0f;

        /// <summary>
        /// Builds the curve matching curvePreset and updates curveMax accordingly.
        /// </summary>
        /// <returns>A newly generated preset curve, or the current curve for ActivationType.Custom</returns>
        public AnimationCurve GenerateCurve() {
            switch (this.curvePreset) {
                case ActivationType.Linear:
                    this.curveMax = 1;
                    return Presets.Linear(1);
                case ActivationType.Power:
                    this.curveMax = 1;
                    return Presets.Power(2.0f, 1);
                case ActivationType.Sqrt:
                    this.curveMax = 1;
                    return Presets.Power(0.5f, 1);
                case ActivationType.Reciprocal:
                    // NOTE(review): Presets.Reciprocal samples from x = 0.001 (peak 1000),
                    // while this assumes 1 / 0.01 = 100 — confirm which bound is intended.
                    this.curveMax = 1 / 0.01f * 1;
                    return Presets.Reciprocal(1);
                case ActivationType.Tanh:
                    this.curveMax = 1;
                    return Presets.Tanh(1);
                case ActivationType.Binary:
                    this.curveMax = 1;
                    return Presets.Binary();
                case ActivationType.Normalized:
                    // NOTE(review): reuses the Binary curve, although ActivatorNormalized outputs
                    // length 1 for every non-zero input — confirm this is only meant as a placeholder.
                    this.curveMax = 1;
                    return Presets.Binary();
                default:
                    // Custom: keep the user-defined curve.
                    this.curveMax = 1;
                    return this.curve;
            }
        }

        /// <summary>
        /// Factory methods for the curves shown for each ActivationType.
        /// </summary>
        /// NOTE(review): these use UnityEditor.AnimationUtility, which only exists in the Editor;
        /// confirm this file is never compiled into a player build.
        public static class Presets {
            /// Number of keyframes used by Power and Tanh.
            private const int samples = 32;

            /// <summary>
            /// Straight line from (0, 0) to (1000, 1000 * weight).
            /// </summary>
            public static AnimationCurve Linear(float weight) {
                return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
            }

            /// <summary>
            /// Samples y = weight * x^exponent on x in [0, 1] with smooth (Auto) tangents.
            /// </summary>
            public static AnimationCurve Power(float exponent, float weight) {
                Keyframe[] keys = new Keyframe[samples];
                for (int i = 0; i < samples; i++) {
                    float t = i / (float)(samples - 1);
                    float v = Mathf.Pow(t, exponent) * weight;
                    keys[i] = new Keyframe(t, v);
                }

                AnimationCurve curve = new(keys);
                // Auto gives smooth tangents; Linear would give straight segments.
                SetTangentModes(curve, AnimationUtility.TangentMode.Auto);
                return curve;
            }

            /// <summary>
            /// Samples y = weight / x on x in [0.001, 1] with linear tangents.
            /// </summary>
            public static AnimationCurve Reciprocal(float weight) {
                const int reciprocalSamples = 128;
                // 1 / x is undefined at 0, so sampling starts just above it.
                const float xMin = 0.001f;
                const float xMax = 1;

                var keys = new Keyframe[reciprocalSamples];
                for (int i = 0; i < reciprocalSamples; i++) {
                    float t = i / (float)(reciprocalSamples - 1);
                    float x = Mathf.Lerp(xMin, xMax, t);
                    float y = 1f / x * weight;
                    keys[i] = new Keyframe(x, y);
                }

                var curve = new AnimationCurve(keys);
                SetTangentModes(curve, AnimationUtility.TangentMode.Linear);
                return curve;
            }

            /// <summary>
            /// Samples y = tanh(x * weight) on x in [0.001, 1] with linear tangents.
            /// </summary>
            public static AnimationCurve Tanh(float weight) {
                const float xMin = 0.001f;
                const float xMax = 1;

                var keys = new Keyframe[samples];
                for (int i = 0; i < samples; i++) {
                    float t = i / (float)(samples - 1);
                    float x = Mathf.Lerp(xMin, xMax, t);
                    float y = MathF.Tanh(x * weight);
                    keys[i] = new Keyframe(x, y);
                }

                var curve = new AnimationCurve(keys);
                SetTangentModes(curve, AnimationUtility.TangentMode.Linear);
                return curve;
            }

            /// <summary>
            /// Straight line from (0, 0) to (1, 1).
            /// </summary>
            public static AnimationCurve Binary() {
                return AnimationCurve.Linear(0, 0, 1, 1);
            }

            /// Applies the same left and right tangent mode to every key of curve.
            private static void SetTangentModes(AnimationCurve curve, AnimationUtility.TangentMode mode) {
                for (int i = 0; i < curve.length; i++) {
                    AnimationUtility.SetKeyLeftTangentMode(curve, i, mode);
                    AnimationUtility.SetKeyRightTangentMode(curve, i, mode);
                }
            }
        }

        #endregion Serialization

#if UNITY_MATHEMATICS

        protected float3 _outputValue;

        /// <summary>
        /// The current output of this Neuron.
        /// </summary>
        /// Every assignment that leaves the Neuron firing (see isFiring) invokes WhenFiring.
        public virtual float3 outputValue {
            get => _outputValue;
            set {
                _outputValue = value;
                if (this.isFiring) {
                    WhenFiring?.Invoke();
                }
            }
        }

        /// <summary>Length of outputValue.</summary>
        public float outputMagnitude => length(_outputValue);

        /// <summary>Squared length of outputValue.</summary>
        public float outputSqrMagnitude => lengthsq(_outputValue);

#else

        protected Vector3 _outputValue;

        /// <summary>
        /// The current output of this Neuron.
        /// </summary>
        /// Every assignment that leaves the Neuron firing (see isFiring) invokes WhenFiring.
        public virtual Vector3 outputValue {
            get => _outputValue;
            set {
                _outputValue = value;
                if (this.isFiring) {
                    WhenFiring?.Invoke();
                }
            }
        }

        /// <summary>Length of outputValue.</summary>
        public float outputMagnitude => _outputValue.magnitude;

        /// <summary>Squared length of outputValue.</summary>
        public float outputSqrMagnitude => _outputValue.sqrMagnitude;

#endif

        /// <summary>
        /// True when the length of the output exceeds 0.5.
        /// </summary>
        public bool isFiring => this.outputMagnitude > 0.5f;

        /// <summary>
        /// Invoked on every assignment of outputValue that leaves the Neuron firing.
        /// </summary>
        public Action WhenFiring;

        /// <summary>
        /// True when more than timeToSleep seconds (of Time.time) have passed since lastUpdate.
        /// </summary>
        public virtual bool isSleeping => Time.time - this.lastUpdate > this.timeToSleep;

        /// <summary>
        /// Resets the bias to zero when this Neuron is sleeping.
        /// </summary>
        public void SleepCheck() {
            if (!this.isSleeping) {
                return;
            }
#if UNITY_MATHEMATICS
            this.bias = new float3(0, 0, 0);
#else
            this.bias = new Vector3(0, 0, 0);
#endif
        }

        /// <summary>
        /// Time.time of the last state update or stimulus of this Neuron.
        /// </summary>
        [NonSerialized]
        public float lastUpdate = 0;

        /// <summary>
        /// Seconds without an update after which the Neuron counts as sleeping.
        /// </summary>
        public readonly float timeToSleep = 1f;

        /// \copydoc NanoBrain::Nucleus::ShallowCloneTo
        public override Nucleus ShallowCloneTo(Cluster newParent) {
            Neuron clone = new(newParent, this.name);
            CloneFields(clone);
            return clone;
        }

        /// \copydoc NanoBrain::Nucleus::Clone
        public override Nucleus Clone(ClusterPrefab prefab) {
            Neuron clone = new(prefab, this.name);
            CloneFields(clone);

            // Recreate the inputs with their original weights.
            foreach (Synapse synapse in this.synapses) {
                Synapse clonedSynapse = clone.AddSynapse(synapse.neuron);
                clonedSynapse.weight = synapse.weight;
            }

            // NOTE(review): receivers are re-linked with AddReceiver's default weight of 1,
            // not the weight of the original receiver synapse — confirm this is intended.
            foreach (Nucleus receiver in this.receivers) {
                clone.AddReceiver(receiver);
            }
            return clone;
        }

        /// <summary>
        /// Copies this Neuron's prefab, bias, combinator and activation settings to clone.
        /// </summary>
        /// <param name="clone">The Neuron receiving the copied settings</param>
        /// NOTE(review): curve is copied by reference, so a Custom curve is shared between the
        /// original and the clone — confirm that sharing is intended.
        protected virtual void CloneFields(Neuron clone) {
            clone.clusterPrefab = this.clusterPrefab;
            clone.bias = this.bias;
            clone.combinator = this.combinator;
            // curve must be set before curvePreset: for Custom, the curvePreset setter keeps the current curve.
            clone.curve = this.curve;
            clone.curvePreset = this.curvePreset;
            // The curvePreset setter overwrote curveMax via GenerateCurve(), so restore the source value.
            clone.curveMax = this.curveMax;
        }

        /// <summary>
        /// Disconnects a nucleus from the network and removes it from its cluster prefab.
        /// </summary>
        /// Each input Neuron that has more than one receiver only loses nucleus as a receiver.
        /// Every other input Neuron is deleted recursively.
        /// Synapses that read from nucleus (or, for a Cluster, from its Neurons) are removed from their receivers.
        /// <param name="nucleus">The nucleus to delete</param>
        public static void Delete(Nucleus nucleus) {
            DeleteRecursive(nucleus, new HashSet<Nucleus>());
        }

        /// Worker for Delete(Nucleus); visited stops infinite recursion when inputs form a cycle.
        private static void DeleteRecursive(Nucleus nucleus, HashSet<Nucleus> visited) {
            if (!visited.Add(nucleus)) {
                return;
            }

            // Iterate over a snapshot: deleting an input below removes its synapse from
            // nucleus.synapses, which would otherwise invalidate this enumeration.
            foreach (Synapse synapse in new List<Synapse>(nucleus.synapses)) {
                if (synapse.neuron is not Neuron inputNeuron) {
                    continue;
                }
                if (inputNeuron.receivers.Count > 1) {
                    // There is another nucleus fed by this input neuron; only unlink it.
                    inputNeuron.receivers.RemoveAll(r => r == nucleus);
                }
                else {
                    // No other links, delete it.
                    DeleteRecursive(inputNeuron, visited);
                }
            }

            if (nucleus is Neuron neuron) {
                foreach (Nucleus receiver in neuron.receivers) {
                    if (receiver != null && receiver.synapses != null) {
                        receiver.synapses.RemoveAll(s => s.neuron == nucleus);
                    }
                }
            }
            else if (nucleus is Cluster cluster) {
                // Remove every synapse that reads from a Neuron of this cluster.
                foreach (Nucleus clusterNucleus in cluster.clusterNuclei) {
                    if (clusterNucleus is not Neuron output) {
                        continue;
                    }
                    foreach (Nucleus receiver in output.receivers) {
                        if (receiver != null && receiver.synapses != null) {
                            receiver.synapses.RemoveAll(s => s.neuron == output);
                        }
                    }
                }
            }

            if (nucleus.clusterPrefab != null) {
                nucleus.clusterPrefab.nuclei.RemoveAll(n => n == nucleus);
                nucleus.clusterPrefab.RefreshOutputs();
                nucleus.clusterPrefab.GarbageCollection();
            }
        }

        /// <summary>
        /// Recomputes outputValue from the current inputs and records the update time.
        /// </summary>
        /// Outputs of inputs behind sleeping synapses are zeroed first. Then the selected
        /// combinator and activator are applied.
        public override void UpdateStateIsolated() {
            CheckSleepingSynapses();
            var combined = Combinator();
            this.outputValue = Activator(combined);
            this.lastUpdate = Time.time;
        }

        /// <summary>
        /// Sets the output of every input neuron whose synapse is sleeping to zero.
        /// </summary>
        protected void CheckSleepingSynapses() {
            foreach (Synapse synapse in this.synapses) {
                if (synapse.isSleeping) {
                    synapse.neuron.outputValue = Vector3.zero;
                }
            }
        }

        #region Combinator

#if UNITY_MATHEMATICS

        /// <summary>
        /// The combinator function selected by combinator (Sum for unknown values).
        /// </summary>
        protected Func<float3> Combinator => combinator switch {
            CombinatorType.Sum => CombinatorSum,
            CombinatorType.Product => CombinatorProduct,
            CombinatorType.Max => CombinatorMax,
            _ => CombinatorSum
        };

        /// <summary>
        /// The bias plus the weighted output of every input neuron.
        /// </summary>
        public float3 CombinatorSum() {
            float3 sum = this.bias;
            foreach (Synapse synapse in this.synapses) {
                sum += synapse.weight * synapse.neuron.outputValue;
            }
            return sum;
        }

        /// <summary>
        /// The bias multiplied component-wise by the weighted output of every input neuron.
        /// </summary>
        /// NOTE(review): because the product is seeded with the bias, it stays zero while the
        /// bias is zero — confirm this is intended.
        public float3 CombinatorProduct() {
            float3 product = this.bias;
            foreach (Synapse synapse in this.synapses) {
                product *= synapse.weight * synapse.neuron.outputValue;
            }
            return product;
        }

        /// <summary>
        /// The longest vector among the bias and the weighted outputs of the input neurons.
        /// </summary>
        /// On equal lengths the earlier candidate (the bias first) wins.
        public float3 CombinatorMax() {
            float3 max = this.bias;
            float maxLength = length(max);

            foreach (Synapse synapse in this.synapses) {
                float3 input = synapse.weight * synapse.neuron.outputValue;
                float inputLength = length(input);
                if (inputLength > maxLength) {
                    max = input;
                    maxLength = inputLength;
                }
            }
            return max;
        }

#else

        /// <summary>
        /// The combinator function selected by combinator (Sum for unknown values).
        /// </summary>
        protected Func<Vector3> Combinator => combinator switch {
            CombinatorType.Sum => CombinatorSum,
            CombinatorType.Product => CombinatorProduct,
            CombinatorType.Max => CombinatorMax,
            _ => CombinatorSum
        };

        /// <summary>
        /// The bias plus the weighted output of every input neuron.
        /// </summary>
        public Vector3 CombinatorSum() {
            Vector3 sum = this.bias;
            foreach (Synapse synapse in this.synapses) {
                sum += synapse.weight * synapse.neuron.outputValue;
            }
            return sum;
        }

        /// <summary>
        /// The bias multiplied component-wise by the weighted output of every input neuron.
        /// </summary>
        /// NOTE(review): because the product is seeded with the bias, it stays zero while the
        /// bias is zero — confirm this is intended.
        public Vector3 CombinatorProduct() {
            Vector3 product = this.bias;
            foreach (Synapse synapse in this.synapses) {
                // Vector3 has no component-wise '*', so Scale is used.
                product = Vector3.Scale(product, synapse.weight * synapse.neuron.outputValue);
            }
            return product;
        }

        /// <summary>
        /// The longest vector among the bias and the weighted outputs of the input neurons.
        /// </summary>
        /// On equal lengths the earlier candidate (the bias first) wins.
        public Vector3 CombinatorMax() {
            Vector3 max = this.bias;
            float maxLength = max.magnitude;

            foreach (Synapse synapse in this.synapses) {
                Vector3 input = synapse.weight * synapse.neuron.outputValue;
                float inputLength = input.magnitude;
                if (inputLength > maxLength) {
                    max = input;
                    maxLength = inputLength;
                }
            }
            return max;
        }

#endif

        #endregion Combinator

        #region Activator

#if UNITY_MATHEMATICS

        /// <summary>
        /// The activation function selected by curvePreset (the custom curve for unknown values).
        /// </summary>
        public Func<float3, float3> Activator => this.curvePreset switch {
            ActivationType.Linear => ActivatorLinear,
            ActivationType.Sqrt => ActivatorSqrt,
            ActivationType.Power => ActivatorPower,
            ActivationType.Reciprocal => ActivatorReciprocal,
            ActivationType.Tanh => ActivatorTanh,
            ActivationType.Binary => ActivatorBinary,
            ActivationType.Normalized => ActivatorNormalized,
            _ => ActivatorCustom
        };

        /// <summary>Returns input unchanged.</summary>
        protected float3 ActivatorLinear(float3 input) {
            return input;
        }

        /// <summary>Keeps the direction of input and takes the square root of its length.</summary>
        protected float3 ActivatorSqrt(float3 input) {
            // normalizesafe yields zero for a zero vector, where normalize would yield NaN
            // and poison every downstream neuron.
            return normalizesafe(input) * MathF.Sqrt(length(input));
        }

        /// <summary>Keeps the direction of input and squares its length.</summary>
        protected float3 ActivatorPower(float3 input) {
            return normalizesafe(input) * MathF.Pow(length(input), 2);
        }

        /// <summary>Keeps the direction of input and replaces its length by the reciprocal; zero stays zero.</summary>
        protected float3 ActivatorReciprocal(float3 input) {
            float magnitude = length(input);
            if (magnitude == 0) {
                return new float3(0, 0, 0);
            }
            return normalize(input) * (1 / magnitude);
        }

        /// <summary>Keeps the direction of input and passes its length through tanh.</summary>
        protected float3 ActivatorTanh(float3 input) {
            return normalizesafe(input) * MathF.Tanh(length(input));
        }

        /// <summary>Sets every component to the length of input clamped to [0, 1].</summary>
        protected float3 ActivatorBinary(float3 input) {
            float value = Mathf.Clamp01(length(input));
            return float3(value, value, value);
        }

        /// <summary>Scales input to unit length; a zero vector is returned unchanged.</summary>
        protected float3 ActivatorNormalized(float3 input) {
            if (lengthsq(input) == 0) {
                return input;
            }
            return normalize(input);
        }

        /// <summary>Keeps the direction of input and maps its length through curve.</summary>
        protected float3 ActivatorCustom(float3 input) {
            float activatedValue = this.curve.Evaluate(length(input));
            return normalizesafe(input) * activatedValue;
        }

#else

        /// <summary>
        /// The activation function selected by curvePreset (the custom curve for unknown values).
        /// </summary>
        public Func<Vector3, Vector3> Activator => this.curvePreset switch {
            ActivationType.Linear => ActivatorLinear,
            ActivationType.Sqrt => ActivatorSqrt,
            ActivationType.Power => ActivatorPower,
            ActivationType.Reciprocal => ActivatorReciprocal,
            ActivationType.Tanh => ActivatorTanh,
            ActivationType.Binary => ActivatorBinary,
            ActivationType.Normalized => ActivatorNormalized,
            _ => ActivatorCustom
        };

        /// <summary>Returns input unchanged.</summary>
        protected Vector3 ActivatorLinear(Vector3 input) {
            return input;
        }

        /// <summary>Keeps the direction of input and takes the square root of its length.</summary>
        protected Vector3 ActivatorSqrt(Vector3 input) {
            // Vector3.normalized is the zero vector for (near-)zero input, so no NaN can appear.
            return input.normalized * MathF.Sqrt(input.magnitude);
        }

        /// <summary>Keeps the direction of input and squares its length.</summary>
        protected Vector3 ActivatorPower(Vector3 input) {
            return input.normalized * MathF.Pow(input.magnitude, 2);
        }

        /// <summary>Keeps the direction of input and replaces its length by the reciprocal; zero stays zero.</summary>
        protected Vector3 ActivatorReciprocal(Vector3 input) {
            float magnitude = input.magnitude;
            if (magnitude == 0) {
                return new Vector3(0, 0, 0);
            }
            return input.normalized * (1 / magnitude);
        }

        /// <summary>Keeps the direction of input and passes its length through tanh.</summary>
        protected Vector3 ActivatorTanh(Vector3 input) {
            return input.normalized * MathF.Tanh(input.magnitude);
        }

        /// <summary>Sets every component to the length of input clamped to [0, 1].</summary>
        protected Vector3 ActivatorBinary(Vector3 input) {
            float value = Mathf.Clamp01(input.magnitude);
            return new Vector3(value, value, value);
        }

        /// <summary>Scales input to unit length; a zero vector is returned unchanged.</summary>
        protected Vector3 ActivatorNormalized(Vector3 input) {
            if (input.sqrMagnitude == 0) {
                return input;
            }
            return input.normalized;
        }

        /// <summary>Keeps the direction of input and maps its length through curve.</summary>
        protected Vector3 ActivatorCustom(Vector3 input) {
            float activatedValue = this.curve.Evaluate(input.magnitude);
            return input.normalized * activatedValue;
        }

#endif

        #endregion Activator

        #region Receivers

        [SerializeReference]
        private List<Nucleus> _receivers = new();

        /// <summary>
        /// The nuclei that have a synapse reading from this Neuron.
        /// </summary>
        public virtual List<Nucleus> receivers {
            get => _receivers;
            set => _receivers = value;
        }

        /// <summary>
        /// Registers receiverToAdd as a receiver and gives it a synapse reading from this Neuron.
        /// </summary>
        /// <param name="receiverToAdd">The nucleus that should receive this Neuron's output</param>
        /// <param name="weight">Weight of the new synapse on receiverToAdd</param>
        public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1) {
            this._receivers.Add(receiverToAdd);
            receiverToAdd.AddSynapse(this, weight);
        }

        /// <summary>
        /// Unregisters receiverToRemove and removes its synapses that read from this Neuron.
        /// </summary>
        /// <param name="receiverToRemove">The receiver to disconnect</param>
        public virtual void RemoveReceiver(Nucleus receiverToRemove) {
            this._receivers.RemoveAll(receiver => receiver == receiverToRemove);
            receiverToRemove.synapses.RemoveAll(synapse => synapse.neuron == this);
        }

        #endregion Receivers

        /// <summary>
        /// Feeds an external stimulus into this Neuron.
        /// </summary>
        /// <param name="inputValue">The stimulus; it becomes the new bias</param>
        public override void ProcessStimulus(Vector3 inputValue) {
            ProcessStimulusDirect(inputValue);
        }

        /// <summary>
        /// Stores inputValue as the bias, marks this Neuron as updated and notifies the parent cluster.
        /// </summary>
        /// <param name="inputValue">The stimulus; it becomes the new bias</param>
        /// parent may be null (the prefab constructor does not set it); then no cluster is notified.
        public void ProcessStimulusDirect(Vector3 inputValue) {
            this.lastUpdate = Time.time;
            this.bias = inputValue;
            this.parent?.UpdateFromNucleus(this);
        }
    }

}