using System.Collections.Generic;
using System.Linq;
using UnityEditor;
using UnityEngine;
using UnityEngine.UIElements;
using Unity.Mathematics;
using static Unity.Mathematics.math;

/// <summary>
/// Custom inspector for <see cref="ClusterPrefab"/> assets. It shows a clickable graph of the
/// selected nucleus (receivers on the left, inputs on the right) next to an editing panel.
/// </summary>
[CustomEditor(typeof(ClusterPrefab))]
public class ClusterInspector : Editor {
    protected static VisualElement mainContainer;
    protected static VisualElement inspectorContainer;
    protected bool breakOnWake = false;

    #region Start

    /// <summary>Builds the inspector UI for the inspected <see cref="ClusterPrefab"/>.</summary>
    /// <returns>The root element; empty when the target is not a ClusterPrefab.</returns>
    public override VisualElement CreateInspectorGUI() {
        VisualElement root = new();
        ClusterPrefab cluster = target as ClusterPrefab;
        // Without a cluster there is nothing to draw; building the graph would dereference null.
        if (cluster == null)
            return root;

        cluster.EnsureInitialization();
        serializedObject.Update();
        CreateInspector(root, cluster);
        serializedObject.ApplyModifiedProperties();
        return root;
    }

    /// <summary>
    /// Fills <paramref name="root"/> with the graph view and the nucleus inspector panel and
    /// selects the cluster's output nucleus.
    /// </summary>
    /// <param name="root">Element that receives the UI.</param>
    /// <param name="cluster">Cluster to visualise.</param>
    /// <returns>The created graph view.</returns>
    public static GraphView CreateInspector(VisualElement root, ClusterPrefab cluster) {
        root.style.paddingLeft = 0;
        root.style.paddingRight = 0;
        root.style.paddingTop = 0;
        root.style.paddingBottom = 0;

        // Only add the style sheet when the resource was found; adding null is not valid.
        StyleSheet graphStyles = Resources.Load<StyleSheet>("GraphStyles");
        if (graphStyles != null)
            root.styleSheets.Add(graphStyles);

        // Open question: does the main container add any value? It is laid out just like the root.
        mainContainer = new() {
            style = {
                height = 450,
                flexDirection = FlexDirection.Row
            }
        };

        GraphView graph = new(cluster);
        graph.style.flexGrow = 1;

        inspectorContainer = new VisualElement {
            name = "inspector",
            style = {
                width = 300,
                flexGrow = 0
            }
        };

        mainContainer.Add(graph);
        mainContainer.Add(inspectorContainer);
        root.Add(mainContainer);

        graph.SetGraph(null, cluster.output, inspectorContainer);
        return graph;
    }

    /// <summary>
    /// Visual element that draws the selected nucleus with its receivers and synapses using
    /// IMGUI handles, and hosts the editing controls for that nucleus.
    /// </summary>
    public class GraphView : VisualElement {
        readonly ClusterPrefab cluster;
        SerializedObject serializedBrain;
        Nucleus currentNucleus;
        GameObject gameObject;
        private List<NeuroidLayer> layers = new();
        private readonly Dictionary<Nucleus, Vector2Int> neuroidPositions = new();
        private bool expandArray = false;
        ClusterWrapper currentWrapper;
        PopupField<string> outputsField;

