using UnityEngine;
using UnityEditor;
using System.Collections.Generic;
using System.Linq;

namespace NanoBrain {

    // Simple DAG data model

    /// <summary>
    /// One drawable node of the graph. <see cref="BrainEditorWindow.GenerateGraph"/> creates one
    /// node per nucleus of the selected prefab.
    /// </summary>
    [System.Serializable]
    public class DagNode {
        public int id;
        public string title;
        public Vector2 position;   // circle centre in graph space (assigned by ComputeLayout)
        public float radius = 20f; // circle radius
        public Nucleus nucleus;    // the nucleus this node was generated from
    }

    /// <summary>A directed connection from node <see cref="fromId"/> to node <see cref="toId"/>.</summary>
    [System.Serializable]
    public class DagEdge {
        public int fromId;
        public int toId;
    }

    /// <summary>Container for the nodes and edges that the brain viewer draws.</summary>
    public class Dag {
        public List<DagNode> nodes = new();
        public List<DagEdge> edges = new();
    }

    /// <summary>
    /// Editor window that draws the nuclei of the currently selected <see cref="ClusterPrefab"/>
    /// as circles, with a line for every neuron-to-receiver connection. The graph is centred in
    /// the window and can be panned with a middle-drag or a Ctrl + right-drag.
    /// </summary>
    public class BrainEditorWindow : EditorWindow {
        Dag dag = new();
        // User pan offset in GUI pixels, applied on top of the automatic centring.
        Vector2 pan = Vector2.zero;
        private readonly System.Type acceptedType = typeof(ClusterPrefab);

        // Cached label style. It is created lazily because EditorStyles is only usable from GUI code.
        static GUIStyle labelStyle;

        [MenuItem("Window/Brain Viewer")]
        public static void ShowWindow() {
            BrainEditorWindow window = GetWindow<BrainEditorWindow>("Brain Viewer");
            window.minSize = new Vector2(500, 300);
        }

        void OnEnable() {
            // Register callback so the window updates when the selection changes
            Selection.selectionChanged += OnSelectionChanged;
            OnSelectionChanged();
        }

        private void OnDisable() {
            Selection.selectionChanged -= OnSelectionChanged;
        }

        /// <summary>Rebuilds and re-lays-out the graph for the current selection, then repaints.</summary>
        private void OnSelectionChanged() {
            dag = RefreshSelection();
            ComputeLayout(dag);
            Repaint();
        }

        /// <summary>
        /// Returns the graph of the selected <see cref="ClusterPrefab"/>, or an empty graph when
        /// the active selection is not one.
        /// </summary>
        private Dag RefreshSelection() {
            ClusterPrefab prefab = Selection.activeObject as ClusterPrefab;
            // Use '!=' rather than 'is not null' so that Unity's overloaded equality also rejects
            // destroyed objects, in case ClusterPrefab is a UnityEngine.Object (not visible here).
            if (prefab != null && acceptedType.IsAssignableFrom(prefab.GetType()))
                return GenerateGraph(prefab);
            return new Dag();
        }

        /// <summary>
        /// Builds a graph with one node per entry of <c>prefab.nuclei</c>. The node id is the
        /// nucleus' position in that sequence. Each <see cref="Neuron"/> gets one edge per receiver.
        /// </summary>
        /// <param name="prefab">The prefab whose nuclei are converted.</param>
        /// <returns>A new graph. Node positions are not yet laid out.</returns>
        public Dag GenerateGraph(ClusterPrefab prefab) {
            Dag graph = new();
            int ix = 0;
            foreach (Nucleus nucleus in prefab.nuclei) {
                graph.nodes.Add(new DagNode { id = ix, title = nucleus.name, nucleus = nucleus });
                if (nucleus is Neuron neuron) {
                    foreach (Nucleus receiver in neuron.receivers) {
                        // NOTE(review): if GetNucleusIndex returns an id outside the graph for an
                        // unknown receiver, the edge is ignored by ComputeLayout and OnGUI.
                        graph.edges.Add(new DagEdge { fromId = ix, toId = prefab.GetNucleusIndex(receiver) });
                    }
                }
                ix++;
            }
            return graph;
        }

