NanoBrain-unitypackage/Editor/ClusterViewer.cs
Pascal Serrarens 4ae9a15fc6 Squashed 'NanoBrain/' changes from 832d849..cc9a845
cc9a845 Fix sleeping for product combinator
e4ba7f8 Better cross-cluster monitoring
4f8a6ab Improved (but not fixed) cross-cluster monitoring
b12616b Fix neuron output visualisation
96439cc Visualize all outputs
d583e67 WIP cluster references/instance
04bab92 Fix links to multiple cluster neurons & cleanup
e17a249 Cross-cluster editor links
0ab2d21 Migrating and cleaning up
b6630ad First steps to using instanceCount for clusters
8801fa2 Cluster reimport fixes
befb69d full graph with collapsed clusters
1a1919f Fix expansion of cluster arrays
c708f4d Improved clusterarray support
c2e4e1b Fix Cluster array extension
02047a4 Added full graph scrollbar
471ed36 Completed full graph integration
830e3e7 Added full graph view mode
249e888 Improve full graph view
308a6a1 The Entities are battling
75d9d1c Cleanup
c8f0f0c Fix aging of neurons
e2e169c small fixes
619ced6 Removed the use of Receptors
19f9296 Simplifications
bc0a796 Integrated clusterarray in cluster
e40dd23 Fixed clusterViewer for clusterarrays
b0f4b41 Status quo adding clusterArrays
1fc75a8 Added ClusterArray
0023920 Cover seeking(-ish) behaviour
1c7b8e7 Added Tanh Activation
a99d40c BrainViewer added
db43655 Pew pew!
18ef4cd Merge commit '89017475984bbbf1899fb38846c5bb0e7775dedd' into NanoBrain

git-subtree-dir: NanoBrain
git-subtree-split: cc9a845b643ffb4a9abe4f7da787ac5c5b14dae8
2026-04-23 15:22:02 +02:00

950 lines
41 KiB
C#

