cc9a845 Fix sleeping for product combinator e4ba7f8 Better cross-cluster monitoring 4f8a6ab Improved (but not fixed) cross-cluster monitoring b12616b Fix neuron output visualisation 96439cc Visualize all outputs d583e67 WIP cluster references/instance 04bab92 Fix links to multiple cluster neurons & cleanup e17a249 Cross-cluster editor links 0ab2d21 Migrating and cleaning up b6630ad First steps to using instanceCount for clusters 8801fa2 Cluster reimport fixes befb69d full graph with collapsed clusters 1a1919f Fix expansion of clsuter arrays c708f4d Improved clusterarray support c2e4e1b Fix Cluster array extension 02047a4 Adde full graph scrollbar 471ed36 Completed full graph integration 830e3e7 Added full graph view mode 249e888 Improve full graph view 308a6a1 The Entities are battling 75d9d1c Cleanup c8f0f0c Fix aging of neurons e2e169c small fixes 619ced6 Removed the use of Receptors 19f9296 Simplifications bc0a796 Integrated clusterarray in cluster e40dd23 Fixed clusterViewer for clusterarrays b0f4b41 Status quo adding clusterArrays 1fc75a8 Added ClusterArray 0023920 Cover seeking(-ish) behaviour 1c7b8e7 Added Tanh Activation a99d40c BrainViewer added db43655 Pew pew! 18ef4cd Merge commit '89017475984bbbf1899fb38846c5bb0e7775dedd' into NanoBrain git-subtree-dir: NanoBrain git-subtree-split: cc9a845b643ffb4a9abe4f7da787ac5c5b14dae8
950 lines
41 KiB
C#
950 lines
41 KiB
C#
using System.Collections.Generic;
|
|
using System.Linq;
|
|
|
|
using UnityEditor;
|
|
using UnityEngine;
|
|
using UnityEngine.UIElements;
|
|
|
|
namespace NanoBrain {
|
|
|
|
public class ClusterViewer : Editor {
|
|
|
|
// The prefab whose editor was shown before navigating into a sub-cluster:
// set in GraphView.OnClusterClick and cleared when its link is clicked in DrawClusterPrefab.
// While non-null, DrawReceivers adds a link back to it next to the selected output.
public static ClusterPrefab previousPrefab;
|
|
|
|
public class GraphView : VisualElement {
|
|
// The cluster whose graph is shown
protected Cluster currentCluster;
// Serialized form of the cluster prefab; only created outside play mode
protected SerializedObject serializedBrain;
// The nucleus in the centre of the focus graph (null shows the 'Outputs' overview)
protected Nucleus currentNucleus;
// The output through which the graph is entered
protected Nucleus selectedOutput;

// The scene object for which the focused neuron's output is drawn in the scene view
protected GameObject gameObject;
// Whether the focused cluster is drawn as an expanded column of its instances
private bool expandArray = false;

protected ClusterPrefab prefabAsset;
protected VisualElement topMenuContainer;
protected ScrollView scrollView;
protected IMGUIContainer graphContainer;
// NOTE(review): not assigned anywhere in this file, so it stays null
protected readonly PopupField<string> outputsPopup;

// How the graph is drawn
public enum Mode {
    // The focused nucleus with its direct receivers and synapses
    Focus,
    // The complete graph upstream of the selected output
    Full
}
public Mode mode = Mode.Focus;

public GraphView(Cluster cluster) {
    this.currentCluster = cluster;

    name = "content";
    style.flexGrow = 1;

    // Top bar holding the mode selector
    topMenuContainer = new VisualElement();
    topMenuContainer.style.flexDirection = FlexDirection.Row;
    topMenuContainer.style.alignItems = Align.Center;

    EnumField modePopup = new EnumField(mode);
    modePopup.style.width = 80;
    modePopup.RegisterValueChangedCallback(OnModeChange);
    topMenuContainer.Add(modePopup);

    // Horizontally scrollable area stretched over the whole element
    scrollView = new ScrollView(ScrollViewMode.Horizontal);
    scrollView.style.position = Position.Absolute;
    scrollView.style.left = 0;
    scrollView.style.top = 0;
    scrollView.style.right = 0;
    scrollView.style.bottom = 0;
    // The horizontal scroller only appears when the graph is wider than the view
    scrollView.horizontalScrollerVisibility = ScrollerVisibility.Auto;
    scrollView.verticalScrollerVisibility = ScrollerVisibility.Hidden;

    // The IMGUI surface on which the graph itself is drawn
    graphContainer = new IMGUIContainer(OnIMGUI) {
        pickingMode = PickingMode.Position,
        focusable = true
    };
    scrollView.contentContainer.Add(graphContainer);

    // The menu is added after the absolutely positioned scroll view
    Add(scrollView);
    Add(topMenuContainer);

    // Only hook into the scene view while this element is part of a panel
    RegisterCallback<AttachToPanelEvent>(_ => Subscribe());
    RegisterCallback<DetachFromPanelEvent>(_ => Unsubscribe());
}
|
|
|
|
// Called when the mode selector changes: switches between the focus and the full graph.
protected virtual void OnModeChange(ChangeEvent<System.Enum> changeEvent) {
    this.mode = (Mode)changeEvent.newValue;
    // The graph is drawn by the IMGUI container; request a repaint so the new mode
    // presumably shows immediately instead of on the next event reaching the container
    graphContainer?.MarkDirtyRepaint();
}
|
|
|
|
// True while OnSceneGUI is registered with SceneView.duringSceneGui
bool subscribed = false;

// Registers the scene view callback once and repaints the scene views so it shows up.
void Subscribe() {
    if (!subscribed) {
        SceneView.duringSceneGui += OnSceneGUI;
        subscribed = true;
        SceneView.RepaintAll();
    }
}

// Removes the scene view callback if it was registered.
void Unsubscribe() {
    if (subscribed) {
        SceneView.duringSceneGui -= OnSceneGUI;
        subscribed = false;
    }
}
|
|
|
|
// Shows the graph of the current cluster for the given scene object and
// focuses the cluster's first output.
public void SetGraph(GameObject gameObject) {
    this.gameObject = gameObject;

    // The serialized prefab is only needed (and updated) outside play mode
    if (Application.isPlaying == false)
        this.serializedBrain = new SerializedObject(this.currentCluster.prefab);

    // A cluster without outputs has nothing to focus on. Both fields stay null, and
    // DrawFocusGraph then shows the 'Outputs' overview instead of throwing on outputs[0].
    this.selectedOutput = this.currentCluster.outputs?.FirstOrDefault();
    this.currentNucleus = this.selectedOutput;
    Rebuild();
}
|
|
|
|
// Resolves the on-disk asset of the current cluster's prefab into prefabAsset.
// When the prefab has no asset on disk, an unsaved in-memory ClusterPrefab is used instead.
void Rebuild() {
    if (this.currentNucleus != null) {
        string assetPath = AssetDatabase.GetAssetPath(this.currentCluster.prefab);
        ClusterPrefab assetOnDisk = AssetDatabase.LoadAssetAtPath<ClusterPrefab>(assetPath);
        // Unity's overloaded == is used on purpose (no ?? on UnityEngine.Object)
        this.prefabAsset = assetOnDisk != null ? assetOnDisk : CreateInstance<ClusterPrefab>();
    }
}
|
|
|
|
// IMGUI callback of the graph container: refreshes the serialized prefab and draws the graph.
public void OnIMGUI() {
    // serializedBrain is only created by SetGraph outside play mode. It is null when
    // SetGraph has not run yet, or when it ran during play mode and play mode was left.
    if (Application.isPlaying == false)
        serializedBrain?.Update();

    Handles.BeginGUI();
    try {
        DrawGraph();
    }
    finally {
        // Always balance BeginGUI, also when drawing throws (e.g. ExitGUIException)
        Handles.EndGUI();
    }
}
|
|
|
|
#region Graph
|
|
|
|
// Draws the graph in the currently selected mode.
protected virtual void DrawGraph() {
    switch (mode) {
        case Mode.Focus:
            DrawFocusGraph();
            break;
        default:
            DrawFullGraph();
            break;
    }
}
|
|
|
|
#region Full Graph
|
|
|
|
// Draws the complete graph upstream of the selected output in layered layout.
// It then resizes the graph container to the graph's width and scrolls the
// focused nucleus to the middle of the view.
protected void DrawFullGraph() {
    Dag dag = GenerateGraph(this.selectedOutput);
    Dag.ComputeLayout(dag);

    // Index the nodes by id once instead of searching the node list for both ends of
    // every edge. TryAdd keeps the first node for an id, like the FirstOrDefault lookup did.
    Dictionary<int, Dag.Node> nodesById = new();
    foreach (Dag.Node node in dag.nodes)
        nodesById.TryAdd(node.id, node);

    // Draw edges; edges referring to unknown nodes are skipped
    foreach (Dag.Edge edge in dag.edges) {
        if (!nodesById.TryGetValue(edge.fromId, out Dag.Node source) ||
            !nodesById.TryGetValue(edge.toId, out Dag.Node target))
            continue;

        DrawEdge(source.position, target.position);
    }

    // Draw nodes
    foreach (Dag.Node node in dag.nodes)
        DrawNucleus(node.nucleus, node.position, 1, node.radius);

    // Determine the graph width and where the focused nucleus ended up.
    // This runs after drawing because a click handled while drawing may change the focus.
    float width = 0;
    float currentNucleusPosition = 0;
    foreach (Dag.Node node in dag.nodes) {
        if (node.position.x > width)
            width = node.position.x;
        if (node.nucleus == currentNucleus)
            currentNucleusPosition = node.position.x;
    }

    // Resize the graph container to the full graph width
    float margin = 50f;
    graphContainer.style.width = width + 2 * margin;

    // Scroll so that the focused nucleus is centred in the viewport,
    // clamped between 0 and the maximum scrollable range
    float viewportWidth = scrollView.layout.width;
    float desiredScrollX = currentNucleusPosition - viewportWidth * 0.5f;
    float maxScrollX = Mathf.Max(0f, graphContainer.resolvedStyle.width - viewportWidth);
    desiredScrollX = Mathf.Clamp(desiredScrollX, 0f, maxScrollX);

    Vector2 current = scrollView.scrollOffset;
    scrollView.scrollOffset = new Vector2(desiredScrollX, current.y);
}
|
|
|
|
// Builds the graph of everything feeding into rootNucleus.
// The root gets node id 0; an empty graph is returned for a null root.
public Dag GenerateGraph(Nucleus rootNucleus) {
    Dag graph = new();
    if (rootNucleus != null) {
        Dag.Node root = new() { id = 0, nucleus = rootNucleus };
        graph.nodes.Add(root);

        int nextId = 1;
        DescendGraph(root, ref nextId, graph);
    }
    return graph;
}
|
|
|
|
// Adds a node and an edge to 'receiver' for every neuron feeding into it, then recurses
// into nodes that were not yet in the graph.
// Nodes are matched by base name (see Dag.FindNode), so array instances share one node.
// Neurons of another cluster than the focused nucleus are represented by their cluster.
// ix is the next free node id.
private void DescendGraph(Dag.Node receiver, ref int ix, Dag dag) {
    foreach (Synapse synapse in receiver.nucleus.synapses) {
        Nucleus nucleus = synapse.neuron;
        // Unconnected synapses are skipped, like in the focus graph
        if (nucleus == null)
            continue;

        if (nucleus.parent != null && nucleus.parent != currentNucleus.parent)
            nucleus = nucleus.parent;

        Dag.Node synapseNode = dag.FindNode(nucleus.name);
        bool isNewNode = synapseNode == null;
        if (isNewNode) {
            synapseNode = new() {
                id = ix,
                nucleus = nucleus
            };
            dag.nodes.Add(synapseNode);
        }

        dag.edges.Add(new Dag.Edge {
            fromId = synapseNode.id,
            toId = receiver.id
        });
        ix++;

        // Only descend into nodes seen for the first time. An existing node's inputs have
        // already been (or are being) added. Re-descending duplicates every edge of shared
        // sub-graphs, and recurrent connections would recurse forever.
        if (isNewNode)
            DescendGraph(synapseNode, ref ix, dag);
    }
}
|
|
|
|
#endregion Full Graph
|
|
|
|
#region Focus Graph
|
|
|
|
// Draws the focused nucleus in the centre, its receivers on the left and its synapses on
// the right. Without a focused nucleus, it draws the 'Outputs' hub with all output neurons.
protected void DrawFocusGraph() {
    const float nodeSize = 20;
    Vector3 center = new(150, 210, 0);

    if (this.currentNucleus == null) {
        DrawAllOutputs(center, nodeSize);
        DrawOutputs(center, nodeSize);
    }
    else {
        DrawReceivers(this.currentNucleus, center, nodeSize);
        DrawSynapses(this.currentNucleus, center, nodeSize);

        // Drawing the neighbours may have handled a click which changed the focus
        // or the expansion, so both are read again here
        Nucleus focused = this.currentNucleus;
        if (expandArray && focused is Cluster expandedCluster)
            DrawExpandedCluster(expandedCluster, nodeSize);
        else
            DrawNucleus(focused, center, FocusMaxValue(focused), nodeSize);
    }

    graphContainer.style.width = 300;
}

// Brightness scale used when drawing the focused nucleus itself.
private static float FocusMaxValue(Nucleus focused) {
    if (focused is Neuron neuron)
        return neuron.outputMagnitude;
    if (focused is Cluster cluster)
        return cluster.defaultOutput.outputMagnitude;
    return 1;
}

// Draws the focused cluster as a vertical column of its instances (or the cluster itself
// when it has no siblings) on a black backdrop, with the cluster's base name underneath.
private void DrawExpandedCluster(Cluster cluster, float size) {
    float spacing = 400f / cluster.instanceCount;
    float margin = 10 + spacing / 2;
    float xMin = 150 - size;
    float xMax = 150 + size;
    float yMin = 10 + margin - size / 2;
    float yMax = 400 - margin + size;

    Handles.color = Color.black;
    Handles.DrawAAConvexPolygon(
        new Vector3(xMin, yMin, 0),
        new Vector3(xMax, yMin, 0),
        new Vector3(xMax, yMax, 0),
        new Vector3(xMin, yMax, 0));

    int row = 0;
    void DrawInstance(Cluster instance) {
        Vector3 pos = new(150, margin + row * spacing, 0.0f);
        // White ring behind every instance disc
        Handles.color = Color.white;
        Handles.DrawSolidDisc(pos, Vector3.forward, size + 2);
        DrawNucleus(instance, pos, 1f, size);
        row++;
    }

    if (cluster.siblingClusters == null) {
        DrawInstance(cluster);
    }
    else {
        foreach (Cluster sibling in cluster.siblingClusters)
            DrawInstance(sibling);
    }

    GUIStyle nameStyle = new(EditorStyles.label) {
        alignment = TextAnchor.UpperCenter,
        normal = { textColor = Color.white },
        fontStyle = FontStyle.Bold,
    };
    Vector3 namePosition = new(150, yMax + size + 5, 0);
    string clusterName = cluster.name;
    int colonPos = clusterName.IndexOf(":");
    // Show the base name, without any ':<index>' suffix
    string shownName = colonPos > 0 ? clusterName[..colonPos] : clusterName;
    Handles.Label(namePosition, shownName, nameStyle);
}
|
|
|
|
// Draws the nuclei receiving the output of 'nucleus' in a column on the left (x = 50),
// each connected to parentPos. For the selected output, it also adds a link back to the
// previous prefab editor (when there is one) and a link to the 'Outputs' overview.
protected void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) {
    List<Nucleus> receivers;
    if (nucleus is Neuron neuron)
        receivers = neuron.receivers;
    else if (nucleus is Cluster cluster)
        receivers = cluster.CollectReceivers();
    else
        return;
    // A missing receiver list is treated as empty, so the links below are still drawn
    receivers ??= new List<Nucleus>();

    // One slot per receiver that is actually drawn. Null entries are skipped below, and
    // counting them would leave gaps and misplace the links.
    int nodeCount = receivers.Count(r => r != null);
    if (nucleus == this.selectedOutput) {
        // The link to 'Outputs'
        nodeCount++;
        // The link to the previous editor
        if (ClusterViewer.previousPrefab != null)
            nodeCount++;
    }

    // The maximum output in this layer scales the brightness of the receivers
    float maxValue = 0;
    foreach (Nucleus receiver in receivers) {
        if (receiver is Neuron receivingNeuron) {
            float value = receivingNeuron.outputMagnitude;
            if (value > maxValue)
                maxValue = value;
        }
    }

    // Spread the slots evenly over the column
    float spacing = 400f / nodeCount;
    float margin = 10 + spacing / 2;

    int row = 0;
    foreach (Nucleus receiver in receivers) {
        if (receiver == null)
            continue;

        Vector3 pos = new(50, margin + row * spacing, 0.0f);
        DrawEdge(parentPos, pos);
        DrawNucleus(receiver, pos, maxValue, size);
        row++;
    }

    if (nucleus == this.selectedOutput) {
        if (ClusterViewer.previousPrefab != null) {
            Vector3 previousPos = new(50, margin + row * spacing, 0);
            DrawEdge(parentPos, previousPos);
            DrawClusterPrefab(ClusterViewer.previousPrefab, previousPos, size);
            row++;
        }

        Vector3 outputsPos = new(50, margin + row * spacing, 0);
        DrawEdge(parentPos, outputsPos);
        DrawAllOutputs(outputsPos, size);
    }
}
|
|
|
|
// Draws the neurons feeding into 'nucleus' in a column on the right (x = 250), each
// connected to parentPos. Multiple synapses to the same neuron are drawn once, using
// the first synapse. In play mode, the disc brightness is the weighted output relative
// to the largest weighted output.
protected void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) {
    if (nucleus == null)
        return;

    // Collect the first synapse to every distinct, non-null neuron (one consistent null
    // check and a set lookup instead of two passes with List.Contains)
    List<Synapse> shownSynapses = new();
    HashSet<Neuron> seenNeurons = new();
    float maxValue = 0;
    foreach (Synapse synapse in nucleus.synapses) {
        if (synapse.neuron == null || !seenNeurons.Add(synapse.neuron))
            continue;

        shownSynapses.Add(synapse);
        float value = synapse.neuron.outputMagnitude * synapse.weight;
        if (value > maxValue)
            maxValue = value;
    }

    // Avoid dividing by zero or by a non-finite value
    if (maxValue == 0 || !float.IsFinite(maxValue))
        maxValue = 1;

    // Spread the neurons evenly over the column
    float spacing = 400f / shownSynapses.Count;
    float margin = 10 + spacing / 2;

    for (int row = 0; row < shownSynapses.Count; row++) {
        Synapse synapse = shownSynapses[row];
        Vector3 pos = new(250, margin + row * spacing, 0.0f);
        DrawEdge(parentPos, pos);

        Color color = Color.black;
        if (Application.isPlaying) {
            float brightness = synapse.neuron.outputMagnitude * synapse.weight / maxValue;
            color = new Color(brightness, brightness, brightness, 1f);
        }
        DrawNucleus(synapse.neuron, pos, size, color);
    }
}
|
|
|
|
// Draws every distinct output neuron of the current cluster in a column on the right
// (x = 250), connected to parentPos. In play mode, discs are shaded by their output
// relative to the largest output.
protected void DrawOutputs(Vector2 parentPos, float size) {
    // Gather each output neuron once, in order, and track the largest output magnitude
    List<Nucleus> uniqueOutputs = new();
    float maxValue = 0;
    foreach (Nucleus candidate in this.currentCluster.outputs) {
        if (candidate is not Neuron candidateNeuron || uniqueOutputs.Contains(candidate))
            continue;

        uniqueOutputs.Add(candidate);
        float magnitude = candidateNeuron.outputMagnitude;
        if (magnitude > maxValue)
            maxValue = magnitude;
    }

    // Fall back to a scale of 1 when there is no usable maximum
    float scale = (maxValue == 0 || !float.IsFinite(maxValue)) ? 1 : maxValue;

    float spacing = 400f / uniqueOutputs.Count;
    float margin = 10 + spacing / 2;

    int row = 0;
    foreach (Nucleus output in uniqueOutputs) {
        Neuron outputNeuron = (Neuron)output;
        Vector3 pos = new(250, margin + row * spacing, 0.0f);
        DrawEdge(parentPos, pos);

        Color color = Color.black;
        if (Application.isPlaying) {
            float brightness = outputNeuron.outputMagnitude / scale;
            color = new Color(brightness, brightness, brightness, 1f);
        }
        DrawNucleus(output, pos, size, color);
        row++;
    }
}
|
|
|
|
#endregion Focus Graph
|
|
|
|
// Draws a nucleus shaded by its output relative to maxValue.
// In play mode, neurons get a grey level of outputMagnitude / maxValue and other nuclei
// are drawn black. In edit mode everything is black.
protected void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) {
    Color color = Color.black;
    if (Application.isPlaying) {
        // A zero or non-finite maximum (e.g. all outputs in a layer are 0) would give a
        // NaN/infinite brightness; use the same fallback as DrawSynapses and DrawOutputs
        if (maxValue == 0 || !float.IsFinite(maxValue))
            maxValue = 1;

        float brightness = nucleus is Neuron neuron ? neuron.outputMagnitude / maxValue : 0;
        color = new Color(brightness, brightness, brightness, 1f);
    }
    DrawNucleus(nucleus, position, size, color);
}
|
|
|
|
// Draws a nucleus disc in the given colour with its decorations and name label, and
// handles hovering (tooltip) and left-clicking (OnNeuronClick) on it.
protected void DrawNucleus(Nucleus nucleus, Vector3 position, float size, Color color) {
    if (nucleus == null)
        return;

    // Highlight ring behind the focused nucleus
    if (nucleus == this.currentNucleus) {
        Handles.color = Color.white;
        Handles.DrawSolidDisc(position, Vector3.forward, size + 2);
    }

    // Memory cells get an extra outline, offset to the right
    if (nucleus is MemoryCell) {
        Handles.color = Color.white;
        Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size);
    }

    Handles.color = color;
    Handles.DrawSolidDisc(position, Vector3.forward, size);
    Handles.color = Color.white;

    // A neuron of another cluster is decorated as that cluster; a cluster as itself
    if (nucleus.parent is Cluster parentCluster && currentNucleus != null && parentCluster != currentNucleus.parent)
        DrawCluster(parentCluster, position, color, size);
    else if (nucleus is Cluster cluster)
        DrawCluster(cluster, position, color, size);

    // The focused cluster drawn as an expanded array gets its label from DrawFocusGraph
    if (expandArray == false || nucleus != currentNucleus)
        DrawNucleusLabel(nucleus, position, size);

    HandleNucleusInput(nucleus, position, size);
}

