Pascal Serrarens 4ae9a15fc6 Squashed 'NanoBrain/' changes from 832d849..cc9a845
cc9a845 Fix sleeping for product combinator
e4ba7f8 Better cross-cluster monitoring
4f8a6ab Improved (but not fixed) cross-cluster monitoring
b12616b Fix neuron output visualisation
96439cc Visualize all outputs
d583e67 WIP cluster references/instance
04bab92 Fix links to multiple cluster neurons & cleanup
e17a249 Cross-cluster editor links
0ab2d21 Migrating and cleaning up
b6630ad First steps to using instanceCount for clusters
8801fa2 Cluster reimport fixes
befb69d full graph with collapsed clusters
1a1919f Fix expansion of clsuter arrays
c708f4d Improved clusterarray support
c2e4e1b Fix Cluster array extension
02047a4 Adde full graph scrollbar
471ed36 Completed full graph integration
830e3e7 Added full graph view mode
249e888 Improve full graph view
308a6a1 The Entities are battling
75d9d1c Cleanup
c8f0f0c Fix aging of neurons
e2e169c small fixes
619ced6 Removed the use of Receptors
19f9296 Simplifications
bc0a796 Integrated clusterarray in cluster
e40dd23 Fixed clusterViewer for clusterarrays
b0f4b41 Status quo adding clusterArrays
1fc75a8 Added ClusterArray
0023920 Cover seeking(-ish) behaviour
1c7b8e7 Added Tanh Activation
a99d40c BrainViewer added
db43655 Pew pew!
18ef4cd Merge commit '89017475984bbbf1899fb38846c5bb0e7775dedd' into NanoBrain

git-subtree-dir: NanoBrain
git-subtree-split: cc9a845b643ffb4a9abe4f7da787ac5c5b14dae8
2026-04-23 15:22:02 +02:00

541 lines
19 KiB
C#

