using UnityEngine;
using UnityEditor;
using System.Collections.Generic;
using System.Linq;

namespace NanoBrain {
    // Simple DAG data model

    /// <summary>A graph node, drawn as a filled circle with its title underneath.</summary>
    [System.Serializable]
    public class DagNode {
        public int id;
        public string title;
        public Vector2 position;   // circle centre, in graph (pre-transform) space
        public float radius = 20f; // circle radius
    }

    /// <summary>A directed edge between two <see cref="DagNode.id"/> values.</summary>
    [System.Serializable]
    public class DagEdge {
        public int fromId;
        public int toId;
    }

    /// <summary>
    /// Editor window that visualises the selected <see cref="ClusterPrefab"/> as a layered graph:
    /// one node per entry of <c>prefab.nuclei</c> and one edge per neuron → receiver link.
    /// Middle-drag (or Ctrl + right-drag) pans the view; the scroll wheel zooms around the cursor.
    /// </summary>
    public class BrainEditorWindow : EditorWindow {
        readonly List<DagNode> nodes = new();
        readonly List<DagEdge> edges = new();

        // User view offset in window pixels, applied after zooming (see OnGUI).
        Vector2 pan = Vector2.zero;
        float zoom = 1.0f;
        const float minZoom = 0.5f;
        const float maxZoom = 2.0f;
        const float zoomSpeed = 0.01f; // zoom change per unit of scroll-wheel delta

        static readonly Color backgroundColor = new(0.11f, 0.11f, 0.11f);
        static readonly Color nodeColor = Color.white * 0.9f; // note: scales alpha too (0.9)
        static readonly Color edgeColor = Color.white * 0.9f;

        // Built lazily because EditorStyles may only be touched from inside OnGUI.
        [System.NonSerialized] GUIStyle labelStyle;

        [MenuItem("Window/Brain Viewer")]
        public static void ShowWindow() {
            BrainEditorWindow window = GetWindow<BrainEditorWindow>("Brain Viewer");
            window.minSize = new Vector2(500, 300);
        }

        void OnEnable() {
            // Register callback so the window updates when the selection changes.
            Selection.selectionChanged += OnSelectionChanged;
            RefreshSelection();
            ComputeLayout();
        }

        private void OnDisable() {
            Selection.selectionChanged -= OnSelectionChanged;
        }

        private void OnSelectionChanged() {
            RefreshSelection();
            ComputeLayout();
            Repaint();
        }

        /// <summary>
        /// Rebuilds the graph when the active selection is a <see cref="ClusterPrefab"/>.
        /// Any other selection leaves the previously shown graph untouched.
        /// </summary>
        private void RefreshSelection() {
            // Unity's overloaded != also rejects destroyed objects.
            ClusterPrefab prefab = Selection.activeObject as ClusterPrefab;
            if (prefab != null) GenerateGraph(prefab);
        }

        /// <summary>
        /// Replaces <see cref="nodes"/>/<see cref="edges"/> with one node per entry of
        /// <c>prefab.nuclei</c> (id = its position in that collection) and one edge from every
        /// neuron to each of its receivers. Node positions are left for <see cref="ComputeLayout"/>.
        /// </summary>
        private void GenerateGraph(ClusterPrefab prefab) {
            nodes.Clear();
            edges.Clear();
            int ix = 0;
            foreach (Nucleus nucleus in prefab.nuclei) {
                // Advance the index even for skipped entries so ids keep matching the
                // collection position (assumes GetNucleusIndex returns that position — confirm).
                int id = ix++;
                if (nucleus == null) continue; // missing reference: no node, edges to it are dropped later

                nodes.Add(new DagNode { id = id, title = nucleus.name });
                if (nucleus is Neuron neuron) {
                    foreach (Nucleus receiver in neuron.receivers) {
                        if (receiver == null) continue;
                        edges.Add(new DagEdge { fromId = id, toId = prefab.GetNucleusIndex(receiver) });
                    }
                }
            }
        }

