Cleanup
This commit is contained in:
parent
42bc32c734
commit
a62d1cc5d9
@ -34,18 +34,7 @@ public class MemoryCell : Neuron {
|
|||||||
|
|
||||||
public override void UpdateStateIsolated() {
|
public override void UpdateStateIsolated() {
|
||||||
// A memorycell does not have an activation function
|
// A memorycell does not have an activation function
|
||||||
Vector3 result = this.bias;
|
float3 result = Combinator();
|
||||||
int n = 0;
|
|
||||||
|
|
||||||
//Applying the weight factgors
|
|
||||||
foreach (Synapse synapse in this.synapses) {
|
|
||||||
result += synapse.weight * synapse.nucleus.outputValue;
|
|
||||||
if (lengthsq(synapse.nucleus.outputValue) != 0)
|
|
||||||
n++;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (this.average)
|
|
||||||
result /= n;
|
|
||||||
|
|
||||||
if (initialized)
|
if (initialized)
|
||||||
// Output the previous, memorized value
|
// Output the previous, memorized value
|
||||||
|
|||||||
137
Neuron.cs
137
Neuron.cs
@ -58,12 +58,6 @@ public class Neuron : Nucleus {
|
|||||||
public AnimationCurve curve;
|
public AnimationCurve curve;
|
||||||
public float curveMax = 1.0f;
|
public float curveMax = 1.0f;
|
||||||
|
|
||||||
#region Parameters
|
|
||||||
|
|
||||||
public bool average = false;
|
|
||||||
|
|
||||||
#endregion Parameters
|
|
||||||
|
|
||||||
public AnimationCurve GenerateCurve() {
|
public AnimationCurve GenerateCurve() {
|
||||||
switch (this.curvePreset) {
|
switch (this.curvePreset) {
|
||||||
case CurvePresets.Linear:
|
case CurvePresets.Linear:
|
||||||
@ -84,14 +78,6 @@ public class Neuron : Nucleus {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public virtual void Deserialize(Neuron nucleus) { }
|
|
||||||
|
|
||||||
#endregion Serialization
|
|
||||||
|
|
||||||
#region Runtime state (not serialized)
|
|
||||||
|
|
||||||
#region Activation
|
|
||||||
|
|
||||||
public static class Presets {
|
public static class Presets {
|
||||||
private const int samples = 32;
|
private const int samples = 32;
|
||||||
public static AnimationCurve Linear(float weight) {
|
public static AnimationCurve Linear(float weight) {
|
||||||
@ -136,9 +122,7 @@ public class Neuron : Nucleus {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endregion Activation
|
#endregion Serialization
|
||||||
|
|
||||||
#endregion Runtime state
|
|
||||||
|
|
||||||
// this clone the nucleus without the synapses and receivers
|
// this clone the nucleus without the synapses and receivers
|
||||||
public override Nucleus ShallowCloneTo(Cluster newParent) {
|
public override Nucleus ShallowCloneTo(Cluster newParent) {
|
||||||
@ -166,9 +150,7 @@ public class Neuron : Nucleus {
|
|||||||
clone.combinator = this.combinator;
|
clone.combinator = this.combinator;
|
||||||
clone.curve = this.curve;
|
clone.curve = this.curve;
|
||||||
clone.curvePreset = this.curvePreset;
|
clone.curvePreset = this.curvePreset;
|
||||||
Debug.Log($"clone preset {clone.name} = {clone.curvePreset}");
|
|
||||||
clone.curveMax = this.curveMax;
|
clone.curveMax = this.curveMax;
|
||||||
clone.average = this.average;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public static void Delete(Nucleus nucleus) {
|
public static void Delete(Nucleus nucleus) {
|
||||||
@ -196,52 +178,20 @@ public class Neuron : Nucleus {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public override void UpdateStateIsolated() {
|
public override void UpdateStateIsolated() {
|
||||||
float3 result = CombinatorAction();
|
float3 result = Combinator();
|
||||||
this.outputValue = Activation(result);
|
this.outputValue = Activator(result);
|
||||||
// switch (this.type) {
|
|
||||||
// case Type.Neuron:
|
|
||||||
// UpdateSum();
|
|
||||||
// break;
|
|
||||||
// case Type.Pulsar:
|
|
||||||
// UpdateProduct();
|
|
||||||
// break;
|
|
||||||
// default:
|
|
||||||
// UpdateSum();
|
|
||||||
// break;
|
|
||||||
// }
|
|
||||||
// Vector3 sum = this.bias;
|
|
||||||
// int n = 0;
|
|
||||||
|
|
||||||
// //Applying the weight factgors
|
|
||||||
// foreach (Synapse synapse in this.synapses) {
|
|
||||||
// sum += synapse.weight * synapse.nucleus.outputValue;
|
|
||||||
|
|
||||||
// // Perhaps synapses should be removed when the output value goes to 0....
|
|
||||||
// if (lengthsq(synapse.nucleus.outputValue) != 0) {
|
|
||||||
// n++;
|
|
||||||
// this.stale = 0;
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// if (this.average && n > 0)
|
|
||||||
// sum /= n;
|
|
||||||
|
|
||||||
// // Activation function
|
|
||||||
// float3 result = Activation(sum);
|
|
||||||
// if (this.stale > staleValueForSleep)
|
|
||||||
// this.outputValue = new float3(0, 0, 0);
|
|
||||||
// else
|
|
||||||
// this.outputValue = result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Func<float3> CombinatorAction => combinator switch {
|
#region Combinator
|
||||||
CombinatorType.Sum => UpdateSum,
|
|
||||||
CombinatorType.Product => UpdateProduct,
|
protected Func<float3> Combinator => combinator switch {
|
||||||
CombinatorType.Max => UpdateMax,
|
CombinatorType.Sum => CombinatorSum,
|
||||||
_ => UpdateSum
|
CombinatorType.Product => CombinatorProduct,
|
||||||
|
CombinatorType.Max => CombinatorMax,
|
||||||
|
_ => CombinatorSum
|
||||||
};
|
};
|
||||||
|
|
||||||
|
public float3 CombinatorSum() {
|
||||||
public float3 UpdateSum() {
|
|
||||||
Vector3 sum = this.bias;
|
Vector3 sum = this.bias;
|
||||||
foreach (Synapse synapse in this.synapses)
|
foreach (Synapse synapse in this.synapses)
|
||||||
sum += synapse.weight * synapse.nucleus.outputValue;
|
sum += synapse.weight * synapse.nucleus.outputValue;
|
||||||
@ -249,7 +199,7 @@ public class Neuron : Nucleus {
|
|||||||
//this.outputValue = Activation(sum);
|
//this.outputValue = Activation(sum);
|
||||||
}
|
}
|
||||||
|
|
||||||
public float3 UpdateProduct() {
|
public float3 CombinatorProduct() {
|
||||||
float3 product = this.bias;
|
float3 product = this.bias;
|
||||||
foreach (Synapse synapse in this.synapses)
|
foreach (Synapse synapse in this.synapses)
|
||||||
product *= synapse.weight * synapse.nucleus.outputValue;
|
product *= synapse.weight * synapse.nucleus.outputValue;
|
||||||
@ -257,7 +207,7 @@ public class Neuron : Nucleus {
|
|||||||
//this.outputValue = Activation(product);
|
//this.outputValue = Activation(product);
|
||||||
}
|
}
|
||||||
|
|
||||||
public float3 UpdateMax() {
|
public float3 CombinatorMax() {
|
||||||
float3 max = this.bias;
|
float3 max = this.bias;
|
||||||
float maxSqrLength = lengthsq(max);
|
float maxSqrLength = lengthsq(max);
|
||||||
|
|
||||||
@ -274,32 +224,49 @@ public class Neuron : Nucleus {
|
|||||||
return max;
|
return max;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected float3 Activation(float3 input) {
|
#endregion Combinator
|
||||||
float3 result = Vector3.zero;
|
|
||||||
switch (this.curvePreset) {
|
#region Activator
|
||||||
case CurvePresets.Linear:
|
|
||||||
result = input;
|
protected Func<float3, float3> Activator => this.curvePreset switch {
|
||||||
break;
|
CurvePresets.Linear => ActivatorLinear,
|
||||||
case CurvePresets.Sqrt:
|
CurvePresets.Sqrt => ActivatorSqrt,
|
||||||
result = normalize(input) * System.MathF.Sqrt(length(input));
|
CurvePresets.Power => ActivatorPower,
|
||||||
break;
|
CurvePresets.Reciprocal => ActivatorReciprocal,
|
||||||
case CurvePresets.Power:
|
_ => ActivatorCustom
|
||||||
result = normalize(input) * System.MathF.Pow(length(input), 2);
|
};
|
||||||
break;
|
|
||||||
case CurvePresets.Reciprocal: {
|
protected float3 ActivatorLinear(float3 input) {
|
||||||
float magnitude = length(input);
|
return input;
|
||||||
if (magnitude > 0)
|
|
||||||
result = normalize(input) * (1 / magnitude);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
float activatedValue = this.curve.Evaluate(length(input));
|
|
||||||
result = normalize(input) * activatedValue;
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected float3 ActivatorSqrt(float3 input) {
|
||||||
|
float3 result = normalize(input) * System.MathF.Sqrt(length(input));
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected float3 ActivatorPower(float3 input) {
|
||||||
|
float3 result = normalize(input) * System.MathF.Pow(length(input), 2);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected float3 ActivatorReciprocal(float3 input) {
|
||||||
|
float magnitude = length(input);
|
||||||
|
if (magnitude == 0)
|
||||||
|
return new float3(0, 0, 0);
|
||||||
|
|
||||||
|
float3 result = normalize(input) * (1 / magnitude);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected float3 ActivatorCustom(float3 input) {
|
||||||
|
float activatedValue = this.curve.Evaluate(length(input));
|
||||||
|
float3 result = normalize(input) * activatedValue;
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
#endregion Activator
|
||||||
|
|
||||||
public virtual void ProcessStimulus(Vector3 inputValue, string thingName = null) {
|
public virtual void ProcessStimulus(Vector3 inputValue, string thingName = null) {
|
||||||
this.stale = 0;
|
this.stale = 0;
|
||||||
this.bias = inputValue;
|
this.bias = inputValue;
|
||||||
|
|||||||
@ -49,6 +49,6 @@ public class Pulsar : Neuron {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Activation function
|
// Activation function
|
||||||
this.outputValue = Activation(product);
|
this.outputValue = Activator(product);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -13,7 +13,6 @@ public class Selector : Neuron {
|
|||||||
curve = this.curve,
|
curve = this.curve,
|
||||||
curvePreset = this.curvePreset,
|
curvePreset = this.curvePreset,
|
||||||
curveMax = this.curveMax,
|
curveMax = this.curveMax,
|
||||||
average = this.average
|
|
||||||
};
|
};
|
||||||
return clone;
|
return clone;
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user