Improve persisting changes

This commit is contained in:
Pascal Serrarens 2026-02-11 11:49:03 +01:00
parent 8c8d5a5a66
commit c1bf54a1cc
2 changed files with 253 additions and 183 deletions

View File

@ -18,9 +18,17 @@ public class ClusterInspector : Editor {
public override VisualElement CreateInspectorGUI() { public override VisualElement CreateInspectorGUI() {
ClusterPrefab prefab = target as ClusterPrefab; ClusterPrefab prefab = target as ClusterPrefab;
string path = AssetDatabase.GetAssetPath(prefab); // or known path
Debug.Log($"{path}");
ClusterPrefab currentWrapper = AssetDatabase.LoadAssetAtPath<ClusterPrefab>(path);
if (currentWrapper == null)
Debug.LogError("CreateInspectorGUI: Cluster Prefab is not found on disk");
if (prefab != null) if (prefab != null)
prefab.EnsureInitialization(); prefab.EnsureInitialization();
serializedObject.Update(); serializedObject.Update();
VisualElement root = new(); VisualElement root = new();
@ -76,7 +84,7 @@ public class ClusterInspector : Editor {
private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new(); private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new();
private bool expandArray = false; private bool expandArray = false;
ClusterWrapper currentWrapper; ClusterPrefab prefabAsset;
readonly PopupField<string> outputsField; readonly PopupField<string> outputsField;
public GraphView(ClusterPrefab prefab) { public GraphView(ClusterPrefab prefab) {
@ -168,9 +176,20 @@ public class ClusterInspector : Editor {
return; return;
} }
if (currentWrapper != null) // if (currentWrapper != null)
DestroyImmediate(currentWrapper); // DestroyImmediate(currentWrapper);
currentWrapper = CreateInstance<ClusterWrapper>().Init(this.currentNucleus, prefab); // currentWrapper = CreateInstance<ClusterWrapper>().Init(this.currentNucleus, prefab);
string path = AssetDatabase.GetAssetPath(this.prefab); // or known path
this.prefabAsset = AssetDatabase.LoadAssetAtPath<ClusterPrefab>(path);
if (this.prefabAsset == null) {
// create and save if it doesn't exist
this.prefabAsset = CreateInstance<ClusterPrefab>();
// AssetDatabase.CreateAsset(currentWrapper, "Assets/ClusterPrefab.asset");
// AssetDatabase.SaveAssets();
Debug.LogError("Cluster Prefab is not found on disk");
}
//currentWrapper.Init(this.currentNucleus, prefab);
DrawInspector(inspectorContainer); DrawInspector(inspectorContainer);
} }
@ -510,15 +529,23 @@ public class ClusterInspector : Editor {
return; return;
// create a SerializedObject wrapper so Unity inspector controls work (and Undo) // create a SerializedObject wrapper so Unity inspector controls work (and Undo)
SerializedObject so = new(currentWrapper); SerializedObject so = new(prefabAsset);
IMGUIContainer container = new(() => { IMGUIContainer container = new(() => InspectorHandler(so));
if (so.targetObject == null)
inspectorContainer.Add(container);
}
void InspectorHandler(SerializedObject serializedObject) {
bool anythingChanged = false;
if (serializedObject == null || serializedObject.targetObject == null)
return; return;
so.Update();
if (this.currentNucleus == null) if (this.currentNucleus == null)
return; return;
serializedObject.Update();
GUIStyle headerStyle = new(EditorStyles.boldLabel) { GUIStyle headerStyle = new(EditorStyles.boldLabel) {
alignment = TextAnchor.MiddleLeft, alignment = TextAnchor.MiddleLeft,
margin = new RectOffset(10, 0, 4, 4) margin = new RectOffset(10, 0, 4, 4)
@ -528,9 +555,6 @@ public class ClusterInspector : Editor {
}; };
GUILayout.Label(this.currentNucleus.GetType().ToString(), headerStyle); GUILayout.Label(this.currentNucleus.GetType().ToString(), headerStyle);
if (this.currentNucleus is Neuron neuron1) {
neuron1.type = (Nucleus.Type)EditorGUILayout.EnumPopup(neuron1.type);
}
string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle); string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle);
if (newName != this.currentNucleus.name) { if (newName != this.currentNucleus.name) {
this.currentNucleus.name = newName; this.currentNucleus.name = newName;
@ -554,8 +578,13 @@ public class ClusterInspector : Editor {
showSynapses = EditorGUILayout.BeginFoldoutHeaderGroup(showSynapses, "Synapses"); showSynapses = EditorGUILayout.BeginFoldoutHeaderGroup(showSynapses, "Synapses");
if (showSynapses) { if (showSynapses) {
ConnectNucleus(this.prefab, this.currentNucleus); ConnectNucleus(this.prefab, this.currentNucleus);
AddSynapse(this.prefab, this.currentNucleus); AddSynapse(this.prefab, this.currentNucleus);
EditorGUILayout.Space();
if (this.currentNucleus is Neuron neuron2)
neuron2.combinator = (Neuron.CombinatorType)EditorGUILayout.EnumPopup("Combinator", neuron2.combinator);
this.currentNucleus.bias = EditorGUILayout.Vector3Field("Bias", this.currentNucleus.bias); this.currentNucleus.bias = EditorGUILayout.Vector3Field("Bias", this.currentNucleus.bias);
@ -597,10 +626,11 @@ public class ClusterInspector : Editor {
} }
EditorGUILayout.EndFoldoutHeaderGroup(); EditorGUILayout.EndFoldoutHeaderGroup();
// Activation
EditorGUILayout.Space(); EditorGUILayout.Space();
showActivation = EditorGUILayout.BeginFoldoutHeaderGroup(showActivation, "Activation"); showActivation = EditorGUILayout.BeginFoldoutHeaderGroup(showActivation, "Activation");
if (showActivation) { if (showActivation) {
if (this.currentNucleus is Neuron neuron) { if (this.currentNucleus is Neuron neuron) {
if (this.currentNucleus is not MemoryCell) { if (this.currentNucleus is not MemoryCell) {
EditorGUILayout.BeginHorizontal(); EditorGUILayout.BeginHorizontal();
@ -616,10 +646,16 @@ public class ClusterInspector : Editor {
neuron.array = new NucleusArray(neuron); neuron.array = new NucleusArray(neuron);
EditorGUILayout.BeginHorizontal(); EditorGUILayout.BeginHorizontal();
EditorGUILayout.IntField("Array size", neuron.array.nuclei.Count()); EditorGUILayout.IntField("Array size", neuron.array.nuclei.Count());
if (GUILayout.Button("Add")) if (GUILayout.Button("Add")) {
Undo.RecordObject(prefabAsset, "Array add " + prefabAsset.name);
neuron.array.AddNucleus(this.prefab); neuron.array.AddNucleus(this.prefab);
if (GUILayout.Button("Del")) anythingChanged = true;
}
if (GUILayout.Button("Del")) {
Undo.RecordObject(prefabAsset, "Array delete " + prefabAsset.name);
neuron.array.RemoveNucleus(); neuron.array.RemoveNucleus();
anythingChanged = true;
}
EditorGUILayout.EndHorizontal(); EditorGUILayout.EndHorizontal();
} }
@ -645,11 +681,12 @@ public class ClusterInspector : Editor {
trace = EditorGUILayout.Toggle("Trace", trace); trace = EditorGUILayout.Toggle("Trace", trace);
this.currentNucleus.trace = trace; this.currentNucleus.trace = trace;
}); serializedObject.ApplyModifiedProperties();
if (anythingChanged) {
inspectorContainer.Add(container); EditorUtility.SetDirty(prefabAsset);
AssetDatabase.SaveAssets();
}
} }
void OnSceneGUI(SceneView sceneView) { void OnSceneGUI(SceneView sceneView) {
if (this.gameObject != null) { if (this.gameObject != null) {
@ -829,28 +866,28 @@ public class NeuroidLayer {
public List<Nucleus> neuroids = new(); public List<Nucleus> neuroids = new();
} }
public class ClusterWrapper : ScriptableObject { // public class ClusterWrapper : ScriptableObject {
// expose fields that map to GraphNode // // expose fields that map to GraphNode
//public string title; // //public string title;
public Vector2 position; // public Vector2 position;
Nucleus node; // Nucleus node;
ClusterPrefab graph; // needed to write back and mark dirty // ClusterPrefab graph; // needed to write back and mark dirty
public ClusterWrapper Init(Nucleus node, ClusterPrefab graphAsset) { // public ClusterWrapper Init(Nucleus node, ClusterPrefab graphAsset) {
this.node = node; // this.node = node;
this.graph = graphAsset; // this.graph = graphAsset;
//this.title = " A " + node.name; // //this.title = " A " + node.name;
//position = node.position; // //position = node.position;
return this; // return this;
} // }
void OnValidate() { // void OnValidate() {
if (node != null) { // if (node != null) {
//node.name = title; // //node.name = title;
//node.position = position; // //node.position = position;
#if UNITY_EDITOR // #if UNITY_EDITOR
if (graph != null) // if (graph != null)
UnityEditor.EditorUtility.SetDirty(graph); // UnityEditor.EditorUtility.SetDirty(graph);
#endif // #endif
} // }
} // }
} // }

View File

@ -26,7 +26,13 @@ public class Neuron : Nucleus {
#region Serialization #region Serialization
public Type type = Type.Neuron; //public Type type = Type.Neuron;
public enum CombinatorType {
Sum,
Product,
Max
}
public CombinatorType combinator = CombinatorType.Sum;
public enum CurvePresets { public enum CurvePresets {
Linear, Linear,
@ -150,9 +156,9 @@ public class Neuron : Nucleus {
} }
protected virtual void CloneFields(Neuron clone) { protected virtual void CloneFields(Neuron clone) {
clone.array = null; clone.array = this.array;
clone.bias = this.bias; clone.bias = this.bias;
clone.type = this.type; clone.combinator = this.combinator;
clone.curve = this.curve; clone.curve = this.curve;
clone.curvePreset = this.curvePreset; clone.curvePreset = this.curvePreset;
clone.curveMax = this.curveMax; clone.curveMax = this.curveMax;
@ -184,17 +190,19 @@ public class Neuron : Nucleus {
} }
public override void UpdateStateIsolated() { public override void UpdateStateIsolated() {
switch (this.type) { float3 result = CombinatorAction();
case Type.Neuron: this.outputValue = Activation(result);
UpdateSum(); // switch (this.type) {
break; // case Type.Neuron:
case Type.Pulsar: // UpdateSum();
UpdateProduct(); // break;
break; // case Type.Pulsar:
default: // UpdateProduct();
UpdateSum(); // break;
break; // default:
} // UpdateSum();
// break;
// }
// Vector3 sum = this.bias; // Vector3 sum = this.bias;
// int n = 0; // int n = 0;
@ -219,20 +227,45 @@ public class Neuron : Nucleus {
// this.outputValue = result; // this.outputValue = result;
} }
public void UpdateSum() { private Func<float3> CombinatorAction => combinator switch {
CombinatorType.Sum => UpdateSum,
CombinatorType.Product => UpdateProduct,
CombinatorType.Max => UpdateMax,
_ => UpdateSum
};
public float3 UpdateSum() {
Vector3 sum = this.bias; Vector3 sum = this.bias;
foreach (Synapse synapse in this.synapses) foreach (Synapse synapse in this.synapses)
sum += synapse.weight * synapse.nucleus.outputValue; sum += synapse.weight * synapse.nucleus.outputValue;
return sum;
this.outputValue = Activation(sum); //this.outputValue = Activation(sum);
} }
public void UpdateProduct() { public float3 UpdateProduct() {
float3 product = this.bias; float3 product = this.bias;
foreach (Synapse synapse in this.synapses) foreach (Synapse synapse in this.synapses)
product *= synapse.weight * synapse.nucleus.outputValue; product *= synapse.weight * synapse.nucleus.outputValue;
return product;
//this.outputValue = Activation(product);
}
this.outputValue = Activation(product); public float3 UpdateMax() {
float3 max = this.bias;
float maxSqrLength = lengthsq(max);
//Applying the weight factors
foreach (Synapse synapse in this.synapses) {
float3 input = synapse.weight * synapse.nucleus.outputValue;
float inputSqrlength = lengthsq(input);
if (inputSqrlength > maxSqrLength) {
max = input;
maxSqrLength = inputSqrlength;
}
}
return max;
} }
protected float3 Activation(float3 input) { protected float3 Activation(float3 input) {