        void OnGUI() {
            HandleInput();

            EditorGUI.DrawRect(new Rect(0, 0, position.width, position.height), backgroundColor);

            Vector2 windowCenter = new(position.width / 2f, position.height / 2f);
            Vector2 graphCenter = GetGraphBounds().center;

            // Graph point p is drawn at windowCenter + pan + zoom * (p - graphCenter):
            // the graph's bounding-box centre sits in the middle of the window (plus the user
            // pan) and zooming scales about that centre. Node positions are never modified.
            Matrix4x4 oldMatrix = GUI.matrix;
            GUI.matrix = Matrix4x4.TRS(windowCenter + pan, Quaternion.identity, Vector3.one * zoom)
                       * Matrix4x4.TRS(-graphCenter, Quaternion.identity, Vector3.one);
            try {
                // Edges first so the node discs are painted over them.
                foreach (DagEdge e in edges) {
                    DagNode from = GetNodeById(e.fromId);
                    DagNode to = GetNodeById(e.toId);
                    if (from == null || to == null) continue;
                    DrawEdgeCircleNodes(from, to);
                }

                foreach (DagNode n in nodes) DrawNucleus(n);
            } finally {
                // Restore even if drawing throws (e.g. ExitGUIException).
                GUI.matrix = oldMatrix;
            }
        }

        /// <summary>Handles scroll-wheel zoom and middle / Ctrl+right drag panning.</summary>
        void HandleInput() {
            Event e = Event.current;

            if (e.type == EventType.ScrollWheel) {
                float oldZoom = zoom;
                zoom = Mathf.Clamp(zoom - e.delta.y * zoomSpeed, minZoom, maxZoom);
                float ratio = zoom / oldZoom;

                // Keep the graph point under the cursor fixed. With screen = c + pan + zoom*(p - g),
                // solving m = c + pan' + zoom'*(p - g) for pan' gives
                // pan' = (m - c) * (1 - ratio) + pan * ratio.
                Vector2 windowCenter = new(position.width / 2f, position.height / 2f);
                pan = (e.mousePosition - windowCenter) * (1f - ratio) + pan * ratio;
                e.Use();
                Repaint();
            }

            // Pan with middle drag, or right drag while holding Ctrl.
            if (e.type == EventType.MouseDrag && (e.button == 2 || (e.button == 1 && e.control))) {
                pan += e.delta; // both in window pixels
                e.Use();
                Repaint();
            }
        }

        DagNode GetNodeById(int id) => nodes.FirstOrDefault(x => x.id == id);

        /// <summary>Draws a node as a solid disc with its bold title centred just below it.</summary>
        void DrawNucleus(DagNode n) {
            Handles.color = nodeColor;
            Handles.DrawSolidDisc(n.position, Vector3.forward, n.radius);
            Handles.color = Color.white;

            labelStyle ??= new GUIStyle(EditorStyles.label) {
                alignment = TextAnchor.UpperCenter,
                normal = { textColor = Color.white },
                fontStyle = FontStyle.Bold,
            };

            // GUI y grows downward, so +y places the label under the disc.
            Vector3 labelPos = (Vector3)n.position + Vector3.up * (n.radius + 10f);
            Handles.Label(labelPos, n.title, labelStyle);
        }

        /// <summary>
        /// Draws the edge as a straight line between the rims of the two circles, so it does not
        /// show through the translucent discs. Nothing is drawn when the circles touch or overlap.
        /// </summary>
        void DrawEdgeCircleNodes(DagNode from, DagNode to) {
            Vector2 a = from.position;
            Vector2 b = to.position;
            Vector2 delta = b - a;
            float length = delta.magnitude;
            if (length <= from.radius + to.radius) return; // also covers a == b

            Vector2 dir = delta / length;
            Handles.color = edgeColor;
            Handles.DrawLine(a + dir * from.radius, b - dir * to.radius);
        }

