359 lines
12 KiB
C#

using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using UnityEditor;
using Unity.Mathematics;
using static Unity.Mathematics.math;
/// <summary>
/// A node of a cluster network. On each update a neuron sums the weighted outputs of the
/// nuclei feeding its <see cref="synapses"/> (plus an optional bias), optionally averages
/// them, passes the magnitude through the activation selected by <see cref="curvePreset"/>,
/// and stores the result in <see cref="outputValue"/>. Downstream nuclei are listed in
/// <see cref="receivers"/>.
/// </summary>
[Serializable]
public class Neuron : INucleus {
    /// <summary>
    /// Creates a neuron belonging to a runtime <see cref="Cluster"/>.
    /// </summary>
    /// <param name="parent">Owning cluster; when not null the neuron is appended to its <c>nuclei</c>.</param>
    /// <param name="name">Name of the neuron.</param>
    public Neuron(Cluster parent, string name) {
        this.parent = parent;
        this.name = name;
        this.parent?.nuclei.Add(this);
    }

    /// <summary>
    /// Creates a neuron belonging to a <see cref="ClusterPrefab"/>.
    /// </summary>
    /// <param name="parent">Owning prefab, stored in <see cref="cluster"/>; when not null the
    /// neuron is appended to its <c>nuclei</c>.</param>
    /// <param name="name">Name of the neuron.</param>
    public Neuron(ClusterPrefab parent, string name) {
        this.cluster = parent;
        this.name = name;
        // Explicit != null (rather than ?.) kept on purpose: if ClusterPrefab is a
        // UnityEngine.Object, only ==/!= honour Unity's destroyed-object null check.
        if (this.cluster != null) {
            this.cluster.nuclei.Add(this);
        }
    }

    [SerializeField]
    protected string _name;
    /// <summary>Name of the neuron (serialized in <c>_name</c>).</summary>
    public virtual string name {
        get => _name;
        set => _name = value;
    }

    [SerializeField]
    private List<Synapse> _synapses = new();
    /// <summary>
    /// Incoming connections. Each synapse's sender output, multiplied by the synapse
    /// weight, is summed in <see cref="UpdateStateIsolated(float3)"/>.
    /// </summary>
    public List<Synapse> synapses => _synapses;

    [SerializeReference]
    private List<INucleus> _receivers = new();
    /// <summary>
    /// Downstream nuclei fed by this neuron (populated by <see cref="AddReceiver"/>).
    /// </summary>
    public List<INucleus> receivers {
        get => _receivers;
        set => _receivers = value;
    }

    [SerializeReference]
    private NucleusArray _array;
    /// <summary>
    /// The <see cref="NucleusArray"/> associated with this neuron; copied by reference when cloning.
    /// </summary>
    public NucleusArray array {
        get => _array;
        set => _array = value;
    }

    #region Serialization

    /// <summary>Activation functions selectable for a neuron.</summary>
    public enum CurvePresets {
        /// <summary>Output is the (possibly averaged) input sum unchanged.</summary>
        Linear,
        /// <summary>Output keeps the input direction with magnitude squared.</summary>
        Power,
        /// <summary>Output keeps the input direction with the square root of the magnitude.</summary>
        Sqrt,
        /// <summary>Output keeps the input direction with magnitude 1 / |input|.</summary>
        Reciprocal,
        /// <summary>Output keeps the input direction with magnitude <c>curve.Evaluate(|input|)</c>.</summary>
        Custom
    }

    [SerializeField]
    private CurvePresets _curvePreset;
    /// <summary>
    /// Activation function used by <see cref="UpdateStateIsolated(float3)"/>. Assigning it
    /// regenerates <see cref="curve"/> and <see cref="curveMax"/> via <see cref="GenerateCurve"/>.
    /// </summary>
    public CurvePresets curvePreset {
        get => _curvePreset;
        set {
            _curvePreset = value;
            this.curve = GenerateCurve();
        }
    }

    /// <summary>
    /// Activation curve. It is evaluated at the input magnitude only for
    /// <see cref="CurvePresets.Custom"/>; for the other presets the <see cref="curvePreset"/>
    /// setter fills it with a sampled version of the built-in function (presumably for display),
    /// while the activation itself is computed analytically.
    /// </summary>
    public AnimationCurve curve;

