Added Tanh Activation
This commit is contained in:
parent
a99d40c5c9
commit
1c7b8e7940
File diff suppressed because it is too large
Load Diff
@ -10,16 +10,17 @@ namespace NanoBrain {
|
||||
public class ClusterViewer : Editor {
|
||||
|
||||
public class GraphView : VisualElement {
|
||||
readonly ClusterPrefab prefab;
|
||||
SerializedObject serializedBrain;
|
||||
Nucleus currentNucleus;
|
||||
GameObject gameObject;
|
||||
protected readonly ClusterPrefab prefab;
|
||||
protected SerializedObject serializedBrain;
|
||||
protected Nucleus currentNucleus;
|
||||
protected GameObject gameObject;
|
||||
private List<NeuroidLayer> layers = new();
|
||||
private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new();
|
||||
private bool expandArray = false;
|
||||
|
||||
ClusterPrefab prefabAsset;
|
||||
readonly PopupField<string> outputsField;
|
||||
protected ClusterPrefab prefabAsset;
|
||||
protected VisualElement outputContainer;
|
||||
protected readonly PopupField<string> outputsField;
|
||||
|
||||
public GraphView(ClusterPrefab prefab) {
|
||||
this.prefab = prefab;
|
||||
@ -35,7 +36,7 @@ namespace NanoBrain {
|
||||
graphContainer.focusable = true;
|
||||
Add(graphContainer);
|
||||
|
||||
VisualElement outputContainer = new() {
|
||||
outputContainer = new() {
|
||||
style = {
|
||||
flexDirection = FlexDirection.Row,
|
||||
alignItems = Align.Center,
|
||||
@ -108,7 +109,7 @@ namespace NanoBrain {
|
||||
//DrawInspector(inspectorContainer);
|
||||
}
|
||||
|
||||
private void BuildLayers() {
|
||||
protected void BuildLayers() {
|
||||
// A temporary list to track what's been added to layers
|
||||
this.layers = new();
|
||||
int layerIx = 0;
|
||||
@ -534,4 +535,9 @@ namespace NanoBrain {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public class NeuroidLayer {
|
||||
public int ix = 0;
|
||||
public List<Nucleus> neuroids = new();
|
||||
}
|
||||
}
|
||||
@ -61,16 +61,17 @@ namespace NanoBrain {
|
||||
/// <summary>
|
||||
/// The type of
|
||||
/// </summary>
|
||||
public enum CurvePresets {
|
||||
public enum ActivationFunction {
|
||||
Linear,
|
||||
Power,
|
||||
Sqrt,
|
||||
Reciprocal,
|
||||
Tanh,
|
||||
Custom
|
||||
}
|
||||
[SerializeField]
|
||||
public CurvePresets _curvePreset;
|
||||
public CurvePresets curvePreset {
|
||||
public ActivationFunction _curvePreset;
|
||||
public ActivationFunction curvePreset {
|
||||
get { return _curvePreset; }
|
||||
set {
|
||||
_curvePreset = value;
|
||||
@ -82,18 +83,21 @@ namespace NanoBrain {
|
||||
|
||||
public AnimationCurve GenerateCurve() {
|
||||
switch (this.curvePreset) {
|
||||
case CurvePresets.Linear:
|
||||
case ActivationFunction.Linear:
|
||||
this.curveMax = 1;
|
||||
return Presets.Linear(1);
|
||||
case CurvePresets.Power:
|
||||
case ActivationFunction.Power:
|
||||
this.curveMax = 1;
|
||||
return Presets.Power(2.0f, 1);
|
||||
case CurvePresets.Sqrt:
|
||||
case ActivationFunction.Sqrt:
|
||||
this.curveMax = 1;
|
||||
return Presets.Power(0.5f, 1);
|
||||
case CurvePresets.Reciprocal:
|
||||
case ActivationFunction.Reciprocal:
|
||||
this.curveMax = 1 / 0.01f * 1;
|
||||
return Presets.Reciprocal(1);
|
||||
case ActivationFunction.Tanh:
|
||||
this.curveMax = 1;
|
||||
return Presets.Tanh(1);
|
||||
default:
|
||||
this.curveMax = 1;
|
||||
return this.curve;
|
||||
@ -142,6 +146,25 @@ namespace NanoBrain {
|
||||
}
|
||||
return curve;
|
||||
}
|
||||
public static AnimationCurve Tanh(float weight) {
|
||||
//int samples = 128;
|
||||
float xMin = 0.001f;
|
||||
float xMax = 1;
|
||||
var keys = new Keyframe[samples];
|
||||
for (int i = 0; i < samples; i++) {
|
||||
float t = i / (float)(samples - 1);
|
||||
float x = Mathf.Lerp(xMin, xMax, t);
|
||||
float y = MathF.Tanh(x * weight);
|
||||
keys[i] = new Keyframe(x, y);
|
||||
}
|
||||
var curve = new AnimationCurve(keys);
|
||||
for (int i = 0; i < curve.length; i++) {
|
||||
AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
|
||||
AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
|
||||
}
|
||||
return curve;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endregion Serialization
|
||||
@ -348,10 +371,11 @@ namespace NanoBrain {
|
||||
#if UNITY_MATHEMATICS
|
||||
|
||||
public Func<float3, float3> Activator => this.curvePreset switch {
|
||||
CurvePresets.Linear => ActivatorLinear,
|
||||
CurvePresets.Sqrt => ActivatorSqrt,
|
||||
CurvePresets.Power => ActivatorPower,
|
||||
CurvePresets.Reciprocal => ActivatorReciprocal,
|
||||
ActivationFunction.Linear => ActivatorLinear,
|
||||
ActivationFunction.Sqrt => ActivatorSqrt,
|
||||
ActivationFunction.Power => ActivatorPower,
|
||||
ActivationFunction.Reciprocal => ActivatorReciprocal,
|
||||
ActivationFunction.Tanh => ActivatorTanh,
|
||||
_ => ActivatorCustom
|
||||
};
|
||||
|
||||
@ -378,6 +402,12 @@ namespace NanoBrain {
|
||||
return result;
|
||||
}
|
||||
|
||||
protected float3 ActivatorTanh(float3 input) {
|
||||
float magnitude = length(input);
|
||||
float3 result = normalize(input) * MathF.Tanh(magnitude);
|
||||
return result;
|
||||
}
|
||||
|
||||
protected float3 ActivatorCustom(float3 input) {
|
||||
float activatedValue = this.curve.Evaluate(length(input));
|
||||
float3 result = normalize(input) * activatedValue;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user