241 lines
8.0 KiB
C#
241 lines
8.0 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using UnityEngine;
|
|
using UnityEditor;
|
|
using Unity.Mathematics;
|
|
using static Unity.Mathematics.math;
|
|
|
|
[Serializable]
public class Neuron : Nucleus {

    /// <summary>
    /// Creates a neuron inside a runtime <see cref="Cluster"/>.
    /// </summary>
    /// <param name="parent">Owning cluster; when non-null the neuron is appended to its <c>nuclei</c> list.</param>
    /// <param name="name">Name of the neuron.</param>
    public Neuron(Cluster parent, string name) {
        this.parent = parent;
        this.name = name;
        this.parent?.nuclei.Add(this);
    }

    /// <summary>
    /// Creates a neuron inside a <see cref="ClusterPrefab"/>.
    /// </summary>
    /// <param name="parent">Owning prefab; when non-null the neuron is appended to its <c>nuclei</c> list.</param>
    /// <param name="name">Name of the neuron.</param>
    public Neuron(ClusterPrefab parent, string name) {
        this.cluster = parent;
        this.name = name;
        // Explicit != null (not ?.) so an overloaded equality operator on the cluster type is respected.
        if (this.cluster != null) {
            this.cluster.nuclei.Add(this);
        }
    }

    #region Serialization

    /// <summary>
    /// Activation function applied to the magnitude of the summed input (the direction is kept).
    /// </summary>
    public enum CurvePresets {
        /// <summary>Output equals the summed input.</summary>
        Linear,
        /// <summary>Magnitude is squared.</summary>
        Power,
        /// <summary>Magnitude is square-rooted.</summary>
        Sqrt,
        /// <summary>Magnitude is replaced by its reciprocal.</summary>
        Reciprocal,
        /// <summary>Magnitude is mapped through the user-supplied <see cref="curve"/>.</summary>
        Custom
    }

    [SerializeField]
    private CurvePresets _curvePreset;

    /// <summary>
    /// Selected activation preset. Assigning it regenerates <see cref="curve"/> (and
    /// <see cref="curveMax"/>) through <see cref="GenerateCurve"/>.
    /// </summary>
    public CurvePresets curvePreset {
        get { return _curvePreset; }
        set {
            _curvePreset = value;
            this.curve = GenerateCurve();
        }
    }

    /// <summary>
    /// Curve describing the activation function. At runtime it is only evaluated for
    /// <see cref="CurvePresets.Custom"/>; the other presets are computed analytically in
    /// <see cref="UpdateStateIsolated"/>.
    /// </summary>
    public AnimationCurve curve;

    /// <summary>
    /// Maximum value associated with the current preset, written by <see cref="GenerateCurve"/>.
    /// NOTE(review): consumers are not visible in this file (presumably graph scaling) — confirm.
    /// </summary>
    public float curveMax = 1.0f;

    #region Parameters

    /// <summary>
    /// When true, the weighted sum is divided by the number of inputs whose output is non-zero.
    /// </summary>
    public bool average = false;

    #endregion Parameters

    /// <summary>
    /// Builds the <see cref="AnimationCurve"/> for the current <see cref="curvePreset"/> and
    /// updates <see cref="curveMax"/> as a side effect.
    /// </summary>
    /// <returns>
    /// A freshly generated curve for the built-in presets, or the existing <see cref="curve"/>
    /// (which may be null) for <see cref="CurvePresets.Custom"/>.
    /// </returns>
    public AnimationCurve GenerateCurve() {
        switch (this.curvePreset) {
            case CurvePresets.Linear:
                this.curveMax = 1;
                return Presets.Linear(1);

            case CurvePresets.Power:
                this.curveMax = 1;
                return Presets.Power(2.0f, 1);

            case CurvePresets.Sqrt:
                this.curveMax = 1;
                return Presets.Power(0.5f, 1);

            case CurvePresets.Reciprocal:
                // NOTE(review): this is 100, but Presets.Reciprocal samples down to x = 0.001,
                // so the generated curve peaks at 1000. Confirm which maximum is intended.
                this.curveMax = 1 / 0.01f;
                return Presets.Reciprocal(1);

            default:
                // Custom: keep whatever curve the user authored.
                this.curveMax = 1;
                return this.curve;
        }
    }

    /// <summary>
    /// Deserialization hook; does nothing in this class and may be overridden by subclasses.
    /// </summary>
    public virtual void Deserialize(Neuron nucleus) { }

    #endregion Serialization

    #region Runtime state (not serialized)

    #region Activation

