Added Tanh Activation
This commit is contained in:
parent
a99d40c5c9
commit
1c7b8e7940
File diff suppressed because it is too large
Load Diff
@ -10,16 +10,17 @@ namespace NanoBrain {
|
|||||||
public class ClusterViewer : Editor {
|
public class ClusterViewer : Editor {
|
||||||
|
|
||||||
public class GraphView : VisualElement {
|
public class GraphView : VisualElement {
|
||||||
readonly ClusterPrefab prefab;
|
protected readonly ClusterPrefab prefab;
|
||||||
SerializedObject serializedBrain;
|
protected SerializedObject serializedBrain;
|
||||||
Nucleus currentNucleus;
|
protected Nucleus currentNucleus;
|
||||||
GameObject gameObject;
|
protected GameObject gameObject;
|
||||||
private List<NeuroidLayer> layers = new();
|
private List<NeuroidLayer> layers = new();
|
||||||
private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new();
|
private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new();
|
||||||
private bool expandArray = false;
|
private bool expandArray = false;
|
||||||
|
|
||||||
ClusterPrefab prefabAsset;
|
protected ClusterPrefab prefabAsset;
|
||||||
readonly PopupField<string> outputsField;
|
protected VisualElement outputContainer;
|
||||||
|
protected readonly PopupField<string> outputsField;
|
||||||
|
|
||||||
public GraphView(ClusterPrefab prefab) {
|
public GraphView(ClusterPrefab prefab) {
|
||||||
this.prefab = prefab;
|
this.prefab = prefab;
|
||||||
@ -35,7 +36,7 @@ namespace NanoBrain {
|
|||||||
graphContainer.focusable = true;
|
graphContainer.focusable = true;
|
||||||
Add(graphContainer);
|
Add(graphContainer);
|
||||||
|
|
||||||
VisualElement outputContainer = new() {
|
outputContainer = new() {
|
||||||
style = {
|
style = {
|
||||||
flexDirection = FlexDirection.Row,
|
flexDirection = FlexDirection.Row,
|
||||||
alignItems = Align.Center,
|
alignItems = Align.Center,
|
||||||
@ -108,7 +109,7 @@ namespace NanoBrain {
|
|||||||
//DrawInspector(inspectorContainer);
|
//DrawInspector(inspectorContainer);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void BuildLayers() {
|
protected void BuildLayers() {
|
||||||
// A temporary list to track what's been added to layers
|
// A temporary list to track what's been added to layers
|
||||||
this.layers = new();
|
this.layers = new();
|
||||||
int layerIx = 0;
|
int layerIx = 0;
|
||||||
@ -534,4 +535,9 @@ namespace NanoBrain {
|
|||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public class NeuroidLayer {
|
||||||
|
public int ix = 0;
|
||||||
|
public List<Nucleus> neuroids = new();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
@ -61,16 +61,17 @@ namespace NanoBrain {
|
|||||||
/// <summary>
|
/// <summary>
|
||||||
/// The type of
|
/// The type of
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public enum CurvePresets {
|
public enum ActivationFunction {
|
||||||
Linear,
|
Linear,
|
||||||
Power,
|
Power,
|
||||||
Sqrt,
|
Sqrt,
|
||||||
Reciprocal,
|
Reciprocal,
|
||||||
|
Tanh,
|
||||||
Custom
|
Custom
|
||||||
}
|
}
|
||||||
[SerializeField]
|
[SerializeField]
|
||||||
public CurvePresets _curvePreset;
|
public ActivationFunction _curvePreset;
|
||||||
public CurvePresets curvePreset {
|
public ActivationFunction curvePreset {
|
||||||
get { return _curvePreset; }
|
get { return _curvePreset; }
|
||||||
set {
|
set {
|
||||||
_curvePreset = value;
|
_curvePreset = value;
|
||||||
@ -82,18 +83,21 @@ namespace NanoBrain {
|
|||||||
|
|
||||||
public AnimationCurve GenerateCurve() {
|
public AnimationCurve GenerateCurve() {
|
||||||
switch (this.curvePreset) {
|
switch (this.curvePreset) {
|
||||||
case CurvePresets.Linear:
|
case ActivationFunction.Linear:
|
||||||
this.curveMax = 1;
|
this.curveMax = 1;
|
||||||
return Presets.Linear(1);
|
return Presets.Linear(1);
|
||||||
case CurvePresets.Power:
|
case ActivationFunction.Power:
|
||||||
this.curveMax = 1;
|
this.curveMax = 1;
|
||||||
return Presets.Power(2.0f, 1);
|
return Presets.Power(2.0f, 1);
|
||||||
case CurvePresets.Sqrt:
|
case ActivationFunction.Sqrt:
|
||||||
this.curveMax = 1;
|
this.curveMax = 1;
|
||||||
return Presets.Power(0.5f, 1);
|
return Presets.Power(0.5f, 1);
|
||||||
case CurvePresets.Reciprocal:
|
case ActivationFunction.Reciprocal:
|
||||||
this.curveMax = 1 / 0.01f * 1;
|
this.curveMax = 1 / 0.01f * 1;
|
||||||
return Presets.Reciprocal(1);
|
return Presets.Reciprocal(1);
|
||||||
|
case ActivationFunction.Tanh:
|
||||||
|
this.curveMax = 1;
|
||||||
|
return Presets.Tanh(1);
|
||||||
default:
|
default:
|
||||||
this.curveMax = 1;
|
this.curveMax = 1;
|
||||||
return this.curve;
|
return this.curve;
|
||||||
@ -142,6 +146,25 @@ namespace NanoBrain {
|
|||||||
}
|
}
|
||||||
return curve;
|
return curve;
|
||||||
}
|
}
|
||||||
|
public static AnimationCurve Tanh(float weight) {
|
||||||
|
//int samples = 128;
|
||||||
|
float xMin = 0.001f;
|
||||||
|
float xMax = 1;
|
||||||
|
var keys = new Keyframe[samples];
|
||||||
|
for (int i = 0; i < samples; i++) {
|
||||||
|
float t = i / (float)(samples - 1);
|
||||||
|
float x = Mathf.Lerp(xMin, xMax, t);
|
||||||
|
float y = MathF.Tanh(x * weight);
|
||||||
|
keys[i] = new Keyframe(x, y);
|
||||||
|
}
|
||||||
|
var curve = new AnimationCurve(keys);
|
||||||
|
for (int i = 0; i < curve.length; i++) {
|
||||||
|
AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
|
||||||
|
AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
|
||||||
|
}
|
||||||
|
return curve;
|
||||||
|
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#endregion Serialization
|
#endregion Serialization
|
||||||
@ -348,10 +371,11 @@ namespace NanoBrain {
|
|||||||
#if UNITY_MATHEMATICS
|
#if UNITY_MATHEMATICS
|
||||||
|
|
||||||
public Func<float3, float3> Activator => this.curvePreset switch {
|
public Func<float3, float3> Activator => this.curvePreset switch {
|
||||||
CurvePresets.Linear => ActivatorLinear,
|
ActivationFunction.Linear => ActivatorLinear,
|
||||||
CurvePresets.Sqrt => ActivatorSqrt,
|
ActivationFunction.Sqrt => ActivatorSqrt,
|
||||||
CurvePresets.Power => ActivatorPower,
|
ActivationFunction.Power => ActivatorPower,
|
||||||
CurvePresets.Reciprocal => ActivatorReciprocal,
|
ActivationFunction.Reciprocal => ActivatorReciprocal,
|
||||||
|
ActivationFunction.Tanh => ActivatorTanh,
|
||||||
_ => ActivatorCustom
|
_ => ActivatorCustom
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -378,6 +402,12 @@ namespace NanoBrain {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
protected float3 ActivatorTanh(float3 input) {
|
||||||
|
float magnitude = length(input);
|
||||||
|
float3 result = normalize(input) * MathF.Tanh(magnitude);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
protected float3 ActivatorCustom(float3 input) {
|
protected float3 ActivatorCustom(float3 input) {
|
||||||
float activatedValue = this.curve.Evaluate(length(input));
|
float activatedValue = this.curve.Evaluate(length(input));
|
||||||
float3 result = normalize(input) * activatedValue;
|
float3 result = normalize(input) * activatedValue;
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user