// Draws the name of a nucleus below its disc. Array suffixes (':<index>') are hidden.
// Neurons of another cluster are prefixed with "<cluster base name>.".
// Null names are replaced by "" on the nucleus itself.
private void DrawNucleusLabel(Nucleus nucleus, Vector3 position, float size) {
    GUIStyle style = new(EditorStyles.label) {
        alignment = TextAnchor.UpperCenter,
        normal = { textColor = Color.white },
        fontStyle = FontStyle.Bold,
    };
    // GUI y grows downwards, so subtracting Vector3.down places the label below the disc
    Vector3 labelPos = position - Vector3.down * (size + 5);

    if (nucleus.parent != null && currentNucleus != null && nucleus.parent != currentNucleus.parent && nucleus.parent is Cluster otherCluster) {
        // This neuron is part of another cluster
        otherCluster.name ??= "";
        int colonPos = otherCluster.name.IndexOf(":");
        string prefix = (colonPos > 0 && colonPos < otherCluster.name.Length - 2)
            ? otherCluster.name[..colonPos]
            : otherCluster.name;
        Handles.Label(labelPos, prefix + "." + nucleus.name, style);
    }
    else {
        nucleus.name ??= "";
        int colonPos = nucleus.name.IndexOf(":");
        // For array elements, do not show the ':<index>' suffix
        string label = (colonPos > 0 && colonPos < nucleus.name.Length - 2)
            ? nucleus.name[..colonPos]
            : nucleus.name;
        Handles.Label(labelPos, label, style);
    }
}