    /// <summary>
    /// Factory methods producing <see cref="AnimationCurve"/>s for the built-in activation presets.
    /// </summary>
    /// <remarks>
    /// Tangents are set through <see cref="AnimationUtility"/>, which lives in the UnityEditor
    /// assembly, so building these curves requires the editor.
    /// </remarks>
    public static class Presets {
        // Number of keyframes generated by Power.
        private const int samples = 32;

        /// <summary>Straight line from (0, 0) to (1000, 1000 * <paramref name="weight"/>).</summary>
        public static AnimationCurve Linear(float weight) {
            return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
        }

        /// <summary>
        /// Samples y = x^<paramref name="exponent"/> * <paramref name="weight"/> on x in [0, 1]
        /// with smooth (Auto) tangents.
        /// </summary>
        public static AnimationCurve Power(float exponent, float weight) {
            Keyframe[] keys = new Keyframe[samples];
            for (int i = 0; i < samples; i++) {
                float t = i / (float)(samples - 1);
                float v = Mathf.Pow(t, exponent) * weight;
                keys[i] = new Keyframe(t, v);
            }

            AnimationCurve curve = new(keys);

            // Auto tangents give a smooth curve; use TangentMode.Linear for straight segments.
            for (int i = 0; i < curve.length; i++) {
                AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
                AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
            }

            return curve;
        }

        /// <summary>
        /// Samples y = (1 / x) * <paramref name="weight"/> on x in [0.001, 1] with linear tangents.
        /// </summary>
        public static AnimationCurve Reciprocal(float weight) {
            const int reciprocalSamples = 128;
            // x = 0 is excluded because 1/x diverges there.
            const float xMin = 0.001f;
            const float xMax = 1;

            var keys = new Keyframe[reciprocalSamples];
            for (int i = 0; i < reciprocalSamples; i++) {
                float t = i / (float)(reciprocalSamples - 1);
                float x = Mathf.Lerp(xMin, xMax, t);
                float y = 1f / x * weight;
                keys[i] = new Keyframe(x, y);
            }

            var curve = new AnimationCurve(keys);

            // Linear tangents keep the steep part of 1/x from overshooting between samples.
            for (int i = 0; i < curve.length; i++) {
                AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
                AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
            }

            return curve;
        }
    }

    #endregion Activation

    #endregion Runtime state

    /// <summary>
    /// Clones this neuron into <paramref name="newParent"/> without its synapses and receivers.
    /// </summary>
    /// <param name="newParent">Cluster receiving the clone (the constructor appends it to its nuclei).</param>
    /// <returns>The new, unconnected neuron.</returns>
    public override Nucleus ShallowCloneTo(Cluster newParent) {
        // Initializer order matters: assigning curvePreset regenerates curve and curveMax, so
        // curveMax is copied afterwards. For Custom the "regenerated" curve is the instance
        // copied just before, i.e. the AnimationCurve object is shared with the source.
        Neuron clone = new(newParent, this.name) {
            array = null,
            curve = this.curve,
            curvePreset = this.curvePreset,
            curveMax = this.curveMax,
            average = this.average
        };

        return clone;
    }

    /// <summary>
    /// Copies this neuron into <paramref name="prefab"/> and reconnects the copy: one
    /// <c>AddSynapse</c> per existing synapse (weights copied) and one <c>AddReceiver</c> per receiver.
    /// </summary>
    /// <param name="prefab">Prefab receiving the clone.</param>
    /// <returns>The connected clone.</returns>
    public override Nucleus Clone(ClusterPrefab prefab) {
        // The ClusterPrefab constructor already adds the clone to prefab.nuclei,
        // so it must not be added again here.
        Neuron clone = new(prefab, this.name) {
            array = this.array,
            curve = this.curve,
            curvePreset = this.curvePreset,
            curveMax = this.curveMax,
            average = this.average
        };

        foreach (Synapse synapse in this.synapses) {
            Synapse clonedSynapse = clone.AddSynapse(synapse.nucleus);
            clonedSynapse.weight = synapse.weight;
        }

        foreach (Nucleus receiver in this.receivers) {
            clone.AddReceiver(receiver);
        }

        return clone;
    }

    /// <summary>
    /// Removes <paramref name="nucleus"/> from the network.
    /// </summary>
    /// <remarks>
    /// For each input neuron: if it has more than one receiver, only the link to
    /// <paramref name="nucleus"/> is dropped; otherwise the input neuron is deleted recursively.
    /// Afterwards every receiver's synapses pointing at <paramref name="nucleus"/> are removed,
    /// the nucleus is removed from its cluster, and the cluster's <c>GarbageCollection()</c> runs.
    /// </remarks>
    public static void Delete(Nucleus nucleus) {
        DeleteRecursive(nucleus, new HashSet<Nucleus>());
    }