        void OnGUI() {
            HandleInput();

            EditorGUI.DrawRect(new Rect(0, 0, position.width, position.height), new Color(0.11f, 0.11f, 0.11f));

            // Move the centre of the graph's bounding box onto the window centre, then apply the
            // user's pan. Node positions themselves are left untouched.
            Vector2 windowCenter = new(position.width / 2f, position.height / 2f);
            Vector2 graphCenter = GetGraphBounds(dag).center;
            Vector2 offset = windowCenter - graphCenter + pan;

            Matrix4x4 oldMatrix = GUI.matrix;
            GUI.matrix = oldMatrix * Matrix4x4.TRS(offset, Quaternion.identity, Vector3.one);
            try {
                Dictionary<int, DagNode> nodesById = BuildNodeLookup(dag);

                // Draw edges first so the node discs are painted on top of them
                foreach (DagEdge e in dag.edges) {
                    if (nodesById.TryGetValue(e.fromId, out DagNode from) &&
                        nodesById.TryGetValue(e.toId, out DagNode to))
                        DrawEdgeCircleNodes(from, to);
                }

                foreach (DagNode n in dag.nodes)
                    DrawNucleus(n);
            }
            finally {
                // Always restore, even if drawing exits early (e.g. ExitGUIException)
                GUI.matrix = oldMatrix;
            }
        }

        /// <summary>Maps node id to node, keeping the first node for a duplicated id (like <see cref="GetNodeById"/>).</summary>
        static Dictionary<int, DagNode> BuildNodeLookup(Dag dag) {
            Dictionary<int, DagNode> lookup = new();
            foreach (DagNode n in dag.nodes) {
                if (!lookup.ContainsKey(n.id))
                    lookup.Add(n.id, n);
            }
            return lookup;
        }

        /// <summary>Pans the view with a middle-mouse drag, or a right-mouse drag while Ctrl is held.</summary>
        void HandleInput() {
            Event e = Event.current;
            bool panButton = e.button == 2 || (e.button == 1 && e.control);
            if (e.type == EventType.MouseDrag && panButton) {
                pan += e.delta;
                e.Use();
                Repaint();
            }
        }

        /// <summary>Returns the first node with the given id, or null if there is none.</summary>
        public static DagNode GetNodeById(Dag dag, int id) => dag.nodes.FirstOrDefault(x => x.id == id);

        /// <summary>Draws a node as a dark filled disc with its title in bold white text below it.</summary>
        public static void DrawNucleus(DagNode n) {
            Handles.color = Color.black * 0.9f;
            Handles.DrawSolidDisc(n.position, Vector3.forward, n.radius);

            Handles.color = Color.white;
            labelStyle ??= new GUIStyle(EditorStyles.label) {
                alignment = TextAnchor.UpperCenter,
                normal = { textColor = Color.white },
                fontStyle = FontStyle.Bold,
            };

            // GUI y grows downwards, so increasing y places the label 10px below the disc's edge.
            Vector3 labelPos = n.position;
            labelPos.y += n.radius + 10f;
            Handles.Label(labelPos, n.title, labelStyle);
        }

        /// <summary>Draws a straight line between the centres of two nodes. Nothing is drawn if they coincide.</summary>
        public static void DrawEdgeCircleNodes(DagNode from, DagNode to) {
            if (from.position == to.position) return;
            Handles.color = Color.white * 0.9f;
            Handles.DrawLine(from.position, to.position);
        }

