2026-03-05 16:04:28 +01:00

340 lines
11 KiB
C#

using System;
using System.Collections.Generic;
using Unity.Mathematics;
using UnityEngine;
using static Unity.Mathematics.math;
#if UNITY_EDITOR
// Editor-only assembly: not available in player builds.
using UnityEditor;
#endif
[Serializable]
public class Neuron : Nucleus {
/// <summary>
/// Creates a neuron with the given name that belongs to a cluster.
/// </summary>
/// <param name="parent">Owning cluster. May be null, in which case the neuron is not registered anywhere.</param>
/// <param name="name">Name stored on the neuron.</param>
public Neuron(Cluster parent, string name) {
    this.parent = parent;
    this.name = name;
    // Register with the owning cluster's nucleus list when there is an owner.
    if (this.parent is { } owner)
        owner.clusterNuclei.Add(this);
}
/// <summary>
/// Creates a neuron with the given name and adds it to a cluster prefab.
/// </summary>
/// <param name="prefab">Prefab that receives the neuron. When null, an error is logged and nothing is registered.</param>
/// <param name="name">Name stored on the neuron.</param>
public Neuron(ClusterPrefab prefab, string name) {
    this.clusterPrefab = prefab;
    this.name = name;
    // Guard clause: without a prefab there is no nucleus list to join.
    if (this.clusterPrefab == null) {
        Debug.LogError("No prefab when adding neuron to prefab");
        return;
    }
    this.clusterPrefab.nuclei.Add(this);
}
#region Serialization
/// <summary>
/// Ways in which the weighted synapse inputs of a neuron can be combined.
/// The numeric values are spelled out to match the implicit numbering.
/// </summary>
public enum CombinatorType {
    /// <summary>Inputs are added together.</summary>
    Sum = 0,
    /// <summary>Inputs are multiplied together.</summary>
    Product = 1,
    /// <summary>The input with the greatest magnitude is selected.</summary>
    Max = 2
}
// Selects which combinator method the Combinator property returns; defaults to Sum.
public CombinatorType combinator = CombinatorType.Sum;
/// <summary>
/// Built-in activation curve shapes, plus <see cref="Custom"/> for a user supplied curve.
/// The numeric values are spelled out to match the implicit numbering.
/// </summary>
public enum CurvePresets {
    /// <summary>Output equals input.</summary>
    Linear = 0,
    /// <summary>Output magnitude is the input magnitude squared.</summary>
    Power = 1,
    /// <summary>Output magnitude is the square root of the input magnitude.</summary>
    Sqrt = 2,
    /// <summary>Output magnitude is one over the input magnitude.</summary>
    Reciprocal = 3,
    /// <summary>Output magnitude is read from the neuron's own curve.</summary>
    Custom = 4
}
/// <summary>Backing field for <see cref="curvePreset"/>.</summary>
[SerializeField]
public CurvePresets _curvePreset;
/// <summary>
/// The activation curve preset of this neuron.
/// Assigning a value also replaces <see cref="curve"/> with the result of <see cref="GenerateCurve"/>.
/// </summary>
public CurvePresets curvePreset {
    get => _curvePreset;
    set {
        // Store first: GenerateCurve reads the preset back through this property.
        _curvePreset = value;
        AnimationCurve regenerated = GenerateCurve();
        this.curve = regenerated;
    }
}
// Activation curve. Replaced by GenerateCurve() whenever curvePreset is assigned and
// evaluated by ActivatorCustom.
public AnimationCurve curve;
// Written by GenerateCurve(): 100 for the Reciprocal preset, 1 for every other preset.
// NOTE(review): presumably the upper bound used when drawing the curve - confirm with the editor code.
public float curveMax = 1.0f;
/// <summary>
/// Builds the curve that belongs to the current <see cref="curvePreset"/> and updates
/// <see cref="curveMax"/> as a side effect.
/// </summary>
/// <returns>
/// A newly generated curve for the built-in presets; the existing <see cref="curve"/>
/// (possibly null) for any other preset.
/// </returns>
public AnimationCurve GenerateCurve() {
    CurvePresets preset = this.curvePreset;
    // Reciprocal is the only preset whose curveMax differs from 1.
    // NOTE(review): Presets.Reciprocal samples from x = 0.001, so its keys reach 1000 while
    // curveMax is set to 100 here - confirm which bound is intended.
    this.curveMax = preset == CurvePresets.Reciprocal ? 1 / 0.01f * 1 : 1;
    return preset switch {
        CurvePresets.Linear => Presets.Linear(1),
        CurvePresets.Power => Presets.Power(2.0f, 1),
        CurvePresets.Sqrt => Presets.Power(0.5f, 1),
        CurvePresets.Reciprocal => Presets.Reciprocal(1),
        _ => this.curve
    };
}
/// <summary>
/// Factory methods that build the animation curves for the built-in curve presets.
/// </summary>
public static class Presets {
    // Number of keyframes used to sample a power curve.
    private const int samples = 32;
    // Number of keyframes used to sample a reciprocal curve (steeper, so sampled more densely).
    private const int reciprocalSamples = 128;