        /// <summary>Creates the graph view for <paramref name="prefab"/>.</summary>
        public GraphView(ClusterPrefab prefab) {
            this.cluster = prefab;

            name = "content";
            style.flexGrow = 1;

            // The graph itself is drawn with IMGUI handles on a container covering this element.
            IMGUIContainer graphContainer = new(OnIMGUI);
            graphContainer.style.position = Position.Absolute;
            graphContainer.style.left = 0;
            graphContainer.style.top = 0;
            graphContainer.style.right = 0;
            graphContainer.style.bottom = 0;
            graphContainer.pickingMode = PickingMode.Position;
            graphContainer.focusable = true;
            Add(graphContainer);

            VisualElement outputContainer = new() {
                style = {
                    flexDirection = FlexDirection.Row,
                    alignItems = Align.Center,
                }
            };

            // A cluster without outputs gets an empty popup instead of an exception from First().
            List<string> names = OutputNames();
            outputsField = names.Count > 0
                ? new PopupField<string>(names, names[0])
                : new PopupField<string> { choices = names };
            outputsField.style.flexGrow = 1;
            outputsField.RegisterValueChangedCallback(evt => OnOutputChanged(evt.newValue));
            outputContainer.Add(outputsField);

            Button addButton = new(() => OnAddClusterOutput()) { text = "Add" };
            outputContainer.Add(addButton);
            Add(outputContainer);

            // Subscribe when added to panel (editor UI ready)
            RegisterCallback<AttachToPanelEvent>(evt => Subscribe());
            RegisterCallback<DetachFromPanelEvent>(evt => Unsubscribe());
        }

        // Names of all cluster outputs, in cluster order.
        List<string> OutputNames() {
            return this.cluster.outputs.Select(output => output.name).ToList();
        }

        void OnOutputChanged(string outputName) {
            this.currentNucleus = this.cluster.GetNucleus(outputName);
        }

        void OnAddClusterOutput() {
            // Pick a name that is not in use yet: starts at "Output <count + 1>".
            List<string> existingNames = OutputNames();
            int index = existingNames.Count + 1;
            while (existingNames.Contains($"Output {index}"))
                index++;

            // NOTE(review): assumes the Neuron constructor registers the new nucleus in
            // cluster.outputs, otherwise it will not show up in the popup - confirm.
            Nucleus newOutput = new Neuron(this.cluster, $"Output {index}");
            outputsField.choices = OutputNames();
            outputsField.value = newOutput.name;
            this.currentNucleus = newOutput;
        }

        bool subscribed = false;

        // Start receiving scene view callbacks (at most once).
        void Subscribe() {
            if (subscribed)
                return;

            SceneView.duringSceneGui += OnSceneGUI;
            subscribed = true;
            SceneView.RepaintAll();
        }

        // Stop receiving scene view callbacks.
        void Unsubscribe() {
            if (!subscribed)
                return;

            SceneView.duringSceneGui -= OnSceneGUI;
            subscribed = false;
        }

        /// <summary>Selects <paramref name="nucleus"/> and rebuilds the inspector panel.</summary>
        /// <param name="gameObject">Object whose transform is used for the scene view gizmo; may be null.</param>
        /// <param name="nucleus">Nucleus to select; may be null.</param>
        /// <param name="inspectorContainer">Container that receives the nucleus inspector.</param>
        public void SetGraph(GameObject gameObject, Nucleus nucleus, VisualElement inspectorContainer) {
            this.gameObject = gameObject;
            if (Application.isPlaying == false)
                this.serializedBrain = new SerializedObject(this.cluster);

            this.currentNucleus = nucleus;
            Rebuild(inspectorContainer);
        }

        void Rebuild(VisualElement inspectorContainer) {
            BuildLayers();

            if (this.currentNucleus == null) {
                inspectorContainer?.Clear();
                return;
            }

            if (currentWrapper != null)
                DestroyImmediate(currentWrapper);
            currentWrapper = CreateInstance<ClusterWrapper>().Init(this.currentNucleus, cluster);

            DrawInspector(inspectorContainer);
        }

        // Rebuilds the three layers around the selected nucleus:
        // its receivers, the nucleus itself, and the nuclei feeding its synapses.
        private void BuildLayers() {
            this.layers = new();
            // Drop positions of nuclei from the previous selection.
            this.neuroidPositions.Clear();

            Nucleus selectedNucleus = this.currentNucleus;
            if (selectedNucleus == null)
                return;

            int layerIx = 0;
            NeuroidLayer currentLayer = new() { ix = layerIx };
            if (selectedNucleus.receivers != null) {
                foreach (Nucleus receiver in selectedNucleus.receivers)
                    AddToLayer(currentLayer, receiver);
            }
            if (currentLayer.neuroids.Count > 0) {
                this.layers.Add(currentLayer);
                layerIx++;
                currentLayer = new() { ix = layerIx };
            }

            AddToLayer(currentLayer, selectedNucleus);
            this.layers.Add(currentLayer);
            layerIx++;

            currentLayer = new() { ix = layerIx };
            if (selectedNucleus.synapses != null) {
                foreach (Synapse synapse in selectedNucleus.synapses)
                    AddToLayer(currentLayer, synapse.nucleus);
            }
            if (currentLayer.neuroids.Count > 0)
                this.layers.Add(currentLayer);
        }

