using UnityEngine;
using UnityEditor;
using System.Collections.Generic;
using System.Linq;

// Simple DAG data model

/// <summary>A graph node, drawn as a filled circle with a label underneath.</summary>
[System.Serializable]
public class DagNode
{
    public int id;
    public string title;

    /// <summary>Position in graph space (assigned by the layout).</summary>
    public Vector2 position;

    /// <summary>Circle radius in graph units.</summary>
    public float radius = 20f;
}

/// <summary>A directed edge from node <see cref="fromId"/> to node <see cref="toId"/>.</summary>
[System.Serializable]
public class DagEdge
{
    public int fromId;
    public int toId;
}

/// <summary>
/// Editor window that shows the nuclei of the currently selected <c>ClusterPrefab</c> as a
/// left-to-right layered graph. Scroll to zoom (anchored at the cursor); drag with the middle
/// mouse button, or Ctrl + right button, to pan.
/// </summary>
public class BrainEditorWindow : EditorWindow
{
    readonly List<DagNode> nodes = new();
    readonly List<DagEdge> edges = new();

    // View state. Graph space maps to window space as
    //   screen = windowCenter + pan + zoom * (graphPos - graphBoundsCenter)
    // so the graph is centred in the window whenever pan == 0, at any zoom level.
    Vector2 pan = Vector2.zero;
    float zoom = 1.0f;
    const float minZoom = 0.5f;
    const float maxZoom = 2.0f;

    // Fraction of the available window area the graph may fill after "Fit".
    const float fitMargin = 0.9f;

    static readonly Color backgroundColor = new(0.11f, 0.11f, 0.11f);

    // Node label style; built once on first use instead of once per node per frame.
    GUIStyle labelStyle;

    Vector2 WindowCenter => new(position.width * 0.5f, position.height * 0.5f);

    [MenuItem("Window/Brain Viewer")]
    public static void ShowWindow()
    {
        var window = GetWindow<BrainEditorWindow>("Brain Viewer");
        window.minSize = new Vector2(500, 300);
    }

    void OnEnable()
    {
        // Rebuild the graph whenever the editor selection changes.
        Selection.selectionChanged += OnSelectionChanged;
        RefreshSelection();
        ComputeLeftToRightLayout();
    }

    private void OnDisable()
    {
        Selection.selectionChanged -= OnSelectionChanged;
    }

    private void OnSelectionChanged()
    {
        RefreshSelection();
        ComputeLeftToRightLayout();
        Repaint();
    }

    /// <summary>
    /// Regenerates the graph if the active selection is a <c>ClusterPrefab</c>; otherwise the
    /// previously shown graph is kept.
    /// </summary>
    private void RefreshSelection()
    {
        // `as` + `!= null` (rather than an `is` pattern) so Unity's overloaded equality also
        // rejects destroyed objects.
        ClusterPrefab prefab = Selection.activeObject as ClusterPrefab;
        if (prefab != null)
        {
            GenerateGraph(prefab);
        }
    }

    /// <summary>
    /// Rebuilds <see cref="nodes"/> and <see cref="edges"/> from the prefab: one node per nucleus,
    /// one edge from every <c>Neuron</c> to each of its receivers.
    /// </summary>
    private void GenerateGraph(ClusterPrefab prefab)
    {
        nodes.Clear();
        edges.Clear();

        if (prefab.nuclei == null)
        {
            return;
        }

        int ix = 0;
        foreach (Nucleus nucleus in prefab.nuclei)
        {
            // Node ids are list positions (counted even for skipped null entries) so they stay
            // comparable with prefab.GetNucleusIndex — presumably that returns the same index; TODO confirm.
            int id = ix++;
            if (nucleus == null)
            {
                continue;
            }

            nodes.Add(new DagNode() { id = id, title = nucleus.name });

            if (nucleus is Neuron neuron && neuron.receivers != null)
            {
                foreach (Nucleus receiver in neuron.receivers)
                {
                    if (receiver == null)
                    {
                        continue;
                    }

                    int receiverIx = prefab.GetNucleusIndex(receiver);
                    // NOTE(review): assumes a negative index means "not part of this prefab".
                    if (receiverIx < 0)
                    {
                        continue;
                    }

                    edges.Add(new DagEdge() { fromId = id, toId = receiverIx });
                }
            }
        }
    }

    void OnGUI()
    {
        HandleInput();

        EditorGUI.DrawRect(new Rect(0, 0, position.width, position.height), backgroundColor);

        Matrix4x4 oldMatrix = GUI.matrix;
        try
        {
            GUI.matrix = GraphToScreenMatrix();

            // Edges first so the node discs are drawn on top of them.
            foreach (DagEdge e in edges)
            {
                DagNode from = GetNodeById(e.fromId);
                DagNode to = GetNodeById(e.toId);
                if (from == null || to == null)
                {
                    continue;
                }

                DrawEdgeCircleNodes(from, to);
            }

            foreach (DagNode n in nodes)
            {
                DrawNucleus(n);
            }
        }
        finally
        {
            // Always restore, otherwise the toolbar below would be drawn transformed.
            GUI.matrix = oldMatrix;
        }

        // Footer toolbar, pushed to the bottom of the window.
        GUILayout.FlexibleSpace();
        EditorGUILayout.BeginHorizontal(EditorStyles.toolbar);
        if (GUILayout.Button("Fit", EditorStyles.toolbarButton))
        {
            FitToView();
        }
        if (GUILayout.Button("Layout LR", EditorStyles.toolbarButton))
        {
            ComputeLeftToRightLayout();
        }
        EditorGUILayout.EndHorizontal();
    }

