using neuron types

This commit is contained in:
Pascal Serrarens 2026-02-11 09:45:21 +01:00
parent ed0a95b4d8
commit 8c8d5a5a66
3 changed files with 91 additions and 49 deletions

View File

@ -4,7 +4,6 @@ using UnityEngine;
[CreateAssetMenu(menuName = "Passer/Cluster")] [CreateAssetMenu(menuName = "Passer/Cluster")]
public class ClusterPrefab : ScriptableObject { public class ClusterPrefab : ScriptableObject {
// The ScriptableObject asset from which the runtime object has been created // The ScriptableObject asset from which the runtime object has been created
public string hello = "hello";
[SerializeReference] [SerializeReference]
public List<Nucleus> nuclei = new(); public List<Nucleus> nuclei = new();

View File

@ -77,7 +77,7 @@ public class ClusterInspector : Editor {
private bool expandArray = false; private bool expandArray = false;
ClusterWrapper currentWrapper; ClusterWrapper currentWrapper;
PopupField<string> outputsField; readonly PopupField<string> outputsField;
public GraphView(ClusterPrefab prefab) { public GraphView(ClusterPrefab prefab) {
this.prefab = prefab; this.prefab = prefab;
@ -414,37 +414,37 @@ public class ClusterInspector : Editor {
fontStyle = FontStyle.Bold, fontStyle = FontStyle.Bold,
}; };
//if (nucleus is Nucleus neuron) { //if (nucleus is Nucleus neuron) {
if (nucleus.array == null || nucleus.array.nuclei == null || nucleus.array.nuclei.Count() == 0) if (nucleus.array == null || nucleus.array.nuclei == null || nucleus.array.nuclei.Count() == 0)
nucleus.array = new NucleusArray(nucleus); nucleus.array = new NucleusArray(nucleus);
if ((!expandArray || nucleus.array.nuclei.First() != this.currentNucleus) && nucleus.array.nuclei.Count() > 1) { if ((!expandArray || nucleus.array.nuclei.First() != this.currentNucleus) && nucleus.array.nuclei.Count() > 1) {
Handles.Label(labelPosition, nucleus.array.nuclei.Count().ToString(), style); Handles.Label(labelPosition, nucleus.array.nuclei.Count().ToString(), style);
}
if (expandArray && nucleus.array.nuclei.First() == this.currentNucleus) {
int arrayIx = 0;
foreach (Nucleus n in nucleus.array.nuclei) {
if (n == nucleus)
break;
arrayIx++;
} }
if (expandArray && nucleus.array.nuclei.First() == this.currentNucleus) { Handles.Label(labelPosition, $"[{arrayIx}]", style);
int arrayIx = 0; }
foreach (Nucleus n in nucleus.array.nuclei) { else {
if (n == nucleus) style.alignment = TextAnchor.UpperCenter;
break; Vector3 labelPos = position - Vector3.down * (size + 10f); // below disc along up axis
arrayIx++; int colonPos = nucleus.name.IndexOf(":");
} if (colonPos > 0) {
Handles.Label(labelPosition, $"[{arrayIx}]", style); string baseName = nucleus.name[..colonPos];
} Handles.Label(labelPos, baseName, style);
else {
style.alignment = TextAnchor.UpperCenter;
Vector3 labelPos = position - Vector3.down * (size + 10f); // below disc along up axis
int colonPos = nucleus.name.IndexOf(":");
if (colonPos > 0) {
string baseName = nucleus.name[..colonPos];
Handles.Label(labelPos, baseName, style);
}
else
Handles.Label(labelPos, nucleus.name, style);
} }
else
Handles.Label(labelPos, nucleus.name, style);
}
if (nucleus is Cluster) { if (nucleus is Cluster) {
Handles.color = Color.white; Handles.color = Color.white;
Handles.DrawWireDisc(position, Vector3.forward, size + 10); Handles.DrawWireDisc(position, Vector3.forward, size + 10);
} }
// } // }
// else { // else {
// style.alignment = TextAnchor.UpperCenter; // style.alignment = TextAnchor.UpperCenter;
@ -528,6 +528,9 @@ public class ClusterInspector : Editor {
}; };
GUILayout.Label(this.currentNucleus.GetType().ToString(), headerStyle); GUILayout.Label(this.currentNucleus.GetType().ToString(), headerStyle);
if (this.currentNucleus is Neuron neuron1) {
neuron1.type = (Nucleus.Type)EditorGUILayout.EnumPopup(neuron1.type);
}
string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle); string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle);
if (newName != this.currentNucleus.name) { if (newName != this.currentNucleus.name) {
this.currentNucleus.name = newName; this.currentNucleus.name = newName;
@ -707,15 +710,24 @@ public class ClusterInspector : Editor {
protected virtual void DeleteNeuron(Nucleus nucleus) { protected virtual void DeleteNeuron(Nucleus nucleus) {
if (nucleus == null) if (nucleus == null)
return; return;
if (nucleus.cluster != null)
this.currentNucleus = nucleus.cluster.output;
foreach (Nucleus receiver in nucleus.receivers) { foreach (Nucleus receiver in nucleus.receivers) {
if (receiver != null) { if (receiver != null) {
this.currentNucleus = receiver; this.currentNucleus = receiver;
break; break;
} }
} }
this.prefab.nuclei.Remove(nucleus);
if (outputsField.value == nucleus.name) {
this.prefab.RefreshOutputs();
outputsField.choices = this.prefab.outputs.Select(output => output.name).ToList();
outputsField.index = 0;
}
Neuron.Delete(nucleus); Neuron.Delete(nucleus);
this.currentNucleus = this.prefab.output;
BuildLayers(); BuildLayers();
} }

View File

@ -26,6 +26,8 @@ public class Neuron : Nucleus {
#region Serialization #region Serialization
public Type type = Type.Neuron;
public enum CurvePresets { public enum CurvePresets {
Linear, Linear,
Power, Power,
@ -150,6 +152,7 @@ public class Neuron : Nucleus {
protected virtual void CloneFields(Neuron clone) { protected virtual void CloneFields(Neuron clone) {
clone.array = null; clone.array = null;
clone.bias = this.bias; clone.bias = this.bias;
clone.type = this.type;
clone.curve = this.curve; clone.curve = this.curve;
clone.curvePreset = this.curvePreset; clone.curvePreset = this.curvePreset;
clone.curveMax = this.curveMax; clone.curveMax = this.curveMax;
@ -181,28 +184,55 @@ public class Neuron : Nucleus {
} }
public override void UpdateStateIsolated() { public override void UpdateStateIsolated() {
Vector3 sum = this.bias; switch (this.type) {
int n = 0; case Type.Neuron:
UpdateSum();
break;
case Type.Pulsar:
UpdateProduct();
break;
default:
UpdateSum();
break;
}
// Vector3 sum = this.bias;
// int n = 0;
//Applying the weight factgors // //Applying the weight factgors
foreach (Synapse synapse in this.synapses) { // foreach (Synapse synapse in this.synapses) {
// sum += synapse.weight * synapse.nucleus.outputValue;
// // Perhaps synapses should be removed when the output value goes to 0....
// if (lengthsq(synapse.nucleus.outputValue) != 0) {
// n++;
// this.stale = 0;
// }
// }
// if (this.average && n > 0)
// sum /= n;
// // Activation function
// float3 result = Activation(sum);
// if (this.stale > staleValueForSleep)
// this.outputValue = new float3(0, 0, 0);
// else
// this.outputValue = result;
}
public void UpdateSum() {
Vector3 sum = this.bias;
foreach (Synapse synapse in this.synapses)
sum += synapse.weight * synapse.nucleus.outputValue; sum += synapse.weight * synapse.nucleus.outputValue;
// Perhaps synapses should be removed when the output value goes to 0.... this.outputValue = Activation(sum);
if (lengthsq(synapse.nucleus.outputValue) != 0) { }
n++;
this.stale = 0;
}
}
if (this.average && n > 0)
sum /= n;
// Activation function public void UpdateProduct() {
float3 result = Activation(sum); float3 product = this.bias;
if (this.stale > staleValueForSleep) foreach (Synapse synapse in this.synapses)
this.outputValue = new float3(0, 0, 0); product *= synapse.weight * synapse.nucleus.outputValue;
else
this.outputValue = result; this.outputValue = Activation(product);
} }
protected float3 Activation(float3 input) { protected float3 Activation(float3 input) {
@ -236,4 +266,5 @@ public class Neuron : Nucleus {
this.bias = inputValue; this.bias = inputValue;
this.parent.UpdateFromNucleus(this); this.parent.UpdateFromNucleus(this);
} }
} }