// Shows a tooltip while the mouse is over the disc and focuses the nucleus on a left click.
private void HandleNucleusInput(Nucleus nucleus, Vector3 position, float size) {
    Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2);

    // Requested unconditionally, so the sequence of control ids does not depend on the mouse
    GUIUtility.GetControlID(FocusType.Passive);

    // Check the event for null before using it (it was dereferenced before the check)
    Event e = Event.current;
    if (e == null || !neuronRect.Contains(e.mousePosition))
        return;

    HandleMouseHover(nucleus, neuronRect);

    if (e.type == EventType.MouseDown && e.button == 0) {
        // Consume the event so the scene doesn't also handle it
        e.Use();
        OnNeuronClick(nucleus);
    }
}
|
|
|
|
// Decorates a disc as a cluster: an outer white ring plus a label. When arrays are
// expanded, the label is the instance index above the disc. Otherwise it is the instance
// (or sibling) count inside the disc, shown only when larger than 1.
protected void DrawCluster(Cluster cluster, Vector3 position, Color color, float size) {
    GUIStyle labelTextStyle = new(EditorStyles.label) {
        normal = { textColor = Color.white },
        fontStyle = FontStyle.Bold,
    };

    if (expandArray) {
        // Put the array index above the disc
        // (GUI y grows downwards, so Vector3.down moves the label up)
        labelTextStyle.alignment = TextAnchor.LowerCenter;
        Vector3 labelPosition = position + Vector3.down * (size + 5);

        // Show only the instance index after the ':'. Trimming handles both "<name>: <index>"
        // and "<name>:<index>". The old fixed +2 offset dropped a digit in the latter form and
        // threw for a trailing ':'.
        string clusterName = cluster.name ?? "";
        int colonPos = clusterName.IndexOf(':');
        string index = colonPos > 0 ? clusterName[(colonPos + 1)..].Trim() : "0";
        Handles.Label(labelPosition, index, labelTextStyle);
    }
    else {
        // Put the instance count inside the disc, in a colour contrasting with the disc
        labelTextStyle.alignment = TextAnchor.MiddleCenter;
        labelTextStyle.normal.textColor = color.grayscale > 0.5f ? Color.black : Color.white;
        Vector3 labelPosition = position + (Vector3.forward * 0.1f);

        if (cluster.instanceCount > 1)
            Handles.Label(labelPosition, cluster.instanceCount.ToString(), labelTextStyle);
        else if (cluster.siblingClusters != null && cluster.siblingClusters.Length > 1)
            Handles.Label(labelPosition, cluster.siblingClusters.Length.ToString(), labelTextStyle);
    }

    // Draw a circle around the disc to indicate this is a Cluster
    Handles.color = Color.white;
    Handles.DrawWireDisc(position, Vector3.forward, size + 5);
}
|
|
|
|
// Draws a link node for a cluster prefab (a black disc with a ring and the prefab name below).
// Left-clicking it selects and pings the prefab, clears ClusterViewer.previousPrefab
// and calls CreateEditor for it.
protected void DrawClusterPrefab(ClusterPrefab prefab, Vector2 position, float size) {
    Handles.color = Color.black;
    Handles.DrawSolidDisc(position, Vector3.forward, size);
    // Draw a circle around the disc to indicate this is a Cluster
    Handles.color = Color.white;
    Handles.DrawWireDisc(position, Vector3.forward, size + 5);

    // Put the name below the disc (GUI y grows downwards)
    GUIStyle style = new(EditorStyles.label) {
        alignment = TextAnchor.UpperCenter,
        normal = { textColor = Color.white },
        fontStyle = FontStyle.Bold,
    };
    Vector2 labelPos = position - Vector2.down * (size + 5);
    Handles.Label(labelPos, prefab.name, style);

    Rect prefabRect = new(position.x - size, position.y - size, size * 2, size * 2);
    // Requested unconditionally, so the sequence of control ids does not depend on the mouse
    GUIUtility.GetControlID(FocusType.Passive);

    // Check the event for null before using it (it was dereferenced before the check)
    Event e = Event.current;
    if (e == null || !prefabRect.Contains(e.mousePosition))
        return;

    if (e.type == EventType.MouseDown && e.button == 0) {
        // Consume the event so the scene doesn't also handle it
        e.Use();
        Selection.activeObject = prefab;
        EditorGUIUtility.PingObject(prefab);
        ClusterViewer.previousPrefab = null;
        // NOTE(review): the Editor instance returned here is neither used nor destroyed
        CreateEditor(prefab);
    }
}
|
|
|
|
// Draws the 'Outputs' link label. A left click within 'size' of it calls OnAllOutputsClick.
protected void DrawAllOutputs(Vector2 position, float size) {
    GUIStyle labelTextStyle = new(EditorStyles.label) {
        alignment = TextAnchor.MiddleCenter,
        fontStyle = FontStyle.Bold,
        normal = { textColor = Color.white },
    };
    Handles.Label(position, "Outputs", labelTextStyle);

    Event current = Event.current;
    if (current == null)
        return;

    Rect hitArea = new(position.x - size, position.y - size, size * 2, size * 2);
    bool leftClickInside = hitArea.Contains(current.mousePosition)
        && current.type == EventType.MouseDown
        && current.button == 0;
    if (leftClickInside) {
        // Consume the event so the scene doesn't also handle it
        current.Use();
        OnAllOutputsClick();
    }
}
|
|
|
|
// Draws a white line between two disc centres, trimmed by 'radius' at both ends so it
// runs from border to border. Nothing is drawn when the discs touch or overlap.
protected void DrawEdge(Vector2 from, Vector2 to, float radius = 20) {
    Handles.color = Color.white;

    Vector2 delta = to - from;
    float distance = delta.magnitude;
    bool tooShort = distance <= 2f * radius || distance <= Mathf.Epsilon;
    if (!tooShort) {
        Vector2 direction = delta / distance;
        Handles.DrawLine(from + direction * radius, to - direction * radius);
    }
}
|
|
|
|
// Shows a tooltip box just below-right of the mouse with the nucleus name and, for
// neurons, their output magnitude. The 'rect' parameter is currently not used.
protected void HandleMouseHover(Nucleus nucleus, Rect rect) {
    string text = nucleus is Neuron neuron
        ? $"{nucleus.name}\nValue: {neuron.outputMagnitude}"
        : $"{nucleus.name}";
    GUIContent tooltip = new(text);

    Vector2 mouse = Event.current.mousePosition;
    Vector2 boxSize = GUI.skin.box.CalcSize(tooltip);
    GUI.Box(new Rect(mouse.x + 10, mouse.y + 10, boxSize.x, boxSize.y), tooltip);
}
|
|
|
|
// Handles a click on a nucleus disc.
// - Clicking the focused nucleus toggles the array expansion of a cluster in play mode,
//   or opens the cluster's prefab in edit mode.
// - Clicking a neuron of another cluster in edit mode focuses that cluster.
// - Otherwise the clicked nucleus becomes the focus; a neuron without receivers also
//   becomes the selected output.
protected void OnNeuronClick(Nucleus nucleus) {
    bool playing = Application.isPlaying;

    if (nucleus == this.currentNucleus) {
        if (playing)
            expandArray = (nucleus is Cluster) && !expandArray;
        else if (nucleus is Cluster clickedCluster)
            OnClusterClick(clickedCluster);
        return;
    }

    bool inOtherCluster = nucleus.parent != null
        && this.currentNucleus != null
        && nucleus.parent != this.currentNucleus.parent;

    if (inOtherCluster && !playing) {
        // Select the cluster, not the neuron in the cluster
        this.currentNucleus = nucleus.parent;
    }
    else {
        this.currentNucleus = nucleus;
        if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0)
            this.selectedOutput = this.currentNucleus;
    }
    expandArray = false;
}
|
|
|
|
// Opens a sub-cluster's prefab: selects and pings it and records the current prefab in
// ClusterViewer.previousPrefab, so the sub-cluster view can link back.
protected void OnClusterClick(Cluster subCluster) {
    Selection.activeObject = subCluster.prefab;
    EditorGUIUtility.PingObject(subCluster.prefab);
    ClusterViewer.previousPrefab = this.currentCluster.prefab;
    // NOTE(review): the Editor instance created here is neither used nor destroyed
    CreateEditor(subCluster.prefab);
}
|
|
|
|
// Clears the focus, so DrawFocusGraph shows the 'Outputs' overview.
protected void OnAllOutputsClick() {
    expandArray = false;
    this.selectedOutput = null;
    this.currentNucleus = null;
}
|
|
|
|
#endregion Graph
|
|
|
|
// Scene view callback. It draws the focused neuron's output value, transformed into world
// space by the game object's transform, as a yellow line starting at the game object.
void OnSceneGUI(SceneView sceneView) {
    if (this.gameObject == null || this.currentNucleus is not Neuron focusedNeuron)
        return;

    Transform origin = this.gameObject.transform;
    Vector3 worldVector = origin.TransformVector(focusedNeuron.outputValue);
    Handles.color = Color.yellow;
    Handles.DrawLine(origin.position, origin.position + worldVector);
}
|
|
|
|
}
|
|
}
|
|
|
|
// A list of nuclei together with an integer index.
// NOTE(review): not used in this file; 'ix' is presumably the layer index.
public class NeuroidLayer {
    public int ix;
    public List<Nucleus> neuroids = new List<Nucleus>();
}
|
|
|
|
// A directed graph of nuclei with a layered, left-to-right layout, used for the full graph view.
public class Dag {