using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_EDITOR
using UnityEditor;
#endif
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif
namespace NanoBrain {
/// <summary>
/// A neuron is a basic Nucleus
/// </summary>
[Serializable]
public class Neuron : Nucleus {
/// <summary>
/// Construct a Neuron which lives in a Cluster instance
/// </summary>
/// <param name="parent">The cluster which will own the new Neuron; when null the Neuron is not registered anywhere</param>
/// <param name="name">The name given to the new Neuron</param>
public Neuron(Cluster parent, string name) {
    this.name = name;
    this.parent = parent;
    // Register with the owning cluster, if there is one
    if (this.parent is { } owner)
        owner.clusterNuclei.Add(this);
}
/// <summary>
/// Construct a Neuron which lives in a Cluster Prefab
/// </summary>
/// <param name="prefab">The Cluster Prefab which will own the new Neuron</param>
/// <param name="name">The name given to the new Neuron</param>
/// An error is logged, and the Neuron is not registered, when no prefab is given
public Neuron(ClusterPrefab prefab, string name) {
    this.name = name;
    this.clusterPrefab = prefab;
    if (this.clusterPrefab == null) {
        Debug.LogError("No prefab when adding neuron to prefab");
        return;
    }
    this.clusterPrefab.nuclei.Add(this);
}
#region Serialization
/// <summary>
/// The type of combinators
/// </summary>
/// A combinator combines the bias and the weighted output values of the synapses
/// into a single value, which is then passed to the activator.
public enum CombinatorType {
/// Add the weighted values to the bias
Sum,
/// Multiply the bias with each of the weighted values (component-wise)
Product,
/// Take the value with the largest magnitude among the bias and the weighted values
Max,
}
/// <summary>
/// The type of combinator used for this Neuron
/// </summary>
/// Defaults to CombinatorType.Sum
public CombinatorType combinator = CombinatorType.Sum;
/// <summary>
/// The type of activation function applied to the combined input of a Neuron
/// </summary>
/// The preset selects the activator function and the curve returned by GenerateCurve.
/// The descriptions below follow the Activator implementations in this class.
public enum ActivationType {
/// The input is passed on unchanged
Linear,
/// The magnitude of the input is squared, the direction is kept
Power,
/// The square root of the magnitude is taken, the direction is kept
Sqrt,
/// The magnitude is replaced by its reciprocal; a zero input stays zero
Reciprocal,
/// The hyperbolic tangent of the magnitude is taken, the direction is kept
Tanh,
/// The magnitude is clamped to the range 0..1
Binary,
/// The input is scaled to unit length; a zero input stays zero
Normalized,
/// The magnitude is mapped through the user-defined curve
Custom
}
/// <summary>
/// Serialized backing field for curvePreset
/// </summary>
/// Writing this field directly does not regenerate the curve; use curvePreset for that
[SerializeField]
public ActivationType _curvePreset;
/// <summary>
/// The activation preset of this Neuron
/// </summary>
/// Assigning a preset replaces curve with the result of GenerateCurve
public ActivationType curvePreset {
    get => _curvePreset;
    set {
        _curvePreset = value;
        AnimationCurve generated = GenerateCurve();
        this.curve = generated;
    }
}
/// <summary>
/// The activation curve belonging to the current preset, or the user-defined curve for the Custom preset
/// </summary>
public AnimationCurve curve;
/// <summary>
/// The maximum value associated with the curve; it is set by GenerateCurve
/// </summary>
public float curveMax = 1.0f;
/// <summary>
/// Produce the activation curve for the current curvePreset
/// </summary>
/// <returns>A newly generated curve for the built-in presets, or the existing curve for the Custom preset</returns>
/// As a side effect curveMax is updated to match the preset.
/// The Normalized preset uses the same curve as the Binary preset.
public AnimationCurve GenerateCurve() {
    ActivationType preset = this.curvePreset;
    // Only the reciprocal preset has a maximum different from 1.
    // NOTE(review): Presets.Reciprocal samples from x = 0.001 where its value is 1000,
    // while the maximum set here is 100 - confirm which of the two is intended.
    this.curveMax = preset == ActivationType.Reciprocal ? 1 / 0.01f * 1 : 1;
    return preset switch {
        ActivationType.Linear => Presets.Linear(1),
        ActivationType.Power => Presets.Power(2.0f, 1),
        ActivationType.Sqrt => Presets.Power(0.5f, 1),
        ActivationType.Reciprocal => Presets.Reciprocal(1),
        ActivationType.Tanh => Presets.Tanh(1),
        ActivationType.Binary => Presets.Binary(),
        ActivationType.Normalized => Presets.Binary(),
        _ => this.curve
    };
}
/// <summary>
/// Generators for the activation curves of the built-in presets
/// </summary>
/// The tangent modes are set with UnityEditor.AnimationUtility when running in the editor.
/// That class does not exist in player builds, so equivalent tangents are computed
/// with the runtime AnimationCurve API there.
public static class Presets {
    /// The default number of keyframes used when sampling a function
    private const int samples = 32;

    /// <summary>
    /// A straight line through the origin
    /// </summary>
    /// <param name="weight">The slope of the line</param>
    /// <returns>A linear curve from (0, 0) to (1000, weight * 1000)</returns>
    public static AnimationCurve Linear(float weight) {
        return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
    }

    /// <summary>
    /// A power curve y = x^exponent * weight sampled on the range 0..1
    /// </summary>
    /// <param name="exponent">The exponent of the power function</param>
    /// <param name="weight">The factor applied to the result</param>
    /// <returns>A smoothed curve with 32 keyframes</returns>
    public static AnimationCurve Power(float exponent, float weight) {
        AnimationCurve curve = Sample(samples, 0f, 1f, x => Mathf.Pow(x, exponent) * weight);
        // Smooth tangents; use straight segments instead if a piecewise linear curve is preferred
        SetTangents(curve, smooth: true);
        return curve;
    }

    /// <summary>
    /// A reciprocal curve y = weight / x sampled on the range 0.001..1
    /// </summary>
    /// <param name="weight">The factor applied to the result</param>
    /// <returns>A piecewise linear curve with 128 keyframes</returns>
    public static AnimationCurve Reciprocal(float weight) {
        // The range starts just above zero because 1/x is undefined at x = 0
        AnimationCurve curve = Sample(128, 0.001f, 1f, x => 1f / x * weight);
        SetTangents(curve, smooth: false);
        return curve;
    }

