cc9a845 Fix sleeping for product combinator e4ba7f8 Better cross-cluster monitoring 4f8a6ab Improved (but not fixed) cross-cluster monitoring b12616b Fix neuron output visualisation 96439cc Visualize all outputs d583e67 WIP cluster references/instance 04bab92 Fix links to multiple cluster neurons & cleanup e17a249 Cross-cluster editor links 0ab2d21 Migrating and cleaning up b6630ad First steps to using instanceCount for clusters 8801fa2 Cluster reimport fixes befb69d full graph with collapsed clusters 1a1919f Fix expansion of clsuter arrays c708f4d Improved clusterarray support c2e4e1b Fix Cluster array extension 02047a4 Adde full graph scrollbar 471ed36 Completed full graph integration 830e3e7 Added full graph view mode 249e888 Improve full graph view 308a6a1 The Entities are battling 75d9d1c Cleanup c8f0f0c Fix aging of neurons e2e169c small fixes 619ced6 Removed the use of Receptors 19f9296 Simplifications bc0a796 Integrated clusterarray in cluster e40dd23 Fixed clusterViewer for clusterarrays b0f4b41 Status quo adding clusterArrays 1fc75a8 Added ClusterArray 0023920 Cover seeking(-ish) behaviour 1c7b8e7 Added Tanh Activation a99d40c BrainViewer added db43655 Pew pew! 18ef4cd Merge commit '89017475984bbbf1899fb38846c5bb0e7775dedd' into NanoBrain git-subtree-dir: NanoBrain git-subtree-split: cc9a845b643ffb4a9abe4f7da787ac5c5b14dae8
541 lines
19 KiB
C#
541 lines
19 KiB
C#
using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_EDITOR
// UnityEditor (AnimationUtility) only exists in the editor; importing it unguarded breaks player builds.
using UnityEditor;
#endif
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif
|
|
|
|
namespace NanoBrain {
|
|
|
|
/// <summary>
|
|
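/// A Neuron is the basic Nucleus: it combines the weighted outputs of its synapses into a single value and passes that value through an activation function.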
|
|
/// </summary>
|
|
[Serializable]
|
|
public class Neuron : Nucleus {
|
|
|
|
/// <summary>
/// Creates a new Neuron inside a Cluster instance.
/// </summary>
/// <remarks>
/// The Neuron registers itself in the parent's <c>clusterNuclei</c> list.
/// A null parent is tolerated: the Neuron is then created without being registered anywhere.
/// </remarks>
/// <param name="parent">The Cluster that will own the new Neuron (may be null)</param>
/// <param name="name">The name of the new Neuron</param>
public Neuron(Cluster parent, string name) {
    this.parent = parent;
    this.name = name;

    // Register with the owning cluster, if there is one.
    if (this.parent is { } owner)
        owner.clusterNuclei.Add(this);
}
|
|
/// <summary>
/// Creates a new Neuron inside a Cluster Prefab.
/// </summary>
/// <remarks>
/// The Neuron adds itself to the prefab's <c>nuclei</c> list. When <paramref name="prefab"/> is null
/// an error is logged and the Neuron is left unregistered.
/// </remarks>
/// <param name="prefab">The Cluster Prefab in which the new Neuron should be created</param>
/// <param name="name">The name of the new Neuron</param>
public Neuron(ClusterPrefab prefab, string name) {
    this.clusterPrefab = prefab;
    this.name = name;

    if (this.clusterPrefab == null) {
        Debug.LogError("No prefab when adding neuron to prefab");
        return;
    }
    this.clusterPrefab.nuclei.Add(this);
}
|
|
|
|
#region Serialization
|
|
|
|
/// <summary>
/// The ways in which the weighted synapse inputs of a Neuron are combined into a single value.
/// </summary>
/// <remarks>
/// Explicit values keep the stored integer of each member stable if the members are ever reordered.
/// </remarks>
public enum CombinatorType {
    /// <summary>Add the bias and all weighted inputs together</summary>
    Sum = 0,
    /// <summary>Multiply the bias component-wise by all weighted inputs</summary>
    Product = 1,
    /// <summary>Take the bias or weighted input with the largest magnitude</summary>
    Max = 2,
}

/// <summary>
/// The combinator used by this Neuron; <see cref="CombinatorType.Sum"/> by default.
/// </summary>
public CombinatorType combinator = CombinatorType.Sum;
|
|
|
|
/// <summary>
/// The activation function a Neuron applies to its combined input.
/// </summary>
/// <remarks>
/// Most activators keep the direction of the combined vector and only reshape its magnitude.
/// Explicit values keep the stored integer of each member stable if the members are ever reordered.
/// </remarks>
public enum ActivationType {
    /// <summary>Output equals the input</summary>
    Linear = 0,
    /// <summary>Magnitude squared</summary>
    Power = 1,
    /// <summary>Square root of the magnitude</summary>
    Sqrt = 2,
    /// <summary>One over the magnitude; a zero input stays zero</summary>
    Reciprocal = 3,
    /// <summary>Hyperbolic tangent of the magnitude</summary>
    Tanh = 4,
    /// <summary>Magnitude clamped to [0, 1], written to all three components</summary>
    Binary = 5,
    /// <summary>Unit vector in the input's direction; a zero input stays zero</summary>
    Normalized = 6,
    /// <summary>Magnitude mapped through the Neuron's own <c>curve</c></summary>
    Custom = 7
}
|
|
/// <summary>
/// Backing field for <see cref="curvePreset"/>.
/// </summary>
/// <remarks>Writing this field directly does not regenerate <see cref="curve"/>.</remarks>
[SerializeField]
public ActivationType _curvePreset;

/// <summary>
/// The activation function of this Neuron.
/// </summary>
/// <remarks>
/// Assigning it regenerates <see cref="curve"/> (and updates <see cref="curveMax"/>) through
/// <see cref="GenerateCurve"/>.
/// </remarks>
public ActivationType curvePreset {
    get => _curvePreset;
    set {
        _curvePreset = value;
        this.curve = GenerateCurve();
    }
}

/// <summary>
/// Curve describing the activation; the Custom activator evaluates it directly.
/// </summary>
public AnimationCurve curve;

/// <summary>
/// Maximum curve value chosen by <see cref="GenerateCurve"/> for the current preset.
/// </summary>
/// <remarks>NOTE(review): presumably used as a display/range bound by callers — confirm.</remarks>
public float curveMax = 1.0f;
|
|
|
|
/// <summary>
/// Builds the AnimationCurve that matches <see cref="curvePreset"/> and updates <see cref="curveMax"/>.
/// </summary>
/// <returns>
/// A freshly generated curve for the built-in presets, or the current <see cref="curve"/> unchanged
/// for <see cref="ActivationType.Custom"/> (and any unknown value).
/// </returns>
public AnimationCurve GenerateCurve() {
    ActivationType preset = this.curvePreset;

    // Only the reciprocal preset gets a bound above one (1 / 0.01 = 100).
    // NOTE(review): Presets.Reciprocal samples down to x = 0.001, so its peak (1000 at weight 1)
    // exceeds this bound — confirm 100 is the intended maximum.
    this.curveMax = preset == ActivationType.Reciprocal ? 1 / 0.01f * 1 : 1;

    // NOTE(review): Normalized reuses the Binary ramp, although ActivatorNormalized yields a
    // unit-length output for any non-zero input — confirm this curve is intended.
    return preset switch {
        ActivationType.Linear => Presets.Linear(1),
        ActivationType.Power => Presets.Power(2.0f, 1),
        ActivationType.Sqrt => Presets.Power(0.5f, 1),
        ActivationType.Reciprocal => Presets.Reciprocal(1),
        ActivationType.Tanh => Presets.Tanh(1),
        ActivationType.Binary => Presets.Binary(),
        ActivationType.Normalized => Presets.Binary(),
        _ => this.curve,
    };
}
|
|
|
|
/// <summary>
/// Factory methods for the AnimationCurves that describe the built-in activation presets.
/// </summary>
/// <remarks>
/// Tangent modes are set through UnityEditor.AnimationUtility, which only exists in the editor.
/// Player builds use runtime equivalents instead, so this runtime class still compiles there.
/// </remarks>
public static class Presets {
    /// <summary>Number of keyframes used by the Power and Tanh presets.</summary>
    private const int samples = 32;

    /// <summary>
    /// A straight line through the origin with slope <paramref name="weight"/>, defined on [0, 1000].
    /// </summary>
    public static AnimationCurve Linear(float weight) {
        return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
    }

    /// <summary>
    /// Samples y = weight * x^exponent on [0, 1] with smooth tangents.
    /// </summary>
    /// <param name="exponent">Exponent applied to x (2 for Power, 0.5 for Sqrt)</param>
    /// <param name="weight">Scale factor applied to the result</param>
    public static AnimationCurve Power(float exponent, float weight) {
        Keyframe[] keys = new Keyframe[samples];
        for (int i = 0; i < samples; i++) {
            float t = i / (float)(samples - 1);
            float v = Mathf.Pow(t, exponent) * weight;
            keys[i] = new Keyframe(t, v);
        }

        AnimationCurve curve = new(keys);
        SetSmoothTangents(curve);
        return curve;
    }

    /// <summary>
    /// Samples y = weight / x on [0.001, 1] with straight segments between the keys.
    /// </summary>
    /// <remarks>x starts at 0.001 rather than 0 to avoid the division by zero.</remarks>
    public static AnimationCurve Reciprocal(float weight) {
        // More keys than the other presets: 1/x changes steeply near zero.
        const int reciprocalSamples = 128;
        const float xMin = 0.001f;
        const float xMax = 1;

        Keyframe[] keys = new Keyframe[reciprocalSamples];
        for (int i = 0; i < reciprocalSamples; i++) {
            float t = i / (float)(reciprocalSamples - 1);
            float x = Mathf.Lerp(xMin, xMax, t);
            float y = 1f / x * weight;
            keys[i] = new Keyframe(x, y);
        }

        AnimationCurve curve = new(keys);
        SetLinearTangents(curve);
        return curve;
    }

    /// <summary>
    /// Samples y = tanh(weight * x) on [0, 1] with straight segments between the keys.
    /// </summary>
    public static AnimationCurve Tanh(float weight) {
        // tanh is defined at 0, so unlike Reciprocal the curve starts at the origin.
        const float xMin = 0f;
        const float xMax = 1;

        Keyframe[] keys = new Keyframe[samples];
        for (int i = 0; i < samples; i++) {
            float t = i / (float)(samples - 1);
            float x = Mathf.Lerp(xMin, xMax, t);
            float y = MathF.Tanh(x * weight);
            keys[i] = new Keyframe(x, y);
        }

        AnimationCurve curve = new(keys);
        SetLinearTangents(curve);
        return curve;
    }

    /// <summary>
    /// A linear ramp from (0, 0) to (1, 1).
    /// </summary>
    public static AnimationCurve Binary() {
        return AnimationCurve.Linear(0, 0, 1, 1);
    }

    /// <summary>Gives every key smooth tangents (TangentMode.Auto in the editor).</summary>
    private static void SetSmoothTangents(AnimationCurve curve) {
        for (int i = 0; i < curve.length; i++) {
#if UNITY_EDITOR
            UnityEditor.AnimationUtility.SetKeyLeftTangentMode(curve, i, UnityEditor.AnimationUtility.TangentMode.Auto);
            UnityEditor.AnimationUtility.SetKeyRightTangentMode(curve, i, UnityEditor.AnimationUtility.TangentMode.Auto);
#else
            // AnimationUtility is editor-only; SmoothTangents approximates Auto tangents at runtime.
            curve.SmoothTangents(i, 0f);
#endif
        }
    }

    /// <summary>Makes every segment of the curve a straight line between its two keys.</summary>
    private static void SetLinearTangents(AnimationCurve curve) {
#if UNITY_EDITOR
        for (int i = 0; i < curve.length; i++) {
            UnityEditor.AnimationUtility.SetKeyLeftTangentMode(curve, i, UnityEditor.AnimationUtility.TangentMode.Linear);
            UnityEditor.AnimationUtility.SetKeyRightTangentMode(curve, i, UnityEditor.AnimationUtility.TangentMode.Linear);
        }
#else
        // Runtime equivalent of TangentMode.Linear: each tangent is the slope towards the neighbouring key.
        Keyframe[] keys = curve.keys;
        for (int i = 0; i < keys.Length; i++) {
            if (i > 0)
                keys[i].inTangent = (keys[i].value - keys[i - 1].value) / (keys[i].time - keys[i - 1].time);
            if (i < keys.Length - 1)
                keys[i].outTangent = (keys[i + 1].value - keys[i].value) / (keys[i + 1].time - keys[i].time);
        }
        curve.keys = keys;
#endif
    }
}
|
|
|
|
#endregion Serialization
|
|
|
|
#if UNITY_MATHEMATICS

/// <summary>Backing field for <see cref="outputValue"/>.</summary>
protected float3 _outputValue;

/// <summary>
/// The current output of this Neuron.
/// </summary>
/// <remarks>
/// After storing the value the setter reads <see cref="isFiring"/>, which first runs <see cref="SleepCheck"/>:
/// when <c>lastUpdate</c> is more than <c>timeToSleep</c> seconds old, the value just stored is reset to zero.
/// <see cref="WhenFiring"/> is invoked when the resulting output's magnitude exceeds 0.5.
/// </remarks>
public virtual float3 outputValue {
    get => _outputValue;
    set {
        _outputValue = value;
        bool firing = this.isFiring;
        if (firing)
            WhenFiring?.Invoke();
    }
}

/// <summary>Length of the current output.</summary>
public float outputMagnitude => length(_outputValue);

/// <summary>Squared length of the current output.</summary>
public float outputSqrMagnitude => lengthsq(_outputValue);

#else

/// <summary>Backing field for <see cref="outputValue"/>.</summary>
protected Vector3 _outputValue;

/// <summary>
/// The current output of this Neuron.
/// </summary>
/// <remarks>
/// After storing the value the setter reads <see cref="isFiring"/>, which first runs <see cref="SleepCheck"/>:
/// when <c>lastUpdate</c> is more than <c>timeToSleep</c> seconds old, the value just stored is reset to zero.
/// <see cref="WhenFiring"/> is invoked when the resulting output's magnitude exceeds 0.5.
/// </remarks>
public virtual Vector3 outputValue {
    get => _outputValue;
    set {
        _outputValue = value;
        bool firing = this.isFiring;
        if (firing)
            WhenFiring?.Invoke();
    }
}

/// <summary>Length of the current output.</summary>
public float outputMagnitude => _outputValue.magnitude;

/// <summary>Squared length of the current output.</summary>
public float outputSqrMagnitude => _outputValue.sqrMagnitude;

#endif
|
|
/// <summary>
/// True when the magnitude of the output exceeds 0.5.
/// </summary>
/// <remarks>
/// Not a pure read: it calls <see cref="SleepCheck"/> first, which zeroes the output of a sleeping Neuron.
/// </remarks>
public bool isFiring {
    get {
        const float firingThreshold = 0.5f;
        SleepCheck();
        float magnitude = this.outputMagnitude;
        return magnitude > firingThreshold;
    }
}

/// <summary>
/// Invoked by the <c>outputValue</c> setter when the newly stored output makes this Neuron fire.
/// </summary>
public Action WhenFiring;
|
|
|
|
|
|
/// <summary>
/// True when this Neuron has not been updated for more than <see cref="timeToSleep"/> seconds (measured with <c>Time.time</c>).
/// </summary>
public virtual bool isSleeping => Time.time - this.lastUpdate > this.timeToSleep;

/// <summary>
/// Resets the output to zero when the Neuron is sleeping.
/// </summary>
/// <remarks>Writes the backing field directly, so <see cref="WhenFiring"/> is not raised.</remarks>
public void SleepCheck() {
    if (!this.isSleeping)
        return;

    // default is the zero vector for both float3 and Vector3.
    this._outputValue = default;
}

/// <summary><c>Time.time</c> of the most recent update; not serialized.</summary>
[NonSerialized]
public float lastUpdate = 0;

/// <summary>Seconds without an update after which the Neuron counts as sleeping.</summary>
public readonly float timeToSleep = 1f;
|
|
|
|
/// \copydoc NanoBrain::Nucleus::ShallowCloneTo
public override Nucleus ShallowCloneTo(Cluster newParent) {
    // The constructor registers the copy with newParent; CloneFields then copies the settings.
    // Synapses and receivers are not copied here.
    var clone = new Neuron(newParent, this.name);
    this.CloneFields(clone);
    return clone;
}
|
|
|
|
/// \copydoc NanoBrain::Nucleus::Clone
public override Nucleus Clone(ClusterPrefab prefab) {
    // The constructor adds the copy to prefab.nuclei.
    // NOTE(review): CloneFields afterwards overwrites clone.clusterPrefab with this.clusterPrefab,
    // so when prefab differs from this Neuron's prefab the clone's back-reference points at the
    // other prefab — confirm callers always pass this.clusterPrefab.
    var clone = new Neuron(prefab, this.name);
    CloneFields(clone);

    // Connect the copy to the same input neurons with the same weights.
    foreach (Synapse original in this.synapses) {
        Synapse copy = clone.AddSynapse(original.neuron);
        copy.weight = original.weight;
    }

    // Feed the copy into the same receivers. Their new synapses get AddReceiver's default
    // weight, not the weight they use for this Neuron.
    foreach (Nucleus target in this.receivers)
        clone.AddReceiver(target);

    return clone;
}
|
|
|
|
/// <summary>
/// Copies this Neuron's settings (prefab, bias, combinator and activation) onto <paramref name="clone"/>.
/// </summary>
/// <remarks>
/// Order matters for the activation fields. Assigning <c>curvePreset</c> regenerates the clone's curve
/// (a Custom preset keeps the curve copied just before it) and resets <c>curveMax</c>, so <c>curveMax</c>
/// is copied last. The curve object itself is shared with the clone, not duplicated.
/// </remarks>
/// <param name="clone">The Neuron receiving the settings</param>
protected virtual void CloneFields(Neuron clone) {
    clone.clusterPrefab = clusterPrefab;
    clone.bias = bias;
    clone.combinator = combinator;

    // Activation: curve first, then the preset (which may regenerate it), then the range.
    clone.curve = curve;
    clone.curvePreset = curvePreset;
    clone.curveMax = curveMax;
}
|
|
|
|
/// <summary>
/// Removes a nucleus from the network.
/// </summary>
/// <remarks>
/// <list type="bullet">
/// <item>Input neurons with at most one receiver are deleted recursively; shared inputs only drop the deleted nucleus from their receivers.</item>
/// <item>Synapses in other nuclei that read from the deleted Neuron (or from the neurons of a deleted Cluster) are removed.</item>
/// <item>If the nucleus belongs to a Cluster Prefab it is removed from it and the prefab is refreshed.</item>
/// </list>
/// </remarks>
/// <param name="nucleus">The nucleus to delete</param>
public static void Delete(Nucleus nucleus) {
    // Iterate over a snapshot: deleting an exclusive input recursively removes the matching
    // synapse from nucleus.synapses, which would otherwise invalidate this enumeration
    // ("Collection was modified").
    foreach (Synapse synapse in new List<Synapse>(nucleus.synapses)) {
        if (synapse.neuron is not Neuron input)
            continue;

        if (input.receivers.Count > 1) {
            // Another nucleus is still fed by this input: only unlink the deleted nucleus.
            input.receivers.RemoveAll(r => r == nucleus);
        }
        else {
            // No other links: delete the input as well.
            Neuron.Delete(input);
        }
    }

    if (nucleus is Neuron neuron) {
        // Drop every synapse that reads from the deleted Neuron.
        foreach (Nucleus receiver in neuron.receivers) {
            if (receiver != null && receiver.synapses != null)
                receiver.synapses.RemoveAll(s => s.neuron == nucleus);
        }
    }
    else if (nucleus is Cluster cluster) {
        // Drop every synapse that reads from one of the cluster's neurons.
        foreach (Nucleus clusterNucleus in cluster.clusterNuclei) {
            if (clusterNucleus is not Neuron output)
                continue;

            foreach (Nucleus receiver in output.receivers) {
                if (receiver != null && receiver.synapses != null)
                    receiver.synapses.RemoveAll(s => s.neuron == output);
            }
        }
    }

    if (nucleus.clusterPrefab != null) {
        nucleus.clusterPrefab.nuclei.RemoveAll(n => n == nucleus);
        nucleus.clusterPrefab.RefreshOutputs();
        nucleus.clusterPrefab.GarbageCollection();
    }
}
|
|
|
|
/// <summary>
/// Recomputes this Neuron's output from its synapses.
/// </summary>
/// <remarks>
/// Zeroes sleeping inputs, combines the weighted inputs, applies the activation function and
/// stores the result as the new output.
/// </remarks>
public override void UpdateStateIsolated() {
    CheckSleepingSynapses();
    var combined = Combinator();
    var activated = Activator(combined);

    // Mark the Neuron as awake *before* publishing the output, as ProcessStimulus does.
    // The outputValue setter runs SleepCheck (via isFiring). With a stale lastUpdate it would
    // immediately reset a previously sleeping Neuron's new output to zero and suppress WhenFiring.
    this.lastUpdate = Time.time;
    this.outputValue = activated;
}
|
|
|
|
/// <summary>
/// Clears the output of every input neuron whose synapse reports that it is sleeping.
/// </summary>
protected void CheckSleepingSynapses() {
    foreach (Synapse input in this.synapses) {
        if (!input.isSleeping)
            continue;

        // A sleeping input must not keep contributing its last value.
        input.neuron.outputValue = Vector3.zero;
    }
}
|
|
|
|
#region Combinator

#if UNITY_MATHEMATICS

/// <summary>
/// The combine function selected by <see cref="combinator"/>; unknown values fall back to Sum.
/// </summary>
protected Func<float3> Combinator => combinator switch {
    CombinatorType.Product => CombinatorProduct,
    CombinatorType.Max => CombinatorMax,
    _ => CombinatorSum
};

/// <summary>The bias plus the sum of all weighted inputs.</summary>
public float3 CombinatorSum() {
    float3 total = this.bias;
    foreach (Synapse synapse in this.synapses) {
        var weighted = synapse.weight * synapse.neuron.outputValue;
        total += weighted;
    }
    return total;
}

/// <summary>The component-wise product of the bias and all weighted inputs.</summary>
/// <remarks>
/// NOTE(review): the product starts from the bias, so a zero bias yields a zero output whatever
/// the inputs are — confirm this is intended.
/// </remarks>
public float3 CombinatorProduct() {
    float3 product = this.bias;
    foreach (Synapse synapse in this.synapses) {
        var weighted = synapse.weight * synapse.neuron.outputValue;
        product *= weighted;
    }
    return product;
}

/// <summary>The bias or weighted input with the largest length (the bias wins ties).</summary>
public float3 CombinatorMax() {
    float3 strongest = this.bias;
    float strongestLength = length(strongest);

    foreach (Synapse synapse in this.synapses) {
        float3 candidate = synapse.weight * synapse.neuron.outputValue;
        float candidateLength = length(candidate);
        if (!(candidateLength > strongestLength))
            continue;

        strongest = candidate;
        strongestLength = candidateLength;
    }
    return strongest;
}

#else

/// <summary>
/// The combine function selected by <see cref="combinator"/>; unknown values fall back to Sum.
/// </summary>
protected Func<Vector3> Combinator => combinator switch {
    CombinatorType.Product => CombinatorProduct,
    CombinatorType.Max => CombinatorMax,
    _ => CombinatorSum
};

/// <summary>The bias plus the sum of all weighted inputs.</summary>
public Vector3 CombinatorSum() {
    Vector3 total = this.bias;
    foreach (Synapse synapse in this.synapses) {
        var weighted = synapse.weight * synapse.neuron.outputValue;
        total += weighted;
    }
    return total;
}

/// <summary>The component-wise product of the bias and all weighted inputs.</summary>
/// <remarks>
/// NOTE(review): the product starts from the bias, so a zero bias yields a zero output whatever
/// the inputs are — confirm this is intended.
/// </remarks>
public Vector3 CombinatorProduct() {
    Vector3 product = this.bias;
    foreach (Synapse synapse in this.synapses) {
        var weighted = synapse.weight * synapse.neuron.outputValue;
        product = Vector3.Scale(product, weighted);
    }
    return product;
}

/// <summary>The bias or weighted input with the largest length (the bias wins ties).</summary>
public Vector3 CombinatorMax() {
    Vector3 strongest = this.bias;
    float strongestLength = strongest.magnitude;

    foreach (Synapse synapse in this.synapses) {
        Vector3 candidate = synapse.weight * synapse.neuron.outputValue;
        float candidateLength = candidate.magnitude;
        if (!(candidateLength > strongestLength))
            continue;

        strongest = candidate;
        strongestLength = candidateLength;
    }
    return strongest;
}

#endif

#endregion Combinator
|
|
|
|
#region Activator

#if UNITY_MATHEMATICS

/// <summary>
/// The activation function selected by <see cref="curvePreset"/>; unknown values use the custom curve.
/// </summary>
public Func<float3, float3> Activator => this.curvePreset switch {
    ActivationType.Linear => ActivatorLinear,
    ActivationType.Sqrt => ActivatorSqrt,
    ActivationType.Power => ActivatorPower,
    ActivationType.Reciprocal => ActivatorReciprocal,
    ActivationType.Tanh => ActivatorTanh,
    ActivationType.Binary => ActivatorBinary,
    ActivationType.Normalized => ActivatorNormalized,
    _ => ActivatorCustom
};

// The magnitude-shaping activators keep the input's direction and only reshape its length.
// They use normalizesafe because normalize() of a zero vector yields NaN, which would then
// propagate through every downstream neuron. A zero input therefore produces a zero output,
// matching the Vector3 build (Vector3.normalized returns zero for a zero vector).

/// <summary>Returns the input unchanged.</summary>
protected float3 ActivatorLinear(float3 input) {
    return input;
}

/// <summary>Square root of the magnitude, direction preserved.</summary>
protected float3 ActivatorSqrt(float3 input) {
    float3 result = normalizesafe(input) * System.MathF.Sqrt(length(input));
    return result;
}

/// <summary>Squared magnitude, direction preserved.</summary>
protected float3 ActivatorPower(float3 input) {
    float3 result = normalizesafe(input) * System.MathF.Pow(length(input), 2);
    return result;
}

/// <summary>One over the magnitude, direction preserved; a zero input returns zero.</summary>
protected float3 ActivatorReciprocal(float3 input) {
    float magnitude = length(input);
    if (magnitude == 0)
        return new float3(0, 0, 0);

    float3 result = normalize(input) * (1 / magnitude);
    return result;
}

/// <summary>Hyperbolic tangent of the magnitude, direction preserved.</summary>
protected float3 ActivatorTanh(float3 input) {
    float magnitude = length(input);
    float3 result = normalizesafe(input) * MathF.Tanh(magnitude);
    return result;
}

/// <summary>The magnitude clamped to [0, 1], written to all three components.</summary>
protected float3 ActivatorBinary(float3 input) {
    float magnitude = length(input);
    float value = Mathf.Clamp01(magnitude);
    return float3(value, value, value);
}

/// <summary>The unit vector in the input's direction; a zero input is returned as is.</summary>
protected float3 ActivatorNormalized(float3 input) {
    if (lengthsq(input) == 0)
        return input;
    float3 result = normalize(input);
    return result;
}

/// <summary>The magnitude mapped through <see cref="curve"/>, direction preserved.</summary>
protected float3 ActivatorCustom(float3 input) {
    float activatedValue = this.curve.Evaluate(length(input));
    float3 result = normalizesafe(input) * activatedValue;
    return result;
}

#else

/// <summary>
/// The activation function selected by <see cref="curvePreset"/>; unknown values use the custom curve.
/// </summary>
public Func<Vector3, Vector3> Activator => this.curvePreset switch {
    ActivationType.Linear => ActivatorLinear,
    ActivationType.Sqrt => ActivatorSqrt,
    ActivationType.Power => ActivatorPower,
    ActivationType.Reciprocal => ActivatorReciprocal,
    ActivationType.Tanh => ActivatorTanh,
    ActivationType.Binary => ActivatorBinary,
    ActivationType.Normalized => ActivatorNormalized,
    _ => ActivatorCustom
};

/// <summary>Returns the input unchanged.</summary>
protected Vector3 ActivatorLinear(Vector3 input) {
    return input;
}

/// <summary>Square root of the magnitude, direction preserved.</summary>
protected Vector3 ActivatorSqrt(Vector3 input) {
    Vector3 result = input.normalized * System.MathF.Sqrt(input.magnitude);
    return result;
}

/// <summary>Squared magnitude, direction preserved.</summary>
protected Vector3 ActivatorPower(Vector3 input) {
    Vector3 result = input.normalized * System.MathF.Pow(input.magnitude, 2);
    return result;
}

/// <summary>One over the magnitude, direction preserved; a zero input returns zero.</summary>
protected Vector3 ActivatorReciprocal(Vector3 input) {
    float magnitude = input.magnitude;
    if (magnitude == 0)
        return new Vector3(0, 0, 0);

    Vector3 result = input.normalized * (1 / magnitude);
    return result;
}

/// <summary>Hyperbolic tangent of the magnitude, direction preserved.</summary>
protected Vector3 ActivatorTanh(Vector3 input) {
    Vector3 result = input.normalized * MathF.Tanh(input.magnitude);
    return result;
}

/// <summary>The magnitude clamped to [0, 1], written to all three components.</summary>
protected Vector3 ActivatorBinary(Vector3 input) {
    float value = Mathf.Clamp01(input.magnitude);
    return new Vector3(value, value, value);
}

/// <summary>The unit vector in the input's direction; a zero input is returned as is.</summary>
protected Vector3 ActivatorNormalized(Vector3 input) {
    if (input.sqrMagnitude == 0)
        return input;
    return input.normalized;
}

/// <summary>The magnitude mapped through <see cref="curve"/>, direction preserved.</summary>
protected Vector3 ActivatorCustom(Vector3 input) {
    float activatedValue = this.curve.Evaluate(input.magnitude);
    Vector3 result = input.normalized * activatedValue;
    return result;
}

#endif

#endregion Activator
|
|
|
|
#region Receivers

/// <summary>Backing list for <see cref="receivers"/>, serialized by reference.</summary>
[SerializeReference]
private List<Nucleus> _receivers = new();

/// <summary>
/// The nuclei that take this Neuron's output as one of their inputs.
/// </summary>
public virtual List<Nucleus> receivers {
    get => _receivers;
    set => _receivers = value;
}

/// <summary>
/// Connects this Neuron to <paramref name="receiverToAdd"/>. It records the receiver and gives it
/// a synapse that reads from this Neuron.
/// </summary>
/// <remarks>
/// Duplicates are not checked: adding the same receiver twice creates two synapses.
/// Works on the backing list directly, bypassing any override of <see cref="receivers"/>.
/// </remarks>
/// <param name="receiverToAdd">The nucleus that should receive this Neuron's output</param>
/// <param name="weight">Weight of the synapse created on the receiver</param>
public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1) {
    _receivers.Add(receiverToAdd);
    receiverToAdd.AddSynapse(this, weight);
}

/// <summary>
/// Disconnects <paramref name="receiverToRemove"/>. It removes every occurrence from the receivers
/// and every synapse in the receiver that reads from this Neuron.
/// </summary>
/// <param name="receiverToRemove">The nucleus to disconnect</param>
public virtual void RemoveReceiver(Nucleus receiverToRemove) {
    _receivers.RemoveAll(candidate => candidate == receiverToRemove);
    receiverToRemove.synapses.RemoveAll(link => link.neuron == this);
}

#endregion Receivers
|
|
|
|
/// <summary>
/// Feeds an external stimulus into this Neuron.
/// </summary>
/// <remarks>
/// The stimulus becomes the bias, which is the starting value of every combinator. The Neuron is
/// marked as awake, and the parent cluster, if any, is asked to update from this Neuron.
/// </remarks>
/// <param name="inputValue">The stimulus value</param>
public override void ProcessStimulus(Vector3 inputValue) {
    this.lastUpdate = Time.time;
    this.bias = inputValue;

    // Neurons built through the ClusterPrefab constructor never get a parent assigned,
    // so guard against null the same way the Cluster constructor does.
    this.parent?.UpdateFromNucleus(this);
}
|
|
}
|
|
|
|
} |