    // A graph node wrapping one nucleus
    public class Node {
        public int id;
        public Vector2 position;
        public float radius = 20f; // circle radius
        public Nucleus nucleus;
    }

    // A directed edge between two node ids
    public class Edge {
        public int fromId;
        public int toId;
    }

    public List<Node> nodes = new();
    public List<Edge> edges = new();

    // Returns the first node whose nucleus name matches 'name', or null.
    // With justBaseName, a ':<index>' suffix is ignored on both sides, so array instances
    // match each other. Null names are treated as empty instead of throwing.
    public Node FindNode(string name, bool justBaseName = true) {
        string target = NormalizeName(name, justBaseName);
        foreach (Node node in this.nodes) {
            if (NormalizeName(node.nucleus.name, justBaseName) == target)
                return node;
        }
        return null;
    }

    // Maps null to "" and, if requested, strips everything from the first ':' on
    // (ordinal char search instead of the culture-sensitive string overload)
    private static string NormalizeName(string name, bool justBaseName) {
        name ??= "";
        if (!justBaseName)
            return name;
        int colonPos = name.IndexOf(':');
        return colonPos > 0 ? name[..colonPos] : name;
    }

    public static Node GetNodeById(Dag dag, int id) => dag.nodes.FirstOrDefault(x => x.id == id);

