Added Tanh Activation

This commit is contained in:
Pascal Serrarens 2026-04-15 09:50:36 +02:00
parent a99d40c5c9
commit 1c7b8e7940
3 changed files with 1133 additions and 504 deletions

File diff suppressed because it is too large Load Diff

View File

@ -10,16 +10,17 @@ namespace NanoBrain {
public class ClusterViewer : Editor {
public class GraphView : VisualElement {
readonly ClusterPrefab prefab;
SerializedObject serializedBrain;
Nucleus currentNucleus;
GameObject gameObject;
protected readonly ClusterPrefab prefab;
protected SerializedObject serializedBrain;
protected Nucleus currentNucleus;
protected GameObject gameObject;
private List<NeuroidLayer> layers = new();
private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new();
private bool expandArray = false;
ClusterPrefab prefabAsset;
readonly PopupField<string> outputsField;
protected ClusterPrefab prefabAsset;
protected VisualElement outputContainer;
protected readonly PopupField<string> outputsField;
public GraphView(ClusterPrefab prefab) {
this.prefab = prefab;
@ -35,7 +36,7 @@ namespace NanoBrain {
graphContainer.focusable = true;
Add(graphContainer);
VisualElement outputContainer = new() {
outputContainer = new() {
style = {
flexDirection = FlexDirection.Row,
alignItems = Align.Center,
@ -108,7 +109,7 @@ namespace NanoBrain {
//DrawInspector(inspectorContainer);
}
private void BuildLayers() {
protected void BuildLayers() {
// A temporary list to track what's been added to layers
this.layers = new();
int layerIx = 0;
@ -534,4 +535,9 @@ namespace NanoBrain {
}
}
public class NeuroidLayer {
public int ix = 0;
public List<Nucleus> neuroids = new();
}
}

View File

@ -61,16 +61,17 @@ namespace NanoBrain {
/// <summary>
/// The type of
/// </summary>
public enum CurvePresets {
public enum ActivationFunction {
Linear,
Power,
Sqrt,
Reciprocal,
Tanh,
Custom
}
[SerializeField]
public CurvePresets _curvePreset;
public CurvePresets curvePreset {
public ActivationFunction _curvePreset;
public ActivationFunction curvePreset {
get { return _curvePreset; }
set {
_curvePreset = value;
@ -82,18 +83,21 @@ namespace NanoBrain {
public AnimationCurve GenerateCurve() {
switch (this.curvePreset) {
case CurvePresets.Linear:
case ActivationFunction.Linear:
this.curveMax = 1;
return Presets.Linear(1);
case CurvePresets.Power:
case ActivationFunction.Power:
this.curveMax = 1;
return Presets.Power(2.0f, 1);
case CurvePresets.Sqrt:
case ActivationFunction.Sqrt:
this.curveMax = 1;
return Presets.Power(0.5f, 1);
case CurvePresets.Reciprocal:
case ActivationFunction.Reciprocal:
this.curveMax = 1 / 0.01f * 1;
return Presets.Reciprocal(1);
case ActivationFunction.Tanh:
this.curveMax = 1;
return Presets.Tanh(1);
default:
this.curveMax = 1;
return this.curve;
@ -142,6 +146,25 @@ namespace NanoBrain {
}
return curve;
}
public static AnimationCurve Tanh(float weight) {
//int samples = 128;
float xMin = 0.001f;
float xMax = 1;
var keys = new Keyframe[samples];
for (int i = 0; i < samples; i++) {
float t = i / (float)(samples - 1);
float x = Mathf.Lerp(xMin, xMax, t);
float y = MathF.Tanh(x * weight);
keys[i] = new Keyframe(x, y);
}
var curve = new AnimationCurve(keys);
for (int i = 0; i < curve.length; i++) {
AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
}
return curve;
}
}
#endregion Serialization
@ -348,10 +371,11 @@ namespace NanoBrain {
#if UNITY_MATHEMATICS
public Func<float3, float3> Activator => this.curvePreset switch {
CurvePresets.Linear => ActivatorLinear,
CurvePresets.Sqrt => ActivatorSqrt,
CurvePresets.Power => ActivatorPower,
CurvePresets.Reciprocal => ActivatorReciprocal,
ActivationFunction.Linear => ActivatorLinear,
ActivationFunction.Sqrt => ActivatorSqrt,
ActivationFunction.Power => ActivatorPower,
ActivationFunction.Reciprocal => ActivatorReciprocal,
ActivationFunction.Tanh => ActivatorTanh,
_ => ActivatorCustom
};
@ -378,6 +402,12 @@ namespace NanoBrain {
return result;
}
protected float3 ActivatorTanh(float3 input) {
float magnitude = length(input);
float3 result = normalize(input) * MathF.Tanh(magnitude);
return result;
}
protected float3 ActivatorCustom(float3 input) {
float activatedValue = this.curve.Evaluate(length(input));
float3 result = normalize(input) * activatedValue;