Fixed the evaluation order

This commit is contained in:
Pascal Serrarens 2026-01-30 12:46:29 +01:00
parent b4f8e5a4d8
commit 91c4500b0a
4 changed files with 121 additions and 20 deletions

View File

@ -262,20 +262,45 @@ public class Cluster : INucleus {
UpdateState(new float3(0, 0, 0)); UpdateState(new float3(0, 0, 0));
} }
public void UpdateState(float3 inputValue) { public void UpdateState(float3 bias) {
float3 sum = inputValue; // new(0, 0, 0); float3 sum = bias; // new(0, 0, 0);
//Applying the weight factgors //Applying the weight factors
foreach (Synapse synapse in this.synapses) { foreach (Synapse synapse in this.synapses) {
sum += synapse.weight * synapse.nucleus.outputValue; sum += synapse.weight * synapse.nucleus.outputValue;
} }
//this.prefab.inputs[0].UpdateState(sum); //this.inputs[0].UpdateState(sum);
this.inputs[0].UpdateState(sum); this.inputs[0].UpdateStateIsolated(sum);
foreach (IReceptor receptor in this.sortedNuclei) {
if (receptor is INucleus nucleus && nucleus != this.inputs[0])
nucleus.UpdateStateIsolated();
}
UpdateResult(this.output.outputValue); UpdateResult(this.output.outputValue);
} }
public void UpdateStateIsolated() {
float3 bias = new(0,0,0);
UpdateStateIsolated(bias);
}
public void UpdateStateIsolated(float3 bias) {
float3 sum = bias; // new(0, 0, 0);
//Applying the weight factors
foreach (Synapse synapse in this.synapses) {
sum += synapse.weight * synapse.nucleus.outputValue;
}
//this.inputs[0].UpdateState(sum);
this.inputs[0].UpdateStateIsolated(sum);
foreach (IReceptor receptor in this.sortedNuclei) {
if (receptor is INucleus nucleus && nucleus != this.inputs[0])
nucleus.UpdateStateIsolated();
}
this.outputValue = this.output.outputValue;
}
public virtual void UpdateResult(Vector3 result) { public virtual void UpdateResult(Vector3 result) {
// float d = Vector3.Distance(result, this.outputValue); // float d = Vector3.Distance(result, this.outputValue);
// if (d < 0.5f) { // if (d < 0.5f) {

View File

@ -20,7 +20,8 @@ public interface INucleus : IReceptor {
public void UpdateState(); public void UpdateState();
public void UpdateState(float3 inputValue); public void UpdateState(float3 inputValue);
public void UpdateStateIsolated();
public void UpdateStateIsolated(float3 inputValue);
#endregion dynamic state #endregion dynamic state
} }

View File

@ -4,27 +4,34 @@ using Unity.Mathematics;
using static Unity.Mathematics.math; using static Unity.Mathematics.math;
[Serializable] [Serializable]
public class MemoryCell : Neuron { public class MemoryCell : Neuron, INucleus {
public MemoryCell(ClusterPrefab cluster, string name) : base(cluster, name) {} public MemoryCell(ClusterPrefab cluster, string name) : base(cluster, name) { }
public MemoryCell(Cluster parent, string name) : base(parent, name) { }
// this.parent = parent;
// this.name = name;
// this.parent?.nuclei.Add(this);
// }
#region Parameters public override IReceptor ShallowCloneTo(Cluster newParent) {
MemoryCell clone = new(newParent, this.name) {
// Returns the memorized value weighted by time array = this.array,
// return lastValue * (current time - last time) curve = this.curve,
// [SerializeField] curvePreset = this.curvePreset,
// public bool deltaValue = false; curveMax = this.curveMax,
average = this.average
#endregion Parameters };
return clone;
}
#region State #region State
private float3 _memorizedValue; private float3 _memorizedValue;
private float _memorizedTime; private float _memorizedTime;
public override void UpdateState() { public override void UpdateState(float3 bias) {
// A memorycell does not have an activation function // A memorycell does not have an activation function
float3 result = new(0, 0, 0); float3 result = bias;
int n = 0; int n = 0;
//Applying the weight factgors //Applying the weight factgors
@ -40,6 +47,32 @@ public class MemoryCell : Neuron {
UpdateResult(result); UpdateResult(result);
} }
public override void UpdateStateIsolated() {
float3 bias = new(0, 0, 0);
UpdateStateIsolated(bias);
}
public override void UpdateStateIsolated(float3 bias) {
// A memorycell does not have an activation function
float3 result = bias;
int n = 0;
//Applying the weight factgors
foreach (Synapse synapse in this.synapses) {
result += synapse.weight * synapse.nucleus.outputValue;
if (lengthsq(synapse.nucleus.outputValue) != 0)
n++;
}
if (this.average)
result /= n;
this.outputValue = this._memorizedValue;
// Store the result for the next time
this._memorizedValue = result;
this._memorizedTime = Time.time;
}
public override void UpdateResult(Vector3 result) { public override void UpdateResult(Vector3 result) {
// output value is the previous value // output value is the previous value
// if (this.deltaValue) { // if (this.deltaValue) {
@ -47,7 +80,7 @@ public class MemoryCell : Neuron {
// this._outputValue = this._memorizedValue * deltaTime; // this._outputValue = this._memorizedValue * deltaTime;
// } // }
//else //else
this.outputValue = this._memorizedValue; this.outputValue = this._memorizedValue;
// Store the result for the next time // Store the result for the next time
this._memorizedValue = result; this._memorizedValue = result;

View File

@ -282,7 +282,8 @@ public class Neuron : INucleus {
} }
public virtual void UpdateState() { public virtual void UpdateState() {
UpdateState(new float3(0, 0, 0)); //UpdateState(new float3(0, 0, 0));
this.parent?.UpdateState();
} }
public virtual void UpdateState(float3 inputValue) { public virtual void UpdateState(float3 inputValue) {
@ -323,6 +324,47 @@ public class Neuron : INucleus {
UpdateResult(result); UpdateResult(result);
} }
public virtual void UpdateStateIsolated() {
UpdateStateIsolated(new float3(0, 0, 0));
}
public virtual void UpdateStateIsolated(float3 bias) {
float3 sum = bias;
int n = 0;
//Applying the weight factgors
foreach (Synapse synapse in this.synapses) {
sum += synapse.weight * synapse.nucleus.outputValue;
// Perhaps synapses should be removed when the output value goes to 0....
if (lengthsq(synapse.nucleus.outputValue) != 0)
n++;
}
if (this.average && n > 0)
sum /= n;
// Activation function
Vector3 result;
switch (this.curvePreset) {
case CurvePresets.Linear:
result = sum;
break;
case CurvePresets.Sqrt:
result = normalize(sum) * System.MathF.Sqrt(length(sum));
break;
case CurvePresets.Power:
result = normalize(sum) * System.MathF.Pow(length(sum), 2);
break;
case CurvePresets.Reciprocal:
result = normalize(sum) * (1 / length(sum));
break;
default:
float activatedValue = this.curve.Evaluate(length(sum));
result = normalize(sum) * activatedValue;
break;
}
this.outputValue = result;
}
public virtual void UpdateResult(Vector3 result) { public virtual void UpdateResult(Vector3 result) {
// float d = Vector3.Distance(result, this.outputValue); // float d = Vector3.Distance(result, this.outputValue);
// if (d < 0.5f) { // if (d < 0.5f) {