    /// <summary>
    /// Graph-to-window transform: screen = windowCenter + pan + zoom * (p - graphCenter).
    /// </summary>
    Matrix4x4 GraphToScreenMatrix()
    {
        Vector2 graphCenter = GetGraphBounds().center;
        return Matrix4x4.TRS(WindowCenter + pan, Quaternion.identity, Vector3.one * zoom)
             * Matrix4x4.TRS(-graphCenter, Quaternion.identity, Vector3.one);
    }

    /// <summary>Scroll-wheel zoom (anchored at the cursor) and middle / Ctrl+right drag panning.</summary>
    void HandleInput()
    {
        Event e = Event.current;

        if (e.type == EventType.ScrollWheel)
        {
            float oldZoom = zoom;
            zoom = Mathf.Clamp(zoom - e.delta.y * 0.01f, minZoom, maxZoom);

            // Keep the graph point under the cursor fixed. With
            // screen = center + pan + zoom * (p - graphCenter), solving for the new pan gives:
            Vector2 fromCenter = e.mousePosition - WindowCenter;
            pan = fromCenter - (fromCenter - pan) * (zoom / oldZoom);

            e.Use();
            Repaint();
        }
        else if (e.type == EventType.MouseDrag && (e.button == 2 || (e.button == 1 && e.control)))
        {
            pan += e.delta;
            e.Use();
            Repaint();
        }
    }

    DagNode GetNodeById(int id) => nodes.FirstOrDefault(x => x.id == id);

    List<DagEdge> GetIncomingEdges(DagNode node) => edges.Where(e => e.toId == node.id).ToList();

    List<DagEdge> GetOutgoingEdges(DagNode node) => edges.Where(e => e.fromId == node.id).ToList();

    /// <summary>
    /// Draws one node: a disc, an entry arrow on its left if it has no incoming edges, an exit
    /// arrow on its right if it has no outgoing edges, and its title below it.
    /// </summary>
    void DrawNucleus(DagNode n)
    {
        Handles.color = Color.white * 0.9f;
        Handles.DrawSolidDisc(n.position, Vector3.forward, n.radius);

        // Divided by zoom so the arrowheads keep a constant on-screen size under GUI.matrix.
        float headWidth = 10f / zoom;
        float headLength = 12f / zoom;

        if (!edges.Any(e => e.toId == n.id))
        {
            DrawArrowHead(n.position - new Vector2(n.radius + 10, 0), n.position - new Vector2(n.radius + 5, 0),
                          headWidth, headLength, Color.white);
        }
        if (!edges.Any(e => e.fromId == n.id))
        {
            DrawArrowHead(n.position + new Vector2(n.radius + 10, 0), n.position + new Vector2(n.radius + 15, 0),
                          headWidth, headLength, Color.white);
        }

        Handles.color = Color.white;
        labelStyle ??= new GUIStyle(EditorStyles.label)
        {
            alignment = TextAnchor.UpperCenter,
            normal = { textColor = Color.white },
            fontStyle = FontStyle.Bold,
        };

        // GUI space has +y pointing down, so offsetting along +y places the label below the disc.
        Vector3 labelPos = (Vector3)n.position + Vector3.up * (n.radius + 10f);
        Handles.Label(labelPos, n.title, labelStyle);
    }

    /// <summary>Draws an edge as a straight line between the two node centres.</summary>
    void DrawEdgeCircleNodes(DagNode from, DagNode to)
    {
        if (from.position == to.position)
        {
            return;
        }

        Handles.color = Color.white * 0.9f;
        Handles.DrawLine(from.position, to.position);
    }

    /// <summary>Draws a filled triangular arrowhead with its tip at <paramref name="to"/>, pointing away from <paramref name="from"/>.</summary>
    void DrawArrowHead(Vector2 from, Vector2 to, float headWidth, float headLength, Color color)
    {
        Vector2 dir = (to - from).normalized;
        if (dir == Vector2.zero)
        {
            return;
        }

        Vector2 right = new Vector2(-dir.y, dir.x);
        Vector3 p1 = to;
        Vector3 p2 = to - dir * headLength + right * headWidth * 0.5f;
        Vector3 p3 = to - dir * headLength - right * headWidth * 0.5f;

        Handles.color = color;
        Handles.DrawAAConvexPolygon(p1, p2, p3);
    }