    /// <summary>Straight line from (0, 0) to (1000, weight * 1000).</summary>
    /// <param name="weight">Slope of the line.</param>
    public static AnimationCurve Linear(float weight) {
        return AnimationCurve.Linear(0f, 0f, 1000f, weight * 1000);
    }

    /// <summary>Samples weight * t^exponent for t in [0, 1].</summary>
    /// <param name="exponent">Exponent applied to the sample position.</param>
    /// <param name="weight">Factor applied to every sampled value.</param>
    public static AnimationCurve Power(float exponent, float weight) {
        // build keyframes
        Keyframe[] keys = new Keyframe[samples];
        for (int i = 0; i < samples; i++) {
            float t = i / (float)(samples - 1);
            float v = Mathf.Pow(t, exponent) * weight;
            keys[i] = new Keyframe(t, v);
        }
        AnimationCurve curve = new(keys);
#if UNITY_EDITOR
        // AnimationUtility lives in the UnityEditor assembly, which does not exist in player
        // builds, so the tangent mode is only applied when compiling for the editor.
        // set tangent modes for each key to Auto (smooth). Use Linear if you prefer straight segments.
        for (int i = 0; i < curve.length; i++) {
            AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
            AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Auto);
        }
#endif
        return curve;
    }

    /// <summary>Samples weight / x for x in [0.001, 1].</summary>
    /// <param name="weight">Factor applied to every sampled value.</param>
    public static AnimationCurve Reciprocal(float weight) {
        // Start slightly above zero: 1 / x is undefined at x = 0.
        float xMin = 0.001f;
        float xMax = 1;
        var keys = new Keyframe[reciprocalSamples];
        for (int i = 0; i < reciprocalSamples; i++) {
            float t = i / (float)(reciprocalSamples - 1);
            float x = Mathf.Lerp(xMin, xMax, t);
            float y = 1f / x * weight;
            keys[i] = new Keyframe(x, y);
        }
        var curve = new AnimationCurve(keys);
#if UNITY_EDITOR
        // Editor-only API, see Power(): skipped in player builds.
        for (int i = 0; i < curve.length; i++) {
            AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
            AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
        }
#endif
        return curve;
    }
}
#endregion Serialization
/// <summary>Backing field for <see cref="outputValue"/>.</summary>
protected float3 _outputValue;
/// <summary>
/// Current output of the neuron. After a new value is stored, <see cref="WhenFiring"/> is
/// invoked if that value leaves the neuron firing.
/// </summary>
public virtual float3 outputValue {
    get => _outputValue;
    set {
        _outputValue = value;
        if (!this.isFiring)
            return;
        WhenFiring?.Invoke();
    }
}
// True while the magnitude of the stored output value exceeds 0.5.
public bool isFiring => length(_outputValue) > 0.5f;
// Callback invoked by the outputValue setter each time a value is stored that leaves the neuron firing.
public Action WhenFiring;
// True when the squared length of the output value is exactly zero.
public virtual bool isSleeping => lengthsq(this.outputValue) == 0;
// Not serialized. Reset to 0 by ProcessStimulusDirect.
// NOTE(review): presumably counts updates since the last stimulus - confirm where it is incremented.
[NonSerialized]
public int stale = 1000;
// NOTE(review): presumably the value of stale at which the neuron is put to sleep; it is not
// read in this part of the file - confirm against the caller.
public readonly int staleValueForSleep = 20;
/// <summary>
/// Clones this neuron into another cluster without copying its synapses or receivers.
/// </summary>
/// <param name="newParent">Cluster that will own the copy.</param>
/// <returns>The new neuron, carrying the same name and the fields copied by CloneFields.</returns>
public override Nucleus ShallowCloneTo(Cluster newParent) {
    var copy = new Neuron(newParent, this.name);
    CloneFields(copy);
    return copy;
}
/// <summary>
/// Clones this neuron into a prefab, including its synapses and receivers.
/// </summary>
/// <param name="prefab">Prefab that will hold the copy.</param>
/// <returns>The new neuron.</returns>
public override Nucleus Clone(ClusterPrefab prefab) {
    var copy = new Neuron(prefab, this.name);
    CloneFields(copy);

    // Recreate every incoming connection with the weight of the original.
    foreach (Synapse source in this.synapses) {
        Synapse duplicate = copy.AddSynapse(source.nucleus);
        duplicate.weight = source.weight;
    }

    // Recreate every outgoing connection.
    // NOTE(review): AddReceiver is called with its default weight of 1, so the weight the
    // receiver holds for the original is not carried over - confirm this is intended.
    foreach (Nucleus target in this.receivers)
        copy.AddReceiver(target);

    return copy;
}
// Copies the configuration fields of this neuron onto the given clone.
// The order of the curve related assignments matters: see the comments below.
protected virtual void CloneFields(Neuron clone) {
clone.clusterPrefab = this.clusterPrefab;
clone.bias = this.bias;
clone.combinator = this.combinator;
// Assigned before curvePreset so that a preset without a generated curve (Custom) keeps this curve.
// The AnimationCurve instance is shared with the original, not copied.
clone.curve = this.curve;
// The setter regenerates clone.curve for the built-in presets and overwrites clone.curveMax.
clone.curvePreset = this.curvePreset;
// Assigned last so the value written by the curvePreset setter is replaced by the original's.
clone.curveMax = this.curveMax;
}
/// <summary>
/// Deletes a nucleus: detaches it from the nuclei feeding into it and from its receivers,
/// and removes it from its cluster prefab. Input neurons that have no other receiver are
/// deleted as well.
/// </summary>
/// <param name="nucleus">Nucleus to delete. Null is ignored.</param>
public static void Delete(Nucleus nucleus) {
    DeleteRecursive(nucleus, new HashSet<Nucleus>());
}