    /// <summary>
    /// Upper value hint updated by <see cref="GenerateCurve"/>. Not read by this class
    /// (presumably consumed by editor tooling — verify against callers).
    /// </summary>
    public float curveMax = 1.0f;

    #region Parameters
    /// <summary>
    /// When true, the weighted input sum is divided by the number of inputs whose
    /// output is currently non-zero.
    /// </summary>
    public bool average = false;
    #endregion Parameters

    /// <summary>
    /// Builds the curve matching <see cref="curvePreset"/> and, as a side effect,
    /// updates <see cref="curveMax"/>.
    /// </summary>
    /// <returns>A freshly generated curve, or the current <see cref="curve"/> for
    /// <see cref="CurvePresets.Custom"/>.</returns>
    public AnimationCurve GenerateCurve() {
        switch (this.curvePreset) {
            case CurvePresets.Linear:
                this.curveMax = 1;
                return Presets.Linear(1);
            case CurvePresets.Power:
                this.curveMax = 1;
                return Presets.Power(2.0f, 1);
            case CurvePresets.Sqrt:
                this.curveMax = 1;
                return Presets.Power(0.5f, 1);
            case CurvePresets.Reciprocal:
                // NOTE(review): Presets.Reciprocal samples x down to 0.001, so the curve peaks
                // at 1000 while curveMax is 1 / 0.01 = 100 — confirm which is intended.
                this.curveMax = 1 / 0.01f * 1;
                return Presets.Reciprocal(1);
            default:
                this.curveMax = 1;
                return this.curve;
        }
    }

    /// <summary>
    /// Extension hook (presumably for restoring subclass state from <paramref name="nucleus"/>);
    /// the base implementation does nothing.
    /// </summary>
    public virtual void Deserialize(Neuron nucleus) { }

    #endregion Serialization

    #region Runtime state (not serialized)

    /// <summary>Owning prefab when created through the <see cref="ClusterPrefab"/> constructor.</summary>
    public ClusterPrefab cluster { get; set; }
    /// <summary>Owning runtime cluster when created through the <see cref="Cluster"/> constructor.</summary>
    public Cluster parent { get; set; }

    #region Activation

    /// <summary>
    /// Factories for sampled <see cref="AnimationCurve"/>s of the built-in activation functions.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <c>AnimationUtility</c> lives in the UnityEditor namespace, so this code only
    /// compiles in the editor unless guarded by <c>#if UNITY_EDITOR</c> — confirm this script is
    /// in an editor-only assembly.
    /// </remarks>
    public static class Presets {
        private const int samples = 32;
        private const int reciprocalSamples = 128;

        /// <summary>Straight line from (0, 0) to (1000, 1000 * weight).</summary>
        public static AnimationCurve Linear(float weight) {
            return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
        }

        /// <summary>
        /// Samples y = x^exponent * weight at 32 evenly spaced points on [0, 1] with
        /// smooth (Auto) tangents.
        /// </summary>
        public static AnimationCurve Power(float exponent, float weight) {
            Keyframe[] keys = new Keyframe[samples];
            for (int i = 0; i < samples; i++) {
                float t = i / (float)(samples - 1);
                keys[i] = new Keyframe(t, Mathf.Pow(t, exponent) * weight);
            }
            AnimationCurve curve = new(keys);
            // Auto gives smooth tangents; use Linear instead for straight segments.
            for (int i = 0; i < curve.length; i++) {
                AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
                AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
            }
            return curve;
        }

        /// <summary>
        /// Samples y = (1 / x) * weight at 128 evenly spaced points on [0.001, 1] with linear
        /// tangents; x = 0 is excluded because 1 / x diverges there.
        /// </summary>
        public static AnimationCurve Reciprocal(float weight) {
            const float xMin = 0.001f;
            const float xMax = 1;
            var keys = new Keyframe[reciprocalSamples];
            for (int i = 0; i < reciprocalSamples; i++) {
                float t = i / (float)(reciprocalSamples - 1);
                float x = Mathf.Lerp(xMin, xMax, t);
                keys[i] = new Keyframe(x, 1f / x * weight);
            }
            var curve = new AnimationCurve(keys);
            for (int i = 0; i < curve.length; i++) {
                AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
                AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
            }
            return curve;
        }
    }