    /// <summary>
    /// A hyperbolic tangent curve y = tanh(x * weight) sampled on the range 0.001..1
    /// </summary>
    /// <param name="weight">The factor applied to the input</param>
    /// <returns>A piecewise linear curve with 32 keyframes</returns>
    public static AnimationCurve Tanh(float weight) {
        AnimationCurve curve = Sample(samples, 0.001f, 1f, x => MathF.Tanh(x * weight));
        SetTangents(curve, smooth: false);
        return curve;
    }

    /// <summary>
    /// The curve used for the Binary and Normalized presets
    /// </summary>
    /// <returns>A linear curve from (0, 0) to (1, 1)</returns>
    public static AnimationCurve Binary() {
        return AnimationCurve.Linear(0, 0, 1, 1);
    }

    /// Build a curve from count evenly spaced samples of function on the range xMin..xMax
    private static AnimationCurve Sample(int count, float xMin, float xMax, Func<float, float> function) {
        Keyframe[] keys = new Keyframe[count];
        for (int i = 0; i < count; i++) {
            float t = i / (float)(count - 1);
            float x = Mathf.Lerp(xMin, xMax, t);
            keys[i] = new Keyframe(x, function(x));
        }
        return new AnimationCurve(keys);
    }

    /// Make every key of the curve smooth (auto tangents) or linear (straight segments)
    private static void SetTangents(AnimationCurve curve, bool smooth) {
#if UNITY_EDITOR
        AnimationUtility.TangentMode mode =
            smooth ? AnimationUtility.TangentMode.Auto : AnimationUtility.TangentMode.Linear;
        for (int i = 0; i < curve.length; i++) {
            AnimationUtility.SetKeyLeftTangentMode(curve, i, mode);
            AnimationUtility.SetKeyRightTangentMode(curve, i, mode);
        }
#else
        // AnimationUtility is editor-only: compute the tangents with the runtime API instead
        int last = curve.length - 1;
        for (int i = 0; i <= last; i++) {
            if (smooth) {
                curve.SmoothTangents(i, 0f);
                continue;
            }
            Keyframe key = curve[i];
            if (i > 0) {
                Keyframe previous = curve[i - 1];
                key.inTangent = (key.value - previous.value) / (key.time - previous.time);
            }
            if (i < last) {
                Keyframe next = curve[i + 1];
                key.outTangent = (next.value - key.value) / (next.time - key.time);
            }
            curve.MoveKey(i, key);
        }
#endif
    }
}
#endregion Serialization
#if UNITY_MATHEMATICS
/// The stored output of this Neuron
protected float3 _outputValue;
/// <summary>
/// The output value of this Neuron
/// </summary>
/// Assigning a value invokes WhenFiring when the Neuron is firing after the assignment
public virtual float3 outputValue {
    get => _outputValue;
    set {
        _outputValue = value;
        if (!this.isFiring)
            return;
        WhenFiring?.Invoke();
    }
}
/// The length of the stored output value
public float outputMagnitude {
    get { return length(_outputValue); }
}
/// The squared length of the stored output value
public float outputSqrMagnitude {
    get { return lengthsq(_outputValue); }
}
#else
/// The stored output of this Neuron
protected Vector3 _outputValue;
/// <summary>
/// The output value of this Neuron
/// </summary>
/// Assigning a value invokes WhenFiring when the Neuron is firing after the assignment
public virtual Vector3 outputValue {
    get => _outputValue;
    set {
        _outputValue = value;
        if (!this.isFiring)
            return;
        WhenFiring?.Invoke();
    }
}
/// The length of the stored output value
public float outputMagnitude {
    get { return _outputValue.magnitude; }
}
/// The squared length of the stored output value
public float outputSqrMagnitude {
    get { return _outputValue.sqrMagnitude; }
}
#endif
/// <summary>
/// True when the magnitude of the output is larger than 0.5
/// </summary>
/// Reading this property runs SleepCheck first, which may reset the stored output to zero
public bool isFiring {
    get {
        SleepCheck();
        float magnitude = this.outputMagnitude;
        return magnitude > 0.5f;
    }
}
/// <summary>
/// Invoked by the outputValue setter when the Neuron is firing after the assignment
/// </summary>
public Action WhenFiring;
/// <summary>
/// True when the time since lastUpdate is longer than timeToSleep
/// </summary>
public virtual bool isSleeping {
    get {
        float idleTime = Time.time - this.lastUpdate;
        return idleTime > this.timeToSleep;
    }
}
/// <summary>
/// Reset the stored output to zero when the Neuron is sleeping
/// </summary>
/// The backing field is written directly, so WhenFiring is not invoked
public void SleepCheck() {
    if (!this.isSleeping)
        return;
    // The default of both float3 and Vector3 is the zero vector
    this._outputValue = default;
}
/// <summary>
/// The value of Time.time at the most recent update of this Neuron
/// </summary>
[NonSerialized]
public float lastUpdate = 0;
/// <summary>
/// The time without updates after which the Neuron is considered sleeping
/// </summary>
public readonly float timeToSleep = 1f;
/// \copydoc NanoBrain::Nucleus::ShallowCloneTo
/// The copy gets the same name and field values, but no synapses or receivers
public override Nucleus ShallowCloneTo(Cluster newParent) {
    var copy = new Neuron(newParent, this.name);
    this.CloneFields(copy);
    return copy;
}
/// \copydoc NanoBrain::Nucleus::Clone
/// The copy is linked to the same input neurons (with the same weights) and the same receivers
public override Nucleus Clone(ClusterPrefab prefab) {
    var copy = new Neuron(prefab, this.name);
    this.CloneFields(copy);

    // Recreate the incoming links, keeping their weights
    foreach (Synapse source in this.synapses) {
        Synapse duplicate = copy.AddSynapse(source.neuron);
        duplicate.weight = source.weight;
    }

    // Recreate the outgoing links
    foreach (Nucleus target in this.receivers)
        copy.AddReceiver(target);

    return copy;
}
/// <summary>
/// Copy the settings of this Neuron to a clone
/// </summary>
/// <param name="clone">The Neuron which receives the copied values</param>
/// Synapses and receivers are not copied here.
/// The order of the assignments matters, see the comments below.
protected virtual void CloneFields(Neuron clone) {
// NOTE(review): when called from Clone this replaces the prefab which was assigned by the constructor - confirm this is intended
clone.clusterPrefab = this.clusterPrefab;
clone.bias = this.bias;
clone.combinator = this.combinator;
// The curve is assigned before curvePreset: the curvePreset setter regenerates the curve
// and for the Custom preset GenerateCurve returns the curve already assigned to the clone.
// The curve object itself is shared with the clone, it is not duplicated.
clone.curve = this.curve;
clone.curvePreset = this.curvePreset;
// Assigned last because the curvePreset setter resets curveMax through GenerateCurve
clone.curveMax = this.curveMax;
}
/// <summary>
/// Delete a nucleus and remove all links to and from it
/// </summary>
/// <param name="nucleus">The nucleus to delete</param>
/// Input neurons which feed no other nucleus are deleted as well, recursively.
/// When the nucleus belongs to a Cluster Prefab it is removed from the prefab,
/// after which the prefab refreshes its outputs and runs its garbage collection.
public static void Delete(Nucleus nucleus) {
    DeleteRecursive(nucleus, new HashSet<Nucleus>());
}

/// Delete a nucleus while tracking the nuclei already handled, so that cyclic links cannot recurse without end
private static void DeleteRecursive(Nucleus nucleus, HashSet<Nucleus> visited) {
    if (!visited.Add(nucleus))
        return;

    if (nucleus.synapses != null) {
        // Enumerate a snapshot: deleting an input below removes its synapse from
        // nucleus.synapses, which would invalidate an enumerator on the list itself
        foreach (Synapse synapse in new List<Synapse>(nucleus.synapses)) {
            if (synapse.neuron is Neuron input) {
                if (input.receivers.Count > 1) {
                    // The input also feeds another nucleus: only unlink it
                    input.receivers.RemoveAll(r => r == nucleus);
                }
                else {
                    // No other links, delete it.
                    DeleteRecursive(input, visited);
                }
            }
        }
    }

    if (nucleus is Neuron neuron) {
        // Remove the synapses of the receivers which refer to the deleted neuron
        foreach (Nucleus receiver in neuron.receivers) {
            if (receiver != null && receiver.synapses != null)
                receiver.synapses.RemoveAll(s => s.neuron == nucleus);
        }
    }
    else if (nucleus is Cluster cluster) {
        // Remove the synapses of the receivers of every neuron in this cluster
        foreach (Nucleus clusterNucleus in cluster.clusterNuclei) {
            if (clusterNucleus is Neuron output) {
                foreach (Nucleus receiver in output.receivers) {
                    if (receiver != null && receiver.synapses != null)
                        receiver.synapses.RemoveAll(s => s.neuron == output);
                }
            }
        }
    }

    if (nucleus.clusterPrefab != null) {
        nucleus.clusterPrefab.nuclei.RemoveAll(n => n == nucleus);
        nucleus.clusterPrefab.RefreshOutputs();
        nucleus.clusterPrefab.GarbageCollection();
    }
}
/// <summary>
/// Recalculate the output of this Neuron from its bias and its synapses
/// </summary>
/// The outputs of sleeping inputs are reset first, then the combinator and the activator are applied.
public override void UpdateStateIsolated() {
    CheckSleepingSynapses();
    var combined = Combinator();
    // Refresh the timestamp before the output is assigned: the outputValue setter
    // evaluates isFiring, which runs SleepCheck. With a stale timestamp the Neuron
    // still counts as sleeping and the freshly computed output would be reset to zero.
    this.lastUpdate = Time.time;
    this.outputValue = Activator(combined);
}
/// <summary>
/// Reset the output of every input neuron whose synapse is sleeping
/// </summary>
protected void CheckSleepingSynapses() {
    foreach (Synapse input in this.synapses) {
        if (!input.isSleeping)
            continue;
        input.neuron.outputValue = Vector3.zero;
    }
}
#region Combinator
#if UNITY_MATHEMATICS
/// <summary>
/// The combinator function selected by the combinator field
/// </summary>
/// Unknown combinator types fall back to the sum
protected Func<float3> Combinator {
    get {
        switch (combinator) {
            case CombinatorType.Product:
                return CombinatorProduct;
            case CombinatorType.Max:
                return CombinatorMax;
            case CombinatorType.Sum:
            default:
                return CombinatorSum;
        }
    }
}

/// <summary>
/// Add the weighted outputs of all synapses to the bias
/// </summary>
/// <returns>The bias plus the sum of the weighted input values</returns>
public float3 CombinatorSum() {
    float3 total = this.bias;
    foreach (Synapse input in this.synapses) {
        float3 weighted = input.weight * input.neuron.outputValue;
        total += weighted;
    }
    return total;
}

/// <summary>
/// Multiply the bias with the weighted outputs of all synapses, component-wise
/// </summary>
/// <returns>The bias multiplied by every weighted input value</returns>
public float3 CombinatorProduct() {
    float3 result = this.bias;
    foreach (Synapse input in this.synapses) {
        float3 weighted = input.weight * input.neuron.outputValue;
        result *= weighted;
    }
    return result;
}

/// <summary>
/// Select the value with the largest magnitude among the bias and the weighted outputs of all synapses
/// </summary>
/// <returns>The bias, or the weighted input value which is longer than the bias and all earlier inputs</returns>
public float3 CombinatorMax() {
    float3 longest = this.bias;
    float longestLength = length(longest);
    foreach (Synapse input in this.synapses) {
        // Apply the weight factor before comparing
        float3 weighted = input.weight * input.neuron.outputValue;
        float weightedLength = length(weighted);
        if (weightedLength <= longestLength)
            continue;
        longest = weighted;
        longestLength = weightedLength;
    }
    return longest;
}
#else
/// <summary>
/// The combinator function selected by the combinator field
/// </summary>
/// Unknown combinator types fall back to the sum
protected Func<Vector3> Combinator {
    get {
        switch (combinator) {
            case CombinatorType.Product:
                return CombinatorProduct;
            case CombinatorType.Max:
                return CombinatorMax;
            case CombinatorType.Sum:
            default:
                return CombinatorSum;
        }
    }
}

/// <summary>
/// Add the weighted outputs of all synapses to the bias
/// </summary>
/// <returns>The bias plus the sum of the weighted input values</returns>
public Vector3 CombinatorSum() {
    Vector3 total = this.bias;
    foreach (Synapse input in this.synapses) {
        Vector3 weighted = input.weight * input.neuron.outputValue;
        total += weighted;
    }
    return total;
}

/// <summary>
/// Multiply the bias with the weighted outputs of all synapses, component-wise
/// </summary>
/// <returns>The bias scaled by every weighted input value</returns>
public Vector3 CombinatorProduct() {
    Vector3 result = this.bias;
    foreach (Synapse input in this.synapses) {
        // Vector3 has no component-wise multiplication operator, Scale is used for that
        Vector3 weighted = input.weight * input.neuron.outputValue;
        result = Vector3.Scale(result, weighted);
    }
    return result;
}

/// <summary>
/// Select the value with the largest magnitude among the bias and the weighted outputs of all synapses
/// </summary>
/// <returns>The bias, or the weighted input value which is longer than the bias and all earlier inputs</returns>
public Vector3 CombinatorMax() {
    Vector3 longest = this.bias;
    float longestLength = longest.magnitude;
    foreach (Synapse input in this.synapses) {
        // Apply the weight factor before comparing
        Vector3 weighted = input.weight * input.neuron.outputValue;
        float weightedLength = weighted.magnitude;
        if (weightedLength <= longestLength)
            continue;
        longest = weighted;
        longestLength = weightedLength;
    }
    return longest;
}
#endif
#endregion Combinator
#region Activator
#if UNITY_MATHEMATICS
/// <summary>
/// The activation function selected by curvePreset
/// </summary>
/// Presets without a dedicated function use the custom curve
public Func<float3, float3> Activator => this.curvePreset switch {
    ActivationType.Linear => ActivatorLinear,
    ActivationType.Sqrt => ActivatorSqrt,
    ActivationType.Power => ActivatorPower,
    ActivationType.Reciprocal => ActivatorReciprocal,
    ActivationType.Tanh => ActivatorTanh,
    ActivationType.Binary => ActivatorBinary,
    ActivationType.Normalized => ActivatorNormalized,
    _ => ActivatorCustom
};

/// <summary>Pass the input on unchanged</summary>
protected float3 ActivatorLinear(float3 input) {
    return input;
}

/// <summary>Take the square root of the magnitude, keeping the direction</summary>
/// normalizesafe is used because normalize yields NaN for a zero-length input
protected float3 ActivatorSqrt(float3 input) {
    float3 result = normalizesafe(input) * System.MathF.Sqrt(length(input));
    return result;
}

/// <summary>Square the magnitude, keeping the direction</summary>
/// normalizesafe is used because normalize yields NaN for a zero-length input
protected float3 ActivatorPower(float3 input) {
    float3 result = normalizesafe(input) * System.MathF.Pow(length(input), 2);
    return result;
}

/// <summary>Replace the magnitude by its reciprocal, keeping the direction</summary>
/// A zero-length input results in a zero output
protected float3 ActivatorReciprocal(float3 input) {
    float magnitude = length(input);
    if (magnitude == 0)
        return new float3(0, 0, 0);
    float3 result = normalize(input) * (1 / magnitude);
    return result;
}

/// <summary>Take the hyperbolic tangent of the magnitude, keeping the direction</summary>
/// normalizesafe is used because normalize yields NaN for a zero-length input
protected float3 ActivatorTanh(float3 input) {
    float magnitude = length(input);
    float3 result = normalizesafe(input) * MathF.Tanh(magnitude);
    return result;
}

/// <summary>Clamp the magnitude to the range 0..1 and use it for all three components</summary>
protected float3 ActivatorBinary(float3 input) {
    float magnitude = length(input);
    float value = Mathf.Clamp01(magnitude);
    return float3(value, value, value);
}

/// <summary>Scale the input to unit length</summary>
/// A zero-length input is returned unchanged
protected float3 ActivatorNormalized(float3 input) {
    if (lengthsq(input) == 0)
        return input;
    float3 result = normalize(input);
    return result;
}

/// <summary>Map the magnitude through the custom curve, keeping the direction</summary>
/// normalizesafe is used because normalize yields NaN for a zero-length input
protected float3 ActivatorCustom(float3 input) {
    float activatedValue = this.curve.Evaluate(length(input));
    float3 result = normalizesafe(input) * activatedValue;
    return result;
}
#else
/// <summary>
/// The activation function selected by curvePreset
/// </summary>
/// Presets without a dedicated function use the custom curve
public Func<Vector3, Vector3> Activator => this.curvePreset switch {
    ActivationType.Linear => ActivatorLinear,
    ActivationType.Sqrt => ActivatorSqrt,
    ActivationType.Power => ActivatorPower,
    ActivationType.Reciprocal => ActivatorReciprocal,
    ActivationType.Tanh => ActivatorTanh,
    ActivationType.Binary => ActivatorBinary,
    ActivationType.Normalized => ActivatorNormalized,
    _ => ActivatorCustom
};

/// <summary>Pass the input on unchanged</summary>
protected Vector3 ActivatorLinear(Vector3 input) {
    return input;
}

/// <summary>Take the square root of the magnitude, keeping the direction</summary>
protected Vector3 ActivatorSqrt(Vector3 input) {
    Vector3 result = input.normalized * System.MathF.Sqrt(input.magnitude);
    return result;
}

/// <summary>Square the magnitude, keeping the direction</summary>
protected Vector3 ActivatorPower(Vector3 input) {
    Vector3 result = input.normalized * System.MathF.Pow(input.magnitude, 2);
    return result;
}

/// <summary>Replace the magnitude by its reciprocal, keeping the direction</summary>
/// A zero-length input results in a zero output
protected Vector3 ActivatorReciprocal(Vector3 input) {
    float magnitude = input.magnitude;
    if (magnitude == 0)
        return new Vector3(0, 0, 0);
    Vector3 result = input.normalized * (1 / magnitude);
    return result;
}

/// <summary>Take the hyperbolic tangent of the magnitude, keeping the direction</summary>
protected Vector3 ActivatorTanh(Vector3 input) {
    Vector3 result = input.normalized * MathF.Tanh(input.magnitude);
    return result;
}

/// <summary>Clamp the magnitude to the range 0..1 and use it for all three components</summary>
protected Vector3 ActivatorBinary(Vector3 input) {
    float value = Mathf.Clamp01(input.magnitude);
    return new Vector3(value, value, value);
}

/// <summary>Scale the input to unit length</summary>
protected Vector3 ActivatorNormalized(Vector3 input) {
    return input.normalized;
}

/// <summary>Map the magnitude through the custom curve, keeping the direction</summary>
protected Vector3 ActivatorCustom(Vector3 input) {
    float activatedValue = this.curve.Evaluate(input.magnitude);
    Vector3 result = input.normalized * activatedValue;
    return result;
}
#endif
#endregion Activator
#region Receivers
/// The nuclei which receive the output of this Neuron
[SerializeReference]
private List<Nucleus> _receivers = new List<Nucleus>();
/// <summary>
/// The nuclei which receive the output of this Neuron
/// </summary>
public virtual List<Nucleus> receivers {
    get => _receivers;
    set => _receivers = value;
}
/// <summary>
/// Link this Neuron to a receiver
/// </summary>
/// <param name="receiverToAdd">The nucleus which will receive the output of this Neuron</param>
/// <param name="weight">The weight of the synapse created on the receiver</param>
public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1) {
    List<Nucleus> targets = this._receivers;
    targets.Add(receiverToAdd);
    // The receiver gets a matching synapse pointing back at this Neuron
    receiverToAdd.AddSynapse(this, weight);
}
/// <summary>
/// Remove the link between this Neuron and a receiver
/// </summary>
/// <param name="receiverToRemove">The nucleus which should no longer receive the output of this Neuron</param>
public virtual void RemoveReceiver(Nucleus receiverToRemove) {
    List<Nucleus> targets = this._receivers;
    targets.RemoveAll(candidate => candidate == receiverToRemove);
    // Drop the synapses of the receiver which point back at this Neuron
    receiverToRemove.synapses.RemoveAll(link => link.neuron == this);
}
#endregion Receivers
/// <summary>
/// Feed an external stimulus into this Neuron
/// </summary>
/// <param name="inputValue">The stimulus, which is stored as the bias of this Neuron</param>
/// The update time is refreshed and the parent cluster, when present, is asked to update from this Neuron.
public override void ProcessStimulus(Vector3 inputValue) {
    this.lastUpdate = Time.time;
    this.bias = inputValue;
    // The constructors allow a Neuron without a parent cluster; there is nothing to propagate to then
    this.parent?.UpdateFromNucleus(this);
}
}
}