using System.Collections.Generic;
using System.Linq;
using UnityEditor;
using UnityEngine;
using UnityEngine.UIElements;
namespace NanoBrain {
public class ClusterViewer : Editor {
// Prefab of the cluster editor the user came from. It is set in GraphView.OnClusterClick,
// drawn as an extra link by GraphView.DrawReceivers and cleared again in GraphView.DrawClusterPrefab.
public static ClusterPrefab previousPrefab;
public class GraphView : VisualElement {
//protected readonly ClusterPrefab prefab;
// Cluster whose nuclei are visualised; assigned in the constructor.
protected Cluster currentCluster;
// Serialized wrapper around currentCluster.prefab; created in SetGraph, only outside play mode.
protected SerializedObject serializedBrain;
// Nucleus drawn in the centre of the focus graph. When null, the 'Outputs' overview is drawn instead.
protected Nucleus currentNucleus;
// Output nucleus that is used as the root of the full graph (see DrawFullGraph).
protected Nucleus selectedOutput;
// Scene object handed to SetGraph; OnSceneGUI draws the focused neuron's output vector from its position.
protected GameObject gameObject;
// When true, a focused Cluster is drawn as an expanded column (see DrawFocusGraph); toggled in OnNeuronClick.
private bool expandArray = false;
// Prefab asset resolved by Rebuild: loaded from the asset database, or created in memory when not found.
protected ClusterPrefab prefabAsset;
// UI elements, all created in the constructor.
protected VisualElement topMenuContainer;
protected ScrollView scrollView;
protected IMGUIContainer graphContainer;
// NOTE(review): never assigned in the code visible here - confirm whether this field is still needed.
protected readonly PopupField<string> outputsPopup;
// Presentation mode of the graph: Focus draws the neighbourhood of currentNucleus,
// Full draws the whole graph below selectedOutput (see DrawGraph).
public enum Mode {
Focus,
Full
}
// Active presentation mode; updated by OnModeChange.
public Mode mode = Mode.Focus;
/// <summary>
/// Builds the viewer UI for <paramref name="cluster"/>: a horizontally scrolling IMGUI graph
/// with a row on top of it that holds the view-mode selector.
/// </summary>
/// <param name="cluster">The cluster to visualise.</param>
public GraphView(Cluster cluster) {
    currentCluster = cluster;
    name = "content";
    style.flexGrow = 1;

    // View-mode selector, placed in a horizontal menu row
    EnumField modeSelector = new(mode);
    modeSelector.style.width = 80;
    modeSelector.RegisterValueChangedCallback(OnModeChange);
    topMenuContainer = new() {
        style = {
            flexDirection = FlexDirection.Row,
            alignItems = Align.Center,
        }
    };
    topMenuContainer.Add(modeSelector);

    // Scroll view stretched over the whole element; only the horizontal scroller may appear
    scrollView = new(ScrollViewMode.Horizontal) {
        style = {
            position = Position.Absolute,
            left = 0,
            top = 0,
            right = 0,
            bottom = 0,
        },
        horizontalScrollerVisibility = ScrollerVisibility.Auto,
        verticalScrollerVisibility = ScrollerVisibility.Hidden,
    };

    // The graph itself is drawn with IMGUI; its width is set by the draw methods
    graphContainer = new(OnIMGUI) {
        pickingMode = PickingMode.Position,
        focusable = true,
    };
    scrollView.contentContainer.Add(graphContainer);

    // The menu row is added after the scroll view
    Add(scrollView);
    Add(topMenuContainer);

    // Follow the panel lifetime for the scene view hook
    RegisterCallback<AttachToPanelEvent>(_ => Subscribe());
    RegisterCallback<DetachFromPanelEvent>(_ => Unsubscribe());
}
/// <summary>Stores the view mode chosen in the mode selector.</summary>
/// <param name="changeEvent">Change event whose new value is a <see cref="Mode"/>.</param>
protected virtual void OnModeChange(ChangeEvent<System.Enum> changeEvent) =>
    mode = (Mode)changeEvent.newValue;
// True while OnSceneGUI is registered with SceneView.duringSceneGui
bool subscribed = false;
/// <summary>Registers the scene view callback once and requests a scene repaint.</summary>
void Subscribe() {
    if (!subscribed) {
        SceneView.duringSceneGui += OnSceneGUI;
        subscribed = true;
        SceneView.RepaintAll();
    }
}
/// <summary>Removes the scene view callback if it is currently registered.</summary>
void Unsubscribe() {
    if (subscribed) {
        SceneView.duringSceneGui -= OnSceneGUI;
        subscribed = false;
    }
}
/// <summary>
/// Attaches the viewer to a scene object and focuses on the first output of the cluster.
/// </summary>
/// <remarks>
/// A cluster without outputs no longer throws: nothing is selected in that case,
/// which makes the focus graph fall back to its 'Outputs' overview.
/// </remarks>
/// <param name="gameObject">Scene object used by the scene view overlay.</param>
public void SetGraph(GameObject gameObject) {
    this.gameObject = gameObject;
    if (Application.isPlaying == false)
        this.serializedBrain = new SerializedObject(this.currentCluster.prefab);

    // Pick the first output, if there is one
    Nucleus firstOutput = null;
    if (this.currentCluster.outputs != null) {
        foreach (Nucleus output in this.currentCluster.outputs) {
            firstOutput = output;
            break;
        }
    }
    this.selectedOutput = firstOutput;
    this.currentNucleus = firstOutput;
    Rebuild();
}
/// <summary>
/// Resolves <c>prefabAsset</c> for the current cluster. Does nothing while no nucleus is selected.
/// </summary>
void Rebuild() {
    if (this.currentNucleus != null) {
        string assetPath = AssetDatabase.GetAssetPath(this.currentCluster.prefab);
        ClusterPrefab loadedAsset = AssetDatabase.LoadAssetAtPath<ClusterPrefab>(assetPath);
        // Fall back to an in-memory prefab when the asset cannot be loaded from the path
        this.prefabAsset = loadedAsset != null ? loadedAsset : CreateInstance<ClusterPrefab>();
    }
}
/// <summary>IMGUI callback of the graph container: refreshes the serialized data and draws the graph.</summary>
/// <remarks>
/// <c>serializedBrain</c> is only created by <c>SetGraph</c>, so it may still be null
/// when the container is drawn before a graph has been set; that case is skipped.
/// </remarks>
public void OnIMGUI() {
    if (Application.isPlaying == false)
        serializedBrain?.Update();
    Handles.BeginGUI();
    DrawGraph();
    Handles.EndGUI();
}
#region Graph
/// <summary>Draws the graph in the presentation selected by <c>mode</c>.</summary>
protected virtual void DrawGraph() {
    switch (mode) {
        case Mode.Focus:
            DrawFocusGraph();
            break;
        default:
            DrawFullGraph();
            break;
    }
}
#region Full Graph
/// <summary>
/// Draws the complete graph below <c>selectedOutput</c>, resizes the graph container to fit it
/// and scrolls horizontally so that the current nucleus is centred where possible.
/// </summary>
protected void DrawFullGraph() {
    Dag graph = GenerateGraph(this.selectedOutput);
    Dag.ComputeLayout(graph);

    // Edges are drawn before the nodes
    foreach (Dag.Edge edge in graph.edges) {
        Dag.Node source = Dag.GetNodeById(graph, edge.fromId);
        Dag.Node target = Dag.GetNodeById(graph, edge.toId);
        if (source != null && target != null)
            DrawEdge(source.position, target.position);
    }
    foreach (Dag.Node node in graph.nodes)
        DrawNucleus(node.nucleus, node.position, 1, node.radius);

    // Find the right-most node and the horizontal position of the current nucleus
    float rightMost = 0;
    float focusX = 0;
    foreach (Dag.Node node in graph.nodes) {
        float nodeX = node.position.x;
        if (nodeX > rightMost)
            rightMost = nodeX;
        if (node.nucleus == currentNucleus)
            focusX = nodeX;
    }

    // Make the container as wide as the graph plus a margin on both sides
    float margin = 50f;
    graphContainer.style.width = rightMost + 2 * margin;

    // Centre the current nucleus in the viewport, limited to the scrollable range
    float viewportWidth = scrollView.layout.width;
    float maxScrollX = Mathf.Max(0f, graphContainer.resolvedStyle.width - viewportWidth);
    float scrollX = Mathf.Clamp(focusX - viewportWidth * 0.5f, 0f, maxScrollX);
    scrollView.scrollOffset = new Vector2(scrollX, scrollView.scrollOffset.y);
}
/// <summary>Builds the graph of everything that feeds into <paramref name="rootNucleus"/>.</summary>
/// <param name="rootNucleus">Nucleus that becomes node 0 of the graph.</param>
/// <returns>The generated graph; it is empty when <paramref name="rootNucleus"/> is null.</returns>
public Dag GenerateGraph(Nucleus rootNucleus) {
    Dag graph = new();
    if (rootNucleus != null) {
        Dag.Node rootNode = new() {
            id = 0,
            nucleus = rootNucleus
        };
        graph.nodes.Add(rootNode);
        // Ids for the remaining nodes continue after the root
        int nextId = 1;
        DescendGraph(rootNode, ref nextId, graph);
    }
    return graph;
}
/// <summary>
/// Adds the inputs of <paramref name="receiver"/> to the graph, depth first.
/// </summary>
/// <remarks>
/// Neurons which belong to another cluster than the current nucleus are represented by their
/// parent cluster. A node is only descended into when it is first created: re-descending known
/// nodes repeated the whole sub-graph for every shared input and never terminated when the
/// network contains a cycle or a link from a cluster to itself.
/// </remarks>
/// <param name="receiver">Node whose synapses are followed.</param>
/// <param name="ix">Next free node id; incremented for every node that is created.</param>
/// <param name="dag">Graph that receives the nodes and edges.</param>
private void DescendGraph(Dag.Node receiver, ref int ix, Dag dag) {
    if (receiver == null || receiver.nucleus == null || receiver.nucleus.synapses == null)
        return;

    foreach (Synapse synapse in receiver.nucleus.synapses) {
        Nucleus nucleus = synapse.neuron;
        // Unconnected synapses are skipped, like in DrawSynapses
        if (nucleus == null)
            continue;

        // Collapse neurons of other clusters into their cluster
        if (nucleus.parent != null && currentNucleus != null && nucleus.parent != currentNucleus.parent)
            nucleus = nucleus.parent;

        Dag.Node synapseNode = dag.FindNode(nucleus.name);
        bool isNewNode = synapseNode == null;
        if (isNewNode) {
            synapseNode = new() {
                id = ix,
                nucleus = nucleus
            };
            dag.nodes.Add(synapseNode);
            ix++;
        }

        int fromId = synapseNode.id;
        int toId = receiver.id;
        // A link of a node to itself cannot be drawn and would make the node part of a cycle
        if (fromId == toId)
            continue;

        // Multiple synapses between the same two nodes result in one edge
        if (!dag.edges.Exists(known => known.fromId == fromId && known.toId == toId)) {
            Dag.Edge edge = new() {
                fromId = fromId,
                toId = toId
            };
            dag.edges.Add(edge);
        }

        // The inputs of a known node are already in the graph (or are being added further up the call stack)
        if (isNewNode)
            DescendGraph(synapseNode, ref ix, dag);
    }
}
#endregion Full Graph
#region Focus Graph
/// <summary>
/// Draws the focus view: the current nucleus in the centre with its receivers and its synapses
/// around it, or the 'Outputs' overview when no nucleus is selected.
/// </summary>
protected void DrawFocusGraph() {
    float discSize = 20;
    Vector3 focusPos = new(150, 210, 0);

    if (this.currentNucleus == null) {
        // Nothing selected: show the 'Outputs' label linked to every output neuron
        DrawAllOutputs(focusPos, discSize);
        DrawOutputs(focusPos, discSize);
    }
    else {
        DrawReceivers(this.currentNucleus, focusPos, discSize);
        DrawSynapses(this.currentNucleus, focusPos, discSize);
        if (expandArray && this.currentNucleus is Cluster expandedCluster)
            DrawExpandedCluster(expandedCluster);
        else
            DrawFocusedNucleus();
    }
    graphContainer.style.width = 300;

    // Draws the selected nucleus as a single disc
    void DrawFocusedNucleus() {
        float maxValue = 1;
        if (this.currentNucleus is Neuron neuron)
            maxValue = neuron.outputMagnitude;
        else if (this.currentNucleus is Cluster cluster)
            maxValue = cluster.defaultOutput.outputMagnitude;
        DrawNucleus(this.currentNucleus, focusPos, maxValue, 20);
    }

    // Draws the selected cluster as a column with one disc per sibling on a black backdrop
    void DrawExpandedCluster(Cluster cluster) {
        float maxValue = 1;
        float spacing = 400f / cluster.instanceCount;
        float margin = 10 + spacing / 2;

        // Backdrop behind the column
        float xMin = 150 - discSize;
        float xMax = 150 + discSize;
        float yMin = 10 + margin - discSize / 2;
        float yMax = 400 - margin + discSize;
        Vector3[] backdrop = new Vector3[4] {
            new(xMin, yMin, 0),
            new(xMax, yMin, 0),
            new(xMax, yMax, 0),
            new(xMin, yMax, 0)
        };
        Handles.color = Color.black;
        Handles.DrawAAConvexPolygon(backdrop);

        // One highlighted disc per row
        void DrawRow(Cluster rowCluster, int rowIx) {
            Vector3 rowPos = new(150, margin + rowIx * spacing, 0.0f);
            Handles.color = Color.white;
            // The highlight ring
            Handles.DrawSolidDisc(rowPos, Vector3.forward, discSize + 2);
            DrawNucleus(rowCluster, rowPos, maxValue, discSize);
        }

        if (cluster.siblingClusters == null)
            DrawRow(cluster, 0);
        else {
            int rowIx = 0;
            foreach (Cluster sibling in cluster.siblingClusters) {
                DrawRow(sibling, rowIx);
                rowIx++;
            }
        }

        // The name below the column, without the part from the colon onwards
        GUIStyle nameStyle = new(EditorStyles.label) {
            alignment = TextAnchor.UpperCenter,
            normal = { textColor = Color.white },
            fontStyle = FontStyle.Bold,
        };
        Vector3 namePos = new(150, yMax + discSize + 5, 0);
        string clusterName = cluster.name;
        int colonPos = clusterName.IndexOf(":");
        string shownName = colonPos > 0 ? clusterName[..colonPos] : clusterName;
        Handles.Label(namePos, shownName, nameStyle);
    }
}
/// <summary>
/// Draws the column with the nuclei that receive the output of <paramref name="nucleus"/>,
/// each linked to <paramref name="parentPos"/>.
/// </summary>
/// <remarks>
/// For the selected output the column also contains a link to the previous editor (when there
/// is one) and the 'Outputs' link. Missing (null) receivers are neither drawn nor counted, so
/// the rows that are drawn use the full height of the column.
/// </remarks>
/// <param name="nucleus">The nucleus whose receivers are drawn; other types than Neuron or Cluster draw nothing.</param>
/// <param name="parentPos">Position of <paramref name="nucleus"/> in the graph.</param>
/// <param name="size">Radius of the discs.</param>
protected void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) {
    List<Nucleus> receivers;
    if (nucleus is Neuron neuron)
        receivers = neuron.receivers;
    else if (nucleus is Cluster cluster)
        receivers = cluster.CollectReceivers();
    else
        return;
    // A nucleus without a receivers list is treated as having no receivers
    receivers ??= new List<Nucleus>();

    // Count the rows and determine the maximum value in this layer
    // The maximum is used to 'scale' the output value colors of the nuclei
    int nodeCount = 0;
    float maxValue = 0;
    foreach (Nucleus receiver in receivers) {
        if (receiver == null)
            continue;
        nodeCount++;
        if (receiver is Neuron neuroid && neuroid.outputMagnitude > maxValue)
            maxValue = neuroid.outputMagnitude;
    }
    // Same fallback as in DrawSynapses: avoid scaling by zero or a non-finite value
    if (maxValue == 0 || !float.IsFinite(maxValue))
        maxValue = 1;

    // For top-level nodes, add link to previous editor and/or 'Outputs'
    bool isSelectedOutput = nucleus == this.selectedOutput;
    if (isSelectedOutput) {
        // Add link to 'Outputs'
        nodeCount++;
        if (ClusterViewer.previousPrefab != null)
            // Add link to previous editor
            nodeCount++;
    }
    if (nodeCount == 0)
        return;

    // Determine the spacing of the nuclei in the layer
    float spacing = 400f / nodeCount;
    float margin = 10 + spacing / 2;
    int row = 0;
    foreach (Nucleus receiver in receivers) {
        if (receiver == null)
            continue;
        Vector3 pos = new(50, margin + row * spacing, 0.0f);
        DrawEdge(parentPos, pos);
        DrawNucleus(receiver, pos, maxValue, size);
        row++;
    }

    if (isSelectedOutput) {
        if (ClusterViewer.previousPrefab != null) {
            Vector3 prefabPos = new(50, margin + row * spacing, 0);
            DrawEdge(parentPos, prefabPos);
            DrawClusterPrefab(ClusterViewer.previousPrefab, prefabPos, size);
            row++;
        }
        Vector3 outputsPos = new(50, margin + row * spacing, 0);
        DrawEdge(parentPos, outputsPos);
        DrawAllOutputs(outputsPos, size);
    }
}
/// <summary>
/// Draws the column with the neurons that feed into <paramref name="nucleus"/>, each linked to
/// <paramref name="parentPos"/>. A neuron with several synapses is drawn once, using its first synapse.
/// </summary>
/// <param name="nucleus">The nucleus whose synapses are drawn; null draws nothing.</param>
/// <param name="parentPos">Position of <paramref name="nucleus"/> in the graph.</param>
/// <param name="size">Radius of the discs.</param>
protected void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) {
    if (nucleus == null)
        return;

    // First pass: count the distinct neurons and find the largest weighted output,
    // which is used to 'scale' the colors of the discs
    float maxValue = 0;
    List<Neuron> seen = new();
    foreach (Synapse synapse in nucleus.synapses) {
        if (synapse.neuron == null || seen.Contains(synapse.neuron))
            continue;
        seen.Add(synapse.neuron);
        float weighted = synapse.neuron.outputMagnitude * synapse.weight;
        if (weighted > maxValue)
            maxValue = weighted;
    }

    // The rows are spread over the column height
    float spacing = 400f / seen.Count;
    float margin = 10 + spacing / 2;

    // Second pass: draw every distinct neuron
    int row = 0;
    seen = new();
    foreach (Synapse synapse in nucleus.synapses) {
        if (synapse.neuron is null || seen.Contains(synapse.neuron))
            continue;
        seen.Add(synapse.neuron);

        Vector3 pos = new(250, margin + row * spacing, 0.0f);
        DrawEdge(parentPos, pos);

        // Discs are black in the editor and show the scaled weighted output in play mode
        Color color = Color.black;
        if (Application.isPlaying) {
            if (maxValue == 0 || !float.IsFinite(maxValue))
                maxValue = 1;
            float brightness = synapse.neuron.outputMagnitude * synapse.weight / maxValue;
            color = new Color(brightness, brightness, brightness, 1f);
        }
        DrawNucleus(synapse.neuron, pos, size, color);
        row++;
    }
}
/// <summary>
/// Draws the column with the output neurons of the current cluster, each linked to
/// <paramref name="parentPos"/>. Outputs that are not a Neuron are skipped and an output that is
/// listed more than once is drawn once.
/// </summary>
/// <param name="parentPos">Position the edges start from.</param>
/// <param name="size">Radius of the discs.</param>
protected void DrawOutputs(Vector2 parentPos, float size) {
    // First pass: count the distinct output neurons and find the largest output,
    // which is used to 'scale' the colors of the discs
    float maxValue = 0;
    List<Nucleus> seen = new();
    foreach (Nucleus output in this.currentCluster.outputs) {
        if (output is Neuron outputNeuron && !seen.Contains(output)) {
            seen.Add(output);
            float value = outputNeuron.outputMagnitude;
            if (value > maxValue)
                maxValue = value;
        }
    }

    // The rows are spread over the column height
    float spacing = 400f / seen.Count;
    float margin = 10 + spacing / 2;

    // Second pass: draw every distinct output neuron
    int row = 0;
    seen = new();
    foreach (Nucleus output in this.currentCluster.outputs) {
        if (output is Neuron outputNeuron && !seen.Contains(output)) {
            seen.Add(output);

            Vector3 pos = new(250, margin + row * spacing, 0.0f);
            DrawEdge(parentPos, pos);

            // Discs are black in the editor and show the scaled output in play mode
            Color color = Color.black;
            if (Application.isPlaying) {
                if (maxValue == 0 || !float.IsFinite(maxValue))
                    maxValue = 1;
                float brightness = outputNeuron.outputMagnitude / maxValue;
                color = new Color(brightness, brightness, brightness, 1f);
            }
            DrawNucleus(output, pos, size, color);
            row++;
        }
    }
}
#endregion Focus Graph
/// <summary>
/// Draws a nucleus with a gray level derived from its output: black in the editor, and in play
/// mode the output magnitude of a Neuron relative to <paramref name="maxValue"/>.
/// </summary>
/// <param name="nucleus">The nucleus to draw.</param>
/// <param name="position">Centre of the disc.</param>
/// <param name="maxValue">
/// Output magnitude that maps to white. Zero or a non-finite value is replaced by 1, as is done
/// in DrawSynapses and DrawOutputs, so that the color never contains NaN or infinity.
/// </param>
/// <param name="size">Radius of the disc.</param>
protected void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) {
    Color color = Color.black;
    if (Application.isPlaying) {
        if (maxValue == 0 || !float.IsFinite(maxValue))
            maxValue = 1;
        float brightness = 0;
        if (nucleus is Neuron neuron)
            brightness = neuron.outputMagnitude / maxValue;
        color = new Color(brightness, brightness, brightness, 1f);
    }
    DrawNucleus(nucleus, position, size, color);
}
/// <summary>
/// Draws one nucleus as a disc with its decorations and name, shows a tooltip while the mouse
/// hovers over it and selects it on a left click.
/// </summary>
/// <remarks>
/// The current event is now checked for null before it is used, and the unused locals of the
/// previous version were removed.
/// </remarks>
/// <param name="nucleus">The nucleus to draw; null draws nothing.</param>
/// <param name="position">Centre of the disc.</param>
/// <param name="size">Radius of the disc.</param>
/// <param name="color">Fill color of the disc.</param>
protected void DrawNucleus(Nucleus nucleus, Vector3 position, float size, Color color) {
    if (nucleus == null)
        return;

    if (nucleus == this.currentNucleus) {
        // The selected nucleus highlight ring
        Handles.color = Color.white;
        Handles.DrawSolidDisc(position, Vector3.forward, size + 2);
    }
    if (nucleus is MemoryCell) {
        // Memory cells get an extra circle, shifted to the right
        Handles.color = Color.white;
        Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size);
    }
    Handles.color = color;
    Handles.DrawSolidDisc(position, Vector3.forward, size);
    Handles.color = Color.white;

    // Nuclei of another cluster are decorated as that cluster
    if (nucleus.parent is Cluster parentCluster && currentNucleus != null && parentCluster != currentNucleus.parent)
        DrawCluster(parentCluster, position, color, size);
    else if (nucleus is Cluster cluster)
        DrawCluster(cluster, position, color, size);

    if (expandArray == false || nucleus != currentNucleus) {
        // Put the name below the nucleus
        GUIStyle style = new(EditorStyles.label) {
            alignment = TextAnchor.UpperCenter,
            normal = { textColor = Color.white },
            fontStyle = FontStyle.Bold,
        };
        Vector3 labelPos = position - Vector3.down * (size + 5);
        if (nucleus.parent != null && currentNucleus != null && nucleus.parent != currentNucleus.parent && nucleus.parent is Cluster parentCluster1) {
            // This neuron is part of another cluster: prefix its name with the cluster name
            parentCluster1.name ??= "";
            string clusterName = parentCluster1.name;
            int colonPos = clusterName.IndexOf(":");
            string baseName;
            if (colonPos > 0 && colonPos < clusterName.Length - 2)
                baseName = clusterName[..colonPos] + ".";
            else
                baseName = clusterName + ".";
            Handles.Label(labelPos, baseName + nucleus.name, style);
        }
        else {
            nucleus.name ??= "";
            int colonPos = nucleus.name.IndexOf(":");
            if (colonPos > 0 && colonPos < nucleus.name.Length - 2) {
                // If it is an array, the index after the colon is not shown
                Handles.Label(labelPos, nucleus.name[..colonPos], style);
            }
            else
                Handles.Label(labelPos, nucleus.name, style);
        }
    }

    // Mouse interaction
    Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2);
    // The id itself is not used; the call is kept so the sequence of IMGUI control ids stays as it was
    GUIUtility.GetControlID(FocusType.Passive);
    Event e = Event.current;
    if (e == null || !neuronRect.Contains(e.mousePosition))
        return;

    // Process hover
    HandleMouseHover(nucleus, neuronRect);
    // Process click
    if (e.type == EventType.MouseDown && e.button == 0) {
        // Consume the event so the scene doesn't also handle it
        e.Use();
        OnNeuronClick(nucleus);
    }
}
/// <summary>
/// Draws the cluster decoration on top of a disc: a surrounding circle and either the array
/// index above the disc (expanded view) or the number of instances inside the disc.
/// </summary>
/// <remarks>
/// The index is the part of the name which starts two characters after the colon. Names which
/// are null or end right after the colon no longer throw; they result in an empty index.
/// </remarks>
/// <param name="cluster">The cluster to decorate.</param>
/// <param name="position">Centre of the disc.</param>
/// <param name="color">Fill color of the disc, used to pick a readable text color.</param>
/// <param name="size">Radius of the disc.</param>
protected void DrawCluster(Cluster cluster, Vector3 position, Color color, float size) {
    GUIStyle labelTextStyle = new(EditorStyles.label) {
        normal = { textColor = Color.white },
        fontStyle = FontStyle.Bold,
    };
    if (expandArray) {
        // Put array indices above the discs
        // (the offset is the opposite of the one DrawNucleus uses for the name below the disc)
        labelTextStyle.alignment = TextAnchor.LowerCenter;
        Vector3 labelPosition = position + Vector3.down * (size + 5);
        // Strip the base name, keep the instance number
        string clusterName = cluster.name ?? "";
        int colonPos1 = clusterName.IndexOf(":");
        string indexLabel = "0";
        if (colonPos1 > 0) {
            int indexStart = Mathf.Min(colonPos1 + 2, clusterName.Length);
            indexLabel = clusterName[indexStart..];
        }
        Handles.Label(labelPosition, indexLabel, labelTextStyle);
    }
    else {
        // Put instance count inside the disc
        labelTextStyle.alignment = TextAnchor.MiddleCenter;
        Vector3 labelPosition = position + (Vector3.forward * 0.1f);
        // Adjust text color based on disc color
        labelTextStyle.normal.textColor = color.grayscale > 0.5f ? Color.black : Color.white;
        if (cluster.instanceCount > 1)
            Handles.Label(labelPosition, cluster.instanceCount.ToString(), labelTextStyle);
        else if (cluster.siblingClusters != null && cluster.siblingClusters.Length > 1)
            Handles.Label(labelPosition, cluster.siblingClusters.Length.ToString(), labelTextStyle);
    }
    // Draw a circle around the disc to indicate this is a Cluster
    Handles.color = Color.white;
    Handles.DrawWireDisc(position, Vector3.forward, size + 5);
}
/// <summary>
/// Draws a cluster prefab as a black disc with a surrounding circle and its name below it.
/// A left click on the disc selects and pings the prefab, clears
/// <c>ClusterViewer.previousPrefab</c> and creates an editor for the prefab.
/// </summary>
/// <remarks>
/// The current event is now checked for null before it is used, and the unused local of the
/// previous version was removed.
/// </remarks>
/// <param name="prefab">The prefab to draw.</param>
/// <param name="position">Centre of the disc.</param>
/// <param name="size">Radius of the disc.</param>
protected void DrawClusterPrefab(ClusterPrefab prefab, Vector2 position, float size) {
    Handles.color = Color.black;
    Handles.DrawSolidDisc(position, Vector3.forward, size);
    // Draw a circle around the disc to indicate this is a Cluster
    Handles.color = Color.white;
    Handles.DrawWireDisc(position, Vector3.forward, size + 5);

    // Put the name below the disc
    GUIStyle style = new(EditorStyles.label) {
        alignment = TextAnchor.UpperCenter,
        normal = { textColor = Color.white },
        fontStyle = FontStyle.Bold,
    };
    Vector2 labelPos = position - Vector2.down * (size + 5);
    Handles.Label(labelPos, prefab.name, style);

    // Mouse interaction
    Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2);
    // The id itself is not used; the call is kept so the sequence of IMGUI control ids stays as it was
    GUIUtility.GetControlID(FocusType.Passive);
    Event e = Event.current;
    if (e == null || !neuronRect.Contains(e.mousePosition))
        return;

    // Process click
    if (e.type == EventType.MouseDown && e.button == 0) {
        // Consume the event so the scene doesn't also handle it
        e.Use();
        Selection.activeObject = prefab;
        EditorGUIUtility.PingObject(prefab);
        ClusterViewer.previousPrefab = null;
        CreateEditor(prefab);
    }
}
/// <summary>
/// Draws the 'Outputs' label and returns to the outputs overview when it is left-clicked.
/// </summary>
/// <param name="position">Centre of the label and of the clickable area.</param>
/// <param name="size">Half the width and height of the clickable area.</param>
protected void DrawAllOutputs(Vector2 position, float size) {
    GUIStyle outputsStyle = new(EditorStyles.label) {
        normal = { textColor = Color.white },
        fontStyle = FontStyle.Bold,
        alignment = TextAnchor.MiddleCenter,
    };
    Handles.Label(position, "Outputs", outputsStyle);

    // Left mouse button pressed inside the area around the label?
    Rect hitRect = new(position.x - size, position.y - size, size * 2, size * 2);
    Event current = Event.current;
    bool clicked =
        current != null &&
        hitRect.Contains(current.mousePosition) &&
        current.type == EventType.MouseDown &&
        current.button == 0;
    if (clicked) {
        // Consume the event so the scene doesn't also handle it
        current.Use();
        OnAllOutputsClick();
    }
}
/// <summary>
/// Draws a white line between two node centres, shortened by <paramref name="radius"/> at both
/// ends. Nothing is drawn when the centres are too close for a line to remain.
/// </summary>
/// <param name="from">Centre of the first node.</param>
/// <param name="to">Centre of the second node.</param>
/// <param name="radius">Length that is trimmed from each end of the line.</param>
protected void DrawEdge(Vector2 from, Vector2 to, float radius = 20) {
    Handles.color = Color.white;
    Vector2 delta = to - from;
    float distance = delta.magnitude;
    bool tooShort = distance <= 2f * radius || distance <= Mathf.Epsilon;
    if (!tooShort) {
        Vector2 unit = delta / distance;
        Vector2 start = from + unit * radius;
        Vector2 end = to - unit * radius;
        Handles.DrawLine(start, end);
    }
}
/// <summary>
/// Shows a box next to the mouse pointer with the name of the nucleus and, for a Neuron, its
/// output magnitude.
/// </summary>
/// <param name="nucleus">The nucleus under the mouse pointer.</param>
/// <param name="rect">Area of the nucleus; not used by this method.</param>
protected void HandleMouseHover(Nucleus nucleus, Rect rect) {
    string text = $"{nucleus.name}";
    if (nucleus is Neuron neuron)
        text += $"\nValue: {neuron.outputMagnitude}";
    GUIContent tooltip = new(text);

    // Place the box slightly below and to the right of the pointer
    Vector2 pointer = Event.current.mousePosition;
    Vector2 boxSize = GUI.skin.box.CalcSize(tooltip);
    GUI.Box(new Rect(pointer.x + 10, pointer.y + 10, boxSize.x, boxSize.y), tooltip);
}
/// <summary>
/// Handles a click on a nucleus: a click on the current nucleus expands a cluster (play mode) or
/// opens its prefab (edit mode); a click on any other nucleus moves the focus to it.
/// </summary>
/// <param name="nucleus">The nucleus that was clicked.</param>
protected void OnNeuronClick(Nucleus nucleus) {
    if (nucleus == this.currentNucleus) {
        // Clicking the focused nucleus again
        if (Application.isPlaying)
            expandArray = nucleus is Cluster && !expandArray;
        else if (nucleus is Cluster clickedCluster)
            OnClusterClick(clickedCluster);
        return;
    }

    bool inOtherCluster =
        nucleus.parent != null &&
        this.currentNucleus != null &&
        nucleus.parent != this.currentNucleus.parent;
    if (inOtherCluster && !Application.isPlaying) {
        // In the editor, select the cluster, not the neuron in the cluster
        this.currentNucleus = nucleus.parent;
        expandArray = false;
        return;
    }

    // Focus on the clicked nucleus; a neuron without receivers becomes the selected output
    this.currentNucleus = nucleus;
    if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0)
        this.selectedOutput = this.currentNucleus;
    expandArray = false;
}
/// <summary>
/// Opens the prefab of a sub-cluster: selects and pings it, remembers the prefab of the current
/// cluster as the previous one and creates an editor for the sub-cluster prefab.
/// </summary>
/// <param name="subCluster">The cluster that was clicked.</param>
protected void OnClusterClick(Cluster subCluster) {
    var target = subCluster.prefab;
    Selection.activeObject = target;
    EditorGUIUtility.PingObject(target);
    // Remember where we came from
    ClusterViewer.previousPrefab = this.currentCluster.prefab;
    CreateEditor(target);
}
/// <summary>Clears the selection, which makes the focus graph show the outputs overview.</summary>
protected void OnAllOutputsClick() {
    expandArray = false;
    this.selectedOutput = null;
    this.currentNucleus = null;
}
#endregion Graph
/// <summary>
/// Scene view callback: when a Neuron is in focus, draws its output value as a yellow line which
/// starts at the position of the attached game object.
/// </summary>
/// <param name="sceneView">The scene view being drawn; not used.</param>
void OnSceneGUI(SceneView sceneView) {
    if (this.gameObject == null)
        return;
    if (this.currentNucleus is not Neuron currentNeuron)
        return;

    // The output value is transformed with the game object's transform before it is drawn
    Vector3 worldVector = this.gameObject.transform.TransformVector(currentNeuron.outputValue);
    Handles.color = Color.yellow;
    Vector3 origin = this.gameObject.transform.position;
    Handles.DrawLine(origin, this.gameObject.transform.position + worldVector);
}
}
}
/// <summary>A numbered list of nuclei. It is not referenced by the other code in this file.</summary>
public class NeuroidLayer {
    // Index of the layer
    public int ix;
    // The nuclei in the layer
    public List<Nucleus> neuroids = new List<Nucleus>();
}
public class Dag {
// A node of the graph: one nucleus with the position computed by ComputeLayout.
public class Node {
// Identifier the edges refer to
public int id;
// Centre of the node; written by ComputeLayout
public Vector2 position;
public float radius = 20f; // circle radius
// The nucleus this node stands for
public Nucleus nucleus;
}
// A directed link between two nodes, referring to them by Node.id.
public class Edge {
public int fromId;
public int toId;
}
// All nodes and edges of the graph
public List<Node> nodes = new();
public List<Edge> edges = new();
/// <summary>Finds the first node whose nucleus has the given name.</summary>
/// <param name="name">Name to look for.</param>
/// <param name="justBaseName">
/// When true, both names are compared without the part from the first colon onwards
/// (a colon at the very start of a name is not treated as a separator).
/// </param>
/// <returns>The matching node, or null when there is none.</returns>
public Node FindNode(string name, bool justBaseName = true) {
    // Reduces a name to the form that is compared
    string Comparable(string raw) {
        if (!justBaseName)
            return raw;
        int colonPos = raw.IndexOf(":");
        return colonPos > 0 ? raw[..colonPos] : raw;
    }

    string wanted = Comparable(name);
    foreach (Node candidate in this.nodes) {
        if (Comparable(candidate.nucleus.name) == wanted)
            return candidate;
    }
    return null;
}
/// <summary>Returns the first node of <paramref name="dag"/> with the given id, or null when there is none.</summary>
public static Node GetNodeById(Dag dag, int id) {
    foreach (Node candidate in dag.nodes) {
        if (candidate.id == id)
            return candidate;
    }
    return null;
}
/// <summary>
/// Assigns a position to every node of <paramref name="dag"/>: nodes are placed in columns from
/// left to right and spread vertically within their column.
/// </summary>
/// <remarks>
/// Nodes without outgoing edges form column 0. Every other node is placed one column further
/// than the furthest node it has an edge to (Kahn's algorithm on the reversed edges). Nodes
/// which are never reached this way each get a column of their own after the others.
/// Edges that refer to unknown node ids are ignored.
/// </remarks>
/// <param name="dag">The graph to lay out; the positions of its nodes are overwritten.</param>
public static void ComputeLayout(Dag dag) {
    // Per node: the nodes with an edge to it, and the number of its outgoing edges still to be processed
    Dictionary<int, Node> nodeById = dag.nodes.ToDictionary(n => n.id);
    Dictionary<int, List<int>> parents = dag.nodes.ToDictionary(n => n.id, _ => new List<int>());
    Dictionary<int, int> pendingChildren = dag.nodes.ToDictionary(n => n.id, _ => 0);
    foreach (Edge edge in dag.edges) {
        bool known = nodeById.ContainsKey(edge.fromId) && nodeById.ContainsKey(edge.toId);
        if (!known)
            continue;
        parents[edge.toId].Add(edge.fromId);
        pendingChildren[edge.fromId]++;
    }

    // Start with the nodes without outgoing edges
    Dictionary<int, int> layer = new();
    Queue<int> queue = new();
    foreach (Node node in dag.nodes) {
        if (pendingChildren[node.id] == 0) {
            layer[node.id] = 0;
            queue.Enqueue(node.id);
        }
    }

    // Walk towards the inputs; a node is queued once all its outgoing edges are processed
    while (queue.Count > 0) {
        int childId = queue.Dequeue();
        int parentLayer = layer[childId] + 1;
        foreach (int parentId in parents[childId]) {
            if (!layer.TryGetValue(parentId, out int assigned) || assigned < parentLayer)
                layer[parentId] = parentLayer;
            pendingChildren[parentId]--;
            if (pendingChildren[parentId] == 0)
                queue.Enqueue(parentId);
        }
    }

    // Nodes that were not reached each get the next free layer
    int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
    foreach (Node node in dag.nodes) {
        if (layer.ContainsKey(node.id))
            continue;
        maxLayer++;
        layer[node.id] = maxLayer;
    }

    // Group the node ids per layer, ordered by layer number
    SortedDictionary<int, List<int>> columns = new();
    foreach (KeyValuePair<int, int> entry in layer) {
        if (!columns.TryGetValue(entry.Value, out List<int> ids)) {
            ids = new List<int>();
            columns[entry.Value] = ids;
        }
        ids.Add(entry.Key);
    }

    // Place nodes: x increases with the column index, y is spaced within the column
    float hSpacing = 100f;
    float totalHeight = 400f;
    int column = 0;
    foreach (List<int> ids in columns.Values) {
        float spacing = totalHeight / ids.Count;
        float margin = 10 + spacing / 2;
        float x = hSpacing + column * hSpacing;
        for (int row = 0; row < ids.Count; row++) {
            if (nodeById.TryGetValue(ids[row], out Node node) && node != null)
                node.position = new Vector2(x, margin + row * spacing);
        }
        column++;
    }
}
}
}