    #endregion Activation

    protected float3 _outputValue;
    /// <summary>
    /// Current output of the neuron. Assigning it marks the value as fresh
    /// (see <see cref="UpdateNuclei"/>).
    /// </summary>
    public virtual float3 outputValue {
        get => _outputValue;
        set {
            this.stale = 0;
            _outputValue = value;
        }
    }

    // Number of UpdateNuclei calls since outputValue was last assigned. It starts high so a
    // never-written neuron is treated as stale on its first update.
    [NonSerialized]
    private int stale = 1000;

    /// <summary>True when <see cref="outputValue"/> is exactly zero.</summary>
    public bool isSleeping => lengthsq(this.outputValue) == 0;

    /// <summary>
    /// Advances the staleness counter. On the third consecutive call without an assignment
    /// to <see cref="outputValue"/> the output is reset to zero.
    /// </summary>
    public void UpdateNuclei() {
        this.stale++;
        if (this.stale > 2)
            _outputValue = float3.zero;
    }

    #endregion Runtime state

    /// <summary>
    /// Creates a neuron under <paramref name="newParent"/> with this neuron's name and
    /// activation settings, but without synapses or receivers.
    /// </summary>
    /// <param name="newParent">Cluster that receives the copy (it is appended to its <c>nuclei</c>).</param>
    /// <returns>The new neuron.</returns>
    public virtual IReceptor ShallowCloneTo(Cluster newParent) {
        Neuron clone = new(newParent, this.name);
        CopySettingsTo(clone);
        return clone;
    }

    /// <summary>
    /// Creates a neuron in the same <see cref="cluster"/> with this neuron's settings, a copy of
    /// every incoming synapse (same sender and weight), and the same receivers.
    /// </summary>
    /// <returns>The new neuron.</returns>
    public virtual IReceptor Clone() {
        // The ClusterPrefab constructor already appends the clone to cluster.nuclei.
        Neuron clone = new(this.cluster, this.name);
        CopySettingsTo(clone);
        foreach (Synapse synapse in this.synapses) {
            Synapse clonedSynapse = clone.AddSynapse(synapse.nucleus);
            clonedSynapse.weight = synapse.weight;
        }
        // NOTE(review): AddReceiver gives each receiver a new synapse from the clone with
        // weight 1, not the weight that receiver uses for this neuron — confirm intended.
        foreach (INucleus receiver in this.receivers) {
            clone.AddReceiver(receiver);
        }
        return clone;
    }

    // Copies the activation settings shared by ShallowCloneTo and Clone. Order matters:
    // curve must be set before curvePreset (GenerateCurve returns the existing curve for
    // Custom), and curveMax after it (the curvePreset setter overwrites curveMax).
    private void CopySettingsTo(Neuron clone) {
        clone.array = this.array;
        clone.curve = this.curve;
        clone.curvePreset = this.curvePreset;
        clone.curveMax = this.curveMax;
        clone.average = this.average;
    }

    /// <summary>
    /// Connects this neuron to <paramref name="receivingNucleus"/>: records it in
    /// <see cref="receivers"/> and adds a synapse from this neuron on the receiver.
    /// </summary>
    /// <param name="receivingNucleus">Downstream nucleus.</param>
    /// <param name="weight">Weight of the synapse created on the receiver.</param>
    public virtual void AddReceiver(INucleus receivingNucleus, float weight = 1) {
        this._receivers.Add(receivingNucleus);
        receivingNucleus.AddSynapse(this, weight);
    }

    /// <summary>
    /// Removes every link between this neuron and <paramref name="receiverNucleus"/>: all
    /// matching entries in <see cref="receivers"/> and all of the receiver's synapses whose
    /// sender is this neuron.
    /// </summary>
    public void RemoveReceiver(INucleus receiverNucleus) {
        this._receivers.RemoveAll(receiver => receiver == receiverNucleus);
        receiverNucleus.synapses.RemoveAll(synapse => synapse.nucleus == this);
    }

