From ebef711981e3d6e53f0eeb9f8f3aa3821132610d Mon Sep 17 00:00:00 2001 From: Pascal Serrarens Date: Tue, 19 May 2026 12:45:11 +0200 Subject: [PATCH] Migrated ClusterEditor to ClusterView --- Editor/Brain_Editor.cs | 4 +- Editor/ClusterEditor.cs | 1154 ++++++++++++++------------- Editor/ClusterPrefab_Drawer.cs | 62 +- Editor/ClusterView.cs | 32 +- Editor/ClusterViewer.cs | 1374 ++++++++++++++++---------------- Runtime/Scripts/Brain.cs | 5 +- 6 files changed, 1399 insertions(+), 1232 deletions(-) diff --git a/Editor/Brain_Editor.cs b/Editor/Brain_Editor.cs index af4eddc..47e3a36 100644 --- a/Editor/Brain_Editor.cs +++ b/Editor/Brain_Editor.cs @@ -1,3 +1,4 @@ +/* using UnityEditor; using UnityEditor.UIElements; @@ -70,4 +71,5 @@ namespace NanoBrain.Unity { } } -} \ No newline at end of file +} +*/ \ No newline at end of file diff --git a/Editor/ClusterEditor.cs b/Editor/ClusterEditor.cs index 31cc63b..6124bf8 100644 --- a/Editor/ClusterEditor.cs +++ b/Editor/ClusterEditor.cs @@ -8,585 +8,663 @@ using UnityEngine.UIElements; namespace NanoBrain.Unity { [CustomEditor(typeof(ClusterPrefab))] - public class ClusterEditor : ClusterViewer { + public class ClusterEditor : Editor { + const float drawAreaWidth = 300f; // adjust as needed + const float padding = 6f; + ClusterPrefab clusterPrefab; + Nucleus currentNucleus { + get { return clusterView.currentNucleus; } + set { clusterView.currentNucleus = value; } + } + Cluster currentCluster => clusterView.currentCluster; + protected Nucleus selectedOutput; + ClusterView clusterView; - public override VisualElement CreateInspectorGUI() { - ClusterPrefab prefab = target as ClusterPrefab; - if (prefab != null) - prefab.EnsureInitialization(); - - serializedObject.Update(); - - VisualElement root = new(); - CreateEditor(root, prefab, null); - - serializedObject.ApplyModifiedProperties(); - return root; + void OnEnable() { + clusterPrefab = (ClusterPrefab)target; + clusterView = ClusterView.GetClusterView(serializedObject); + clusterView.currentCluster ??= clusterPrefab.cluster; + clusterView.currentNucleus = clusterPrefab.cluster.defaultOutput; } - public GraphView CreateEditor(VisualElement root, ClusterPrefab cluster, GameObject gameObject) { - root.style.paddingLeft = 0; - root.style.paddingRight = 0; - root.style.paddingTop = 0; - root.style.paddingBottom = 0; + public override void OnInspectorGUI() { + // Begin horizontal split + EditorGUILayout.BeginHorizontal(); - root.styleSheets.Add(Resources.Load("GraphStyles")); + // Left: fixed-width drawing area + GUILayoutOption[] leftOptions = { GUILayout.Width(drawAreaWidth) }; + Rect drawRect = GUILayoutUtility.GetRect(drawAreaWidth, 450f, leftOptions); // height adjustable - VisualElement mainContainer = new() { - style = { - flexDirection = FlexDirection.Row, - } - }; - GraphEditor graphContainer = new(cluster); - graphContainer.style.flexShrink = 0; - graphContainer.style.width = 300; - graphContainer.style.overflow = Overflow.Hidden; + // add padding inside rect + Rect innerRect = new(drawRect.x + padding, drawRect.y + padding, + drawRect.width - padding * 2, drawRect.height - padding * 2); - VisualElement inspectorContainer = new() { - name = "inspector", - style = { - minHeight = 450, - width = 300, - flexGrow = 0, - flexDirection = FlexDirection.Row, - } - }; + clusterView.Render(innerRect); - mainContainer.Add(graphContainer); - mainContainer.Add(inspectorContainer); - root.Add(mainContainer); + // Right: info panel (takes remaining width) + EditorGUILayout.BeginVertical(GUILayout.ExpandWidth(true)); + float prevLabelWidth = EditorGUIUtility.labelWidth; + EditorGUIUtility.labelWidth = 100f; // smaller labels -> larger fields - graphContainer.SetGraph(gameObject, inspectorContainer); + InspectorHandler(serializedObject); - return graphContainer; + EditorGUIUtility.labelWidth = prevLabelWidth; + EditorGUILayout.EndVertical(); // end right column + EditorGUILayout.EndHorizontal(); // end split } - public class GraphEditor : GraphView { + // public override void OnInspectorGUI() { + // float totalWidth = EditorGUIUtility.currentViewWidth; + // float leftW = drawAreaWidth; + // float rightW = Mathf.Max(80f, totalWidth - leftW - padding); - protected ClusterPrefab prefab; - //protected Nucleus currentPrefabNucleus; + // Rect row = GUILayoutUtility.GetRect(totalWidth, 450f); // request full width + // Rect leftRect = new Rect(row.x, row.y, leftW, row.height); + // Rect rightRect = new Rect(row.x + leftW + padding, row.y, rightW, 450f); - protected override Nucleus currentNucleus { - get => base.currentNucleus; - set { - base.currentNucleus = value; - // this.currentPrefabNucleus = value != null ? this.prefab.GetNucleus(value.name) : null; - } - } + // Rect innerLeft = new Rect(leftRect.x + padding, leftRect.y + padding, + // leftRect.width - padding*2, leftRect.height - padding*2); + // clusterView.Render(innerLeft); - public GraphEditor(ClusterPrefab prefab) : base(prefab.cluster.defaultOutput.parent) { - this.prefab = prefab; + // GUILayout.BeginArea(rightRect); + // float prev = EditorGUIUtility.labelWidth; + // EditorGUIUtility.labelWidth = 100f; + // InspectorHandler(serializedObject); + // EditorGUIUtility.labelWidth = prev; + // GUILayout.EndArea(); + // } + //} + /* + public class ClusterEditor : ClusterViewer { - // In a Prefab editor, no instance exists but we need it for the ClusterViewer. - // So we create a temporary instance - //this.currentCluster = new(prefab); - this.currentCluster = prefab.cluster; - this.currentCluster.Refresh(); - } - - public void SetGraph(GameObject gameObject, VisualElement inspectorContainer) { - this.gameObject = gameObject; - - if (Application.isPlaying == false) - this.serializedBrain = new SerializedObject(this.prefab); - this.selectedOutput = this.currentCluster.defaultOutput; - this.currentNucleus = this.selectedOutput; - //this.currentCluster = this.currentNucleus.parent; - Rebuild(inspectorContainer); - // if (outputsPopup != null) - // OnOutputChanged(outputsPopup.choices[0]); - } - - private void Rebuild(VisualElement inspectorContainer) { - - if (this.currentNucleus == null) { - inspectorContainer.Clear(); - return; - } - - string path = AssetDatabase.GetAssetPath(this.prefab); // or known path - this.prefabAsset = AssetDatabase.LoadAssetAtPath(path); - if (this.prefabAsset == null) { - // create in memory save if it doesn't exist - this.prefabAsset = CreateInstance(); - //Debug.LogError("Cluster Prefab is not found on disk"); - } - // DrawInspector(inspectorContainer); - if (inspectorContainer == null) - return; - - inspectorContainer.Clear(); - if (this.currentNucleus == null) - return; - - // create a SerializedObject wrapper so Unity inspector controls work (and Undo) - SerializedObject so = new(prefabAsset); - - // foreach (Nucleus nucleus in this.prefab.cluster.nuclei) { - // nucleus.Initialize(); - // } - - this.inspectorIMGUIContainer = new IMGUIContainer(() => InspectorHandler(so)); - - inspectorContainer.Add(inspectorIMGUIContainer); - } - - #region Inspector - - private VisualElement inspectorIMGUIContainer; - private bool showSynapses = true; - private bool showActivation = true; - protected bool breakOnWake = false; - protected bool trace = false; - - void InspectorHandler(SerializedObject serializedObject) { - bool anythingChanged = false; - - if (serializedObject == null || serializedObject.targetObject == null) - return; + public override VisualElement CreateInspectorGUI() { + ClusterPrefab prefab = target as ClusterPrefab; + if (prefab != null) + prefab.EnsureInitialization(); serializedObject.Update(); - GUIStyle boldTextFieldStyle = new(EditorStyles.textField) { - fontStyle = FontStyle.Bold - }; - - if (this.currentNucleus == null) { - OutputsInspector(ref anythingChanged); - return; - } - else { - GUIStyle headerStyle = new(EditorStyles.boldLabel) { - alignment = TextAnchor.MiddleLeft, - margin = new RectOffset(10, 0, 4, 4) - }; - // Nucleus type - string nucleusType = this.currentNucleus.GetType().Name; - GUILayout.Label(nucleusType, headerStyle); - - // Nucleus name - string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle); - if (newName != this.currentNucleus.name) { - this.currentNucleus.name = newName; - anythingChanged = true; - } - - // Current output value - if (Application.isPlaying) { - if (currentNucleus is Neuron currentNeuron1) { - GUIContent nameLabel = new("Output", currentNeuron1.outputValue.ToString()); - EditorGUILayout.FloatField(nameLabel, currentNeuron1.outputMagnitude); - } - else - EditorGUILayout.LabelField(" "); - } - else - EditorGUILayout.LabelField(" "); - - // Memory cell - if (this.currentNucleus is MemoryCell memory) - MemoryCellInspector(memory, ref anythingChanged); - // Cluster - else if (this.currentNucleus is Cluster cluster) - ClusterInspector(cluster, ref anythingChanged); - // Other - else - NucleusInspector(this.currentNucleus, ref anythingChanged); - - if (GUILayout.Button("Delete")) - DeleteNucleus(this.currentNucleus); - } + VisualElement root = new(); + CreateEditor(root, prefab, null); serializedObject.ApplyModifiedProperties(); - if (anythingChanged) { - EditorUtility.SetDirty(prefabAsset); - AssetDatabase.SaveAssets(); - } + return root; } - protected void OutputsInspector(ref bool anythingChanged) { + public GraphView CreateEditor(VisualElement root, ClusterPrefab cluster, GameObject gameObject) { + root.style.paddingLeft = 0; + root.style.paddingRight = 0; + root.style.paddingTop = 0; + root.style.paddingBottom = 0; + + root.styleSheets.Add(Resources.Load("GraphStyles")); + + VisualElement mainContainer = new() { + style = { + flexDirection = FlexDirection.Row, + } + }; + GraphEditor graphContainer = new(cluster); + graphContainer.style.flexShrink = 0; + graphContainer.style.width = 300; + graphContainer.style.overflow = Overflow.Hidden; + + VisualElement inspectorContainer = new() { + name = "inspector", + style = { + minHeight = 450, + width = 300, + flexGrow = 0, + flexDirection = FlexDirection.Row, + } + }; + + mainContainer.Add(graphContainer); + mainContainer.Add(inspectorContainer); + root.Add(mainContainer); + + graphContainer.SetGraph(gameObject, inspectorContainer); + + return graphContainer; + } + + public class GraphEditor : GraphView { + + protected ClusterPrefab prefab; + //protected Nucleus currentPrefabNucleus; + + protected override Nucleus currentNucleus { + get => base.currentNucleus; + set { + base.currentNucleus = value; + // this.currentPrefabNucleus = value != null ? this.prefab.GetNucleus(value.name) : null; + } + } + + public GraphEditor(ClusterPrefab prefab) : base(prefab.cluster.defaultOutput.parent) { + this.prefab = prefab; + + // In a Prefab editor, no instance exists but we need it for the ClusterViewer. + // So we create a temporary instance + //this.currentCluster = new(prefab); + this.currentCluster = prefab.cluster; + this.currentCluster.Refresh(); + } + + public void SetGraph(GameObject gameObject, VisualElement inspectorContainer) { + this.gameObject = gameObject; + + if (Application.isPlaying == false) + this.serializedBrain = new SerializedObject(this.prefab); + this.selectedOutput = this.currentCluster.defaultOutput; + this.currentNucleus = this.selectedOutput; + //this.currentCluster = this.currentNucleus.parent; + Rebuild(inspectorContainer); + // if (outputsPopup != null) + // OnOutputChanged(outputsPopup.choices[0]); + } + + private void Rebuild(VisualElement inspectorContainer) { + + if (this.currentNucleus == null) { + inspectorContainer.Clear(); + return; + } + + string path = AssetDatabase.GetAssetPath(this.prefab); // or known path + this.prefabAsset = AssetDatabase.LoadAssetAtPath(path); + if (this.prefabAsset == null) { + // create in memory save if it doesn't exist + this.prefabAsset = CreateInstance(); + //Debug.LogError("Cluster Prefab is not found on disk"); + } + // DrawInspector(inspectorContainer); + if (inspectorContainer == null) + return; + + inspectorContainer.Clear(); + if (this.currentNucleus == null) + return; + + // create a SerializedObject wrapper so Unity inspector controls work (and Undo) + SerializedObject so = new(prefabAsset); + + // foreach (Nucleus nucleus in this.prefab.cluster.nuclei) { + // nucleus.Initialize(); + // } + + this.inspectorIMGUIContainer = new IMGUIContainer(() => InspectorHandler(so)); + + inspectorContainer.Add(inspectorIMGUIContainer); + } + */ + #region Inspector + + //private VisualElement inspectorIMGUIContainer; + private bool showSynapses = true; + private bool showActivation = true; + protected bool breakOnWake = false; + protected bool trace = false; + + void InspectorHandler(SerializedObject serializedObject) { + bool anythingChanged = false; + + if (serializedObject == null || serializedObject.targetObject == null) + return; + + serializedObject.Update(); + + GUIStyle boldTextFieldStyle = new(EditorStyles.textField) { + fontStyle = FontStyle.Bold + }; + + if (this.currentNucleus == null) { + OutputsInspector(ref anythingChanged); + return; + } + else { GUIStyle headerStyle = new(EditorStyles.boldLabel) { alignment = TextAnchor.MiddleLeft, margin = new RectOffset(10, 0, 4, 4) }; - GUILayout.Label("Outputs", headerStyle); + // Nucleus type + string nucleusType = this.currentNucleus.GetType().Name; + GUILayout.Label(nucleusType, headerStyle); - bool connecting = GUILayout.Button("Add Output Neuron"); - if (connecting) { - Nucleus newOutput = new Neuron(this.currentCluster, "New Output"); - this.currentCluster.Refresh(); - this.currentNucleus = newOutput; - this.selectedOutput = this.currentNucleus; + // Nucleus name + string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle); + if (newName != this.currentNucleus.name) { + this.currentNucleus.name = newName; + anythingChanged = true; } - } - protected void MemoryCellInspector(MemoryCell memoryCell, ref bool anythingChanged) { - //memoryCell.staticMemory = EditorGUILayout.Toggle("Static Memory", memoryCell.staticMemory); - NucleusInspector(memoryCell, ref anythingChanged); - } - - protected void ClusterInspector(Cluster cluster, ref bool anythingChanged) { - EditorGUILayout.BeginHorizontal(); - - int instanceCount = cluster.instanceCount; - if (instanceCount <= 1) { - if (cluster.instances != null && cluster.instances.Length > 1) - instanceCount = cluster.instances.Count(); + // Current output value + if (Application.isPlaying) { + if (currentNucleus is Neuron currentNeuron1) { + GUIContent nameLabel = new("Output", currentNeuron1.outputValue.ToString()); + EditorGUILayout.FloatField(nameLabel, currentNeuron1.outputMagnitude); + } else - instanceCount = 1; + EditorGUILayout.LabelField(" "); } - EditorGUILayout.IntField("Instances", instanceCount, GUILayout.MinWidth(150)); + else + EditorGUILayout.LabelField(" "); - if (GUILayout.Button("Add")) { - Undo.RecordObject(prefabAsset, "Array add " + prefabAsset.name); - cluster.AddInstance(); - anythingChanged = true; - } - if (GUILayout.Button("Del")) { - Undo.RecordObject(prefabAsset, "Array delete " + prefabAsset.name); - cluster.RemoveInstance(); - anythingChanged = true; - } - EditorGUILayout.EndHorizontal(); + // Memory cell + if (this.currentNucleus is MemoryCell memory) + MemoryCellInspector(memory, ref anythingChanged); + // Cluster + else if (this.currentNucleus is Cluster cluster) + ClusterInspector(cluster, ref anythingChanged); + // Other + else + NucleusInspector(this.currentNucleus, ref anythingChanged); - // if (GUILayout.Button("Reimport Cluster")) - // ReimportCluster(cluster); + if (GUILayout.Button("Delete")) + DeleteNucleus(this.currentNucleus); } - protected void NucleusInspector(Nucleus nucleus, ref bool anythingChanged) { - SynapsesInspector(ref anythingChanged); - ActivationInspector(ref anythingChanged); - - EditorGUILayout.Space(); - breakOnWake = EditorGUILayout.Toggle("Break on wake", breakOnWake); - if (breakOnWake && this.currentNucleus is Neuron currentNeuron) { - if (currentNeuron.isSleeping == false) - Debug.Break(); - // trace = EditorGUILayout.Toggle("Trace", trace); - // currentNeuron.trace = trace; - } + serializedObject.ApplyModifiedProperties(); + if (anythingChanged) { + EditorUtility.SetDirty(clusterPrefab); + AssetDatabase.SaveAssets(); } + } - protected void SynapsesInspector(ref bool anythingChanged) { - showSynapses = EditorGUILayout.BeginFoldoutHeaderGroup(showSynapses, "Synapses"); - if (showSynapses) { - if (this.currentNucleus is Neuron neuron2) { - Neuron.CombinatorType newCombinator = (Neuron.CombinatorType)EditorGUILayout.EnumPopup("Combinator", neuron2.combinator); - anythingChanged |= newCombinator != neuron2.combinator; - neuron2.combinator = newCombinator; + protected void OutputsInspector(ref bool anythingChanged) { + GUIStyle headerStyle = new(EditorStyles.boldLabel) { + alignment = TextAnchor.MiddleLeft, + margin = new RectOffset(10, 0, 4, 4) + }; + GUILayout.Label("Outputs", headerStyle); - EditorGUIUtility.wideMode = true; - float previousLabelWidth = EditorGUIUtility.labelWidth; - EditorGUIUtility.labelWidth = 100; - - Vector3 newBias = EditorGUILayout.Vector3Field("Bias", neuron2.bias); - if (newBias != neuron2.bias) { - anythingChanged |= newBias != neuron2.bias; - neuron2.bias = newBias; - } - EditorGUIUtility.labelWidth = previousLabelWidth; - } - - Nucleus[] array = null; - int elementIx = -1; - if (this.currentNucleus is Neuron currentNeuron && currentNeuron.synapses.Count > 0) { - Synapse[] synapses = currentNeuron.synapses.ToArray(); - foreach (Synapse synapse in synapses) { - if (synapse.neuron == null) - continue; - - if (array != null) { - if (synapse.neuron.parent is Cluster iCluster && elementIx > 0) { - int thisElementIx = Cluster.GetNucleusIndex(iCluster.nuclei, synapse.neuron); - if (thisElementIx == elementIx) - continue; - else - elementIx = thisElementIx; - } - if (array.Contains(synapse.neuron)) - continue; - else if (array.Contains(synapse.neuron.parent)) - continue; - } - else { - if (synapse.neuron.parent is Cluster iReceptor) { - array = iReceptor.instances; - if (iReceptor is Cluster iCluster) - elementIx = Cluster.GetNucleusIndex(iCluster.nuclei, synapse.neuron); - } - } - - EditorGUILayout.Space(); - - if (Application.isPlaying) { - if (synapse.neuron is Neuron synapseNeuron) { - Vector3 value = synapseNeuron.outputValue * synapse.weight; - GUIContent synapseValueLabel = new(synapse.neuron.name, synapseNeuron.outputValue.ToString()); - EditorGUILayout.FloatField(synapseValueLabel, synapseNeuron.outputMagnitude); - } - } - else { - EditorGUILayout.BeginHorizontal(); - - if (synapse.neuron.parent != this.currentNucleus.parent) { - // If it is a different cluster - GUIStyle labelStyle = new(GUI.skin.label); - float labelWidth = 200; - if (synapse.neuron.parent != null) { - labelWidth = labelStyle.CalcSize(new GUIContent($"{synapse.neuron.parent.name}.")).x; - GUILayout.Label($"{synapse.neuron.parent.name}", GUILayout.Width(labelWidth)); - } - string[] options = synapse.neuron.parent.nuclei.Select(n => n.name).ToArray(); - int selectedIndex = System.Array.IndexOf(options, synapse.neuron.name); - int newIndex = EditorGUILayout.Popup(selectedIndex, options); - if (newIndex != selectedIndex) { - Neuron newNeuron = synapse.neuron.parent.nuclei[newIndex] as Neuron; - ChangeSynapse(synapse, newNeuron); - } - } - else - GUILayout.Label(synapse.neuron.name); - - bool disconnecting = GUILayout.Button("Disconnect", GUILayout.Width(80)); - if (disconnecting) { - synapse.neuron.RemoveReceiver(this.currentNucleus); - this.currentCluster.Refresh(); - anythingChanged = true; - } - EditorGUILayout.EndHorizontal(); - } - - EditorGUI.indentLevel++; - float newWeight = EditorGUILayout.FloatField("Weight", synapse.weight); - if (newWeight != synapse.weight) { - synapse.weight = newWeight; - anythingChanged = true; - } - EditorGUI.indentLevel--; - } - } - - EditorGUILayout.Space(); - anythingChanged |= ConnectNucleus(this.prefab, this.currentNucleus); - anythingChanged |= AddSynapse(this.prefab, this.currentNucleus); - } - EditorGUILayout.EndFoldoutHeaderGroup(); - } - - protected void ActivationInspector(ref bool anythingChanged) { - EditorGUILayout.Space(); - showActivation = EditorGUILayout.BeginFoldoutHeaderGroup(showActivation, "Activation"); - if (showActivation) { - if (this.currentNucleus is Neuron neuron) { - if (this.currentNucleus is not MemoryCell) { - EditorGUILayout.BeginHorizontal(); - EditorGUILayout.LabelField("Activation Curve", GUILayout.MinWidth(60)); - if (neuron.curveMax > 0) - EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, 0, 1, neuron.curveMax), GUILayout.Width(40)); - else - EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, neuron.curveMax, 1, -neuron.curveMax), GUILayout.Width(40)); - Neuron.ActivationType newPreset = (Neuron.ActivationType)EditorGUILayout.EnumPopup(neuron.activator, GUILayout.MinWidth(50)); - anythingChanged |= newPreset != neuron.activator; - neuron.activator = newPreset; - EditorGUILayout.EndHorizontal(); - } - // if (neuron is Receptor receptor2) { - // if (receptor2.nucleiArray == null || receptor2.nucleiArray.Count() == 0) - // receptor2.array = new NucleusArray(neuron); - // } - } - - EditorGUILayout.Space(); - } - EditorGUILayout.EndFoldoutHeaderGroup(); - } - - #region Synapses - - protected virtual void AddInput(Nucleus.Type selectedType, Nucleus nucleus) { - switch (selectedType) { - case Nucleus.Type.Neuron: - AddNeuronInput(nucleus); - break; - case Nucleus.Type.MemoryCell: - AddMemoryCellInput(nucleus); - break; - case Nucleus.Type.Cluster: - AddClusterInput(nucleus); - break; - default: - break; - } - } - - protected virtual void AddNeuronInput(Nucleus nucleus) { - Neuron newNeuron = new(this.currentCluster, "New Neuron"); - //Neuron newNeuroid = new(this.prefab.cluster, "New neuron"); - newNeuron.AddReceiver(nucleus); - this.currentNucleus = newNeuron; - } - - protected virtual void AddMemoryCellInput(Nucleus nucleus) { - MemoryCell newMemory = new(this.prefab.cluster, "New memory cell"); - newMemory.AddReceiver(nucleus); - this.currentNucleus = newMemory; - } - - protected virtual void AddClusterInput(Nucleus nucleus) { - ClusterPickerWindow.ShowPicker(brain => OnClusterPicked(nucleus, brain), "Select Cluster"); - } - private void OnClusterPicked(Nucleus nucleus, ClusterPrefab selectedPrefab) { - Cluster subclusterInstance = new(selectedPrefab, this.currentCluster); - subclusterInstance.defaultOutput.AddReceiver(nucleus); - } - - // private void ReimportCluster(Cluster subCluster) { - // if (subCluster.siblingClusters == null || subCluster.siblingClusters.Length <= 0) { - // Cluster reimportedCluster = new(subCluster.prefab, this.prefab); - // subCluster.MoveReceivers(reimportedCluster); - // // subcluster should be garbage now... - // this.currentNucleus = reimportedCluster; - // } - // else { - // this.currentNucleus = null; - // List newSiblingsList = new(); - // foreach (Cluster sibling in subCluster.siblingClusters) { - // Cluster reimportedCluster = new(sibling.prefab, this.prefab) { - // name = sibling.name - // }; - // sibling.MoveReceivers(reimportedCluster); - // newSiblingsList.Add(reimportedCluster); - // // make the first reimportedCluster the new current nucleus - // this.currentNucleus ??= reimportedCluster; - // } - // Cluster[] newSiblings = newSiblingsList.ToArray(); - // foreach (Cluster sibling in newSiblings) - // sibling.siblingClusters = newSiblings; - // } - // } - - int selectedConnectNucleus = -1; - // Connect to another nucleus - protected virtual bool ConnectNucleus(ClusterPrefab cluster, Nucleus nucleusToConnect) { - if (cluster == null) - return false; - - Neuron currentNeuron = this.currentNucleus as Neuron; - IEnumerable synapseNuclei = currentNeuron.synapses - .Where(synapse => synapse.neuron != null) - .Select(synapse => synapse.neuron); - - IEnumerable nuclei = cluster.cluster.nuclei - .Except(synapseNuclei); - IEnumerable nucleiNames = nuclei - .Select(n => { - int idx = n.name.IndexOf(':'); - return idx < 0 ? n.name : n.name[..idx]; - }) - .Distinct(); - - string[] names = nucleiNames.ToArray(); - EditorGUILayout.BeginHorizontal(); - selectedConnectNucleus = EditorGUILayout.Popup(selectedConnectNucleus, names); - bool connecting = GUILayout.Button("Connect", GUILayout.Width(80)); - EditorGUILayout.EndHorizontal(); - if (connecting) { - Nucleus nucleus = nuclei.ElementAt(selectedConnectNucleus); - // if (nucleus is Cluster subCluster) { - // subCluster.AddArrayReceiver(this.currentNucleus); - // } - // else - if (nucleus is Neuron neuron) - neuron.AddReceiver(this.currentNucleus); - this.currentCluster.Refresh(); - } - return connecting; - } - - protected virtual void DeleteNucleus(Nucleus nucleus) { - if (nucleus == null) - return; - - if (nucleus is Neuron neuron) { - foreach (Nucleus receiver in neuron.receivers) { - if (receiver != null) { - this.currentNucleus = receiver; - break; - } - } - } - this.currentCluster.DeleteNucleus(nucleus);//clusterNuclei.Remove(nucleus); - - // this.prefab.nuclei.Remove(nucleus); - // Neuron.Delete(nucleus); - this.prefab.cluster.RefreshOutputs(); - - - this.currentNucleus = this.prefab.cluster.defaultOutput; + bool connecting = GUILayout.Button("Add Output Neuron"); + if (connecting) { + Nucleus newOutput = new Neuron(this.currentCluster, "New Output"); + this.currentCluster.Refresh(); + this.currentNucleus = newOutput; this.selectedOutput = this.currentNucleus; } - - Nucleus.Type selectedType = Nucleus.Type.None; - protected virtual bool AddSynapse(ClusterPrefab cluster, Nucleus nucleus) { - if (cluster == null) - return false; - - EditorGUILayout.BeginHorizontal(); - selectedType = (Nucleus.Type)EditorGUILayout.EnumPopup(selectedType); - bool connecting = GUILayout.Button("Add", GUILayout.Width(80)); - EditorGUILayout.EndHorizontal(); - - if (connecting) { - AddInput(selectedType, this.currentNucleus); - } - return connecting; - } - - protected virtual void ChangeSynapse(Synapse synapse, Neuron newNucleus) { - Neuron synapseNeuron = synapse.neuron; - if (synapse.neuron.parent is Cluster subCluster && subCluster.prefab != this.prefab) { - // if (synapse.neuron.parent is ClusterReceptor receptor) { - // // the new nucleus is part of a (cluster) receptor, - // // so we have to change all synapses to this nucleus array elements - // int oldNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, synapse.neuron); - // int newNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, newNucleus); - // foreach (Nucleus element in receptor.nucleiArray) { - // if (element is not ClusterReceptor clusterReceptor) - // continue; - // // Get the same neuron as the synapse.nucleus in a different element - // // of the ClusterReceptor array - // Nucleus oldElementNucleus = clusterReceptor.clusterNuclei[oldNucleusIx]; - // if (oldElementNucleus is not Neuron oldElementNeuron) - // continue; - // // Get the same neuron as newNucleus in a different element - // // of the ClusterReceptor array - // Nucleus newElementNucleus = clusterReceptor.clusterNuclei[newNucleusIx]; - // if (newElementNucleus is not Neuron newElementNeuron) - // continue; - - // oldElementNeuron.RemoveReceiver(this.currentNucleus); - // newElementNeuron.AddReceiver(this.currentNucleus); - // // Now find the synapse which pointed to the old Neuron - // // Synapse synapseForUpdate = this.currentNucleus.GetSynapse(oldElementNeuron); - // // synapseForUpdate.nucleus = newElementNeuron; - // } - // } - // else { - // it is a neuron in a subcluster - synapseNeuron.RemoveReceiver(this.currentNucleus); - newNucleus.AddReceiver(this.currentNucleus); - // } - } - else { - synapseNeuron.RemoveReceiver(this.currentNucleus); - newNucleus.AddReceiver(this.currentNucleus); - } - } - - #endregion Synapses - - #endregion Inspector } + + protected void MemoryCellInspector(MemoryCell memoryCell, ref bool anythingChanged) { + //memoryCell.staticMemory = EditorGUILayout.Toggle("Static Memory", memoryCell.staticMemory); + NucleusInspector(memoryCell, ref anythingChanged); + } + + protected void ClusterInspector(Cluster cluster, ref bool anythingChanged) { + EditorGUILayout.BeginHorizontal(); + + int instanceCount = cluster.instanceCount; + if (instanceCount <= 1) { + if (cluster.instances != null && cluster.instances.Length > 1) + instanceCount = cluster.instances.Count(); + else + instanceCount = 1; + } + EditorGUILayout.IntField("Instances", instanceCount, GUILayout.MinWidth(150)); + + if (GUILayout.Button("Add")) { + Undo.RecordObject(clusterPrefab, "Array add " + clusterPrefab.name); + cluster.AddInstance(); + anythingChanged = true; + } + if (GUILayout.Button("Del")) { + Undo.RecordObject(clusterPrefab, "Array delete " + clusterPrefab.name); + cluster.RemoveInstance(); + anythingChanged = true; + } + EditorGUILayout.EndHorizontal(); + + // if (GUILayout.Button("Reimport Cluster")) + // ReimportCluster(cluster); + } + + protected void NucleusInspector(Nucleus nucleus, ref bool anythingChanged) { + SynapsesInspector(ref anythingChanged); + ActivationInspector(ref anythingChanged); + + EditorGUILayout.Space(); + breakOnWake = EditorGUILayout.Toggle("Break on wake", breakOnWake); + if (breakOnWake && this.currentNucleus is Neuron currentNeuron) { + if (currentNeuron.isSleeping == false) + Debug.Break(); + // trace = EditorGUILayout.Toggle("Trace", trace); + // currentNeuron.trace = trace; + } + } + + protected void SynapsesInspector(ref bool anythingChanged) { + EditorGUI.indentLevel++; + //showSynapses = EditorGUILayout.BeginFoldoutHeaderGroup(showSynapses, "Synapses"); + showSynapses = EditorGUILayout.Foldout(showSynapses, "Synapses"); + if (showSynapses) { + if (this.currentNucleus is Neuron neuron2) { + Neuron.CombinatorType newCombinator = (Neuron.CombinatorType)EditorGUILayout.EnumPopup("Combinator", neuron2.combinator); + anythingChanged |= newCombinator != neuron2.combinator; + neuron2.combinator = newCombinator; + + EditorGUIUtility.wideMode = true; + float previousLabelWidth = EditorGUIUtility.labelWidth; + EditorGUIUtility.labelWidth = 100; + + Vector3 newBias = EditorGUILayout.Vector3Field("Bias", neuron2.bias); + if (newBias != neuron2.bias) { + anythingChanged |= newBias != neuron2.bias; + neuron2.bias = newBias; + } + EditorGUIUtility.labelWidth = previousLabelWidth; + } + + Nucleus[] array = null; + int elementIx = -1; + if (this.currentNucleus is Neuron currentNeuron && currentNeuron.synapses.Count > 0) { + Synapse[] synapses = currentNeuron.synapses.ToArray(); + foreach (Synapse synapse in synapses) { + if (synapse.neuron == null) + continue; + + if (array != null) { + if (synapse.neuron.parent is Cluster iCluster && elementIx > 0) { + int thisElementIx = Cluster.GetNucleusIndex(iCluster.nuclei, synapse.neuron); + if (thisElementIx == elementIx) + continue; + else + elementIx = thisElementIx; + } + if (array.Contains(synapse.neuron)) + continue; + else if (array.Contains(synapse.neuron.parent)) + continue; + } + else { + if (synapse.neuron.parent is Cluster iReceptor) { + array = iReceptor.instances; + if (iReceptor is Cluster iCluster) + elementIx = Cluster.GetNucleusIndex(iCluster.nuclei, synapse.neuron); + } + } + + EditorGUILayout.Space(); + + if (Application.isPlaying) { + if (synapse.neuron is Neuron synapseNeuron) { + Vector3 value = synapseNeuron.outputValue * synapse.weight; + GUIContent synapseValueLabel = new(synapse.neuron.name, synapseNeuron.outputValue.ToString()); + EditorGUILayout.FloatField(synapseValueLabel, synapseNeuron.outputMagnitude); + } + } + else { + float indentPx = EditorGUI.indentLevel * EditorGUIUtility.singleLineHeight; + EditorGUILayout.BeginHorizontal(); + GUILayout.Space(indentPx); + + if (synapse.neuron.parent != this.currentNucleus.parent) { + // If it is a different cluster + GUIStyle labelStyle = new(GUI.skin.label); + float labelWidth = 200; + if (synapse.neuron.parent != null) { + labelWidth = labelStyle.CalcSize(new GUIContent($"{synapse.neuron.parent.name}.")).x; + GUILayout.Label($"{synapse.neuron.parent.name}", GUILayout.Width(labelWidth)); + } + string[] options = synapse.neuron.parent.nuclei.Select(n => n.name).ToArray(); + int selectedIndex = System.Array.IndexOf(options, synapse.neuron.name); + int newIndex = EditorGUILayout.Popup(selectedIndex, options); + if (newIndex != selectedIndex) { + Neuron newNeuron = synapse.neuron.parent.nuclei[newIndex] as Neuron; + ChangeSynapse(synapse, newNeuron); + } + } + else + GUILayout.Label(synapse.neuron.name); + + bool disconnecting = GUILayout.Button("Disconnect", GUILayout.Width(80)); + if (disconnecting) { + synapse.neuron.RemoveReceiver(this.currentNucleus); + this.currentCluster.Refresh(); + anythingChanged = true; + } + EditorGUILayout.EndHorizontal(); + } + + EditorGUI.indentLevel++; + float newWeight = EditorGUILayout.FloatField("Weight", synapse.weight); + if (newWeight != synapse.weight) { + synapse.weight = newWeight; + anythingChanged = true; + } + EditorGUI.indentLevel--; + } + } + + EditorGUILayout.Space(); + anythingChanged |= ConnectNucleus(this.clusterPrefab, this.currentNucleus); + anythingChanged |= AddSynapse(this.clusterPrefab, this.currentNucleus); + } + //EditorGUILayout.EndFoldoutHeaderGroup(); + EditorGUI.indentLevel--; + + } + + protected void ActivationInspector(ref bool anythingChanged) { + EditorGUILayout.Space(); + EditorGUI.indentLevel++; + showActivation = EditorGUILayout.Foldout(showActivation, "Activation"); + if (showActivation) { + if (this.currentNucleus is Neuron neuron) { + if (this.currentNucleus is not MemoryCell) { + EditorGUILayout.BeginHorizontal(); + EditorGUILayout.LabelField("Activation Curve", GUILayout.MinWidth(60)); + if (neuron.curveMax > 0) + EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, 0, 1, neuron.curveMax), GUILayout.Width(40)); + else + EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, neuron.curveMax, 1, -neuron.curveMax), GUILayout.Width(40)); + Neuron.ActivationType newPreset = (Neuron.ActivationType)EditorGUILayout.EnumPopup(neuron.activator, GUILayout.MinWidth(50)); + anythingChanged |= newPreset != neuron.activator; + neuron.activator = newPreset; + EditorGUILayout.EndHorizontal(); + } + // if (neuron is Receptor receptor2) { + // if (receptor2.nucleiArray == null || receptor2.nucleiArray.Count() == 0) + // receptor2.array = new NucleusArray(neuron); + // } + } + + EditorGUILayout.Space(); + } + //EditorGUILayout.EndFoldoutHeaderGroup(); + EditorGUI.indentLevel--; + } + + #region Synapses + + protected virtual void AddInput(Nucleus.Type selectedType, Nucleus nucleus) { + switch (selectedType) { + case Nucleus.Type.Neuron: + AddNeuronInput(nucleus); + break; + case Nucleus.Type.MemoryCell: + AddMemoryCellInput(nucleus); + break; + case Nucleus.Type.Cluster: + AddClusterInput(nucleus); + break; + default: + break; + } + } + + protected virtual void AddNeuronInput(Nucleus nucleus) { + Neuron newNeuron = new(this.currentCluster, "New Neuron"); + //Neuron newNeuroid = new(this.prefab.cluster, "New neuron"); + newNeuron.AddReceiver(nucleus); + this.currentNucleus = newNeuron; + } + + protected virtual void AddMemoryCellInput(Nucleus nucleus) { + MemoryCell newMemory = new(this.clusterPrefab.cluster, "New memory cell"); + newMemory.AddReceiver(nucleus); + this.currentNucleus = newMemory; + } + + protected virtual void AddClusterInput(Nucleus nucleus) { + ClusterPickerWindow.ShowPicker(brain => OnClusterPicked(nucleus, brain), "Select Cluster"); + } + private void OnClusterPicked(Nucleus nucleus, ClusterPrefab selectedPrefab) { + Cluster subclusterInstance = new(selectedPrefab, this.currentCluster); + subclusterInstance.defaultOutput.AddReceiver(nucleus); + } + + // private void ReimportCluster(Cluster subCluster) { + // if (subCluster.siblingClusters == null || subCluster.siblingClusters.Length <= 0) { + // Cluster reimportedCluster = new(subCluster.prefab, this.prefab); + // subCluster.MoveReceivers(reimportedCluster); + // // subcluster should be garbage now... + // this.currentNucleus = reimportedCluster; + // } + // else { + // this.currentNucleus = null; + // List newSiblingsList = new(); + // foreach (Cluster sibling in subCluster.siblingClusters) { + // Cluster reimportedCluster = new(sibling.prefab, this.prefab) { + // name = sibling.name + // }; + // sibling.MoveReceivers(reimportedCluster); + // newSiblingsList.Add(reimportedCluster); + // // make the first reimportedCluster the new current nucleus + // this.currentNucleus ??= reimportedCluster; + // } + // Cluster[] newSiblings = newSiblingsList.ToArray(); + // foreach (Cluster sibling in newSiblings) + // sibling.siblingClusters = newSiblings; + // } + // } + + int selectedConnectNucleus = -1; + // Connect to another nucleus + protected virtual bool ConnectNucleus(ClusterPrefab cluster, Nucleus nucleusToConnect) { + if (cluster == null) + return false; + + Neuron currentNeuron = this.currentNucleus as Neuron; + IEnumerable synapseNuclei = currentNeuron.synapses + .Where(synapse => synapse.neuron != null) + .Select(synapse => synapse.neuron); + + IEnumerable nuclei = cluster.cluster.nuclei + .Except(synapseNuclei); + IEnumerable nucleiNames = nuclei + .Select(n => { + int idx = n.name.IndexOf(':'); + return idx < 0 ? n.name : n.name[..idx]; + }) + .Distinct(); + + string[] names = nucleiNames.ToArray(); + EditorGUILayout.BeginHorizontal(); + selectedConnectNucleus = EditorGUILayout.Popup(selectedConnectNucleus, names); + bool connecting = GUILayout.Button("Connect", GUILayout.Width(80)); + EditorGUILayout.EndHorizontal(); + if (connecting) { + Nucleus nucleus = nuclei.ElementAt(selectedConnectNucleus); + // if (nucleus is Cluster subCluster) { + // subCluster.AddArrayReceiver(this.currentNucleus); + // } + // else + if (nucleus is Neuron neuron) + neuron.AddReceiver(this.currentNucleus); + this.currentCluster.Refresh(); + } + return connecting; + } + + protected virtual void DeleteNucleus(Nucleus nucleus) { + if (nucleus == null) + return; + + if (nucleus is Neuron neuron) { + foreach (Nucleus receiver in neuron.receivers) { + if (receiver != null) { + this.currentNucleus = receiver; + break; + } + } + } + this.currentCluster.DeleteNucleus(nucleus);//clusterNuclei.Remove(nucleus); + + // this.prefab.nuclei.Remove(nucleus); + // Neuron.Delete(nucleus); + this.clusterPrefab.cluster.RefreshOutputs(); + + + this.currentNucleus = this.clusterPrefab.cluster.defaultOutput; + this.selectedOutput = this.currentNucleus; + } + + Nucleus.Type selectedType = Nucleus.Type.None; + protected virtual bool AddSynapse(ClusterPrefab cluster, Nucleus nucleus) { + if (cluster == null) + return false; + + EditorGUILayout.BeginHorizontal(); + selectedType = (Nucleus.Type)EditorGUILayout.EnumPopup(selectedType); + bool connecting = GUILayout.Button("Add", GUILayout.Width(80)); + EditorGUILayout.EndHorizontal(); + + if (connecting) { + AddInput(selectedType, this.currentNucleus); + } + return connecting; + } + + protected virtual void ChangeSynapse(Synapse synapse, Neuron newNucleus) { + Neuron synapseNeuron = synapse.neuron; + if (synapse.neuron.parent is Cluster subCluster && subCluster.prefab != this.clusterPrefab) { + // if (synapse.neuron.parent is ClusterReceptor receptor) { + // // the new nucleus is part of a (cluster) receptor, + // // so we have to change all synapses to this nucleus array elements + // int oldNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, synapse.neuron); + // int newNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, newNucleus); + // foreach (Nucleus element in receptor.nucleiArray) { + // if (element is not ClusterReceptor clusterReceptor) + // continue; + // // Get the same neuron as the synapse.nucleus in a different element + // // of the ClusterReceptor array + // Nucleus oldElementNucleus = clusterReceptor.clusterNuclei[oldNucleusIx]; + // if (oldElementNucleus is not Neuron oldElementNeuron) + // continue; + // // Get the same neuron as newNucleus in a different element + // // of the ClusterReceptor array + // Nucleus newElementNucleus = clusterReceptor.clusterNuclei[newNucleusIx]; + // if (newElementNucleus is not Neuron newElementNeuron) + // continue; + + // oldElementNeuron.RemoveReceiver(this.currentNucleus); + // newElementNeuron.AddReceiver(this.currentNucleus); + // // Now find the synapse which pointed to the old Neuron + // // Synapse synapseForUpdate = this.currentNucleus.GetSynapse(oldElementNeuron); + // // synapseForUpdate.nucleus = newElementNeuron; + // } + // } + // else { + // it is a neuron in a subcluster + synapseNeuron.RemoveReceiver(this.currentNucleus); + newNucleus.AddReceiver(this.currentNucleus); + // } + } + else { + synapseNeuron.RemoveReceiver(this.currentNucleus); + newNucleus.AddReceiver(this.currentNucleus); + } + } + + #endregion Synapses + + #endregion Inspector + /* + } + */ } + } \ No newline at end of file diff --git a/Editor/ClusterPrefab_Drawer.cs b/Editor/ClusterPrefab_Drawer.cs index a75c747..05dc396 100644 --- a/Editor/ClusterPrefab_Drawer.cs +++ b/Editor/ClusterPrefab_Drawer.cs @@ -1,7 +1,10 @@ +using System.Linq; using System.Collections.Generic; using UnityEngine; using UnityEngine.UIElements; using UnityEditor; +using System; +using System.Reflection; namespace NanoBrain.Unity { @@ -58,7 +61,6 @@ namespace NanoBrain.Unity { // content rect below header Rect drawRect = new(fieldRect.x, headerRect.yMax + 2f, fieldRect.width, 450f); - // IMGUIContainer should be inserted here ClusterView.Render(drawRect, prefab.cluster, property); } } @@ -86,4 +88,62 @@ namespace NanoBrain.Unity { // } } +/* + [InitializeOnLoad] + static class ClusterPrefabInspectorRepaints { + static ClusterPrefabInspectorRepaints() { + EditorApplication.update += OnEditorUpdate; + } + + static double lastRepaint = 0; + const double repaintInterval = 1.0 / 15.0; // up to 15 FPS in inspector + + static void OnEditorUpdate() { + if (!Application.isPlaying) return; + + // throttle repaint frequency + if (EditorApplication.timeSinceStartup - lastRepaint < repaintInterval) return; + lastRepaint = EditorApplication.timeSinceStartup; + + // Find all open inspectors (Editors) that target objects containing ClusterPrefab fields + var editors = Resources.FindObjectsOfTypeAll(); + foreach (var ed in editors) { + var targets = ed.targets; + if (targets == null) + continue; + bool shouldRepaint = targets.Any(t => ObjectHasClusterPrefabField(t)); + if (shouldRepaint) { + try { + ed.Repaint(); + } + catch { + // ignore + } + } + } + } + + static bool ObjectHasClusterPrefabField(UnityEngine.Object obj) { + if (obj == null) + return false; + Type type = obj.GetType(); + // search fields (instance, non-public/public) + FieldInfo[] fields = type.GetFields(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic); + foreach (FieldInfo f in fields) { + if (f.FieldType == typeof(ClusterPrefab)) + return true; + // also handle arrays/lists of ClusterPrefab or serializable wrappers: + if (f.FieldType.IsArray && f.FieldType.GetElementType() == typeof(ClusterPrefab)) + return true; + if (f.FieldType.IsGenericType) { + Type[] gen = f.FieldType.GetGenericArguments(); + if (gen.Length == 1 && gen[0] == typeof(ClusterPrefab)) + return true; + } + } + return false; + } + } +*/ + } \ No newline at end of file diff --git a/Editor/ClusterView.cs b/Editor/ClusterView.cs index 751e223..6424483 100644 --- a/Editor/ClusterView.cs +++ b/Editor/ClusterView.cs @@ -9,32 +9,46 @@ namespace NanoBrain.Unity { private static readonly float discRadius = 20; static readonly Dictionary viewStates = new(); - private static ClusterView GetClusterView(SerializedProperty property) { + public static ClusterView GetClusterView(SerializedProperty property) { string key = property.propertyPath + "_" + property.serializedObject.targetObject.GetEntityId(); if (!viewStates.TryGetValue(key, out ClusterView state)) state = new() { key = key }; return state; } + public static ClusterView GetClusterView(SerializedObject serializedObject) { + string key = serializedObject.targetObject.GetEntityId().ToString(); + if (!viewStates.TryGetValue(key, out ClusterView state)) + state = new() { key = key }; + return state; + } - private static void UpdateViewState(ClusterView viewState) { - viewStates[viewState.key] = viewState; + private void UpdateViewState() { + viewStates[this.key] = this; } public static void Render(Rect drawRect, Cluster cluster, SerializedProperty property) { ClusterView clusterView = GetClusterView(property); clusterView.currentCluster ??= cluster; + clusterView.Render(drawRect); + } + public static void Render(Rect drawRect, Cluster cluster, SerializedObject obj) { + ClusterView clusterView = GetClusterView(obj); + clusterView.currentCluster ??= cluster; + clusterView.Render(drawRect); + } + public void Render(Rect drawRect) { // background EditorGUI.DrawRect(drawRect, Color.black); const float contentWidth = 1000f; - Rect contentRect = new Rect(0f, 0f, contentWidth, drawRect.height); + Rect contentRect = new(0f, 0f, contentWidth, drawRect.height - 20); // Begin horizontal-only scroll view - clusterView.scrollPos = GUI.BeginScrollView(drawRect, clusterView.scrollPos, contentRect, true, false); + this.scrollPos = GUI.BeginScrollView(drawRect, this.scrollPos, contentRect, false, false); // Local content group: draw GUI content using content-local coords (0..contentWidth) - GUI.BeginGroup(new Rect(-clusterView.scrollPos.x, 0f, contentWidth, drawRect.height)); + GUI.BeginGroup(new Rect(-this.scrollPos.x, 0f, contentWidth, drawRect.height)); EditorGUI.DrawRect(new Rect(0f, 0f, contentWidth, drawRect.height), new Color(0.08f, 0.08f, 0.08f, 1f)); GUI.EndGroup(); GUI.EndScrollView(); @@ -43,16 +57,16 @@ namespace NanoBrain.Unity { GUI.BeginGroup(drawRect); // Inner group positions content origin so local coords match content space and respect scroll - GUI.BeginGroup(new Rect(-clusterView.scrollPos.x, 0f, contentWidth, drawRect.height)); + GUI.BeginGroup(new Rect(-this.scrollPos.x, 0f, contentWidth, drawRect.height)); Handles.BeginGUI(); - clusterView.DrawFocusGraph(); + this.DrawFocusGraph(); Handles.EndGUI(); GUI.EndGroup(); // end inner group GUI.EndGroup(); // end clipping group - UpdateViewState(clusterView); + UpdateViewState(); } public string key = null; diff --git a/Editor/ClusterViewer.cs b/Editor/ClusterViewer.cs index 1396890..d7da0a9 100644 --- a/Editor/ClusterViewer.cs +++ b/Editor/ClusterViewer.cs @@ -73,7 +73,6 @@ namespace NanoBrain.Unity { Add(scrollView); Add(topMenuContainer); - // Subscribe when added to panel (editor UI ready) RegisterCallback(evt => Subscribe()); RegisterCallback(evt => Unsubscribe()); @@ -83,16 +82,18 @@ namespace NanoBrain.Unity { this.mode = (Mode)changeEvent.newValue; } - bool subscribed = false; + private bool subscribed = false; void Subscribe() { - if (subscribed) return; + if (subscribed) + return; SceneView.duringSceneGui += OnSceneGUI; subscribed = true; SceneView.RepaintAll(); } void Unsubscribe() { - if (!subscribed) return; + if (!subscribed) + return; SceneView.duringSceneGui -= OnSceneGUI; subscribed = false; } @@ -123,725 +124,736 @@ namespace NanoBrain.Unity { } public void OnIMGUI() { + var rect = graphContainer.layout; // container local size + int id = GUIUtility.GetControlID(123456, FocusType.Passive, new Rect(0, 0, rect.width, rect.height)); + + //int id = GUIUtility.GetControlID(FocusType.Passive); + if (Application.isPlaying == false && serializedBrain != null) serializedBrain.Update(); - Handles.BeginGUI(); - DrawGraph(); - Handles.EndGUI(); + Rect r = new Rect(0, 0, rect.width, rect.height); + ClusterView.Render(r, currentCluster, serializedBrain); + // ClusterView clusterView = ClusterView.GetClusterView(serializedBrain); + // clusterView.currentCluster ??= currentCluster; + // clusterView.DrawGraph(id); + + // Handles.BeginGUI(); + // DrawGraph(); + // Handles.EndGUI(); } #region Graph - public virtual void DrawGraph() { - if (mode == Mode.Focus) - DrawFocusGraph(); - else - DrawFullGraph(); - } + // public virtual void DrawGraph() { + // if (mode == Mode.Focus) + // DrawFocusGraph(); + // // else + // // DrawFullGraph(); + // } #region Full Graph + /* + protected void DrawFullGraph() { + //Dag dag = GenerateGraph(this.prefab); + Dag dag = GenerateGraph(this.selectedOutput); + Dag.ComputeLayout(dag); + // Draw edges + foreach (Dag.Edge e in dag.edges) { + Dag.Node from = dag.nodes.FirstOrDefault(x => x.id == e.fromId); + Dag.Node to = dag.nodes.FirstOrDefault(x => x.id == e.toId); + if (from == null || to == null) + continue; - protected void DrawFullGraph() { - //Dag dag = GenerateGraph(this.prefab); - Dag dag = GenerateGraph(this.selectedOutput); - Dag.ComputeLayout(dag); - // Draw edges - foreach (Dag.Edge e in dag.edges) { - Dag.Node from = dag.nodes.FirstOrDefault(x => x.id == e.fromId); - Dag.Node to = dag.nodes.FirstOrDefault(x => x.id == e.toId); - if (from == null || to == null) - continue; + Vector2 fromPosition = from.position; + Vector2 toPosition = to.position; + DrawEdge(fromPosition, toPosition); + } - Vector2 fromPosition = from.position; - Vector2 toPosition = to.position; - DrawEdge(fromPosition, toPosition); - } + // Draw nodes + foreach (Dag.Node n in dag.nodes) + DrawNucleus(n.nucleus, n.position, 1, n.radius); - // Draw nodes - foreach (Dag.Node n in dag.nodes) - DrawNucleus(n.nucleus, n.position, 1, n.radius); + // Determine graph width + float width = 0; + float currentNucleusPosition = 0; + foreach (Dag.Node node in dag.nodes) { + if (node.position.x > width) + width = node.position.x; + if (node.nucleus == currentNucleus) + currentNucleusPosition = node.position.x; + } - // Determine graph width - float width = 0; - float currentNucleusPosition = 0; - foreach (Dag.Node node in dag.nodes) { - if (node.position.x > width) - width = node.position.x; - if (node.nucleus == currentNucleus) - currentNucleusPosition = node.position.x; - } + // Resize the graph container to the full graph width + float margin = 50f; + graphContainer.style.width = width + 2 * margin; - // Resize the graph container to the full graph width - float margin = 50f; - graphContainer.style.width = width + 2 * margin; + // Scroll to the current nucleus + float viewportWidth = scrollView.layout.width; + // center currentNucleus in viewport + float desiredScrollX = currentNucleusPosition - viewportWidth * 0.5f; + // clamp between 0 and maximum scrollable range + float maxScrollX = Mathf.Max(0f, graphContainer.resolvedStyle.width - viewportWidth); + desiredScrollX = Mathf.Clamp(desiredScrollX, 0f, maxScrollX); - // Scroll to the current nucleus - float viewportWidth = scrollView.layout.width; - // center currentNucleus in viewport - float desiredScrollX = currentNucleusPosition - viewportWidth * 0.5f; - // clamp between 0 and maximum scrollable range - float maxScrollX = Mathf.Max(0f, graphContainer.resolvedStyle.width - viewportWidth); - desiredScrollX = Mathf.Clamp(desiredScrollX, 0f, maxScrollX); + Vector2 current = scrollView.scrollOffset; + scrollView.scrollOffset = new Vector2(desiredScrollX, current.y); + } - Vector2 current = scrollView.scrollOffset; - scrollView.scrollOffset = new Vector2(desiredScrollX, current.y); - } + public Dag GenerateGraph(Nucleus rootNucleus) { + Dag dag = new(); + if (rootNucleus == null) + return dag; - public Dag GenerateGraph(Nucleus rootNucleus) { - Dag dag = new(); - if (rootNucleus == null) - return dag; - - int ix = 0; - Dag.Node receiver = new() { - id = ix, - //title = nucleus.name, - nucleus = rootNucleus - }; - dag.nodes.Add(receiver); - ix++; - DescendGraph(receiver, ref ix, dag); - return dag; - } - - private void DescendGraph(Dag.Node receiver, ref int ix, Dag dag) { - Neuron receiverNeuron = receiver.nucleus as Neuron; - foreach (Synapse synapse in receiverNeuron.synapses) { - Nucleus nucleus = synapse.neuron; - if (nucleus.parent != null && nucleus.parent != currentNucleus.parent) { - nucleus = nucleus.parent; - } - string nucleusName = nucleus.name; - Dag.Node synapseNode = dag.FindNode(nucleusName); - if (synapseNode == null) { - synapseNode = new() { - id = ix, - nucleus = nucleus - }; - dag.nodes.Add(synapseNode); - } - Dag.Edge edge = new() { - fromId = synapseNode.id, - toId = receiver.id - }; - dag.edges.Add(edge); - ix++; - DescendGraph(synapseNode, ref ix, dag); - } - } + int ix = 0; + Dag.Node receiver = new() { + id = ix, + //title = nucleus.name, + nucleus = rootNucleus + }; + dag.nodes.Add(receiver); + ix++; + DescendGraph(receiver, ref ix, dag); + return dag; + } + private void DescendGraph(Dag.Node receiver, ref int ix, Dag dag) { + Neuron receiverNeuron = receiver.nucleus as Neuron; + foreach (Synapse synapse in receiverNeuron.synapses) { + Nucleus nucleus = synapse.neuron; + if (nucleus.parent != null && nucleus.parent != currentNucleus.parent) { + nucleus = nucleus.parent; + } + string nucleusName = nucleus.name; + Dag.Node synapseNode = dag.FindNode(nucleusName); + if (synapseNode == null) { + synapseNode = new() { + id = ix, + nucleus = nucleus + }; + dag.nodes.Add(synapseNode); + } + Dag.Edge edge = new() { + fromId = synapseNode.id, + toId = receiver.id + }; + dag.edges.Add(edge); + ix++; + DescendGraph(synapseNode, ref ix, dag); + } + } + */ #endregion Full Graph #region Focus Graph + /* + protected void DrawFocusGraph() { + float size = 20; + Vector3 position = new(150, 210, 0); - protected void DrawFocusGraph() { - float size = 20; - Vector3 position = new(150, 210, 0); + if (this.currentNucleus != null) { + DrawReceivers(this.currentNucleus, position, size); + DrawSynapses(this.currentNucleus, position, size); - if (this.currentNucleus != null) { - DrawReceivers(this.currentNucleus, position, size); - DrawSynapses(this.currentNucleus, position, size); + // Draw selected Nucleus + if (expandArray) { + float maxValue = 1; - // Draw selected Nucleus - if (expandArray) { - float maxValue = 1; + if (this.currentNucleus is Cluster cluster) { + float spacing = 400f / cluster.instanceCount; + float margin = 10 + spacing / 2; + float xMin = 150 - size; + float xMax = 150 + size; + float yMin = 10 + margin - size / 2; + float yMax = 400 - margin + size; + Vector3[] verts = new Vector3[4] { + new(xMin, yMin, 0), + new(xMax, yMin, 0), + new(xMax, yMax, 0), + new(xMin, yMax, 0) + }; + Handles.color = Color.black; + Handles.DrawAAConvexPolygon(verts); + int row = 0; + if (cluster.instances == null) { + Vector3 pos = new(150, margin + row * spacing, 0.0f); + Handles.color = Color.white; + // The selected sibling highlight ring + Handles.DrawSolidDisc(pos, Vector3.forward, size + 2); + DrawNucleus(cluster, pos, maxValue, size); + row++; + } + else { + foreach (Cluster sibling in cluster.instances) { + Vector3 pos = new(150, margin + row * spacing, 0.0f); + Handles.color = Color.white; + // The selected sibling highlight ring + Handles.DrawSolidDisc(pos, Vector3.forward, size + 2); + DrawNucleus(sibling, pos, maxValue, size); + row++; + } + } + GUIStyle style = new(EditorStyles.label) { + alignment = TextAnchor.UpperCenter, + normal = { textColor = Color.white }, + fontStyle = FontStyle.Bold, + }; + Vector3 labelPos = new(150, yMax + size + 5, 0); + string clusterName = cluster.name; + int colonPos = clusterName.IndexOf(":"); + if (colonPos > 0) { + string baseName = clusterName[..colonPos]; + Handles.Label(labelPos, baseName, style); + } + else + Handles.Label(labelPos, clusterName, style); + } + else { + if (this.currentNucleus is Neuron neuron) + maxValue = neuron.outputMagnitude; - if (this.currentNucleus is Cluster cluster) { - float spacing = 400f / cluster.instanceCount; - float margin = 10 + spacing / 2; - float xMin = 150 - size; - float xMax = 150 + size; - float yMin = 10 + margin - size / 2; - float yMax = 400 - margin + size; - Vector3[] verts = new Vector3[4] { - new(xMin, yMin, 0), - new(xMax, yMin, 0), - new(xMax, yMax, 0), - new(xMin, yMax, 0) - }; - Handles.color = Color.black; - Handles.DrawAAConvexPolygon(verts); - int row = 0; - if (cluster.instances == null) { - Vector3 pos = new(150, margin + row * spacing, 0.0f); - Handles.color = Color.white; - // The selected sibling highlight ring - Handles.DrawSolidDisc(pos, Vector3.forward, size + 2); - DrawNucleus(cluster, pos, maxValue, size); - row++; - } - else { - foreach (Cluster sibling in cluster.instances) { - Vector3 pos = new(150, margin + row * spacing, 0.0f); - Handles.color = Color.white; - // The selected sibling highlight ring - Handles.DrawSolidDisc(pos, Vector3.forward, size + 2); - DrawNucleus(sibling, pos, maxValue, size); - row++; + DrawNucleus(this.currentNucleus, position, maxValue, 20); + + } + } + else { + float maxValue = 1; + if (this.currentNucleus is Neuron neuron) + maxValue = neuron.outputMagnitude; + else if (this.currentNucleus is Cluster cluster) + maxValue = cluster.defaultOutput.outputMagnitude; + DrawNucleus(this.currentNucleus, position, maxValue, 20); } } + else { + DrawAllOutputs(position, size); + DrawOutputs(position, size); + } + graphContainer.style.width = 300; + } + + protected void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) { + List receivers; + if (nucleus is Neuron neuron) + receivers = neuron.receivers; + else if (nucleus is Cluster cluster) + receivers = cluster.CollectReceivers(true); + else + return; + + // For top-level nodes, add link to previous editor and/or 'Outputs' + int nodeCount = receivers.Count(); + if (nucleus == this.selectedOutput) { + // Add link to 'Outpus' + nodeCount++; + if (ClusterViewer.previousPrefab != null) + // Add link to previous editor + nodeCount++; + } + + // Determine the maximum value in this layer + // This is used to 'scale' the output value colors of the nuclei + float maxValue = 0; + foreach (Nucleus receiver in receivers) { + if (receiver is Neuron neuroid) { + float value = neuroid.outputMagnitude; + if (value > maxValue) + maxValue = value; + } + } + + // Determine the spacing of the nuclei in the layer + float spacing = 400f / nodeCount; + float margin = 10 + spacing / 2; + + int row = 0; + foreach (Nucleus receiver in receivers) { + Nucleus receiverNucleus = receiver; + if (receiverNucleus == null) + continue; + + Vector3 pos = new(50, margin + row * spacing, 0.0f); + DrawEdge(parentPos, pos); + + DrawNucleus(receiverNucleus, pos, maxValue, size); + row++; + } + if (nucleus == this.selectedOutput) { + Vector3 pos = new(50, margin + row * spacing, 0); + if (ClusterViewer.previousPrefab != null) { + DrawEdge(parentPos, pos); + DrawClusterPrefab(ClusterViewer.previousPrefab, pos, size); + row++; + } + pos = new(50, margin + row * spacing, 0); + DrawEdge(parentPos, pos); + DrawAllOutputs(pos, size); + } + } + + protected void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) { + if (nucleus is not Neuron neuron) + return; + + if (this.selectedSynapseNeuron != null) { + DrawClusterSynapses(this.selectedSynapseNeuron, parentPos, size); + return; + } + if (nucleus == null) + return; + + // Determine the maximum value in this layer + // This is used to 'scale' the output value colors of the nuclei + float maxValue = 0; + int neuronCount = 0; + List drawnNeuronNames = new(); + foreach (Synapse synapse in neuron.synapses) { + if (synapse.neuron == null) + continue; + + // Count multiple synapses to the same neuron only once + string neuronName = synapse.neuron.name; + if (synapse.neuron.parent != null) + neuronName = synapse.neuron.parent.baseName + "." + neuronName; + + if (drawnNeuronNames.Contains(neuronName)) + continue; + drawnNeuronNames.Add(neuronName); + + float value = synapse.neuron.outputMagnitude * synapse.weight; + if (value > maxValue) + maxValue = value; + + neuronCount++; + } + + // Determine the spacing of the nuclei in the layer + float spacing = 400f / neuronCount; + float margin = 10 + spacing / 2; + + int row = 0; + //List drawnNeurons = new(); + drawnNeuronNames = new(); + foreach (Synapse synapse in neuron.synapses) { + if (synapse.neuron is null) + continue; + + // Draw multiple synapses to the same neuron only once + string neuronName = synapse.neuron.name; + if (synapse.neuron.parent != null) + neuronName = synapse.neuron.parent.baseName + "." + neuronName; + + if (drawnNeuronNames.Contains(neuronName)) + continue; + drawnNeuronNames.Add(neuronName); + + Vector3 pos = new(250, margin + row * spacing, 0.0f); + DrawEdge(parentPos, pos); + // Handles.color = Color.white; + // Handles.DrawLine(parentPos, pos); + Color color = Color.black; + if (Application.isPlaying) { + if (maxValue == 0 || !float.IsFinite(maxValue)) + maxValue = 1; + float brightness = synapse.neuron.outputMagnitude * synapse.weight / maxValue; + color = new Color(brightness, brightness, brightness, 1f); + } + DrawNucleus(synapse.neuron, pos, size, color); + row++; + } + } + + protected void DrawClusterSynapses(Nucleus nucleus, Vector3 parentPos, float size) { + if (nucleus == null || nucleus.parent == null || nucleus.parent.instances == null) + return; + + // Hack to disable showing labels + expandArray = true; + + float maxValue = 0; + foreach (Cluster sibling in nucleus.parent.instances) { + Neuron siblingNeuron = sibling.GetNucleus(nucleus.name) as Neuron; + float value = siblingNeuron.outputMagnitude; // no need to add weight as they are all the same + if (value > maxValue) + maxValue = value; + } + + // Determine the spacing of the nuclei in the layer + float spacing = 400f / nucleus.parent.instanceCount; + float margin = 10 + spacing / 2; + + int row = 0; + foreach (Cluster sibling in nucleus.parent.instances) { + Neuron siblingNeuron = sibling.GetNucleus(nucleus.name) as Neuron; + Vector3 position = new(250, margin + row * spacing, 0.0f); + DrawEdge(parentPos, position); + Color color = Color.black; + if (Application.isPlaying) { + if (maxValue == 0 || !float.IsFinite(maxValue)) + maxValue = 1; + float brightness = siblingNeuron.outputMagnitude / maxValue; + color = new Color(brightness, brightness, brightness, 1f); + } + DrawNucleus(siblingNeuron, position, size, color); + GUIStyle style = new(EditorStyles.label) { + alignment = TextAnchor.UpperCenter, + normal = { textColor = Color.white }, + fontStyle = FontStyle.Bold, + }; + Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron + string name = $"{sibling.baseName}\n{nucleus.name}"; + Handles.Label(labelPos, name, style); + row++; + } + expandArray = false; + } + + protected void DrawOutputs(Vector2 parentPos, float size) { + if (this.currentCluster == null) + return; + + // Determine the maximum value in this layer + // This is used to 'scale' the output value colors of the nuclei + float maxValue = 0; + int neuronCount = 0; + List drawnNuclei = new(); + foreach (Nucleus nucleus in this.currentCluster.outputs) { + if (nucleus is not Neuron neuron) + continue; + + // Draw multiple synapses to the same neuron only once + if (drawnNuclei.Contains(nucleus)) + continue; + drawnNuclei.Add(nucleus); + + float value = neuron.outputMagnitude; + if (value > maxValue) + maxValue = value; + + neuronCount++; + } + + // Determine the spacing of the nuclei in the layer + float spacing = 400f / neuronCount; + float margin = 10 + spacing / 2; + + int row = 0; + drawnNuclei = new(); + foreach (Nucleus nucleus in this.currentCluster.outputs) { + if (nucleus is not Neuron neuron) + continue; + + // Draw multiple synapses to the same neuron only once + if (drawnNuclei.Contains(nucleus)) + continue; + drawnNuclei.Add(nucleus); + + Vector3 pos = new(250, margin + row * spacing, 0.0f); + DrawEdge(parentPos, pos); + Color color = Color.black; + if (Application.isPlaying) { + if (maxValue == 0 || !float.IsFinite(maxValue)) + maxValue = 1; + float brightness = neuron.outputMagnitude / maxValue; + color = new Color(brightness, brightness, brightness, 1f); + } + DrawNucleus(nucleus, pos, size, color); + row++; + } + } + */ + #endregion Focus Graph + /* + protected void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) { + Color color; + if (Application.isPlaying) { + float brightness = 0; + if (nucleus is Neuron neuron) + brightness = neuron.outputMagnitude / maxValue; + color = new Color(brightness, brightness, brightness, 1f); + } + else + color = Color.black; + DrawNucleus(nucleus, position, size, color); + } + + protected void DrawNucleus(Nucleus nucleus, Vector3 position, float size, Color color) { + if (nucleus == null) + return; + + if (nucleus == this.currentNucleus) { + // The selected nucleus highlight ring + Handles.color = Color.white; + Handles.DrawSolidDisc(position, Vector3.forward, size + 2); + } + + if (nucleus is MemoryCell) { + Handles.color = Color.white; + Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size); + } + + Handles.color = color; + Handles.DrawSolidDisc(position, Vector3.forward, size); + + Handles.color = Color.white; + // Position the label in front of the disc + Vector3 labelPosition = position + (Vector3.forward * 0.1f); + GUIStyle style = new(EditorStyles.label) { - alignment = TextAnchor.UpperCenter, + alignment = TextAnchor.MiddleCenter, normal = { textColor = Color.white }, fontStyle = FontStyle.Bold, }; - Vector3 labelPos = new(150, yMax + size + 5, 0); - string clusterName = cluster.name; - int colonPos = clusterName.IndexOf(":"); - if (colonPos > 0) { - string baseName = clusterName[..colonPos]; - Handles.Label(labelPos, baseName, style); + + if (nucleus.parent is Cluster parentCluster && currentNucleus != null && parentCluster != currentNucleus.parent) + DrawCluster(parentCluster, position, color, size); + else if (nucleus is Cluster cluster) + DrawCluster(cluster, position, color, size); + + if (expandArray == false) {// || nucleus != currentNucleus) { + // put name below nucleus + Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron + style.alignment = TextAnchor.UpperCenter; + + if (nucleus.parent != null && currentNucleus != null && nucleus.parent != currentNucleus.parent && nucleus.parent is Cluster parentCluster1) { + // This neuron is part of another cluster + parentCluster1.name ??= ""; + int colonPos = parentCluster1.name.IndexOf(":"); + string baseName; + if (colonPos > 0 && colonPos < parentCluster1.name.Length - 2) + baseName = parentCluster1.name[..colonPos] + "\n"; + else + baseName = parentCluster1.name + "\n"; + Handles.Label(labelPos, baseName + nucleus.name, style); + } + else { + nucleus.name ??= ""; + int colonPos = nucleus.name.IndexOf(":"); + if (colonPos > 0 && colonPos < nucleus.name.Length - 2) { + // if it is an array, we should not show the :0 of the first element + string baseName = nucleus.name[..colonPos]; + Handles.Label(labelPos, baseName, style); + } + else + Handles.Label(labelPos, nucleus.name, style); + } + } + + // Tooltip + Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); + + int id = GUIUtility.GetControlID(FocusType.Passive); + Event e = Event.current; + EventType et = e.GetTypeForControl(id); + if (e != null && neuronRect.Contains(e.mousePosition)) { + // Process Hover + HandleMouseHover(nucleus, neuronRect); + // Process click + if (e.type == EventType.MouseDown && e.button == 0) { + // Consume the event so the scene doesn't also handle it + e.Use(); + if (nucleus is Cluster parentCluster2) + OnNeuronClick(parentCluster2); + else + OnNeuronClick(nucleus); + } + } + } + + protected void DrawCluster(Cluster cluster, Vector3 position, Color color, float size) { + GUIStyle labelTextStyle = new(EditorStyles.label) { + normal = { textColor = Color.white }, + fontStyle = FontStyle.Bold, + }; + + if (expandArray) { + // Put array indices above the discs + labelTextStyle.alignment = TextAnchor.LowerCenter; + Vector3 labelPosition = position + Vector3.down * (size + 5); // below disc + + // Strip the instance number in the name + int colonPos1 = cluster.name.IndexOf(":"); + if (colonPos1 > 0) { + string extName = cluster.name[(colonPos1 + 2)..]; + Handles.Label(labelPosition, extName, labelTextStyle); + } + else + Handles.Label(labelPosition, "0", labelTextStyle); + } + else { + // Put instance count inside the disc + labelTextStyle.alignment = TextAnchor.MiddleCenter; + Vector3 labelPosition = position + (Vector3.forward * 0.1f); + + // Adjust text color based on disc color + if (color.grayscale > 0.5f) + labelTextStyle.normal.textColor = Color.black; + else + labelTextStyle.normal.textColor = Color.white; + + if (cluster.instanceCount > 1) { + Handles.Label(labelPosition, cluster.instanceCount.ToString(), labelTextStyle); + labelTextStyle.normal.textColor = Color.white; + } + else if (cluster.instances != null && cluster.instances.Length > 1) { + Handles.Label(labelPosition, cluster.instances.Length.ToString(), labelTextStyle); + labelTextStyle.normal.textColor = Color.white; + } + } + + // Draw a circle around the disc to indicate this is a Cluster + Handles.color = Color.white; + Handles.DrawWireDisc(position, Vector3.forward, size + 5); + } + + protected void DrawClusterPrefab(ClusterPrefab prefab, Vector2 position, float size) { + Handles.color = Color.black; + Handles.DrawSolidDisc(position, Vector3.forward, size); + // Draw a circle around the disc to indicate this is a Cluster + Handles.color = Color.white; + Handles.DrawWireDisc(position, Vector3.forward, size + 5); + + // put name below nucleus + GUIStyle style = new(EditorStyles.label) { + alignment = TextAnchor.MiddleCenter, + normal = { textColor = Color.white }, + fontStyle = FontStyle.Bold, + }; + Vector2 labelPos = position - Vector2.down * (size + 5); // below neuron + style.alignment = TextAnchor.UpperCenter; + Handles.Label(labelPos, prefab.name, style); + + Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); + int id = GUIUtility.GetControlID(FocusType.Passive); + Event e = Event.current; + EventType et = e.GetTypeForControl(id); + if (e != null && neuronRect.Contains(e.mousePosition)) { + // Process click + if (e.type == EventType.MouseDown && e.button == 0) { + // Consume the event so the scene doesn't also handle it + e.Use(); + Selection.activeObject = prefab; + EditorGUIUtility.PingObject(prefab); + ClusterViewer.previousPrefab = null; + CreateEditor(prefab); + } + } + } + + protected void DrawAllOutputs(Vector2 position, float size) { + GUIStyle labelTextStyle = new(EditorStyles.label) { + normal = { textColor = Color.white }, + fontStyle = FontStyle.Bold, + alignment = TextAnchor.MiddleCenter, + }; + Handles.Label(position, "Outputs", labelTextStyle); + + Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); + Event e = Event.current; + if (e != null && neuronRect.Contains(e.mousePosition)) { + // Process click + if (e.type == EventType.MouseDown && e.button == 0) { + // Consume the event so the scene doesn't also handle it + e.Use(); + OnAllOutputsClick(); + } + } + + } + + protected void DrawEdge(Vector2 from, Vector2 to, float radius = 20) { + Handles.color = Color.white; + // Handles.DrawLine(from, to); + + Vector2 dir = to - from; + float len = dir.magnitude; + if (len <= 2f * radius || len <= Mathf.Epsilon) + // line too short + return; + + Vector2 n = dir / len; // normalized + Vector2 a = from + n * radius; + Vector2 b = to - n * radius; + Handles.DrawLine(a, b); + } + + protected void HandleMouseHover(Nucleus nucleus, Rect rect) { + GUIContent tooltip; + if (nucleus is Neuron neuron) { + tooltip = new( + $"{nucleus.name}" + + $"\nValue: {neuron.outputMagnitude}"); } else - Handles.Label(labelPos, clusterName, style); + tooltip = new($"{nucleus.name}"); + + Vector2 mousePosition = Event.current.mousePosition; + + // Display tooltip with some offset + Vector2 tooltipSize = GUI.skin.box.CalcSize(tooltip); + Rect tooltipRect = new Rect(mousePosition.x + 10, mousePosition.y + 10, tooltipSize.x, tooltipSize.y); + + GUI.Box(tooltipRect, tooltip); } - else { - if (this.currentNucleus is Neuron neuron) - maxValue = neuron.outputMagnitude; - DrawNucleus(this.currentNucleus, position, maxValue, 20); + protected void OnNeuronClick(Nucleus nucleus) { + if (nucleus == this.currentNucleus) { + this.selectedSynapseNeuron = null; + // if (Application.isPlaying) { + // if (nucleus is Cluster) + // expandArray = !expandArray; + // else + // expandArray = false; + // } + // else { + if (nucleus is Cluster cluster) + OnClusterClick(cluster); + // } + } + else if (nucleus.parent != null && this.currentNucleus != null && nucleus.parent != this.currentNucleus.parent) { + // We go to a different cluster + if (Application.isPlaying) { + if (this.selectedSynapseNeuron == null && nucleus.parent.instanceCount > 1) { + this.selectedSynapseNeuron = nucleus; + expandArray = false; + } + else { + this.currentNucleus = nucleus; + if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0) + this.selectedOutput = this.currentNucleus; + this.selectedSynapseNeuron = null; + expandArray = false; + } + } + else { + // select the cluster, not the neuron in the cluster + this.currentNucleus = nucleus.parent; + expandArray = false; + } + } + else { + this.currentNucleus = nucleus; + if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0) + this.selectedOutput = this.currentNucleus; + expandArray = false; + } } - } - else { - float maxValue = 1; - if (this.currentNucleus is Neuron neuron) - maxValue = neuron.outputMagnitude; - else if (this.currentNucleus is Cluster cluster) - maxValue = cluster.defaultOutput.outputMagnitude; - DrawNucleus(this.currentNucleus, position, maxValue, 20); - } - } - else { - DrawAllOutputs(position, size); - DrawOutputs(position, size); - } - graphContainer.style.width = 300; - } - protected void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) { - List receivers; - if (nucleus is Neuron neuron) - receivers = neuron.receivers; - else if (nucleus is Cluster cluster) - receivers = cluster.CollectReceivers(true); - else - return; - - // For top-level nodes, add link to previous editor and/or 'Outputs' - int nodeCount = receivers.Count(); - if (nucleus == this.selectedOutput) { - // Add link to 'Outpus' - nodeCount++; - if (ClusterViewer.previousPrefab != null) - // Add link to previous editor - nodeCount++; - } - - // Determine the maximum value in this layer - // This is used to 'scale' the output value colors of the nuclei - float maxValue = 0; - foreach (Nucleus receiver in receivers) { - if (receiver is Neuron neuroid) { - float value = neuroid.outputMagnitude; - if (value > maxValue) - maxValue = value; - } - } - - // Determine the spacing of the nuclei in the layer - float spacing = 400f / nodeCount; - float margin = 10 + spacing / 2; - - int row = 0; - foreach (Nucleus receiver in receivers) { - Nucleus receiverNucleus = receiver; - if (receiverNucleus == null) - continue; - - Vector3 pos = new(50, margin + row * spacing, 0.0f); - DrawEdge(parentPos, pos); - - DrawNucleus(receiverNucleus, pos, maxValue, size); - row++; - } - if (nucleus == this.selectedOutput) { - Vector3 pos = new(50, margin + row * spacing, 0); - if (ClusterViewer.previousPrefab != null) { - DrawEdge(parentPos, pos); - DrawClusterPrefab(ClusterViewer.previousPrefab, pos, size); - row++; - } - pos = new(50, margin + row * spacing, 0); - DrawEdge(parentPos, pos); - DrawAllOutputs(pos, size); - } - } - - protected void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) { - if (nucleus is not Neuron neuron) - return; - - if (this.selectedSynapseNeuron != null) { - DrawClusterSynapses(this.selectedSynapseNeuron, parentPos, size); - return; - } - if (nucleus == null) - return; - - // Determine the maximum value in this layer - // This is used to 'scale' the output value colors of the nuclei - float maxValue = 0; - int neuronCount = 0; - List drawnNeuronNames = new(); - foreach (Synapse synapse in neuron.synapses) { - if (synapse.neuron == null) - continue; - - // Count multiple synapses to the same neuron only once - string neuronName = synapse.neuron.name; - if (synapse.neuron.parent != null) - neuronName = synapse.neuron.parent.baseName + "." + neuronName; - - if (drawnNeuronNames.Contains(neuronName)) - continue; - drawnNeuronNames.Add(neuronName); - - float value = synapse.neuron.outputMagnitude * synapse.weight; - if (value > maxValue) - maxValue = value; - - neuronCount++; - } - - // Determine the spacing of the nuclei in the layer - float spacing = 400f / neuronCount; - float margin = 10 + spacing / 2; - - int row = 0; - //List drawnNeurons = new(); - drawnNeuronNames = new(); - foreach (Synapse synapse in neuron.synapses) { - if (synapse.neuron is null) - continue; - - // Draw multiple synapses to the same neuron only once - string neuronName = synapse.neuron.name; - if (synapse.neuron.parent != null) - neuronName = synapse.neuron.parent.baseName + "." + neuronName; - - if (drawnNeuronNames.Contains(neuronName)) - continue; - drawnNeuronNames.Add(neuronName); - - Vector3 pos = new(250, margin + row * spacing, 0.0f); - DrawEdge(parentPos, pos); - // Handles.color = Color.white; - // Handles.DrawLine(parentPos, pos); - Color color = Color.black; - if (Application.isPlaying) { - if (maxValue == 0 || !float.IsFinite(maxValue)) - maxValue = 1; - float brightness = synapse.neuron.outputMagnitude * synapse.weight / maxValue; - color = new Color(brightness, brightness, brightness, 1f); - } - DrawNucleus(synapse.neuron, pos, size, color); - row++; - } - } - - protected void DrawClusterSynapses(Nucleus nucleus, Vector3 parentPos, float size) { - if (nucleus == null || nucleus.parent == null || nucleus.parent.instances == null) - return; - - // Hack to disable showing labels - expandArray = true; - - float maxValue = 0; - foreach (Cluster sibling in nucleus.parent.instances) { - Neuron siblingNeuron = sibling.GetNucleus(nucleus.name) as Neuron; - float value = siblingNeuron.outputMagnitude; // no need to add weight as they are all the same - if (value > maxValue) - maxValue = value; - } - - // Determine the spacing of the nuclei in the layer - float spacing = 400f / nucleus.parent.instanceCount; - float margin = 10 + spacing / 2; - - int row = 0; - foreach (Cluster sibling in nucleus.parent.instances) { - Neuron siblingNeuron = sibling.GetNucleus(nucleus.name) as Neuron; - Vector3 position = new(250, margin + row * spacing, 0.0f); - DrawEdge(parentPos, position); - Color color = Color.black; - if (Application.isPlaying) { - if (maxValue == 0 || !float.IsFinite(maxValue)) - maxValue = 1; - float brightness = siblingNeuron.outputMagnitude / maxValue; - color = new Color(brightness, brightness, brightness, 1f); - } - DrawNucleus(siblingNeuron, position, size, color); - GUIStyle style = new(EditorStyles.label) { - alignment = TextAnchor.UpperCenter, - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - }; - Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron - string name = $"{sibling.baseName}\n{nucleus.name}"; - Handles.Label(labelPos, name, style); - row++; - } - expandArray = false; - } - - protected void DrawOutputs(Vector2 parentPos, float size) { - if (this.currentCluster == null) - return; - - // Determine the maximum value in this layer - // This is used to 'scale' the output value colors of the nuclei - float maxValue = 0; - int neuronCount = 0; - List drawnNuclei = new(); - foreach (Nucleus nucleus in this.currentCluster.outputs) { - if (nucleus is not Neuron neuron) - continue; - - // Draw multiple synapses to the same neuron only once - if (drawnNuclei.Contains(nucleus)) - continue; - drawnNuclei.Add(nucleus); - - float value = neuron.outputMagnitude; - if (value > maxValue) - maxValue = value; - - neuronCount++; - } - - // Determine the spacing of the nuclei in the layer - float spacing = 400f / neuronCount; - float margin = 10 + spacing / 2; - - int row = 0; - drawnNuclei = new(); - foreach (Nucleus nucleus in this.currentCluster.outputs) { - if (nucleus is not Neuron neuron) - continue; - - // Draw multiple synapses to the same neuron only once - if (drawnNuclei.Contains(nucleus)) - continue; - drawnNuclei.Add(nucleus); - - Vector3 pos = new(250, margin + row * spacing, 0.0f); - DrawEdge(parentPos, pos); - Color color = Color.black; - if (Application.isPlaying) { - if (maxValue == 0 || !float.IsFinite(maxValue)) - maxValue = 1; - float brightness = neuron.outputMagnitude / maxValue; - color = new Color(brightness, brightness, brightness, 1f); - } - DrawNucleus(nucleus, pos, size, color); - row++; - } - } - - #endregion Focus Graph - - protected void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) { - Color color; - if (Application.isPlaying) { - float brightness = 0; - if (nucleus is Neuron neuron) - brightness = neuron.outputMagnitude / maxValue; - color = new Color(brightness, brightness, brightness, 1f); - } - else - color = Color.black; - DrawNucleus(nucleus, position, size, color); - } - - protected void DrawNucleus(Nucleus nucleus, Vector3 position, float size, Color color) { - if (nucleus == null) - return; - - if (nucleus == this.currentNucleus) { - // The selected nucleus highlight ring - Handles.color = Color.white; - Handles.DrawSolidDisc(position, Vector3.forward, size + 2); - } - - if (nucleus is MemoryCell) { - Handles.color = Color.white; - Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size); - } - - Handles.color = color; - Handles.DrawSolidDisc(position, Vector3.forward, size); - - Handles.color = Color.white; - // Position the label in front of the disc - Vector3 labelPosition = position + (Vector3.forward * 0.1f); - - GUIStyle style = new(EditorStyles.label) { - alignment = TextAnchor.MiddleCenter, - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - }; - - if (nucleus.parent is Cluster parentCluster && currentNucleus != null && parentCluster != currentNucleus.parent) - DrawCluster(parentCluster, position, color, size); - else if (nucleus is Cluster cluster) - DrawCluster(cluster, position, color, size); - - if (expandArray == false) {// || nucleus != currentNucleus) { - // put name below nucleus - Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron - style.alignment = TextAnchor.UpperCenter; - - if (nucleus.parent != null && currentNucleus != null && nucleus.parent != currentNucleus.parent && nucleus.parent is Cluster parentCluster1) { - // This neuron is part of another cluster - parentCluster1.name ??= ""; - int colonPos = parentCluster1.name.IndexOf(":"); - string baseName; - if (colonPos > 0 && colonPos < parentCluster1.name.Length - 2) - baseName = parentCluster1.name[..colonPos] + "\n"; - else - baseName = parentCluster1.name + "\n"; - Handles.Label(labelPos, baseName + nucleus.name, style); - } - else { - nucleus.name ??= ""; - int colonPos = nucleus.name.IndexOf(":"); - if (colonPos > 0 && colonPos < nucleus.name.Length - 2) { - // if it is an array, we should not show the :0 of the first element - string baseName = nucleus.name[..colonPos]; - Handles.Label(labelPos, baseName, style); + protected void OnClusterClick(Cluster subCluster) { + // May be used with storedPrefab... + Selection.activeObject = subCluster.prefab; + EditorGUIUtility.PingObject(subCluster.prefab); + ClusterViewer.previousPrefab = this.currentCluster.prefab; + ClusterEditor newEditor = CreateEditor(subCluster.prefab) as ClusterEditor; } - else - Handles.Label(labelPos, nucleus.name, style); - } - } - // Tooltip - Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); - - int id = GUIUtility.GetControlID(FocusType.Passive); - Event e = Event.current; - EventType et = e.GetTypeForControl(id); - if (e != null && neuronRect.Contains(e.mousePosition)) { - // Process Hover - HandleMouseHover(nucleus, neuronRect); - // Process click - if (e.type == EventType.MouseDown && e.button == 0) { - // Consume the event so the scene doesn't also handle it - e.Use(); - if (nucleus is Cluster parentCluster2) - OnNeuronClick(parentCluster2); - else - OnNeuronClick(nucleus); - } - } - } - - protected void DrawCluster(Cluster cluster, Vector3 position, Color color, float size) { - GUIStyle labelTextStyle = new(EditorStyles.label) { - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - }; - - if (expandArray) { - // Put array indices above the discs - labelTextStyle.alignment = TextAnchor.LowerCenter; - Vector3 labelPosition = position + Vector3.down * (size + 5); // below disc - - // Strip the instance number in the name - int colonPos1 = cluster.name.IndexOf(":"); - if (colonPos1 > 0) { - string extName = cluster.name[(colonPos1 + 2)..]; - Handles.Label(labelPosition, extName, labelTextStyle); - } - else - Handles.Label(labelPosition, "0", labelTextStyle); - } - else { - // Put instance count inside the disc - labelTextStyle.alignment = TextAnchor.MiddleCenter; - Vector3 labelPosition = position + (Vector3.forward * 0.1f); - - // Adjust text color based on disc color - if (color.grayscale > 0.5f) - labelTextStyle.normal.textColor = Color.black; - else - labelTextStyle.normal.textColor = Color.white; - - if (cluster.instanceCount > 1) { - Handles.Label(labelPosition, cluster.instanceCount.ToString(), labelTextStyle); - labelTextStyle.normal.textColor = Color.white; - } - else if (cluster.instances != null && cluster.instances.Length > 1) { - Handles.Label(labelPosition, cluster.instances.Length.ToString(), labelTextStyle); - labelTextStyle.normal.textColor = Color.white; - } - } - - // Draw a circle around the disc to indicate this is a Cluster - Handles.color = Color.white; - Handles.DrawWireDisc(position, Vector3.forward, size + 5); - } - - protected void DrawClusterPrefab(ClusterPrefab prefab, Vector2 position, float size) { - Handles.color = Color.black; - Handles.DrawSolidDisc(position, Vector3.forward, size); - // Draw a circle around the disc to indicate this is a Cluster - Handles.color = Color.white; - Handles.DrawWireDisc(position, Vector3.forward, size + 5); - - // put name below nucleus - GUIStyle style = new(EditorStyles.label) { - alignment = TextAnchor.MiddleCenter, - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - }; - Vector2 labelPos = position - Vector2.down * (size + 5); // below neuron - style.alignment = TextAnchor.UpperCenter; - Handles.Label(labelPos, prefab.name, style); - - Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); - int id = GUIUtility.GetControlID(FocusType.Passive); - Event e = Event.current; - EventType et = e.GetTypeForControl(id); - if (e != null && neuronRect.Contains(e.mousePosition)) { - // Process click - if (e.type == EventType.MouseDown && e.button == 0) { - // Consume the event so the scene doesn't also handle it - e.Use(); - Selection.activeObject = prefab; - EditorGUIUtility.PingObject(prefab); - ClusterViewer.previousPrefab = null; - CreateEditor(prefab); - } - } - } - - protected void DrawAllOutputs(Vector2 position, float size) { - GUIStyle labelTextStyle = new(EditorStyles.label) { - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - alignment = TextAnchor.MiddleCenter, - }; - Handles.Label(position, "Outputs", labelTextStyle); - - Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); - Event e = Event.current; - if (e != null && neuronRect.Contains(e.mousePosition)) { - // Process click - if (e.type == EventType.MouseDown && e.button == 0) { - // Consume the event so the scene doesn't also handle it - e.Use(); - OnAllOutputsClick(); - } - } - - } - - protected void DrawEdge(Vector2 from, Vector2 to, float radius = 20) { - Handles.color = Color.white; - // Handles.DrawLine(from, to); - - Vector2 dir = to - from; - float len = dir.magnitude; - if (len <= 2f * radius || len <= Mathf.Epsilon) - // line too short - return; - - Vector2 n = dir / len; // normalized - Vector2 a = from + n * radius; - Vector2 b = to - n * radius; - Handles.DrawLine(a, b); - } - - protected void HandleMouseHover(Nucleus nucleus, Rect rect) { - GUIContent tooltip; - if (nucleus is Neuron neuron) { - tooltip = new( - $"{nucleus.name}" + - $"\nValue: {neuron.outputMagnitude}"); - } - else - tooltip = new($"{nucleus.name}"); - - Vector2 mousePosition = Event.current.mousePosition; - - // Display tooltip with some offset - Vector2 tooltipSize = GUI.skin.box.CalcSize(tooltip); - Rect tooltipRect = new Rect(mousePosition.x + 10, mousePosition.y + 10, tooltipSize.x, tooltipSize.y); - - GUI.Box(tooltipRect, tooltip); - } - - protected void OnNeuronClick(Nucleus nucleus) { - if (nucleus == this.currentNucleus) { - this.selectedSynapseNeuron = null; - // if (Application.isPlaying) { - // if (nucleus is Cluster) - // expandArray = !expandArray; - // else - // expandArray = false; - // } - // else { - if (nucleus is Cluster cluster) - OnClusterClick(cluster); - // } - } - else if (nucleus.parent != null && this.currentNucleus != null && nucleus.parent != this.currentNucleus.parent) { - // We go to a different cluster - if (Application.isPlaying) { - if (this.selectedSynapseNeuron == null && nucleus.parent.instanceCount > 1) { - this.selectedSynapseNeuron = nucleus; + protected void OnAllOutputsClick() { + this.currentNucleus = null; + this.selectedOutput = null; expandArray = false; } - else { - this.currentNucleus = nucleus; - if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0) - this.selectedOutput = this.currentNucleus; - this.selectedSynapseNeuron = null; - expandArray = false; - } - - } - else { - // select the cluster, not the neuron in the cluster - this.currentNucleus = nucleus.parent; - expandArray = false; - } - } - else { - this.currentNucleus = nucleus; - if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0) - this.selectedOutput = this.currentNucleus; - expandArray = false; - } - } - - protected void OnClusterClick(Cluster subCluster) { - // May be used with storedPrefab... - Selection.activeObject = subCluster.prefab; - EditorGUIUtility.PingObject(subCluster.prefab); - ClusterViewer.previousPrefab = this.currentCluster.prefab; - ClusterEditor newEditor = CreateEditor(subCluster.prefab) as ClusterEditor; - } - - protected void OnAllOutputsClick() { - this.currentNucleus = null; - this.selectedOutput = null; - expandArray = false; - } - + */ #endregion Graph void OnSceneGUI(SceneView sceneView) { diff --git a/Runtime/Scripts/Brain.cs b/Runtime/Scripts/Brain.cs index 5ac2d82..fe12ba5 100644 --- a/Runtime/Scripts/Brain.cs +++ b/Runtime/Scripts/Brain.cs @@ -1,6 +1,6 @@ using System; using UnityEngine; - +/* namespace NanoBrain.Unity { /// @@ -65,4 +65,5 @@ namespace NanoBrain.Unity { } } -} \ No newline at end of file +} +*/ \ No newline at end of file