        // Appends the nucleus to the layer and records its (layer, row) position. Null is ignored.
        private void AddToLayer(NeuroidLayer layer, Nucleus nucleus) {
            if (nucleus == null)
                return;

            layer.neuroids.Add(nucleus);
            neuroidPositions[nucleus] = new Vector2Int(layer.ix, layer.neuroids.Count - 1);
        }

        /// <summary>IMGUI callback that draws the graph for the selected nucleus.</summary>
        public void OnIMGUI() {
            if (currentNucleus == null)
                return;

            // serializedBrain is only created outside play mode, so it can still be null here
            // when the graph was set up while playing.
            if (Application.isPlaying == false)
                serializedBrain?.Update();

            Handles.BeginGUI();
            DrawGraph();
            Handles.EndGUI();
        }

        private void DrawGraph() {
            float size = 20;
            Vector3 position = new(150, 210, 0);

            DrawReceivers(this.currentNucleus, position, size);
            DrawSynapses(this.currentNucleus, position, size);

            // Draw selected Nucleus
            if (expandArray) {
                DrawExpandedArray(size);
            }
            else {
                Handles.color = Color.white;
                // The selected nucleus highlight ring
                Handles.DrawSolidDisc(position, Vector3.forward, size + 2);
                DrawNucleus(this.currentNucleus, position, length(this.currentNucleus.outputValue), size);
            }
        }

        // Draws every member of the selected nucleus' array in a column on a black backdrop.
        private void DrawExpandedArray(float size) {
            // The array is created lazily elsewhere; make sure it exists before reading it.
            EnsureArray(this.currentNucleus);

            int count = 0;
            float maxValue = 0;
            foreach (Nucleus nucleus in this.currentNucleus.array.nuclei) {
                float value = length(nucleus.outputValue);
                if (value > maxValue)
                    maxValue = value;
                count++;
            }
            if (count == 0)
                return;

            float spacing = 400f / count;
            float margin = 10 + spacing / 2;

            float xMin = 150 - size;
            float xMax = 150 + size;
            float yMin = 10 + margin - size / 2;
            float yMax = 400 - margin + size;
            Vector3[] verts = new Vector3[4] {
                new(xMin, yMin, 0),
                new(xMax, yMin, 0),
                new(xMax, yMax, 0),
                new(xMin, yMax, 0)
            };
            Handles.color = Color.black;
            Handles.DrawAAConvexPolygon(verts);

            int row = 0;
            foreach (Nucleus nucleus in this.currentNucleus.array.nuclei) {
                Vector3 pos = new(150, margin + row * spacing, 0.0f);
                Handles.color = Color.white;
                // The selected nucleus highlight ring
                Handles.DrawSolidDisc(pos, Vector3.forward, size + 2);
                DrawNucleus(nucleus, pos, maxValue, size);
                row++;
            }
        }

        // Gives the nucleus a fresh array when it has none or an empty one.
        private static void EnsureArray(Nucleus nucleus) {
            if (nucleus.array == null || nucleus.array.nuclei == null || !nucleus.array.nuclei.Any())
                nucleus.array = new NucleusArray(nucleus);
        }

        // Maps a value to a 0..1 grey level relative to maxValue; 0 when there is no positive maximum.
        private static float Brightness(float value, float maxValue) {
            return maxValue > 0 ? saturate(value / maxValue) : 0f;
        }