    /// <summary>
    /// Disconnects <paramref name="nucleus"/> from the network and removes it from its cluster.
    /// A sending <see cref="Neuron"/> that feeds no nucleus other than the deleted one is
    /// deleted recursively; senders that still feed other nuclei only drop it from their
    /// receivers. Cyclic connections are handled.
    /// </summary>
    /// <param name="nucleus">Nucleus to delete.</param>
    public static void Delete(INucleus nucleus) {
        DeleteRecursive(nucleus, new HashSet<INucleus>());
    }

    // Recursive worker for Delete. `deleting` holds every nucleus already being deleted in this
    // call tree, so cycles (A feeds B, B feeds A) terminate instead of overflowing the stack.
    private static void DeleteRecursive(INucleus nucleus, HashSet<INucleus> deleting) {
        if (!deleting.Add(nucleus))
            return;

        // Iterate over a snapshot: deleting a sender removes the sender's synapses from this
        // nucleus (through the sender's receivers), which would invalidate a live enumerator.
        foreach (Synapse synapse in nucleus.synapses.ToArray()) {
            if (synapse.nucleus is not Neuron sender || deleting.Contains(sender))
                continue;

            if (sender._receivers.Any(r => r != nucleus)) {
                // The sender still feeds other nuclei: keep it, just unlink the deleted one.
                sender._receivers.RemoveAll(r => r == nucleus);
            }
            else {
                // The deleted nucleus was the sender's only receiver: delete the sender too.
                DeleteRecursive(sender, deleting);
            }
        }

        // Drop the synapses that downstream nuclei hold from the deleted nucleus.
        foreach (INucleus receiver in nucleus.receivers) {
            if (receiver != null && receiver.synapses != null)
                receiver.synapses.RemoveAll(s => s.nucleus == nucleus);
        }

        if (nucleus.cluster != null) {
            nucleus.cluster.nuclei.RemoveAll(n => n == nucleus);
            nucleus.cluster.GarbageCollection();
        }
    }

    /// <summary>
    /// Adds an incoming synapse from <paramref name="sendingNucleus"/>. The sender's receiver
    /// list is not modified.
    /// </summary>
    /// <param name="sendingNucleus">Nucleus whose output feeds this neuron.</param>
    /// <param name="weight">Synapse weight.</param>
    /// <returns>The created synapse.</returns>
    public Synapse AddSynapse(IReceptor sendingNucleus, float weight = 1.0f) {
        Synapse synapse = new(sendingNucleus, weight);
        this.synapses.Add(synapse);
        return synapse;
    }

    /// <summary>Recomputes <see cref="outputValue"/> from the synapses with a zero bias.</summary>
    public virtual void UpdateStateIsolated() {
        UpdateStateIsolated(new float3(0, 0, 0));
    }

    /// <summary>
    /// Recomputes <see cref="outputValue"/> from this neuron's synapses only; receivers are
    /// not notified.
    /// </summary>
    /// <param name="bias">Value added to the weighted input sum before averaging.</param>
    public virtual void UpdateStateIsolated(float3 bias) {
        float3 sum = bias;
        int activeInputs = 0;
        foreach (Synapse synapse in this.synapses) {
            float3 input = synapse.nucleus.outputValue;
            sum += synapse.weight * input;
            // Only inputs currently carrying a signal count towards the average.
            // Perhaps synapses should be removed when the output value goes to 0....
            if (lengthsq(input) != 0)
                activeInputs++;
        }
        if (this.average && activeInputs > 0)
            sum /= activeInputs;

        this.outputValue = ApplyActivation(sum);
    }

    // Applies the activation selected by curvePreset to the magnitude of `sum`, keeping its
    // direction. A zero-length (or NaN) sum has no direction, so every non-linear preset maps
    // it to zero instead of calling normalize(0), which yields NaN.
    private float3 ApplyActivation(float3 sum) {
        if (this.curvePreset == CurvePresets.Linear)
            return sum;

        float magnitude = length(sum);
        if (!(magnitude > 0))
            return float3.zero;

        float activated = this.curvePreset switch {
            CurvePresets.Sqrt => System.MathF.Sqrt(magnitude),
            CurvePresets.Power => System.MathF.Pow(magnitude, 2),
            CurvePresets.Reciprocal => 1 / magnitude,
            _ => this.curve.Evaluate(magnitude),
        };
        return normalize(sum) * activated;
    }
}