    /// <summary>
    /// Left-to-right layered layout: sources on the left, each node one column right of its
    /// furthest predecessor (longest-path layering via Kahn's algorithm).
    /// </summary>
    void ComputeLeftToRightLayout()
    {
        const float hSpacing = 150f;
        const float vSpacing = 100f;
        const float layerCenterY = 400f; // absolute offset is irrelevant: OnGUI recentres on the bounds

        // Adjacency list and in-degree per node id (edges to unknown ids are ignored).
        var adjacency = nodes.ToDictionary(n => n.id, n => new List<int>());
        var indegree = nodes.ToDictionary(n => n.id, n => 0);
        foreach (DagEdge e in edges)
        {
            if (!adjacency.ContainsKey(e.fromId) || !adjacency.ContainsKey(e.toId))
            {
                continue;
            }

            adjacency[e.fromId].Add(e.toId);
            indegree[e.toId]++;
        }

        // Kahn's algorithm: process nodes once all their predecessors are placed.
        var layer = new Dictionary<int, int>();
        var queue = new Queue<int>(indegree.Where(kv => kv.Value == 0).Select(kv => kv.Key));
        foreach (int id in queue)
        {
            layer[id] = 0;
        }

        while (queue.Count > 0)
        {
            int u = queue.Dequeue();
            int next = layer[u] + 1;
            foreach (int v in adjacency[u])
            {
                // v must sit at least one column right of every predecessor.
                if (!layer.TryGetValue(v, out int current) || current < next)
                {
                    layer[v] = next;
                }

                if (--indegree[v] == 0)
                {
                    queue.Enqueue(v);
                }
            }
        }

        // Nodes that only lie on / behind cycles never reach in-degree 0 and may have no layer;
        // give each of them its own trailing column.
        int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
        foreach (DagNode n in nodes)
        {
            if (!layer.ContainsKey(n.id))
            {
                layer[n.id] = ++maxLayer;
            }
        }

        // Group ids by layer, ordered left to right.
        List<List<int>> layers = layer
            .GroupBy(kv => kv.Value)
            .OrderBy(g => g.Key)
            .Select(g => g.Select(x => x.Key).ToList())
            .ToList();

        // x grows with the column index; nodes within a column are centred vertically.
        for (int li = 0; li < layers.Count; li++)
        {
            List<int> column = layers[li];
            float totalHeight = (column.Count - 1) * vSpacing;
            for (int i = 0; i < column.Count; i++)
            {
                DagNode n = GetNodeById(column[i]);
                if (n == null)
                {
                    continue;
                }

                float x = hSpacing + li * hSpacing;
                float y = layerCenterY - totalHeight / 2f + i * vSpacing;
                n.position = new Vector2(x, y);
            }
        }

        Repaint();
    }

    /// <summary>
    /// Centres the graph (pan = 0) and picks the zoom that fits its bounds into the window
    /// above the footer toolbar, clamped to [minZoom, maxZoom].
    /// </summary>
    void FitToView()
    {
        if (nodes.Count == 0)
        {
            return;
        }

        Rect bounds = GetGraphBounds();
        float availableWidth = position.width;
        float availableHeight = position.height - EditorStyles.toolbar.fixedHeight;

        float fitZoom = Mathf.Min(
            availableWidth / Mathf.Max(bounds.width, 1f),
            availableHeight / Mathf.Max(bounds.height, 1f)) * fitMargin;

        zoom = Mathf.Clamp(fitZoom, minZoom, maxZoom);
        pan = Vector2.zero;
        Repaint();
    }

    static Rect RectUnion(Rect a, Rect b)
    {
        float xMin = Mathf.Min(a.xMin, b.xMin);
        float xMax = Mathf.Max(a.xMax, b.xMax);
        float yMin = Mathf.Min(a.yMin, b.yMin);
        float yMax = Mathf.Max(a.yMax, b.yMax);
        return Rect.MinMaxRect(xMin, yMin, xMax, yMax);
    }

    /// <summary>Inverse of <see cref="GraphToScreenMatrix"/>: window position to graph position.</summary>
    Vector2 ScreenToGraph(Vector2 screenPos)
    {
        Vector2 graphCenter = GetGraphBounds().center;
        return (screenPos - WindowCenter - pan) / zoom + graphCenter;
    }

    /// <summary>Bounding box of all node discs in graph space; a unit rect at the origin when empty.</summary>
    Rect GetGraphBounds()
    {
        if (nodes.Count == 0)
        {
            return new Rect(Vector2.zero, Vector2.one);
        }

        Rect bounds = new(
            nodes[0].position - Vector2.one * nodes[0].radius,
            2f * nodes[0].radius * Vector2.one);
        foreach (DagNode n in nodes)
        {
            bounds = RectUnion(bounds, new Rect(n.position - Vector2.one * n.radius, 2f * n.radius * Vector2.one));
        }

        return bounds;
    }

    /// <summary>Returns the id of the topmost node whose disc contains <paramref name="graphPos"/>, or -1.</summary>
    int HitTestNode(Vector2 graphPos)
    {
        // Iterate backwards: later nodes are drawn on top.
        for (int i = nodes.Count - 1; i >= 0; i--)
        {
            DagNode n = nodes[i];
            if ((graphPos - n.position).sqrMagnitude <= n.radius * n.radius)
            {
                return n.id;
            }
        }

        return -1;
    }
}