        // Draws the receivers of the nucleus in a column left of parentPos.
        // Receivers sharing an array are drawn once.
        private void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) {
            if (nucleus?.receivers == null)
                return;

            // Determine the maximum value in this layer
            // This is used to 'scale' the output value colors of the nuclei
            float maxValue = 0;
            List<Nucleus> visibleReceivers = new();
            HashSet<NucleusArray> drawnArrays = new();
            foreach (Nucleus receiver in nucleus.receivers) {
                if (receiver == null)
                    continue;

                if (receiver is Neuron neuroid) {
                    float value = length(neuroid.outputValue);
                    if (value > maxValue)
                        maxValue = value;
                }

                // Receivers without an array yet cannot be duplicates of each other.
                if (receiver.array != null && !drawnArrays.Add(receiver.array))
                    continue;
                visibleReceivers.Add(receiver);
            }
            if (visibleReceivers.Count == 0)
                return;

            // Determine the spacing of the nuclei in the layer, based on what is actually drawn
            float spacing = 400f / visibleReceivers.Count;
            float margin = 10 + spacing / 2;

            for (int row = 0; row < visibleReceivers.Count; row++) {
                Vector3 pos = new(50, margin + row * spacing, 0.0f);
                Handles.color = Color.white;
                Handles.DrawLine(parentPos, pos);
                DrawNucleus(visibleReceivers[row], pos, maxValue, size);
            }
        }