        /// <summary>
        /// Right-to-left layered layout: sinks (no outgoing edges) form the leftmost column and
        /// every other node sits one column right of its furthest child, so sources end up on the
        /// right. Nodes on or upstream of a cycle are never layered by that pass; each of them gets
        /// its own extra column after the last layer.
        /// </summary>
        void ComputeLayout() {
            const float hSpacing = 100f;
            const float vSpacing = 100f;
            // Arbitrary vertical centre line; OnGUI recentres the whole graph anyway.
            const float layoutCenterY = 400f;

            // Reverse adjacency (parents[id] = nodes with an edge into id) and, per node, the number
            // of outgoing edges whose target has not been layered yet.
            Dictionary<int, List<int>> parents = nodes.ToDictionary(n => n.id, _ => new List<int>());
            Dictionary<int, int> pendingChildren = nodes.ToDictionary(n => n.id, _ => 0);
            foreach (DagEdge edge in edges) {
                // Skip edges whose endpoints are not nodes (e.g. an unresolved receiver index).
                if (!parents.ContainsKey(edge.fromId) || !parents.ContainsKey(edge.toId)) continue;
                parents[edge.toId].Add(edge.fromId);
                pendingChildren[edge.fromId]++;
            }

            // Kahn's algorithm on the reversed graph, seeded with the sinks at layer 0. A node is
            // enqueued only once all its children were processed, so its layer is final by then.
            Dictionary<int, int> layer = new();
            Queue<int> queue = new();
            foreach (DagNode node in nodes) {
                if (pendingChildren[node.id] != 0) continue;
                layer[node.id] = 0;
                queue.Enqueue(node.id);
            }

            while (queue.Count > 0) {
                int child = queue.Dequeue();
                int parentLayer = layer[child] + 1;
                foreach (int parent in parents[child]) {
                    if (!layer.TryGetValue(parent, out int current) || current < parentLayer) {
                        layer[parent] = parentLayer;
                    }
                    pendingChildren[parent]--;
                    if (pendingChildren[parent] == 0) queue.Enqueue(parent);
                }
            }

            // Nodes never reached (on or upstream of a cycle): one fresh layer each.
            int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
            foreach (DagNode node in nodes) {
                if (!layer.ContainsKey(node.id)) layer[node.id] = ++maxLayer;
            }

            // Node ids grouped by layer, left to right; within a layer, in the order they were layered.
            List<List<int>> layers = layer
                .GroupBy(kv => kv.Value)
                .OrderBy(g => g.Key)
                .Select(g => g.Select(kv => kv.Key).ToList())
                .ToList();

            // x grows with the column index; each column is centred vertically on layoutCenterY.
            Dictionary<int, DagNode> nodeById = nodes.ToDictionary(n => n.id);
            for (int column = 0; column < layers.Count; column++) {
                List<int> ids = layers[column];
                float totalHeight = (ids.Count - 1) * vSpacing;
                float x = hSpacing + column * hSpacing;
                for (int row = 0; row < ids.Count; row++) {
                    float y = layoutCenterY - totalHeight / 2f + row * vSpacing;
                    nodeById[ids[row]].position = new Vector2(x, y);
                }
            }

            Repaint();
        }

        /// <summary>Smallest axis-aligned rectangle containing both <paramref name="a"/> and <paramref name="b"/>.</summary>
        static Rect RectUnion(Rect a, Rect b) {
            float xMin = Mathf.Min(a.xMin, b.xMin);
            float xMax = Mathf.Max(a.xMax, b.xMax);
            float yMin = Mathf.Min(a.yMin, b.yMin);
            float yMax = Mathf.Max(a.yMax, b.yMax);
            return Rect.MinMaxRect(xMin, yMin, xMax, yMax);
        }

        /// <summary>Square that tightly encloses a node's circle.</summary>
        static Rect NodeRect(DagNode n) =>
            new(n.position - Vector2.one * n.radius, 2f * n.radius * Vector2.one);

        /// <summary>
        /// Bounding box of all node circles in graph space (labels excluded); a unit rect at the
        /// origin when the graph is empty.
        /// </summary>
        Rect GetGraphBounds() {
            if (nodes.Count == 0) return new Rect(Vector2.zero, Vector2.one);

            Rect bounds = NodeRect(nodes[0]);
            for (int i = 1; i < nodes.Count; i++) bounds = RectUnion(bounds, NodeRect(nodes[i]));
            return bounds;
        }
    }
}