Added Tanh Activation

This commit is contained in:
Pascal Serrarens 2026-04-15 09:50:36 +02:00
parent a99d40c5c9
commit 1c7b8e7940
3 changed files with 1133 additions and 504 deletions

File diff suppressed because it is too large Load Diff

View File

@ -10,16 +10,17 @@ namespace NanoBrain {
public class ClusterViewer : Editor { public class ClusterViewer : Editor {
public class GraphView : VisualElement { public class GraphView : VisualElement {
readonly ClusterPrefab prefab; protected readonly ClusterPrefab prefab;
SerializedObject serializedBrain; protected SerializedObject serializedBrain;
Nucleus currentNucleus; protected Nucleus currentNucleus;
GameObject gameObject; protected GameObject gameObject;
private List<NeuroidLayer> layers = new(); private List<NeuroidLayer> layers = new();
private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new(); private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new();
private bool expandArray = false; private bool expandArray = false;
ClusterPrefab prefabAsset; protected ClusterPrefab prefabAsset;
readonly PopupField<string> outputsField; protected VisualElement outputContainer;
protected readonly PopupField<string> outputsField;
public GraphView(ClusterPrefab prefab) { public GraphView(ClusterPrefab prefab) {
this.prefab = prefab; this.prefab = prefab;
@ -35,7 +36,7 @@ namespace NanoBrain {
graphContainer.focusable = true; graphContainer.focusable = true;
Add(graphContainer); Add(graphContainer);
VisualElement outputContainer = new() { outputContainer = new() {
style = { style = {
flexDirection = FlexDirection.Row, flexDirection = FlexDirection.Row,
alignItems = Align.Center, alignItems = Align.Center,
@ -108,7 +109,7 @@ namespace NanoBrain {
//DrawInspector(inspectorContainer); //DrawInspector(inspectorContainer);
} }
private void BuildLayers() { protected void BuildLayers() {
// A temporary list to track what's been added to layers // A temporary list to track what's been added to layers
this.layers = new(); this.layers = new();
int layerIx = 0; int layerIx = 0;
@ -534,4 +535,9 @@ namespace NanoBrain {
} }
} }
public class NeuroidLayer {
public int ix = 0;
public List<Nucleus> neuroids = new();
}
} }

View File

@ -61,16 +61,17 @@ namespace NanoBrain {
/// <summary> /// <summary>
/// The type of /// The type of
/// </summary> /// </summary>
public enum CurvePresets { public enum ActivationFunction {
Linear, Linear,
Power, Power,
Sqrt, Sqrt,
Reciprocal, Reciprocal,
Tanh,
Custom Custom
} }
[SerializeField] [SerializeField]
public CurvePresets _curvePreset; public ActivationFunction _curvePreset;
public CurvePresets curvePreset { public ActivationFunction curvePreset {
get { return _curvePreset; } get { return _curvePreset; }
set { set {
_curvePreset = value; _curvePreset = value;
@ -82,18 +83,21 @@ namespace NanoBrain {
public AnimationCurve GenerateCurve() { public AnimationCurve GenerateCurve() {
switch (this.curvePreset) { switch (this.curvePreset) {
case CurvePresets.Linear: case ActivationFunction.Linear:
this.curveMax = 1; this.curveMax = 1;
return Presets.Linear(1); return Presets.Linear(1);
case CurvePresets.Power: case ActivationFunction.Power:
this.curveMax = 1; this.curveMax = 1;
return Presets.Power(2.0f, 1); return Presets.Power(2.0f, 1);
case CurvePresets.Sqrt: case ActivationFunction.Sqrt:
this.curveMax = 1; this.curveMax = 1;
return Presets.Power(0.5f, 1); return Presets.Power(0.5f, 1);
case CurvePresets.Reciprocal: case ActivationFunction.Reciprocal:
this.curveMax = 1 / 0.01f * 1; this.curveMax = 1 / 0.01f * 1;
return Presets.Reciprocal(1); return Presets.Reciprocal(1);
case ActivationFunction.Tanh:
this.curveMax = 1;
return Presets.Tanh(1);
default: default:
this.curveMax = 1; this.curveMax = 1;
return this.curve; return this.curve;
@ -142,6 +146,25 @@ namespace NanoBrain {
} }
return curve; return curve;
} }
public static AnimationCurve Tanh(float weight) {
//int samples = 128;
float xMin = 0.001f;
float xMax = 1;
var keys = new Keyframe[samples];
for (int i = 0; i < samples; i++) {
float t = i / (float)(samples - 1);
float x = Mathf.Lerp(xMin, xMax, t);
float y = MathF.Tanh(x * weight);
keys[i] = new Keyframe(x, y);
}
var curve = new AnimationCurve(keys);
for (int i = 0; i < curve.length; i++) {
AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
}
return curve;
}
} }
#endregion Serialization #endregion Serialization
@ -348,10 +371,11 @@ namespace NanoBrain {
#if UNITY_MATHEMATICS #if UNITY_MATHEMATICS
public Func<float3, float3> Activator => this.curvePreset switch { public Func<float3, float3> Activator => this.curvePreset switch {
CurvePresets.Linear => ActivatorLinear, ActivationFunction.Linear => ActivatorLinear,
CurvePresets.Sqrt => ActivatorSqrt, ActivationFunction.Sqrt => ActivatorSqrt,
CurvePresets.Power => ActivatorPower, ActivationFunction.Power => ActivatorPower,
CurvePresets.Reciprocal => ActivatorReciprocal, ActivationFunction.Reciprocal => ActivatorReciprocal,
ActivationFunction.Tanh => ActivatorTanh,
_ => ActivatorCustom _ => ActivatorCustom
}; };
@ -378,6 +402,12 @@ namespace NanoBrain {
return result; return result;
} }
protected float3 ActivatorTanh(float3 input) {
float magnitude = length(input);
float3 result = normalize(input) * MathF.Tanh(magnitude);
return result;
}
protected float3 ActivatorCustom(float3 input) { protected float3 ActivatorCustom(float3 input) {
float activatedValue = this.curve.Evaluate(length(input)); float activatedValue = this.curve.Evaluate(length(input));
float3 result = normalize(input) * activatedValue; float3 result = normalize(input) * activatedValue;