        // Draws the nuclei feeding the synapses of the nucleus in a column right of parentPos.
        // Neurons sharing an array are drawn once.
        private void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) {
            if (nucleus?.synapses == null)
                return;

            // Determine the maximum value in this layer
            // This is used to 'scale' the output value colors of the nuclei
            float maxValue = 0;
            List<Synapse> visibleSynapses = new();
            HashSet<NucleusArray> drawnArrays = new();
            foreach (Synapse synapse in nucleus.synapses) {
                if (synapse.nucleus is Neuron neuron && neuron.array != null && !drawnArrays.Add(neuron.array))
                    continue;
                visibleSynapses.Add(synapse);

                // A synapse without a nucleus still takes up a row, but has no value.
                if (synapse.nucleus == null)
                    continue;
                float value = length(synapse.nucleus.outputValue) * synapse.weight;
                if (value > maxValue)
                    maxValue = value;
            }
            if (visibleSynapses.Count == 0)
                return;

            // Determine the spacing of the nuclei in the layer
            float spacing = 400f / visibleSynapses.Count;
            float margin = 10 + spacing / 2;

            for (int row = 0; row < visibleSynapses.Count; row++) {
                Synapse synapse = visibleSynapses[row];
                Vector3 pos = new(250, margin + row * spacing, 0.0f);
                Handles.color = Color.white;
                Handles.DrawLine(parentPos, pos);

                if (synapse.nucleus == null)
                    continue;

                Color color = Color.black;
                if (synapse.nucleus.isSleeping)
                    color = Color.darkRed;
                else if (Application.isPlaying) {
                    float brightness = Brightness(length(synapse.nucleus.outputValue) * synapse.weight, maxValue);
                    color = new Color(brightness, brightness, brightness, 1f);
                }
                DrawNucleus(synapse.nucleus, pos, maxValue, size, color);
            }
        }

        // Draws a nucleus with a color derived from its state:
        // dark red when sleeping, grey level by output while playing, black otherwise.
        private void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) {
            if (nucleus == null)
                return;

            Color color;
            if (nucleus.isSleeping)
                color = Color.darkRed;
            else if (Application.isPlaying) {
                float brightness = Brightness(length(nucleus.outputValue), maxValue);
                color = new Color(brightness, brightness, brightness, 1f);
            }
            else
                color = Color.black;

            DrawNucleus(nucleus, position, maxValue, size, color);
        }

        // Draws the disc, labels and decorations for a nucleus and handles hover and click on it.
        private void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size, Color color) {
            if (nucleus == null)
                return;

            if (nucleus is MemoryCell) {
                Handles.color = Color.white;
                Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size);
            }
            Handles.color = color;
            Handles.DrawSolidDisc(position, Vector3.forward, size);
            Handles.color = Color.white;

            // Position the label in front of the disc
            Vector3 labelPosition = position + (Vector3.forward * 0.1f);
            GUIStyle style = new(EditorStyles.label) {
                alignment = TextAnchor.MiddleCenter,
                normal = { textColor = Color.white },
                fontStyle = FontStyle.Bold,
            };

            EnsureArray(nucleus);
            int arraySize = nucleus.array.nuclei.Count();
            // True when this nucleus belongs to the array that is currently shown expanded.
            bool isExpandedMember = expandArray && nucleus.array.nuclei.FirstOrDefault() == this.currentNucleus;

            if (!isExpandedMember && arraySize > 1)
                Handles.Label(labelPosition, arraySize.ToString(), style);

            if (isExpandedMember) {
                // Show the index of this nucleus within its array
                int arrayIx = 0;
                foreach (Nucleus member in nucleus.array.nuclei) {
                    if (member == nucleus)
                        break;
                    arrayIx++;
                }
                Handles.Label(labelPosition, $"[{arrayIx}]", style);
            }
            else {
                style.alignment = TextAnchor.UpperCenter;
                // below disc along up axis
                Vector3 labelPos = position - Vector3.down * (size + 10f);
                // Only the part of the name before a ':' is shown
                string fullName = nucleus.name ?? "";
                int colonPos = fullName.IndexOf(':');
                string label = colonPos > 0 ? fullName[..colonPos] : fullName;
                Handles.Label(labelPos, label, style);
            }

            if (nucleus is Cluster) {
                Handles.color = Color.white;
                Handles.DrawWireDisc(position, Vector3.forward, size + 10);
            }

            Event e = Event.current;
            if (e == null)
                return;

            Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2);
            if (!neuronRect.Contains(e.mousePosition))
                return;

            // Process Hover
            HandleMouseHover(nucleus, neuronRect);

            // Process click
            if (e.type == EventType.MouseDown && e.button == 0) {
                // Consume the event so the scene doesn't also handle it
                e.Use();
                HandleClicked(nucleus);
            }
        }

        // Shows a tooltip with the nucleus name and output magnitude next to the mouse cursor.
        private void HandleMouseHover(Nucleus nucleus, Rect rect) {
            GUIContent tooltip = new(
                $"{nucleus.name}" +
                $"\nValue: {length(nucleus.outputValue)}");

            Vector2 mousePosition = Event.current.mousePosition;

            // Display tooltip with some offset
            Vector2 tooltipSize = GUI.skin.box.CalcSize(tooltip);
            Rect tooltipRect = new(mousePosition.x + 10, mousePosition.y + 10, tooltipSize.x, tooltipSize.y);
            GUI.Box(tooltipRect, tooltip);
        }

        // Clicking the selected nucleus toggles the expanded array view;
        // clicking any other nucleus selects it.
        private void HandleClicked(Nucleus nucleus) {
            if (nucleus == null)
                return;

            if (nucleus == this.currentNucleus) {
                expandArray = !expandArray;
                return;
            }

            this.currentNucleus = nucleus;
            BuildLayers();
        }

        private int selectedInputType = 0;

        // Replaces the content of the inspector container with the IMGUI editor
        // for the selected nucleus.
        void DrawInspector(VisualElement inspectorContainer) {
            if (inspectorContainer == null)
                return;

            inspectorContainer.Clear();

            if (this.currentNucleus == null)
                return;

            // create a SerializedObject wrapper so Unity inspector controls work (and Undo)
            SerializedObject so = new(currentWrapper);
            IMGUIContainer container = new(() => DrawInspectorGUI(so));
            inspectorContainer.Add(container);
        }

        // IMGUI body of the nucleus inspector.
        void DrawInspectorGUI(SerializedObject so) {
            if (so.targetObject == null)
                return;
            so.Update();

            if (this.currentNucleus == null)
                return;

            GUIStyle headerStyle = new(EditorStyles.boldLabel) {
                alignment = TextAnchor.MiddleLeft,
                margin = new RectOffset(10, 0, 4, 4)
            };
            GUIStyle boldTextFieldStyle = new(EditorStyles.textField) {
                fontStyle = FontStyle.Bold
            };

            GUILayout.Label(this.currentNucleus.GetType().ToString(), headerStyle);
            this.currentNucleus.name = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle);

            if (this.currentNucleus is Neuron neuroid)
                DrawNeuronFields(neuroid);

            if (Application.isPlaying)
                EditorGUILayout.FloatField("Output", length(this.currentNucleus.outputValue));
            else
                EditorGUILayout.LabelField(" ");

            DrawSynapseFields();

            EditorGUILayout.Space();

            ConnectNucleus(this.cluster, this.currentNucleus);

            EditorGUILayout.BeginHorizontal();
            string[] options = { "Neuron", "MemoryCell", "Selector", "Cluster" };
            selectedInputType = EditorGUILayout.Popup(selectedInputType, options);
            if (GUILayout.Button("Add Input"))
                AddInput(selectedInputType, this.currentNucleus);
            EditorGUILayout.EndHorizontal();

            EditorGUILayout.Space();
            if (GUILayout.Button("Delete this neuron"))
                DeleteNeuron(this.currentNucleus);

            if (this.currentNucleus is Cluster subCluster) {
                if (GUILayout.Button("Edit Cluster"))
                    EditCluster(subCluster);
            }

            // The controls above write straight into the cluster's objects, so the asset
            // has to be marked dirty for the edits to be saved.
            if (GUI.changed && !Application.isPlaying && this.cluster != null)
                EditorUtility.SetDirty(this.cluster);
        }

        // Activation curve (not for memory cells) and array size controls of a neuron.
        void DrawNeuronFields(Neuron neuroid) {
            if (neuroid is not MemoryCell) {
                EditorGUILayout.BeginHorizontal();
                EditorGUILayout.LabelField("Activation Curve", GUILayout.Width(150));
                if (neuroid.curveMax > 0)
                    EditorGUILayout.CurveField(neuroid.curve, Color.cyan, new Rect(0, 0, 1, neuroid.curveMax));
                else
                    EditorGUILayout.CurveField(neuroid.curve, Color.cyan, new Rect(0, neuroid.curveMax, 1, -neuroid.curveMax));
                neuroid.curvePreset = (Neuron.CurvePresets)EditorGUILayout.EnumPopup(neuroid.curvePreset, GUILayout.Width(100));
                EditorGUILayout.EndHorizontal();
            }

            if (neuroid.array == null || neuroid.array.nuclei == null || !neuroid.array.nuclei.Any())
                neuroid.array = new NucleusArray(neuroid);

            EditorGUILayout.BeginHorizontal();
            // Shown for information only: the size changes through the Add/Del buttons.
            EditorGUILayout.IntField("Array size", neuroid.array.nuclei.Count());
            if (GUILayout.Button("Add"))
                neuroid.array.AddNucleus(this.cluster);
            if (GUILayout.Button("Del"))
                neuroid.array.RemoveNucleus();
            EditorGUILayout.EndHorizontal();
        }

        // Lists the synapses of the selected nucleus with their weight and, outside play mode,
        // a disconnect button.
        void DrawSynapseFields() {
            if (this.currentNucleus.synapses == null || this.currentNucleus.synapses.Count == 0)
                return;

            EditorGUILayout.LabelField("Synapses");
            // Iterate over a copy: disconnecting changes the synapse list.
            Synapse[] synapses = this.currentNucleus.synapses.ToArray();
            foreach (Synapse synapse in synapses) {
                if (synapse.nucleus == null)
                    continue;

                EditorGUILayout.Space();
                if (Application.isPlaying)
                    EditorGUILayout.FloatField(synapse.nucleus.name, length(synapse.nucleus.outputValue) * synapse.weight);
                else {
                    EditorGUILayout.BeginHorizontal();
                    EditorGUILayout.LabelField(synapse.nucleus.name);
                    if (GUILayout.Button("Disconnect"))
                        synapse.nucleus.RemoveReceiver(this.currentNucleus);
                    EditorGUILayout.EndHorizontal();
                }
                EditorGUI.indentLevel++;
                synapse.weight = EditorGUILayout.FloatField("Weight", synapse.weight);
                EditorGUI.indentLevel--;
            }
        }

        // Draws the output of the selected nucleus as a yellow line in the scene view.
        void OnSceneGUI(SceneView sceneView) {
            if (this.gameObject == null || this.currentNucleus == null)
                return;

            Transform transform = this.gameObject.transform;
            Vector3 worldVector = transform.TransformVector(this.currentNucleus.outputValue);
            Handles.color = Color.yellow;
            Handles.DrawLine(transform.position, transform.position + worldVector);
        }

        /// <summary>Adds a new input of the chosen type to <paramref name="nucleus"/>.</summary>
        /// <param name="selectedInputType">0 = Neuron, 1 = MemoryCell, 2 = Selector, 3 = Cluster.</param>
        /// <param name="nucleus">Nucleus that receives the new input.</param>
        protected virtual void AddInput(int selectedInputType, Nucleus nucleus) {
            switch (selectedInputType) {
                case 0: // Neuron
                    AddInputNeuron(nucleus);
                    break;
                case 1: // MemoryCell
                    AddInputMemoryCell(nucleus);
                    break;
                case 2: // Selector
                    AddSelectorInput(nucleus);
                    break;
                case 3: // Cluster
                    AddCluster(nucleus);
                    break;
            }
        }

        /// <summary>Creates a neuron feeding <paramref name="nucleus"/> and selects it.</summary>
        protected virtual void AddInputNeuron(Nucleus nucleus) {
            Neuron newNeuroid = new(this.cluster, "New neuron");
            newNeuroid.AddReceiver(nucleus);
            this.currentNucleus = newNeuroid;
            BuildLayers();
        }

        /// <summary>
        /// Deletes <paramref name="nucleus"/>. The selection moves to its first receiver,
        /// or to the output of its cluster when it has no receiver.
        /// </summary>
        protected virtual void DeleteNeuron(Nucleus nucleus) {
            if (nucleus == null)
                return;

            if (nucleus.cluster != null)
                this.currentNucleus = nucleus.cluster.output;

            if (nucleus.receivers != null) {
                Nucleus firstReceiver = nucleus.receivers.FirstOrDefault(receiver => receiver != null);
                if (firstReceiver != null)
                    this.currentNucleus = firstReceiver;
            }

            Neuron.Delete(nucleus);
            BuildLayers();
        }

        /// <summary>Creates a selector feeding <paramref name="nucleus"/> and selects it.</summary>
        protected void AddSelectorInput(Nucleus nucleus) {
            Selector newSelector = new(this.cluster, "New Selector");
            newSelector.AddReceiver(nucleus);
            this.currentNucleus = newSelector;
            BuildLayers();
        }

        /// <summary>Creates a memory cell feeding <paramref name="nucleus"/> and selects it.</summary>
        protected virtual void AddInputMemoryCell(Nucleus nucleus) {
            MemoryCell newMemory = new(this.cluster, "New memory cell");
            newMemory.AddReceiver(nucleus);
            this.currentNucleus = newMemory;
            BuildLayers();
        }

        /// <summary>Opens the cluster picker; the picked cluster becomes an input of <paramref name="nucleus"/>.</summary>
        protected virtual void AddCluster(Nucleus nucleus) {
            ClusterPickerWindow.ShowPicker(brain => OnClusterPicked(nucleus, brain), "Select Cluster");
        }

        private void OnClusterPicked(Nucleus nucleus, ClusterPrefab prefab) {
            // Nothing was picked
            if (prefab == null)
                return;

            Cluster subclusterInstance = new(prefab, this.cluster);
            subclusterInstance.AddReceiver(nucleus);
            // Selecting the new sub-cluster here does not work somehow:
            // this.currentNucleus = subclusterInstance;
            // BuildLayers();
        }

        // Selects and pings the prefab of the sub-cluster so it opens in the inspector.
        private void EditCluster(Cluster subCluster) {
            if (subCluster == null || subCluster.prefab == null)
                return;

            // May be used with storedPrefab...
            Selection.activeObject = subCluster.prefab;
            EditorGUIUtility.PingObject(subCluster.prefab);
        }

        /// <summary>
        /// Shows a popup with the nuclei of <paramref name="cluster"/> that are not yet connected
        /// to the selected nucleus and connects the chosen one.
        /// </summary>
        /// <remarks>
        /// NOTE(review): operates on the currently selected nucleus; the
        /// <paramref name="nucleus"/> argument is not used.
        /// </remarks>
        protected virtual void ConnectNucleus(ClusterPrefab cluster, Nucleus nucleus) {
            if (cluster == null || this.currentNucleus == null || this.currentNucleus.synapses == null)
                return;

            IEnumerable<Nucleus> synapseNuclei = this.currentNucleus.synapses
                .Where(synapse => synapse.nucleus != null)
                .Select(synapse => synapse.nucleus);
            // Materialize once so the popup entries and the selected index refer to the same list.
            List<Nucleus> nuclei = cluster.nuclei
                .Except(synapseNuclei)
                .ToList();
            string[] names = nuclei.Select(n => n.name).ToArray();

            // The popup never keeps a selection: -1 in, chosen index out.
            int selectedIndex = EditorGUILayout.Popup("Connect to", -1, names);
            if (selectedIndex >= 0 && selectedIndex < nuclei.Count) {
                Nucleus receptor = nuclei[selectedIndex];
                receptor.AddReceiver(this.currentNucleus);
            }
        }

        /// <summary>
        /// Shows a popup with the inputs of the selected nucleus and disconnects the chosen one.
        /// </summary>
        /// <remarks>
        /// NOTE(review): operates on the currently selected nucleus; the
        /// <paramref name="nucleus"/> argument is not used.
        /// </remarks>
        protected virtual void DisconnectNucleus(Neuron nucleus) {
            if (this.currentNucleus == null || this.currentNucleus.cluster == null || this.currentNucleus.synapses == null)
                return;

            // One entry per synapse, so the popup index is also the synapse index.
            string[] names = this.currentNucleus.synapses
                .Select(synapse => synapse.nucleus != null ? synapse.nucleus.name : "")
                .ToArray();

            int selectedIndex = EditorGUILayout.Popup("Disconnect from", -1, names);
            if (selectedIndex >= 0 && selectedIndex < this.currentNucleus.synapses.Count) {
                Synapse synapse = this.currentNucleus.synapses[selectedIndex];
                if (synapse.nucleus != null)
                    synapse.nucleus.RemoveReceiver(this.currentNucleus);
            }
        }
    }

    #endregion Start
}

/// <summary>One column of nuclei in the graph.</summary>
public class NeuroidLayer {
    public int ix = 0;
    public List<Nucleus> neuroids = new();
}

/// <summary>
/// ScriptableObject stand-in for a nucleus, so it can be wrapped in a SerializedObject.
/// </summary>
public class ClusterWrapper : ScriptableObject {
    // expose fields that map to GraphNode
    public Vector2 position;

    Nucleus node;
    ClusterPrefab graph; // needed to write back and mark dirty

    /// <summary>Binds the wrapper to a nucleus and the asset that owns it.</summary>
    /// <returns>This wrapper, to allow chaining after CreateInstance.</returns>
    public ClusterWrapper Init(Nucleus node, ClusterPrefab graphAsset) {
        this.node = node;
        this.graph = graphAsset;
        return this;
    }

    // Marks the owning asset dirty when the wrapper is validated.
    void OnValidate() {
        if (node != null) {
#if UNITY_EDITOR
            if (graph != null)
                UnityEditor.EditorUtility.SetDirty(graph);
#endif
        }
    }
}