// Performs the deletion; inProgress holds the nuclei already being deleted so that
// self links and cyclic links cannot cause endless recursion.
private static void DeleteRecursive(Nucleus nucleus, HashSet<Nucleus> inProgress) {
    if (nucleus == null || !inProgress.Add(nucleus))
        return;

    if (nucleus.synapses != null) {
        // Iterate over a snapshot: deleting an input removes its synapse from
        // nucleus.synapses, which would invalidate an enumerator over the live list.
        foreach (Synapse synapse in new List<Synapse>(nucleus.synapses)) {
            if (synapse.nucleus is Neuron synapse_nucleus) {
                if (synapse_nucleus.receivers.Count > 1) {
                    // there is another nucleus feeding into this input nucleus
                    synapse_nucleus.receivers.RemoveAll(r => r == nucleus);
                }
                else {
                    // No other links, delete it.
                    DeleteRecursive(synapse_nucleus, inProgress);
                }
            }
        }
    }

    if (nucleus is Neuron neuron) {
        // Remove the synapses through which the receivers read this neuron.
        foreach (Nucleus receiver in neuron.receivers) {
            if (receiver != null && receiver.synapses != null)
                receiver.synapses.RemoveAll(s => s.nucleus == nucleus);
        }
    }
    else if (nucleus is Cluster cluster) {
        // remove all receivers for this cluster
        foreach (Neuron output in cluster.outputs) {
            if (output == null)
                continue;
            foreach (Nucleus receiver in output.receivers) {
                // Same null checks as the neuron branch above.
                if (receiver != null && receiver.synapses != null)
                    receiver.synapses.RemoveAll(s => s.nucleus == output);
            }
        }
    }

    if (nucleus.clusterPrefab != null) {
        nucleus.clusterPrefab.nuclei.RemoveAll(n => n == nucleus);
        nucleus.clusterPrefab.RefreshOutputs();
        nucleus.clusterPrefab.GarbageCollection();
    }
}
/// <summary>
/// Recomputes the output value: combines the inputs with the selected combinator and passes
/// the result through the selected activator.
/// </summary>
public override void UpdateStateIsolated() {
    Func<float3> combine = Combinator;
    float3 combined = combine();
    Func<float3, float3> activate = Activator;
    this.outputValue = activate(combined);
}
#region Combinator
/// <summary>
/// The combinator method that matches <see cref="combinator"/>; unknown values fall back to the sum.
/// </summary>
protected Func<float3> Combinator {
    get {
        switch (combinator) {
            case CombinatorType.Product:
                return CombinatorProduct;
            case CombinatorType.Max:
                return CombinatorMax;
            case CombinatorType.Sum:
            default:
                return CombinatorSum;
        }
    }
}
/// <summary>
/// Adds the weighted output of every neuron synapse to the bias.
/// Synapses whose nucleus is not a neuron are ignored.
/// </summary>
/// <returns>The bias plus the sum of the weighted inputs.</returns>
public float3 CombinatorSum() {
    float3 total = this.bias;
    foreach (Synapse link in this.synapses) {
        if (link.nucleus is not Neuron source)
            continue;
        total += link.weight * source.outputValue;
    }
    return total;
}
/// <summary>
/// Multiplies the bias by the weighted output of every neuron synapse.
/// Synapses whose nucleus is not a neuron are ignored.
/// </summary>
/// <returns>The bias multiplied by all weighted inputs.</returns>
public float3 CombinatorProduct() {
    // NOTE(review): the bias is the first factor, so a zero bias component keeps that
    // component of the result at zero - confirm this is intended.
    float3 total = this.bias;
    foreach (Synapse link in this.synapses) {
        if (link.nucleus is not Neuron source)
            continue;
        total *= link.weight * source.outputValue;
    }
    return total;
}
/// <summary>
/// Selects whichever of the bias and the weighted neuron inputs has the greatest magnitude.
/// Synapses whose nucleus is not a neuron are ignored.
/// </summary>
/// <returns>The bias, or the weighted input that is strictly longer than every earlier candidate.</returns>
public float3 CombinatorMax() {
    float3 best = this.bias;
    float bestLength = length(best);
    foreach (Synapse link in this.synapses) {
        if (link.nucleus is not Neuron source)
            continue;
        // Apply the weight factor before comparing magnitudes.
        float3 candidate = link.weight * source.outputValue;
        float candidateLength = length(candidate);
        if (candidateLength <= bestLength)
            continue;
        best = candidate;
        bestLength = candidateLength;
    }
    return best;
}
#endregion Combinator
#region Activator
/// <summary>
/// The activator method that matches <see cref="curvePreset"/>; every preset without a
/// dedicated method uses the custom curve.
/// </summary>
public Func<float3, float3> Activator {
    get {
        switch (this.curvePreset) {
            case CurvePresets.Linear:
                return ActivatorLinear;
            case CurvePresets.Sqrt:
                return ActivatorSqrt;
            case CurvePresets.Power:
                return ActivatorPower;
            case CurvePresets.Reciprocal:
                return ActivatorReciprocal;
            default:
                return ActivatorCustom;
        }
    }
}
/// <summary>Identity activation: returns the input unchanged.</summary>
/// <param name="input">Combined input of the neuron.</param>
protected float3 ActivatorLinear(float3 input) => input;
/// <summary>
/// Square root activation: keeps the direction of the input and replaces its magnitude by
/// the square root of that magnitude.
/// </summary>
/// <param name="input">Combined input of the neuron.</param>
/// <returns>The activated vector; the zero vector for a zero input.</returns>
protected float3 ActivatorSqrt(float3 input) {
    float magnitude = length(input);
    // A zero vector has no direction: normalize() would return NaN, which would then
    // spread through every neuron reading this output.
    if (magnitude == 0)
        return new float3(0, 0, 0);
    float3 result = normalize(input) * System.MathF.Sqrt(magnitude);
    return result;
}
/// <summary>
/// Power activation: keeps the direction of the input and replaces its magnitude by the
/// square of that magnitude.
/// </summary>
/// <param name="input">Combined input of the neuron.</param>
/// <returns>The activated vector; the zero vector for a zero input.</returns>
protected float3 ActivatorPower(float3 input) {
    float magnitude = length(input);
    // A zero vector has no direction: normalize() would return NaN, which would then
    // spread through every neuron reading this output.
    if (magnitude == 0)
        return new float3(0, 0, 0);
    float3 result = normalize(input) * System.MathF.Pow(magnitude, 2);
    return result;
}
/// <summary>
/// Reciprocal activation: keeps the direction of the input and replaces its magnitude by
/// one over that magnitude.
/// </summary>
/// <param name="input">Combined input of the neuron.</param>
/// <returns>The activated vector; the zero vector for a zero input.</returns>
protected float3 ActivatorReciprocal(float3 input) {
    float inputLength = length(input);
    // The reciprocal of zero is undefined, so a zero input maps to a zero output.
    return inputLength == 0
        ? new float3(0, 0, 0)
        : normalize(input) * (1 / inputLength);
}
/// <summary>
/// Custom activation: keeps the direction of the input and replaces its magnitude by the
/// value of <see cref="curve"/> evaluated at that magnitude.
/// </summary>
/// <param name="input">Combined input of the neuron.</param>
/// <returns>
/// The activated vector; the zero vector for a zero input, whatever the curve yields at 0,
/// because a zero input has no direction to scale.
/// </returns>
protected float3 ActivatorCustom(float3 input) {
    float magnitude = length(input);
    // A zero vector has no direction: normalize() would return NaN, which would then
    // spread through every neuron reading this output.
    if (magnitude == 0)
        return new float3(0, 0, 0);
    float activatedValue = this.curve.Evaluate(magnitude);
    float3 result = normalize(input) * activatedValue;
    return result;
}
#endregion Activator
#region Receivers
/// <summary>Backing list for <see cref="receivers"/>, serialized by reference.</summary>
[SerializeReference]
private List<Nucleus> _receivers = new List<Nucleus>();
/// <summary>The nuclei that receive the output of this neuron.</summary>
public virtual List<Nucleus> receivers {
    get => _receivers;
    set => _receivers = value;
}
/// <summary>
/// Connects this neuron to a receiver: the receiver is stored in the receiver list and gets
/// a synapse that reads this neuron.
/// </summary>
/// <param name="receiverToAdd">Nucleus that will receive the output of this neuron.</param>
/// <param name="weight">Weight of the synapse created on the receiver.</param>
/// <exception cref="ArgumentNullException"><paramref name="receiverToAdd"/> is null.</exception>
public virtual void AddReceiver(Nucleus receiverToAdd, float weight = 1) {
    // Validate before mutating: otherwise a null entry would already be stored in the
    // receiver list by the time the synapse creation fails.
    if (receiverToAdd == null)
        throw new ArgumentNullException(nameof(receiverToAdd));
    this._receivers.Add(receiverToAdd);
    receiverToAdd.AddSynapse(this, weight);
}
/// <summary>
/// Disconnects a receiver: removes it from the receiver list and removes the synapses
/// through which it reads this neuron. When this neuron is an IReceptor, the same is done
/// for every neuron in the receptor's nucleus array instead.
/// </summary>
/// <param name="receiverToRemove">Nucleus to disconnect.</param>
public virtual void RemoveReceiver(Nucleus receiverToRemove) {
    if (this is not IReceptor receptor) {
        // Plain neuron: only this neuron is linked to the receiver.
        this._receivers.RemoveAll(candidate => candidate == receiverToRemove);
        receiverToRemove.synapses.RemoveAll(link => link.nucleus == this);
        return;
    }

    // Receptor: unlink every neuron of the receptor's array.
    foreach (Nucleus element in receptor.nucleiArray) {
        if (element is not Neuron member)
            continue;
        member._receivers.RemoveAll(candidate => candidate == receiverToRemove);
        receiverToRemove.synapses.RemoveAll(link => link.nucleus == member);
    }
}
#endregion Receivers
/// <summary>
/// Handles an incoming stimulus. When the parent is a ClusterReceptor the stimulus is
/// handed to that receptor; otherwise it is applied to this neuron directly.
/// </summary>
/// <param name="inputValue">Stimulus value.</param>
/// <param name="thingId">Identifier passed along with the stimulus.</param>
/// <param name="thingName">Name passed along with the stimulus.</param>
public override void ProcessStimulus(Vector3 inputValue, int thingId = 0, string thingName = null) {
    if (this.parent is not ClusterReceptor receptorCluster) {
        ProcessStimulusDirect(inputValue, thingId, thingName);
        return;
    }
    receptorCluster.ProcessStimulus(this, inputValue, thingId, thingName);
}
/// <summary>
/// Applies a stimulus to this neuron: resets the stale counter, stores the stimulus as the
/// bias and asks the parent, when there is one, to update from this neuron.
/// </summary>
/// <param name="inputValue">Stimulus value that becomes the new bias.</param>
/// <param name="thingId">Not used by this method.</param>
/// <param name="thingName">Not used by this method.</param>
public void ProcessStimulusDirect(Vector3 inputValue, int thingId = 0, string thingName = null) {
    this.stale = 0;
    this.bias = inputValue;
    // The parent is optional: a neuron created for a cluster prefab never gets one assigned
    // by its constructor, so there may be nothing to notify.
    this.parent?.UpdateFromNucleus(this);
}
}