    // Worker for Delete. `deleting` records every nucleus whose deletion has started, so a
    // cycle of single-receiver neurons terminates instead of recursing forever.
    private static void DeleteRecursive(Nucleus nucleus, HashSet<Nucleus> deleting) {
        if (!deleting.Add(nucleus))
            return;

        if (nucleus.synapses != null) {
            // Iterate a snapshot: a recursive delete removes entries from this very list
            // (through its receivers loop below), which would invalidate a live enumerator.
            foreach (Synapse synapse in nucleus.synapses.ToArray()) {
                if (synapse.nucleus is Neuron inputNeuron) {
                    if (inputNeuron.receivers.Count > 1) {
                        // Another nucleus still feeds from this input; only unlink it from us.
                        inputNeuron.receivers.RemoveAll(r => r == nucleus);
                    }
                    else {
                        // No other links remain, so the input is deleted as well.
                        DeleteRecursive(inputNeuron, deleting);
                    }
                }
            }
        }

        foreach (Nucleus receiver in nucleus.receivers) {
            if (receiver != null && receiver.synapses != null)
                receiver.synapses.RemoveAll(s => s.nucleus == nucleus);
        }

        if (nucleus.cluster != null) {
            nucleus.cluster.nuclei.RemoveAll(n => n == nucleus);
            nucleus.cluster.GarbageCollection();
        }
    }

    /// <summary>
    /// Additive input set by <see cref="ProcessStimulus"/>; it seeds the weighted sum in
    /// <see cref="UpdateStateIsolated"/>.
    /// </summary>
    public float3 bias = float3(0, 0, 0);

    /// <summary>
    /// Recomputes <c>outputValue</c> from the weighted inputs and the activation preset.
    /// </summary>
    /// <param name="bias_unused">Ignored; the neuron's own <see cref="bias"/> field is used instead.</param>
    /// <remarks>
    /// The sum <c>bias + Σ weight * input.outputValue</c> is optionally averaged over the inputs
    /// with non-zero output. Then its magnitude is mapped by the activation preset while its
    /// direction is kept. A zero sum produces a zero output (not NaN). The output is forced to
    /// zero when <c>stale</c> exceeds <c>staleValueForSleep</c>.
    /// </remarks>
    public override void UpdateStateIsolated(float3 bias_unused) {
        float3 sum = this.bias;
        int activeInputs = 0;

        // Apply the synapse weights.
        foreach (Synapse synapse in this.synapses) {
            sum += synapse.weight * synapse.nucleus.outputValue;

            // Any input with a non-zero output keeps this neuron awake.
            // Perhaps synapses should be removed when the output value goes to 0....
            if (lengthsq(synapse.nucleus.outputValue) != 0) {
                activeInputs++;
                this.stale = 0;
            }
        }

        if (this.average && activeInputs > 0)
            sum /= activeInputs;

        // Activation: map the magnitude, keep the direction. normalize() of a zero vector is
        // NaN, so the direction is derived explicitly and falls back to zero.
        float magnitude = length(sum);
        float3 direction = magnitude > 0 ? sum / magnitude : float3(0, 0, 0);

        float3 result;
        switch (this.curvePreset) {
            case CurvePresets.Linear:
                result = sum;
                break;

            case CurvePresets.Sqrt:
                result = direction * System.MathF.Sqrt(magnitude);
                break;

            case CurvePresets.Power:
                result = direction * (magnitude * magnitude);
                break;

            case CurvePresets.Reciprocal:
                result = magnitude > 0 ? direction / magnitude : float3(0, 0, 0);
                break;

            default:
                // A Custom preset without an authored curve contributes nothing instead of throwing.
                float activatedValue = this.curve != null ? this.curve.Evaluate(magnitude) : 0f;
                result = direction * activatedValue;
                break;
        }

        this.outputValue = this.stale > staleValueForSleep ? float3(0, 0, 0) : result;
    }

    /// <summary>
    /// Feeds an external stimulus into the neuron: it replaces <see cref="bias"/> and resets
    /// the stale counter.
    /// </summary>
    /// <param name="inputValue">Stimulus vector; becomes the new bias.</param>
    /// <param name="thingName">Not used here; presumably identifies the stimulus source for overrides.</param>
    public virtual void ProcessStimulus(Vector3 inputValue, string thingName = null) {
        this.stale = 0;
        this.bias = inputValue;
    }
}