diff --git a/MemoryCell.cs b/MemoryCell.cs index 5eb13d6..e531922 100644 --- a/MemoryCell.cs +++ b/MemoryCell.cs @@ -34,18 +34,7 @@ public class MemoryCell : Neuron { public override void UpdateStateIsolated() { // A memorycell does not have an activation function - Vector3 result = this.bias; - int n = 0; - - //Applying the weight factgors - foreach (Synapse synapse in this.synapses) { - result += synapse.weight * synapse.nucleus.outputValue; - if (lengthsq(synapse.nucleus.outputValue) != 0) - n++; - } - - if (this.average) - result /= n; + float3 result = Combinator(); if (initialized) // Output the previous, memorized value diff --git a/Neuron.cs b/Neuron.cs index 400359e..0a58257 100644 --- a/Neuron.cs +++ b/Neuron.cs @@ -58,12 +58,6 @@ public class Neuron : Nucleus { public AnimationCurve curve; public float curveMax = 1.0f; - #region Parameters - - public bool average = false; - - #endregion Parameters - public AnimationCurve GenerateCurve() { switch (this.curvePreset) { case CurvePresets.Linear: @@ -84,14 +78,6 @@ public class Neuron : Nucleus { } } - public virtual void Deserialize(Neuron nucleus) { } - - #endregion Serialization - - #region Runtime state (not serialized) - - #region Activation - public static class Presets { private const int samples = 32; public static AnimationCurve Linear(float weight) { @@ -136,9 +122,7 @@ public class Neuron : Nucleus { } } - #endregion Activation - - #endregion Runtime state + #endregion Serialization // this clone the nucleus without the synapses and receivers public override Nucleus ShallowCloneTo(Cluster newParent) { @@ -166,9 +150,7 @@ public class Neuron : Nucleus { clone.combinator = this.combinator; clone.curve = this.curve; clone.curvePreset = this.curvePreset; - Debug.Log($"clone preset {clone.name} = {clone.curvePreset}"); clone.curveMax = this.curveMax; - clone.average = this.average; } public static void Delete(Nucleus nucleus) { @@ -196,52 +178,20 @@ public class Neuron : Nucleus { } public override void UpdateStateIsolated() { - float3 result = CombinatorAction(); - this.outputValue = Activation(result); - // switch (this.type) { - // case Type.Neuron: - // UpdateSum(); - // break; - // case Type.Pulsar: - // UpdateProduct(); - // break; - // default: - // UpdateSum(); - // break; - // } - // Vector3 sum = this.bias; - // int n = 0; - - // //Applying the weight factgors - // foreach (Synapse synapse in this.synapses) { - // sum += synapse.weight * synapse.nucleus.outputValue; - - // // Perhaps synapses should be removed when the output value goes to 0.... - // if (lengthsq(synapse.nucleus.outputValue) != 0) { - // n++; - // this.stale = 0; - // } - // } - // if (this.average && n > 0) - // sum /= n; - - // // Activation function - // float3 result = Activation(sum); - // if (this.stale > staleValueForSleep) - // this.outputValue = new float3(0, 0, 0); - // else - // this.outputValue = result; + float3 result = Combinator(); + this.outputValue = Activator(result); } - private Func CombinatorAction => combinator switch { - CombinatorType.Sum => UpdateSum, - CombinatorType.Product => UpdateProduct, - CombinatorType.Max => UpdateMax, - _ => UpdateSum + #region Combinator + + protected Func Combinator => combinator switch { + CombinatorType.Sum => CombinatorSum, + CombinatorType.Product => CombinatorProduct, + CombinatorType.Max => CombinatorMax, + _ => CombinatorSum }; - - public float3 UpdateSum() { + public float3 CombinatorSum() { Vector3 sum = this.bias; foreach (Synapse synapse in this.synapses) sum += synapse.weight * synapse.nucleus.outputValue; @@ -249,7 +199,7 @@ public class Neuron : Nucleus { //this.outputValue = Activation(sum); } - public float3 UpdateProduct() { + public float3 CombinatorProduct() { float3 product = this.bias; foreach (Synapse synapse in this.synapses) product *= synapse.weight * synapse.nucleus.outputValue; @@ -257,7 +207,7 @@ public class Neuron : Nucleus { //this.outputValue = Activation(product); } - public float3 UpdateMax() { + public float3 CombinatorMax() { float3 max = this.bias; float maxSqrLength = lengthsq(max); @@ -274,32 +224,49 @@ public class Neuron : Nucleus { return max; } - protected float3 Activation(float3 input) { - float3 result = Vector3.zero; - switch (this.curvePreset) { - case CurvePresets.Linear: - result = input; - break; - case CurvePresets.Sqrt: - result = normalize(input) * System.MathF.Sqrt(length(input)); - break; - case CurvePresets.Power: - result = normalize(input) * System.MathF.Pow(length(input), 2); - break; - case CurvePresets.Reciprocal: { - float magnitude = length(input); - if (magnitude > 0) - result = normalize(input) * (1 / magnitude); - break; - } - default: - float activatedValue = this.curve.Evaluate(length(input)); - result = normalize(input) * activatedValue; - break; - } + #endregion Combinator + + #region Activator + + protected Func Activator => this.curvePreset switch { + CurvePresets.Linear => ActivatorLinear, + CurvePresets.Sqrt => ActivatorSqrt, + CurvePresets.Power => ActivatorPower, + CurvePresets.Reciprocal => ActivatorReciprocal, + _ => ActivatorCustom + }; + + protected float3 ActivatorLinear(float3 input) { + return input; + } + + protected float3 ActivatorSqrt(float3 input) { + float3 result = normalize(input) * System.MathF.Sqrt(length(input)); return result; } + protected float3 ActivatorPower(float3 input) { + float3 result = normalize(input) * System.MathF.Pow(length(input), 2); + return result; + } + + protected float3 ActivatorReciprocal(float3 input) { + float magnitude = length(input); + if (magnitude == 0) + return new float3(0, 0, 0); + + float3 result = normalize(input) * (1 / magnitude); + return result; + } + + protected float3 ActivatorCustom(float3 input) { + float activatedValue = this.curve.Evaluate(length(input)); + float3 result = normalize(input) * activatedValue; + return result; + } + + #endregion Activator + public virtual void ProcessStimulus(Vector3 inputValue, string thingName = null) { this.stale = 0; this.bias = inputValue; diff --git a/Pulsar.cs b/Pulsar.cs index 88e6251..8e2cc54 100644 --- a/Pulsar.cs +++ b/Pulsar.cs @@ -49,6 +49,6 @@ public class Pulsar : Neuron { } // Activation function - this.outputValue = Activation(product); + this.outputValue = Activator(product); } } \ No newline at end of file diff --git a/Selector.cs b/Selector.cs index ed7fae4..6364053 100644 --- a/Selector.cs +++ b/Selector.cs @@ -13,7 +13,6 @@ public class Selector : Neuron { curve = this.curve, curvePreset = this.curvePreset, curveMax = this.curveMax, - average = this.average }; return clone; }