Fixed the evaluation order
This commit is contained in:
parent
b4f8e5a4d8
commit
91c4500b0a
35
Cluster.cs
35
Cluster.cs
@ -262,20 +262,45 @@ public class Cluster : INucleus {
|
||||
UpdateState(new float3(0, 0, 0));
|
||||
}
|
||||
|
||||
public void UpdateState(float3 inputValue) {
|
||||
float3 sum = inputValue; // new(0, 0, 0);
|
||||
public void UpdateState(float3 bias) {
|
||||
float3 sum = bias; // new(0, 0, 0);
|
||||
|
||||
//Applying the weight factgors
|
||||
//Applying the weight factors
|
||||
foreach (Synapse synapse in this.synapses) {
|
||||
sum += synapse.weight * synapse.nucleus.outputValue;
|
||||
}
|
||||
|
||||
//this.prefab.inputs[0].UpdateState(sum);
|
||||
this.inputs[0].UpdateState(sum);
|
||||
//this.inputs[0].UpdateState(sum);
|
||||
this.inputs[0].UpdateStateIsolated(sum);
|
||||
foreach (IReceptor receptor in this.sortedNuclei) {
|
||||
if (receptor is INucleus nucleus && nucleus != this.inputs[0])
|
||||
nucleus.UpdateStateIsolated();
|
||||
}
|
||||
|
||||
UpdateResult(this.output.outputValue);
|
||||
}
|
||||
|
||||
public void UpdateStateIsolated() {
|
||||
float3 bias = new(0,0,0);
|
||||
UpdateStateIsolated(bias);
|
||||
}
|
||||
public void UpdateStateIsolated(float3 bias) {
|
||||
float3 sum = bias; // new(0, 0, 0);
|
||||
|
||||
//Applying the weight factors
|
||||
foreach (Synapse synapse in this.synapses) {
|
||||
sum += synapse.weight * synapse.nucleus.outputValue;
|
||||
}
|
||||
|
||||
//this.inputs[0].UpdateState(sum);
|
||||
this.inputs[0].UpdateStateIsolated(sum);
|
||||
foreach (IReceptor receptor in this.sortedNuclei) {
|
||||
if (receptor is INucleus nucleus && nucleus != this.inputs[0])
|
||||
nucleus.UpdateStateIsolated();
|
||||
}
|
||||
this.outputValue = this.output.outputValue;
|
||||
}
|
||||
|
||||
public virtual void UpdateResult(Vector3 result) {
|
||||
// float d = Vector3.Distance(result, this.outputValue);
|
||||
// if (d < 0.5f) {
|
||||
|
||||
@ -20,7 +20,8 @@ public interface INucleus : IReceptor {
|
||||
|
||||
public void UpdateState();
|
||||
public void UpdateState(float3 inputValue);
|
||||
|
||||
public void UpdateStateIsolated();
|
||||
public void UpdateStateIsolated(float3 inputValue);
|
||||
|
||||
#endregion dynamic state
|
||||
}
|
||||
|
||||
@ -4,27 +4,34 @@ using Unity.Mathematics;
|
||||
using static Unity.Mathematics.math;
|
||||
|
||||
[Serializable]
|
||||
public class MemoryCell : Neuron {
|
||||
public class MemoryCell : Neuron, INucleus {
|
||||
|
||||
public MemoryCell(ClusterPrefab cluster, string name) : base(cluster, name) {}
|
||||
public MemoryCell(ClusterPrefab cluster, string name) : base(cluster, name) { }
|
||||
public MemoryCell(Cluster parent, string name) : base(parent, name) { }
|
||||
// this.parent = parent;
|
||||
// this.name = name;
|
||||
// this.parent?.nuclei.Add(this);
|
||||
// }
|
||||
|
||||
#region Parameters
|
||||
|
||||
// Returns the memorized value weighted by time
|
||||
// return lastValue * (current time - last time)
|
||||
// [SerializeField]
|
||||
// public bool deltaValue = false;
|
||||
|
||||
#endregion Parameters
|
||||
public override IReceptor ShallowCloneTo(Cluster newParent) {
|
||||
MemoryCell clone = new(newParent, this.name) {
|
||||
array = this.array,
|
||||
curve = this.curve,
|
||||
curvePreset = this.curvePreset,
|
||||
curveMax = this.curveMax,
|
||||
average = this.average
|
||||
};
|
||||
return clone;
|
||||
}
|
||||
|
||||
#region State
|
||||
|
||||
private float3 _memorizedValue;
|
||||
private float _memorizedTime;
|
||||
|
||||
public override void UpdateState() {
|
||||
public override void UpdateState(float3 bias) {
|
||||
// A memorycell does not have an activation function
|
||||
float3 result = new(0, 0, 0);
|
||||
float3 result = bias;
|
||||
int n = 0;
|
||||
|
||||
//Applying the weight factgors
|
||||
@ -40,6 +47,32 @@ public class MemoryCell : Neuron {
|
||||
UpdateResult(result);
|
||||
}
|
||||
|
||||
public override void UpdateStateIsolated() {
|
||||
float3 bias = new(0, 0, 0);
|
||||
UpdateStateIsolated(bias);
|
||||
}
|
||||
public override void UpdateStateIsolated(float3 bias) {
|
||||
// A memorycell does not have an activation function
|
||||
float3 result = bias;
|
||||
int n = 0;
|
||||
|
||||
//Applying the weight factgors
|
||||
foreach (Synapse synapse in this.synapses) {
|
||||
result += synapse.weight * synapse.nucleus.outputValue;
|
||||
if (lengthsq(synapse.nucleus.outputValue) != 0)
|
||||
n++;
|
||||
}
|
||||
|
||||
if (this.average)
|
||||
result /= n;
|
||||
|
||||
this.outputValue = this._memorizedValue;
|
||||
|
||||
// Store the result for the next time
|
||||
this._memorizedValue = result;
|
||||
this._memorizedTime = Time.time;
|
||||
}
|
||||
|
||||
public override void UpdateResult(Vector3 result) {
|
||||
// output value is the previous value
|
||||
// if (this.deltaValue) {
|
||||
@ -47,7 +80,7 @@ public class MemoryCell : Neuron {
|
||||
// this._outputValue = this._memorizedValue * deltaTime;
|
||||
// }
|
||||
//else
|
||||
this.outputValue = this._memorizedValue;
|
||||
this.outputValue = this._memorizedValue;
|
||||
|
||||
// Store the result for the next time
|
||||
this._memorizedValue = result;
|
||||
|
||||
44
Neuron.cs
44
Neuron.cs
@ -282,7 +282,8 @@ public class Neuron : INucleus {
|
||||
}
|
||||
|
||||
public virtual void UpdateState() {
|
||||
UpdateState(new float3(0, 0, 0));
|
||||
//UpdateState(new float3(0, 0, 0));
|
||||
this.parent?.UpdateState();
|
||||
}
|
||||
|
||||
public virtual void UpdateState(float3 inputValue) {
|
||||
@ -323,6 +324,47 @@ public class Neuron : INucleus {
|
||||
UpdateResult(result);
|
||||
}
|
||||
|
||||
public virtual void UpdateStateIsolated() {
|
||||
UpdateStateIsolated(new float3(0, 0, 0));
|
||||
}
|
||||
public virtual void UpdateStateIsolated(float3 bias) {
|
||||
float3 sum = bias;
|
||||
int n = 0;
|
||||
|
||||
//Applying the weight factgors
|
||||
foreach (Synapse synapse in this.synapses) {
|
||||
sum += synapse.weight * synapse.nucleus.outputValue;
|
||||
|
||||
// Perhaps synapses should be removed when the output value goes to 0....
|
||||
if (lengthsq(synapse.nucleus.outputValue) != 0)
|
||||
n++;
|
||||
}
|
||||
if (this.average && n > 0)
|
||||
sum /= n;
|
||||
|
||||
// Activation function
|
||||
Vector3 result;
|
||||
switch (this.curvePreset) {
|
||||
case CurvePresets.Linear:
|
||||
result = sum;
|
||||
break;
|
||||
case CurvePresets.Sqrt:
|
||||
result = normalize(sum) * System.MathF.Sqrt(length(sum));
|
||||
break;
|
||||
case CurvePresets.Power:
|
||||
result = normalize(sum) * System.MathF.Pow(length(sum), 2);
|
||||
break;
|
||||
case CurvePresets.Reciprocal:
|
||||
result = normalize(sum) * (1 / length(sum));
|
||||
break;
|
||||
default:
|
||||
float activatedValue = this.curve.Evaluate(length(sum));
|
||||
result = normalize(sum) * activatedValue;
|
||||
break;
|
||||
}
|
||||
this.outputValue = result;
|
||||
}
|
||||
|
||||
public virtual void UpdateResult(Vector3 result) {
|
||||
// float d = Vector3.Distance(result, this.outputValue);
|
||||
// if (d < 0.5f) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user