diff --git a/Editor/BrainEditorWindow.cs b/Editor/BrainEditorWindow.cs deleted file mode 100644 index 078095c..0000000 --- a/Editor/BrainEditorWindow.cs +++ /dev/null @@ -1,290 +0,0 @@ -using UnityEngine; -using UnityEditor; -using System.Collections.Generic; -using System.Linq; - -namespace NanoBrain { - - // Simple DAG data model - [System.Serializable] - public class DagNode { - public int id; - public string title; - public Vector2 position; - public float radius = 20f; // circle radius - public Nucleus nucleus; - } - - [System.Serializable] - public class DagEdge { - public int fromId; - public int toId; - } - public class Dag { - public List nodes = new(); - public List edges = new(); - } - - public class BrainEditorWindow : EditorWindow { - Dag dag = new(); - - Vector2 pan = Vector2.zero; - - private readonly System.Type acceptedType = typeof(ClusterPrefab); - - [MenuItem("Window/Brain Viewer")] - public static void ShowWindow() { - var w = GetWindow("Brain Viewer"); - w.minSize = new Vector2(500, 300); - } - - void OnEnable() { - // Register callback so window updates when selection changes - Selection.selectionChanged += OnSelectionChanged; - dag = RefreshSelection(); - ComputeLayout(dag); - Repaint(); - } - - private void OnDisable() { - Selection.selectionChanged -= OnSelectionChanged; - } - - private void OnSelectionChanged() { - dag = RefreshSelection(); - ComputeLayout(dag); - Repaint(); - } - - private Dag RefreshSelection() { - ClusterPrefab prefab = Selection.activeObject as ClusterPrefab; - if (prefab != null && acceptedType.IsAssignableFrom(prefab.GetType())) - return GenerateGraph(prefab); - else - return new Dag(); - } - - public Dag GenerateGraph(ClusterPrefab prefab) { - Dag dag = new(); - - int ix = 0; - foreach (Nucleus nucleus in prefab.nuclei) { - DagNode node = new() { - id = ix, - title = nucleus.name - }; - dag.nodes.Add(node); - if (nucleus is Neuron neuron) { - foreach (Nucleus receiver in neuron.receivers) { - DagEdge edge = new() { - fromId = ix, - toId = prefab.GetNucleusIndex(receiver) - }; - dag.edges.Add(edge); - } - } - ix++; - } - return dag; - } - - void OnGUI() { - HandleInput(); - - Rect rect = new(0, 0, position.width, position.height); - EditorGUI.DrawRect(rect, new Color(0.11f, 0.11f, 0.11f)); - - // compute window center - Vector2 windowCenter = new(position.width / 2f, position.height / 2f); - - // compute graph bounds center (in graph space) - Rect bounds = GetGraphBounds(dag); - Vector2 graphCenter = bounds.center; - - // compute autoPan that recenters the graph (does not modify node positions) - Vector2 autoPan = -graphCenter; // moves graph center to origin - // total translation = windowCenter + autoPan + user pan - Matrix4x4 oldMatrix = GUI.matrix; - GUI.matrix = Matrix4x4.TRS(windowCenter + autoPan + pan, Quaternion.identity, Vector3.one) * - Matrix4x4.TRS(-windowCenter, Quaternion.identity, Vector3.one); - - - // Draw edges first - foreach (DagEdge e in dag.edges) { - DagNode from = GetNodeById(dag, e.fromId); - DagNode to = GetNodeById(dag, e.toId); - if (from == null || to == null) - continue; - DrawEdgeCircleNodes(from, to); - } - - // Draw nodes (circles) - foreach (DagNode n in dag.nodes) - DrawNucleus(n); - - GUI.matrix = oldMatrix; - } - - void HandleInput() { - Event e = Event.current; - - // Pan with middle or right+ctrl drag - if (e.type == EventType.MouseDrag && (e.button == 2 || (e.button == 1 && e.control))) { - pan += e.delta; - e.Use(); - } - } - - public static DagNode GetNodeById(Dag dag, int id) => dag.nodes.FirstOrDefault(x => x.id == id); - - public static void DrawNucleus(DagNode n) { - Vector3 position = n.position; - - Handles.color = Color.black * 0.9f; - Handles.DrawSolidDisc(n.position, Vector3.forward, n.radius); - - Handles.color = Color.white; - GUIStyle style = new(EditorStyles.label) { - alignment = TextAnchor.UpperCenter, - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - }; - Vector3 labelPos = position - Vector3.down * (n.radius + 10f); // below disc along up axis - Handles.Label(labelPos, n.title, style); - } - - public static void DrawEdgeCircleNodes(DagNode from, DagNode to) { - Vector2 a = from.position; - Vector2 b = to.position; - if (a == b) return; - - Handles.color = Color.white * 0.9f; - Handles.DrawLine(from.position, to.position); - } - - public static void ComputeLayout(Dag dag) { - Dictionary> adjacency = dag.nodes.ToDictionary(n => n.id, n => new List()); - Dictionary outdegree = dag.nodes.ToDictionary(node => node.id, n => 0); - foreach (DagEdge edge in dag.edges) { - if (!adjacency.ContainsKey(edge.fromId) || !adjacency.ContainsKey(edge.toId)) - continue; - adjacency[edge.fromId].Add(edge.toId); - outdegree[edge.fromId]++; - } - - // Kahn's algorithm to compute topological layers (horizontal layers) - // build parent list (reverse adjacency) and parentIndegree = number of children each parent has - Dictionary> parents = dag.nodes.ToDictionary(n => n.id, _ => new List()); - Dictionary childCount = dag.nodes.ToDictionary(n => n.id, _ => 0); - - foreach (DagEdge edge in dag.edges) { - if (!adjacency.ContainsKey(edge.fromId) || !adjacency.ContainsKey(edge.toId)) continue; - adjacency[edge.fromId].Add(edge.toId); - parents[edge.toId].Add(edge.fromId); // parent of 'to' is 'from' - childCount[edge.fromId]++; // outdegree - } - - Dictionary layer = new(); - Queue queue = new(outdegree.Where(kv => kv.Value == 0).Select(kv => kv.Key)); - foreach (int id in queue) - layer[id] = 0; - - // process parents (reverse traversal) - while (queue.Count > 0) { - int u = queue.Dequeue(); - int l = layer[u]; - foreach (int p in parents[u]) { - if (!layer.ContainsKey(p) || layer[p] < l + 1) - layer[p] = l + 1; - childCount[p]--; // decrement remaining unprocessed children - if (childCount[p] == 0) - queue.Enqueue(p); - } - } - - // Any unreachable nodes -> assign next layers - int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0; - foreach (DagNode node in dag.nodes) { - if (!layer.ContainsKey(node.id)) { - maxLayer++; - layer[node.id] = maxLayer; - } - } - - // Group nodes by layer (left to right) - List> layers = - layer. - GroupBy(kv => kv.Value). - OrderBy(g => g.Key). - Select(g => g.Select(x => x.Key).ToList()). - ToList(); - - // Same code without using Linq - // Build layers dictionary: layerIndex -> List nodeIds - // Dictionary> layersDict = new(); - // foreach (KeyValuePair kv in layer) { - // int nodeId = kv.Key; - // int layerIndex = kv.Value; - // if (!layersDict.TryGetValue(layerIndex, out List list)) { - // list = new List(); - // layersDict[layerIndex] = list; - // } - // list.Add(nodeId); - // } - - // // Determine sorted layer indices - // List layerIndices = new(layersDict.Keys); - // layerIndices.Sort(); // ascending order - - // // Build final List> in sorted order - // List> layers = new(); - // foreach (int idx in layerIndices) { - // layers.Add(layersDict[idx]); - // } - - float hSpacing = 100f; - float totalHeight = 400f; - - // Place nodes: x increases with layer index, y spaced within layer - for (int layerIx = 0; layerIx < layers.Count; layerIx++) { - List nodeList = layers[layerIx]; - float spacing = totalHeight / nodeList.Count; - float margin = 10 + spacing / 2; - for (int i = 0; i < nodeList.Count; i++) { - int index = nodeList[i]; - DagNode node = GetNodeById(dag, index); - if (node == null) - continue; - float x = hSpacing + layerIx * hSpacing; - //float y = 400 - totalHeight / 2f + i * vSpacing; - float y = margin + i * spacing; - // Debug.Log($"({li}, {i}) -> {x}, {y}"); - node.position = new Vector2(x, y); - } - } - - //Repaint(); - } - - static Rect RectUnion(Rect a, Rect b) { - float xMin = Mathf.Min(a.xMin, b.xMin); - float xMax = Mathf.Max(a.xMax, b.xMax); - float yMin = Mathf.Min(a.yMin, b.yMin); - float yMax = Mathf.Max(a.yMax, b.yMax); - return Rect.MinMaxRect(xMin, yMin, xMax, yMax); - } - - Rect GetGraphBounds(Dag dag) { - if (dag.nodes == null || dag.nodes.Count == 0) - return new Rect(Vector2.zero, Vector2.one); - Rect bounds = new( - dag.nodes[0].position - Vector2.one * dag.nodes[0].radius, - 2f * dag.nodes[0].radius * Vector2.one); - foreach (var n in dag.nodes) - bounds = RectUnion(bounds, - new Rect(n.position - Vector2.one * n.radius, 2f * n.radius * Vector2.one)); - return bounds; - } - } - -} \ No newline at end of file diff --git a/Editor/BrainEditorWindow.cs.meta b/Editor/BrainEditorWindow.cs.meta deleted file mode 100644 index 5d8b61f..0000000 --- a/Editor/BrainEditorWindow.cs.meta +++ /dev/null @@ -1,2 +0,0 @@ -fileFormatVersion: 2 -guid: f041740900808273ab006e7d276a78e9 diff --git a/Editor/ClusterInspector.cs b/Editor/ClusterInspector.cs index 06d9db2..556f6a2 100644 --- a/Editor/ClusterInspector.cs +++ b/Editor/ClusterInspector.cs @@ -62,18 +62,7 @@ namespace NanoBrain { public class GraphEditor : GraphView { - public enum Mode { - Focus, - Full - } - public Mode mode = Mode.Focus; - public GraphEditor(ClusterPrefab prefab) : base(prefab) { - // create an EnumField for Mode - EnumField enumField = new(mode); - enumField.style.width = 80; - enumField.RegisterValueChangedCallback(OnModeChange); - outputContainer.Insert(0, enumField); Button addButton = new(() => OnAddClusterOutput()) { text = "Add" @@ -83,22 +72,6 @@ namespace NanoBrain { Add(outputContainer); } - private void OnModeChange(ChangeEvent evt) { - mode = (Mode)evt.newValue; - - Debug.Log("Mode changed to: " + mode); - } - - Nucleus selectedOutput; - protected override void OnOutputChanged(string outputName) { - if (this.currentNucleus.parent != null) - // Get nucleus in the parent instance - this.selectedOutput = this.currentNucleus.parent.GetNucleus(outputName); - else - // Get nucleus in the prefab - this.selectedOutput = this.prefab.GetNucleus(outputName); - } - void OnAddClusterOutput() { Nucleus newOutput = new Neuron(this.prefab, "New Output"); this.prefab.RefreshOutputs(); @@ -115,6 +88,7 @@ namespace NanoBrain { this.serializedBrain = new SerializedObject(this.prefab); this.currentNucleus = nucleus; Rebuild(inspectorContainer); + OnOutputChanged(outputsField.choices[0]); } void Rebuild(VisualElement inspectorContainer) { @@ -147,58 +121,6 @@ namespace NanoBrain { inspectorContainer.Add(inspectorIMGUIContainer); } - protected override void DrawGraph() { - if (mode == Mode.Focus) - DrawFocusGraph(); - else - DrawFullGraph(); - } - - protected void DrawFullGraph() { - Dag dag = GenerateGraph(this.prefab); - BrainEditorWindow.ComputeLayout(dag); - // Draw edges - foreach (DagEdge e in dag.edges) { - DagNode from = dag.nodes.FirstOrDefault(x => x.id == e.fromId); - DagNode to = dag.nodes.FirstOrDefault(x => x.id == e.toId); - if (from == null || to == null) - continue; - - Vector2 fromPosition = from.position; - Vector2 toPosition = to.position; - DrawEdge(fromPosition, toPosition); - } - - // Draw nodes - foreach (DagNode n in dag.nodes) - DrawNucleus(n.nucleus, n.position, 1, n.radius); - } - - public Dag GenerateGraph(ClusterPrefab prefab) { - Dag dag = new(); - - int ix = 0; - foreach (Nucleus nucleus in prefab.nuclei) { - DagNode node = new() { - id = ix, - title = nucleus.name, - nucleus = nucleus - }; - dag.nodes.Add(node); - if (nucleus is Neuron neuron) { - foreach (Nucleus receiver in neuron.receivers) { - DagEdge edge = new() { - fromId = ix, - toId = prefab.GetNucleusIndex(receiver) - }; - dag.edges.Add(edge); - } - } - ix++; - } - return dag; - } - #region Inspector private VisualElement inspectorIMGUIContainer; diff --git a/Editor/ClusterViewer.cs b/Editor/ClusterViewer.cs index 241abf7..972738c 100644 --- a/Editor/ClusterViewer.cs +++ b/Editor/ClusterViewer.cs @@ -22,6 +22,12 @@ namespace NanoBrain { protected VisualElement outputContainer; protected readonly PopupField outputsField; + public enum Mode { + Focus, + Full + } + public Mode mode = Mode.Focus; + public GraphView(ClusterPrefab prefab) { this.prefab = prefab; @@ -43,6 +49,12 @@ namespace NanoBrain { } }; + EnumField enumField = new(mode); + enumField.style.width = 80; + enumField.RegisterValueChangedCallback(OnModeChange); + outputContainer.Add(enumField); + + List names = this.prefab.outputs.Select(output => output.name).ToList(); if (names.Count > 0 && names.First() != null) { outputsField = new(names, names.First()) { @@ -59,13 +71,19 @@ namespace NanoBrain { RegisterCallback(evt => Unsubscribe()); } + protected virtual void OnModeChange(ChangeEvent evt) { + mode = (Mode)evt.newValue; + } + + protected Nucleus selectedOutput; protected virtual void OnOutputChanged(string outputName) { if (this.currentNucleus.parent != null) // Get nucleus in the parent instance - this.currentNucleus = this.currentNucleus.parent.GetNucleus(outputName); + this.selectedOutput = this.currentNucleus.parent.GetNucleus(outputName); else // Get nucleus in the prefab - this.currentNucleus = this.prefab.GetNucleus(outputName); + this.selectedOutput = this.prefab.GetNucleus(outputName); + this.currentNucleus = this.selectedOutput; } bool subscribed = false; @@ -82,22 +100,21 @@ namespace NanoBrain { subscribed = false; } - public void SetGraph(GameObject gameObject, Nucleus nucleus) { //}, VisualElement inspectorContainer) { + public void SetGraph(GameObject gameObject, Nucleus nucleus) { this.gameObject = gameObject; - //this.cluster = brain; if (Application.isPlaying == false) this.serializedBrain = new SerializedObject(this.prefab); this.currentNucleus = nucleus; Rebuild(); //inspectorContainer); + OnOutputChanged(outputsField.choices[0]); + } - void Rebuild() { //VisualElement inspectorContainer) { + void Rebuild() { BuildLayers(); - if (this.currentNucleus == null) { - // inspectorContainer.Clear(); + if (this.currentNucleus == null) return; - } string path = AssetDatabase.GetAssetPath(this.prefab); // or known path this.prefabAsset = AssetDatabase.LoadAssetAtPath(path); @@ -106,7 +123,6 @@ namespace NanoBrain { this.prefabAsset = CreateInstance(); //Debug.LogError("Cluster Prefab is not found on disk"); } - //DrawInspector(inspectorContainer); } protected void BuildLayers() { @@ -178,7 +194,68 @@ namespace NanoBrain { } protected virtual void DrawGraph() { - DrawFocusGraph(); + if (mode == Mode.Focus) + DrawFocusGraph(); + else + DrawFullGraph(); + } + + protected void DrawFullGraph() { + //Dag dag = GenerateGraph(this.prefab); + Dag dag = GenerateGraph(this.selectedOutput); + Dag.ComputeLayout(dag); + // Draw edges + foreach (DagEdge e in dag.edges) { + DagNode from = dag.nodes.FirstOrDefault(x => x.id == e.fromId); + DagNode to = dag.nodes.FirstOrDefault(x => x.id == e.toId); + if (from == null || to == null) + continue; + + Vector2 fromPosition = from.position; + Vector2 toPosition = to.position; + DrawEdge(fromPosition, toPosition); + } + + // Draw nodes + foreach (DagNode n in dag.nodes) + DrawNucleus(n.nucleus, n.position, 1, n.radius); + } + + public Dag GenerateGraph(Nucleus rootNucleus) { + Dag dag = new(); + if (rootNucleus == null) + return dag; + + int ix = 0; + DagNode receiver = new() { + id = ix, + //title = nucleus.name, + nucleus = rootNucleus + }; + dag.nodes.Add(receiver); + ix++; + DescendGraph(receiver, ref ix, dag); + return dag; + } + + private void DescendGraph(DagNode receiver, ref int ix, Dag dag) { + foreach (Synapse synapse in receiver.nucleus.synapses) { + DagNode synapseNode = dag.FindNode(synapse.neuron.name); + if (synapseNode == null) { + synapseNode = new() { + id = ix, + nucleus = synapse.neuron + }; + dag.nodes.Add(synapseNode); + } + DagEdge edge = new() { + fromId = synapseNode.id, + toId = receiver.id + }; + dag.edges.Add(edge); + ix++; + DescendGraph(synapseNode, ref ix, dag); + } } protected void DrawFocusGraph() { @@ -416,7 +493,7 @@ namespace NanoBrain { fontStyle = FontStyle.Bold, }; - if (nucleus.parent != null && nucleus.parent is Cluster parentCluster) { + if (nucleus.parent is Cluster parentCluster) { if (expandArray) { // Put array indices above elements style.alignment = TextAnchor.LowerCenter; @@ -463,12 +540,12 @@ namespace NanoBrain { } } - if (expandArray == false) {// || nucleus is not IReceptor) { + if (expandArray == false) { // put name below nucleus Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron style.alignment = TextAnchor.UpperCenter; - if (nucleus.parent != null && nucleus.parent is Cluster parentCluster1) { + if (nucleus.parent != currentNucleus.parent && nucleus.parent is Cluster parentCluster1) { // This neuron is part of another cluster parentCluster1.name ??= ""; string baseName = ""; @@ -612,4 +689,137 @@ namespace NanoBrain { public int ix = 0; public List neuroids = new(); } + + [System.Serializable] + public class DagNode { + public int id; + public Vector2 position; + public float radius = 20f; // circle radius + public Nucleus nucleus; + } + + [System.Serializable] + public class DagEdge { + public int fromId; + public int toId; + } + public class Dag { + public List nodes = new(); + public List edges = new(); + + public DagNode FindNode(string name) { + foreach (DagNode node in this.nodes) { + if (node.nucleus.name == name) + return node; + } + return null; + } + + public static DagNode GetNodeById(Dag dag, int id) => dag.nodes.FirstOrDefault(x => x.id == id); + + public static void ComputeLayout(Dag dag) { + Dictionary> adjacency = dag.nodes.ToDictionary(n => n.id, n => new List()); + Dictionary outdegree = dag.nodes.ToDictionary(node => node.id, n => 0); + foreach (DagEdge edge in dag.edges) { + if (!adjacency.ContainsKey(edge.fromId) || !adjacency.ContainsKey(edge.toId)) + continue; + adjacency[edge.fromId].Add(edge.toId); + outdegree[edge.fromId]++; + } + + // Kahn's algorithm to compute topological layers (horizontal layers) + // build parent list (reverse adjacency) and parentIndegree = number of children each parent has + Dictionary> parents = dag.nodes.ToDictionary(n => n.id, _ => new List()); + Dictionary childCount = dag.nodes.ToDictionary(n => n.id, _ => 0); + + foreach (DagEdge edge in dag.edges) { + if (!adjacency.ContainsKey(edge.fromId) || !adjacency.ContainsKey(edge.toId)) continue; + adjacency[edge.fromId].Add(edge.toId); + parents[edge.toId].Add(edge.fromId); // parent of 'to' is 'from' + childCount[edge.fromId]++; // outdegree + } + + Dictionary layer = new(); + Queue queue = new(outdegree.Where(kv => kv.Value == 0).Select(kv => kv.Key)); + foreach (int id in queue) + layer[id] = 0; + + // process parents (reverse traversal) + while (queue.Count > 0) { + int u = queue.Dequeue(); + int l = layer[u]; + foreach (int p in parents[u]) { + if (!layer.ContainsKey(p) || layer[p] < l + 1) + layer[p] = l + 1; + childCount[p]--; // decrement remaining unprocessed children + if (childCount[p] == 0) + queue.Enqueue(p); + } + } + + // Any unreachable nodes -> assign next layers + int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0; + foreach (DagNode node in dag.nodes) { + if (!layer.ContainsKey(node.id)) { + maxLayer++; + layer[node.id] = maxLayer; + } + } + + // Group nodes by layer (left to right) + List> layers = + layer. + GroupBy(kv => kv.Value). + OrderBy(g => g.Key). + Select(g => g.Select(x => x.Key).ToList()). + ToList(); + + // Same code without using Linq + // Build layers dictionary: layerIndex -> List nodeIds + // Dictionary> layersDict = new(); + // foreach (KeyValuePair kv in layer) { + // int nodeId = kv.Key; + // int layerIndex = kv.Value; + // if (!layersDict.TryGetValue(layerIndex, out List list)) { + // list = new List(); + // layersDict[layerIndex] = list; + // } + // list.Add(nodeId); + // } + + // // Determine sorted layer indices + // List layerIndices = new(layersDict.Keys); + // layerIndices.Sort(); // ascending order + + // // Build final List> in sorted order + // List> layers = new(); + // foreach (int idx in layerIndices) { + // layers.Add(layersDict[idx]); + // } + + float hSpacing = 100f; + float totalHeight = 400f; + + // Place nodes: x increases with layer index, y spaced within layer + for (int layerIx = 0; layerIx < layers.Count; layerIx++) { + List nodeList = layers[layerIx]; + float spacing = totalHeight / nodeList.Count; + float margin = 10 + spacing / 2; + for (int i = 0; i < nodeList.Count; i++) { + int index = nodeList[i]; + DagNode node = GetNodeById(dag, index); + if (node == null) + continue; + float x = hSpacing + layerIx * hSpacing; + //float y = 400 - totalHeight / 2f + i * vSpacing; + float y = margin + i * spacing; + // Debug.Log($"({li}, {i}) -> {x}, {y}"); + node.position = new Vector2(x, y); + } + } + + //Repaint(); + } + } + } \ No newline at end of file