    // Assigns a position to every node. Nodes without outgoing edges are placed in the left
    // column (layer 0), and every other node one column right of the deepest node it feeds.
    // Nodes that cannot be layered (on or upstream of a cycle) each get an extra column.
    // Within a column, nodes are spread evenly over 400 pixels.
    public static void ComputeLayout(Dag dag) {
        // For every node: the nodes feeding into it, and how many of its outgoing
        // edges lead to nodes that have not been layered yet
        Dictionary<int, List<int>> parents = dag.nodes.ToDictionary(n => n.id, _ => new List<int>());
        Dictionary<int, int> childCount = dag.nodes.ToDictionary(n => n.id, _ => 0);
        foreach (Edge edge in dag.edges) {
            // Ignore edges to or from unknown nodes
            if (!parents.ContainsKey(edge.fromId) || !parents.ContainsKey(edge.toId))
                continue;
            parents[edge.toId].Add(edge.fromId);
            childCount[edge.fromId]++;
        }

        // Kahn's algorithm on the reversed graph: start at the nodes without children
        Dictionary<int, int> layer = new();
        Queue<int> queue = new(childCount.Where(kv => kv.Value == 0).Select(kv => kv.Key));
        foreach (int id in queue)
            layer[id] = 0;

        while (queue.Count > 0) {
            int child = queue.Dequeue();
            int childLayer = layer[child];
            foreach (int parent in parents[child]) {
                // Longest path: a parent lies one layer beyond its deepest child
                if (!layer.TryGetValue(parent, out int parentLayer) || parentLayer < childLayer + 1)
                    layer[parent] = childLayer + 1;

                // Once all of its children are layered, the parent can be processed
                if (--childCount[parent] == 0)
                    queue.Enqueue(parent);
            }
        }

        // Any unreachable nodes -> assign next layers
        int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
        foreach (Node node in dag.nodes) {
            if (!layer.ContainsKey(node.id))
                layer[node.id] = ++maxLayer;
        }

        // Group node ids by layer, from left to right
        List<List<int>> layers = layer
            .GroupBy(kv => kv.Value)
            .OrderBy(g => g.Key)
            .Select(g => g.Select(x => x.Key).ToList())
            .ToList();

        // Look nodes up by id in O(1) instead of a linear search per node
        Dictionary<int, Node> nodeById = dag.nodes.ToDictionary(n => n.id);

        const float hSpacing = 100f;
        const float totalHeight = 400f;

        // Place nodes: x increases with the layer index, y is spaced evenly within a layer
        for (int layerIx = 0; layerIx < layers.Count; layerIx++) {
            List<int> nodeList = layers[layerIx];
            float spacing = totalHeight / nodeList.Count;
            float margin = 10 + spacing / 2;
            for (int i = 0; i < nodeList.Count; i++) {
                if (!nodeById.TryGetValue(nodeList[i], out Node node))
                    continue;

                float x = hSpacing + layerIx * hSpacing;
                float y = margin + i * spacing;
                node.position = new Vector2(x, y);
            }
        }
    }
}
|
|
|
|
} |