        /// <summary>
        /// Assigns <see cref="DagNode.position"/> for every node, arranging the nodes in vertical
        /// columns ("layers").
        /// <para>
        /// A node's layer is the length of the longest path from it to a node with no outgoing
        /// edges. Layer 0 (the sinks) is leftmost. The layers are computed with Kahn's algorithm run
        /// backwards from the sinks. A node that can reach a cycle never gets a layer that way, so
        /// each such node gets its own extra layer to the right of the others.
        /// </para>
        /// Edges whose endpoints are not nodes of <paramref name="dag"/> are ignored.
        /// </summary>
        /// <param name="dag">Graph to lay out. Node ids must be unique (ArgumentException otherwise).</param>
        /// <param name="hSpacing">Horizontal distance between layers. It is also the x of layer 0.</param>
        /// <param name="totalHeight">Vertical extent across which each layer's nodes are spread.</param>
        public static void ComputeLayout(Dag dag, float hSpacing = 100f, float totalHeight = 400f) {
            Dictionary<int, DagNode> nodesById = dag.nodes.ToDictionary(n => n.id);

            // parents[v]    = ids of the nodes that have an edge into v
            // childCount[u] = outgoing edges of u whose target has not been layered yet
            Dictionary<int, List<int>> parents = dag.nodes.ToDictionary(n => n.id, _ => new List<int>());
            Dictionary<int, int> childCount = dag.nodes.ToDictionary(n => n.id, _ => 0);
            foreach (DagEdge edge in dag.edges) {
                if (!nodesById.ContainsKey(edge.fromId) || !nodesById.ContainsKey(edge.toId)) continue;
                parents[edge.toId].Add(edge.fromId);
                childCount[edge.fromId]++;
            }

            // Seed with the sinks (no outgoing edges) in node order. They form layer 0.
            Dictionary<int, int> layer = new();
            Queue<int> queue = new();
            foreach (KeyValuePair<int, int> kv in childCount) {
                if (kv.Value != 0) continue;
                queue.Enqueue(kv.Key);
                layer[kv.Key] = 0;
            }

            // Walk edges in reverse. A parent is enqueued once all of its children have been layered.
            while (queue.Count > 0) {
                int u = queue.Dequeue();
                int next = layer[u] + 1;
                foreach (int p in parents[u]) {
                    if (!layer.TryGetValue(p, out int current) || current < next)
                        layer[p] = next;
                    if (--childCount[p] == 0)
                        queue.Enqueue(p);
                }
            }

            // Nodes that can reach a cycle were never dequeued. Give each one its own extra layer.
            int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
            foreach (DagNode node in dag.nodes) {
                if (!layer.ContainsKey(node.id))
                    layer[node.id] = ++maxLayer;
            }

            // Group node ids by layer, ordered left to right
            List<List<int>> layers = layer
                .GroupBy(kv => kv.Value)
                .OrderBy(g => g.Key)
                .Select(g => g.Select(kv => kv.Key).ToList())
                .ToList();

            // Place nodes: x grows with the layer index, y spreads a column evenly over totalHeight
            for (int layerIx = 0; layerIx < layers.Count; layerIx++) {
                List<int> column = layers[layerIx];
                float spacing = totalHeight / column.Count;
                float margin = 10 + spacing / 2;
                float x = hSpacing + layerIx * hSpacing;
                for (int i = 0; i < column.Count; i++)
                    nodesById[column[i]].position = new Vector2(x, margin + i * spacing);
            }
        }

        /// <summary>Smallest rectangle that contains both <paramref name="a"/> and <paramref name="b"/>.</summary>
        static Rect RectUnion(Rect a, Rect b) {
            float xMin = Mathf.Min(a.xMin, b.xMin);
            float xMax = Mathf.Max(a.xMax, b.xMax);
            float yMin = Mathf.Min(a.yMin, b.yMin);
            float yMax = Mathf.Max(a.yMax, b.yMax);
            return Rect.MinMaxRect(xMin, yMin, xMax, yMax);
        }

        /// <summary>Square that encloses a node's circle.</summary>
        static Rect NodeRect(DagNode n) =>
            new(n.position - Vector2.one * n.radius, 2f * n.radius * Vector2.one);

        /// <summary>
        /// Bounding box of all node circles in graph space. An empty graph gives a 1x1 rect at the origin.
        /// </summary>
        Rect GetGraphBounds(Dag dag) {
            if (dag.nodes == null || dag.nodes.Count == 0)
                return new Rect(Vector2.zero, Vector2.one);

            Rect bounds = NodeRect(dag.nodes[0]);
            foreach (DagNode n in dag.nodes)
                bounds = RectUnion(bounds, NodeRect(n));
            return bounds;
        }
    }
}