Squashed 'NanoBrain/' changes from 832d849..cc9a845

cc9a845 Fix sleeping for product combinator
e4ba7f8 Better cross-cluster monitoring
4f8a6ab Improved (but not fixed) cross-cluster monitoring
b12616b Fix neuron output visualisation
96439cc Visualize all outputs
d583e67 WIP cluster references/instance
04bab92 Fix links to multiple cluster neurons & cleanup
e17a249 Cross-cluster editor links
0ab2d21 Migrating and cleaning up
b6630ad First steps to using instanceCount for clusters
8801fa2 Cluster reimport fixes
befb69d full graph with collapsed clusters
1a1919f Fix expansion of clsuter arrays
c708f4d Improved clusterarray support
c2e4e1b Fix Cluster array extension
02047a4 Adde full graph scrollbar
471ed36 Completed full graph integration
830e3e7 Added full graph view mode
249e888 Improve full graph view
308a6a1 The Entities are battling
75d9d1c Cleanup
c8f0f0c Fix aging of neurons
e2e169c small fixes
619ced6 Removed the use of Receptors
19f9296 Simplifications
bc0a796 Integrated clusterarray in cluster
e40dd23 Fixed clusterViewer for clusterarrays
b0f4b41 Status quo adding clusterArrays
1fc75a8 Added ClusterArray
0023920 Cover seeking(-ish) behaviour
1c7b8e7 Added Tanh Activation
a99d40c BrainViewer added
db43655 Pew pew!
18ef4cd Merge commit '89017475984bbbf1899fb38846c5bb0e7775dedd' into NanoBrain

git-subtree-dir: NanoBrain
git-subtree-split: cc9a845b643ffb4a9abe4f7da787ac5c5b14dae8
This commit is contained in:
Pascal Serrarens 2026-04-23 15:22:02 +02:00
parent 05fd588f9b
commit 4ae9a15fc6
29 changed files with 2562 additions and 3122 deletions

View File

@ -1,369 +0,0 @@
using UnityEngine;
using UnityEditor;
using System.Collections.Generic;
using System.Linq;
namespace NanoBrain {
// Simple DAG data model
[System.Serializable]
public class DagNode {
public int id;
public string title;
public Vector2 position;
public float radius = 20f; // circle radius
}
[System.Serializable]
public class DagEdge {
public int fromId;
public int toId;
}
public class BrainEditorWindow : EditorWindow {
readonly List<DagNode> nodes = new();
readonly List<DagEdge> edges = new();
Vector2 pan = Vector2.zero;
float zoom = 1.0f;
const float minZoom = 0.5f;
const float maxZoom = 2.0f;
// Vector2 dragStart;
// bool draggingNode = false;
// int draggingNodeId = -1;
private readonly System.Type acceptedType = typeof(ClusterPrefab);
[MenuItem("Window/Brain Viewer")]
public static void ShowWindow() {
var w = GetWindow<BrainEditorWindow>("Brain Viewer");
w.minSize = new Vector2(500, 300);
}
void OnEnable() {
// if (nodes.Count == 0)
// CreateSampleGraph();
// Register callback so window updates when selection changes
Selection.selectionChanged += OnSelectionChanged;
RefreshSelection();
ComputeLeftToRightLayout();
}
private void OnDisable() {
Selection.selectionChanged -= OnSelectionChanged;
}
private void OnSelectionChanged() {
RefreshSelection();
ComputeLeftToRightLayout();
Repaint();
}
private void RefreshSelection() {
ClusterPrefab prefab = Selection.activeObject as ClusterPrefab;
if (prefab != null && acceptedType.IsAssignableFrom(prefab.GetType())) {
GenerateGraph(prefab);
}
}
private void GenerateGraph(ClusterPrefab prefab) {
nodes.Clear();
edges.Clear();
int ix = 0;
foreach (Nucleus nucleus in prefab.nuclei) {
nodes.Add(new DagNode() { id = ix, title = nucleus.name });
if (nucleus is Neuron neuron) {
foreach (Nucleus receiver in neuron.receivers) {
int receiverIx = prefab.GetNucleusIndex(receiver);
edges.Add(new DagEdge() { fromId = ix, toId = receiverIx });
}
}
ix++;
}
}
// void CreateSampleGraph() {
// nodes.Clear();
// edges.Clear();
// nodes.Add(new DagNode() { id = 0, title = "In1" });
// nodes.Add(new DagNode() { id = 1, title = "In2" });
// nodes.Add(new DagNode() { id = 2, title = "A" });
// nodes.Add(new DagNode() { id = 3, title = "B" });
// nodes.Add(new DagNode() { id = 4, title = "C" });
// nodes.Add(new DagNode() { id = 5, title = "Out1" });
// nodes.Add(new DagNode() { id = 6, title = "Out2" });
// edges.Add(new DagEdge() { fromId = 0, toId = 2 });
// edges.Add(new DagEdge() { fromId = 1, toId = 2 });
// edges.Add(new DagEdge() { fromId = 2, toId = 3 });
// edges.Add(new DagEdge() { fromId = 2, toId = 4 });
// edges.Add(new DagEdge() { fromId = 3, toId = 5 });
// edges.Add(new DagEdge() { fromId = 4, toId = 6 });
// }
void OnGUI() {
HandleInput();
Rect rect = new(0, 0, position.width, position.height);
EditorGUI.DrawRect(rect, new Color(0.11f, 0.11f, 0.11f));
// compute window center
Vector2 windowCenter = new(position.width / 2f, position.height / 2f);
// compute graph bounds center (in graph space)
Rect bounds = GetGraphBounds();
Vector2 graphCenter = bounds.center;
// compute autoPan that recenters the graph (does not modify node positions)
Vector2 autoPan = -graphCenter; // moves graph center to origin
// total translation = windowCenter + autoPan + user pan
Matrix4x4 oldMatrix = GUI.matrix;
GUI.matrix = Matrix4x4.TRS(windowCenter + autoPan + pan, Quaternion.identity, Vector3.one * zoom) *
Matrix4x4.TRS(-windowCenter, Quaternion.identity, Vector3.one);
// Draw edges first
foreach (DagEdge e in edges) {
DagNode from = GetNodeById(e.fromId);
DagNode to = GetNodeById(e.toId);
if (from == null || to == null) continue;
DrawEdgeCircleNodes(from, to);
}
// Draw nodes (circles)
foreach (DagNode n in nodes)
DrawNucleus(n);
GUI.matrix = oldMatrix;
// Footer toolbar
GUILayout.FlexibleSpace();
EditorGUILayout.BeginHorizontal(EditorStyles.toolbar);
if (GUILayout.Button("Fit", EditorStyles.toolbarButton)) FitToView();
if (GUILayout.Button("Layout LR", EditorStyles.toolbarButton)) ComputeLeftToRightLayout();
EditorGUILayout.EndHorizontal();
}
void HandleInput() {
Event e = Event.current;
// Zoom with scroll
if (e.type == EventType.ScrollWheel) {
float oldZoom = zoom;
float delta = -e.delta.y * 0.01f;
zoom = Mathf.Clamp(zoom + delta, minZoom, maxZoom);
Vector2 mouse = e.mousePosition;
pan += (mouse - new Vector2(position.width / 2, position.height / 2)) * (1 - zoom / oldZoom);
e.Use();
}
// Pan with middle or right+ctrl drag
if (e.type == EventType.MouseDrag && (e.button == 2 || (e.button == 1 && e.control))) {
pan += e.delta;
e.Use();
}
}
DagNode GetNodeById(int id) => nodes.FirstOrDefault(x => x.id == id);
List<DagEdge> GetIncomingEdges(DagNode node) {
List<DagEdge> incoming = new();
foreach (DagEdge e in edges) {
if (e.toId == node.id)
incoming.Add(e);
}
return incoming;
}
List<DagEdge> GetOutgoingEdges(DagNode node) {
List<DagEdge> outgoing = new();
foreach (DagEdge e in edges) {
if (e.fromId == node.id)
outgoing.Add(e);
}
return outgoing;
}
void DrawNucleus(DagNode n) {
Vector3 position = n.position;
Handles.color = Color.white * 0.9f;
Handles.DrawSolidDisc(n.position, Vector3.forward, n.radius);
if (GetIncomingEdges(n).Count == 0)
DrawArrowHead(n.position - new Vector2(n.radius + 10, 0), n.position - new Vector2(n.radius + 5, 0), 10f / zoom, 12f / zoom, Color.white);
if (GetOutgoingEdges(n).Count == 0)
DrawArrowHead(n.position + new Vector2(n.radius + 10, 0), n.position + new Vector2(n.radius + 15, 0), 10f / zoom, 12f / zoom, Color.white);
Handles.color = Color.white;
GUIStyle style = new(EditorStyles.label) {
alignment = TextAnchor.UpperCenter,
normal = { textColor = Color.white },
fontStyle = FontStyle.Bold,
};
Vector3 labelPos = position - Vector3.down * (n.radius + 10f); // below disc along up axis
Handles.Label(labelPos, n.title, style);
}
void DrawEdgeCircleNodes(DagNode from, DagNode to) {
Vector2 a = from.position;
Vector2 b = to.position;
if (a == b) return;
Handles.color = Color.white * 0.9f;
Handles.DrawLine(from.position, to.position);
// Vector2 dir = (b - a).normalized;
// Vector2 start = a + dir * from.radius;
// Vector2 end = b - dir * to.radius;
//DrawArrowHead(end - dir * 2f, end, 10f / zoom, 12f / zoom, Color.white);
}
void DrawArrowHead(Vector2 from, Vector2 to, float headWidth, float headLength, Color color) {
Vector2 dir = (to - from).normalized;
if (dir == Vector2.zero) return;
Vector2 right = new Vector2(-dir.y, dir.x);
Vector3 p1 = to;
Vector3 p2 = to - dir * headLength + right * headWidth * 0.5f;
Vector3 p3 = to - dir * headLength - right * headWidth * 0.5f;
Handles.color = color;
Handles.DrawAAConvexPolygon(p1, p2, p3);
}
// Left-to-right layered layout (sources on the left, sinks on the right)
void ComputeLeftToRightLayout() {
// build adjacency and indegree
var adj = nodes.ToDictionary(n => n.id, n => new List<int>());
var indeg = nodes.ToDictionary(n => n.id, n => 0);
foreach (var e in edges) {
if (!adj.ContainsKey(e.fromId) || !adj.ContainsKey(e.toId)) continue;
adj[e.fromId].Add(e.toId);
indeg[e.toId]++;
}
// Kahn's algorithm to compute topological layers (horizontal layers)
Dictionary<int, int> layer = new();
Queue<int> q = new(indeg.Where(kv => kv.Value == 0).Select(kv => kv.Key));
foreach (var id in q) layer[id] = 0;
while (q.Count > 0) {
int u = q.Dequeue();
int l = layer[u];
foreach (var v in adj[u]) {
// prefer placing v at least one layer after u
if (!layer.ContainsKey(v) || layer[v] < l + 1) layer[v] = l + 1;
indeg[v]--;
if (indeg[v] == 0) q.Enqueue(v);
}
}
// Any unreachable nodes -> assign next layers
int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
foreach (var n in nodes) {
if (!layer.ContainsKey(n.id)) {
maxLayer++;
layer[n.id] = maxLayer;
}
}
// Group nodes by layer (left to right)
var layers = layer.GroupBy(kv => kv.Value).OrderBy(g => g.Key).Select(g => g.Select(x => x.Key).ToList()).ToList();
// Layout parameters (horizontal spacing drives left->right)
float hSpacing = 150f;
float vSpacing = 100f;
// Place nodes: x increases with layer index, y spaced within layer
for (int li = 0; li < layers.Count; li++) {
var lst = layers[li];
float totalHeight = (lst.Count - 1) * vSpacing;
for (int i = 0; i < lst.Count; i++) {
int id = lst[i];
var n = GetNodeById(id);
if (n == null) continue;
float x = hSpacing + li * hSpacing;
float y = 400 - totalHeight / 2f + i * vSpacing;
// Debug.Log($"({li}, {i}) -> {x}, {y}");
n.position = new Vector2(x, y);
}
}
Repaint();
}
void FitToView() {
if (nodes.Count == 0) return;
// compute bounds including radii
Rect bounds = new Rect(nodes[0].position - Vector2.one * nodes[0].radius, Vector2.one * nodes[0].radius * 2f);
foreach (var n in nodes)
bounds = RectUnion(bounds, new Rect(n.position - Vector2.one * n.radius, Vector2.one * n.radius * 2f));
// center graph at origin (0,0) then set pan so it appears centered in window
Vector2 graphCenter = bounds.center;
// move nodes so center is at origin
for (int i = 0; i < nodes.Count; i++)
nodes[i].position -= graphCenter;
// reset pan/zoom so centered
pan = Vector2.zero;
zoom = 1.0f;
Repaint();
}
static Rect RectUnion(Rect a, Rect b) {
float xMin = Mathf.Min(a.xMin, b.xMin);
float xMax = Mathf.Max(a.xMax, b.xMax);
float yMin = Mathf.Min(a.yMin, b.yMin);
float yMax = Mathf.Max(a.yMax, b.yMax);
return Rect.MinMaxRect(xMin, yMin, xMax, yMax);
}
Vector2 ScreenToGraph_old(Vector2 screenPos) {
Vector2 origin = new Vector2(position.width / 2, position.height / 2);
// invert the GUI.matrix transform (approx for current simple transforms)
return (screenPos - (origin + pan)) / zoom + origin * (1 - 1 / zoom);
}
Vector2 ScreenToGraph(Vector2 screenPos) {
Vector2 windowCenter = new Vector2(position.width / 2f, position.height / 2f);
Rect bounds = GetGraphBounds();
Vector2 graphCenter = bounds.center;
Vector2 autoPan = -graphCenter;
// inverse of: screen -> translate by -(windowCenter+autoPan+pan), scale by 1/zoom, translate by windowCenter
return (screenPos - (windowCenter + autoPan + pan)) / zoom + windowCenter;
}
Rect GetGraphBounds() {
if (nodes == null || nodes.Count == 0) return new Rect(Vector2.zero, Vector2.one);
Rect bounds = new(
nodes[0].position - Vector2.one * nodes[0].radius,
2f * nodes[0].radius * Vector2.one);
foreach (var n in nodes)
bounds = RectUnion(bounds,
new Rect(n.position - Vector2.one * n.radius, 2f * n.radius * Vector2.one));
return bounds;
}
int HitTestNode(Vector2 graphPos) {
// returns node id under point or -1
for (int i = nodes.Count - 1; i >= 0; i--) {
var n = nodes[i];
if ((graphPos - n.position).sqrMagnitude <= n.radius * n.radius) return n.id;
}
return -1;
}
}
}

View File

@ -1,2 +0,0 @@
fileFormatVersion: 2
guid: f041740900808273ab006e7d276a78e9

73
Editor/Brain_Editor.cs Normal file
View File

@ -0,0 +1,73 @@
using UnityEditor;
using UnityEditor.UIElements;
using UnityEngine;
using UnityEngine.UIElements;
namespace NanoBrain {
[CustomEditor(typeof(Brain))]
public class Brain_Editor : Editor {
protected static VisualElement mainContainer;
protected static VisualElement inspectorContainer;
public Brain component;
private SerializedProperty brainProp;
public void OnEnable() {
component = target as Brain;
if (Application.isPlaying == false && serializedObject != null) {
string propertyName = nameof(Brain.brainPrefab);
brainProp = serializedObject.FindProperty(propertyName);
}
}
public override VisualElement CreateInspectorGUI() {
if (Application.isPlaying == false)
serializedObject.Update();
VisualElement root = new() {
style = {
paddingLeft = 0,
paddingRight = 0,
paddingTop = 0,
paddingBottom = 0
}
};
root.styleSheets.Add(Resources.Load<StyleSheet>("GraphStyles"));
PropertyField brainField = new(brainProp) {
label = "Cluster Prefab"
};
root.Add(brainField);
CreateViewer(root, component.brain, component.gameObject);
if (Application.isPlaying == false)
serializedObject.ApplyModifiedProperties();
return root;
}
public ClusterViewer.GraphView CreateViewer(VisualElement root, Cluster cluster, GameObject gameObject) {
VisualElement mainContainer = new() {
style = {
flexDirection = FlexDirection.Row,
minHeight = 450
}
};
ClusterViewer.GraphView graph = new(cluster);
graph.style.flexGrow = 1;
mainContainer.Add(graph);
root.Add(mainContainer);
graph.SetGraph(gameObject);
return graph;
}
}
}

632
Editor/ClusterEditor.cs Normal file
View File

@ -0,0 +1,632 @@
using System.Collections.Generic;
using System.Linq;
using UnityEditor;
using UnityEngine;
using UnityEngine.UIElements;
namespace NanoBrain {
[CustomEditor(typeof(ClusterPrefab))]
public class ClusterEditor : ClusterViewer {
public override VisualElement CreateInspectorGUI() {
ClusterPrefab prefab = target as ClusterPrefab;
if (prefab != null)
prefab.EnsureInitialization();
serializedObject.Update();
VisualElement root = new();
CreateEditor(root, prefab, null);
serializedObject.ApplyModifiedProperties();
return root;
}
public GraphView CreateEditor(VisualElement root, ClusterPrefab cluster, GameObject gameObject) {
root.style.paddingLeft = 0;
root.style.paddingRight = 0;
root.style.paddingTop = 0;
root.style.paddingBottom = 0;
root.styleSheets.Add(Resources.Load<StyleSheet>("GraphStyles"));
VisualElement mainContainer = new() {
style = {
flexDirection = FlexDirection.Row,
}
};
GraphEditor graphContainer = new(cluster);
graphContainer.style.flexShrink = 0;
graphContainer.style.width = 300;
graphContainer.style.overflow = Overflow.Hidden;
VisualElement inspectorContainer = new() {
name = "inspector",
style = {
minHeight = 450,
width = 300,
flexGrow = 0,
flexDirection = FlexDirection.Row,
}
};
mainContainer.Add(graphContainer);
mainContainer.Add(inspectorContainer);
root.Add(mainContainer);
graphContainer.SetGraph(gameObject, inspectorContainer);
return graphContainer;
}
public class GraphEditor : GraphView {
protected ClusterPrefab prefab;
public GraphEditor(ClusterPrefab prefab) : base(prefab.output.parent) {
this.prefab = prefab;
// In a Prefab editor, no instance exists but we need it for the ClusterViewer.
// So we create a temporary instance
Cluster cluster = new(prefab);
this.currentCluster = cluster;
Button addButton = new(() => OnAddClusterOutput()) {
text = "Add"
};
topMenuContainer?.Add(addButton);
Add(topMenuContainer);
}
void OnAddClusterOutput() {
Nucleus newOutput = new Neuron(this.prefab, "New Output");
this.prefab.RefreshOutputs();
outputsPopup.choices = this.prefab.outputs.Select(output => output.name).ToList();
outputsPopup.value = newOutput.name;
this.currentNucleus = newOutput;
}
public void SetGraph(GameObject gameObject, VisualElement inspectorContainer) {
this.gameObject = gameObject;
if (Application.isPlaying == false)
this.serializedBrain = new SerializedObject(this.prefab);
this.selectedOutput = this.currentCluster.outputs[0];
this.currentNucleus = this.selectedOutput;
//this.currentCluster = this.currentNucleus.parent;
Rebuild(inspectorContainer);
// if (outputsPopup != null)
// OnOutputChanged(outputsPopup.choices[0]);
}
private void Rebuild(VisualElement inspectorContainer) {
if (this.currentNucleus == null) {
inspectorContainer.Clear();
return;
}
string path = AssetDatabase.GetAssetPath(this.prefab); // or known path
this.prefabAsset = AssetDatabase.LoadAssetAtPath<ClusterPrefab>(path);
if (this.prefabAsset == null) {
// create in memory save if it doesn't exist
this.prefabAsset = CreateInstance<ClusterPrefab>();
//Debug.LogError("Cluster Prefab is not found on disk");
}
// DrawInspector(inspectorContainer);
if (inspectorContainer == null)
return;
inspectorContainer.Clear();
if (this.currentNucleus == null)
return;
// create a SerializedObject wrapper so Unity inspector controls work (and Undo)
SerializedObject so = new(prefabAsset);
foreach (Nucleus nucleus in this.prefab.nuclei) {
nucleus.Initialize();
}
this.inspectorIMGUIContainer = new IMGUIContainer(() => InspectorHandler(so));
inspectorContainer.Add(inspectorIMGUIContainer);
}
#region Inspector
private VisualElement inspectorIMGUIContainer;
private bool showSynapses = true;
private bool showActivation = true;
protected bool breakOnWake = false;
protected bool trace = false;
void InspectorHandler(SerializedObject serializedObject) {
bool anythingChanged = false;
if (serializedObject == null || serializedObject.targetObject == null)
return;
if (this.currentNucleus == null)
return;
serializedObject.Update();
GUIStyle headerStyle = new(EditorStyles.boldLabel) {
alignment = TextAnchor.MiddleLeft,
margin = new RectOffset(10, 0, 4, 4)
};
GUIStyle boldTextFieldStyle = new(EditorStyles.textField) {
fontStyle = FontStyle.Bold
};
// Nucleus type
string nucleusType = this.currentNucleus.GetType().Name;
GUILayout.Label(nucleusType, headerStyle);
// Nucleus name
if (this.currentNucleus.parent is Cluster parentCluster) {
EditorGUILayout.BeginHorizontal();
if (GUILayout.Button(this.currentNucleus.parent.name))
OnClusterClick(parentCluster);
EditorGUI.BeginDisabledGroup(true);
EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle);
EditorGUI.EndDisabledGroup();
if (GUILayout.Button("Reimport"))
ReimportCluster(parentCluster);
EditorGUILayout.EndHorizontal();
}
else {
string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle);
if (newName != this.currentNucleus.name) {
this.currentNucleus.name = newName;
this.prefab.RefreshOutputs();
outputsPopup.choices = this.prefab.outputs.Select(output => output.name).ToList();
anythingChanged = true;
}
}
// Current output value
if (Application.isPlaying) {
if (currentNucleus is Neuron currentNeuron1) {
GUIContent nameLabel = new("Output", currentNeuron1.outputValue.ToString());
EditorGUILayout.FloatField(nameLabel, currentNeuron1.outputMagnitude);
}
else
EditorGUILayout.LabelField(" ");
}
else
EditorGUILayout.LabelField(" ");
// Memory cell
if (this.currentNucleus is MemoryCell memory)
MemoryCellInspector(memory, ref anythingChanged);
// Cluster
else if (this.currentNucleus is Cluster cluster)
ClusterInspector(cluster, ref anythingChanged);
// Other
else
NucleusInspector(this.currentNucleus, ref anythingChanged);
if (GUILayout.Button("Delete"))
DeleteNucleus(this.currentNucleus);
serializedObject.ApplyModifiedProperties();
if (anythingChanged) {
EditorUtility.SetDirty(prefabAsset);
AssetDatabase.SaveAssets();
}
}
protected void MemoryCellInspector(MemoryCell memoryCell, ref bool anythingChanged) {
memoryCell.staticMemory = EditorGUILayout.Toggle("Static Memory", memoryCell.staticMemory);
NucleusInspector(memoryCell, ref anythingChanged);
}
protected void ClusterInspector(Cluster cluster, ref bool anythingChanged) {
EditorGUILayout.BeginHorizontal();
int instanceCount = cluster.instanceCount;
if (instanceCount <= 1) {
if (cluster.siblingClusters != null && cluster.siblingClusters.Length > 1)
instanceCount = cluster.siblingClusters.Count();
else
instanceCount = 1;
}
EditorGUILayout.IntField("Instances", instanceCount, GUILayout.MinWidth(150));
if (GUILayout.Button("Add")) {
Undo.RecordObject(prefabAsset, "Array add " + prefabAsset.name);
//cluster.AddInstance(this.prefab);
cluster.AddInstance();
anythingChanged = true;
}
if (GUILayout.Button("Del")) {
Undo.RecordObject(prefabAsset, "Array delete " + prefabAsset.name);
cluster.RemoveInstance();
anythingChanged = true;
}
EditorGUILayout.EndHorizontal();
if (GUILayout.Button("Reimport Cluster"))
ReimportCluster(cluster);
}
protected void NucleusInspector(Nucleus nucleus, ref bool anythingChanged) {
SynapsesInspector(ref anythingChanged);
ActivationInspector(ref anythingChanged);
EditorGUILayout.Space();
breakOnWake = EditorGUILayout.Toggle("Break on wake", breakOnWake);
if (breakOnWake && this.currentNucleus is Neuron currentNeuron) {
if (currentNeuron.isSleeping == false)
Debug.Break();
}
trace = EditorGUILayout.Toggle("Trace", trace);
this.currentNucleus.trace = trace;
}
protected void SynapsesInspector(ref bool anythingChanged) {
showSynapses = EditorGUILayout.BeginFoldoutHeaderGroup(showSynapses, "Synapses");
if (showSynapses) {
if (this.currentNucleus is Neuron neuron2) {
Neuron.CombinatorType newCombinator = (Neuron.CombinatorType)EditorGUILayout.EnumPopup("Combinator", neuron2.combinator);
anythingChanged |= newCombinator != neuron2.combinator;
neuron2.combinator = newCombinator;
}
EditorGUIUtility.wideMode = true;
float previousLabelWidth = EditorGUIUtility.labelWidth;
EditorGUIUtility.labelWidth = 100;
Vector3 newBias = EditorGUILayout.Vector3Field("Bias", this.currentNucleus.bias);
anythingChanged |= newBias != this.currentNucleus.bias;
this.currentNucleus.bias = newBias;
EditorGUIUtility.labelWidth = previousLabelWidth;
Nucleus[] array = null;
int elementIx = -1;
if (this.currentNucleus.synapses.Count > 0) {
Synapse[] synapses = this.currentNucleus.synapses.ToArray();
foreach (Synapse synapse in synapses) {
if (synapse.neuron == null)
continue;
if (array != null) {
if (synapse.neuron.parent is Cluster iCluster && elementIx > 0) {
int thisElementIx = Cluster.GetNucleusIndex(iCluster.clusterNuclei, synapse.neuron);
if (thisElementIx == elementIx)
continue;
else
elementIx = thisElementIx;
}
if (array.Contains(synapse.neuron))
continue;
else if (array.Contains(synapse.neuron.parent))
continue;
}
else {
if (synapse.neuron.parent is Cluster iReceptor) {
array = iReceptor.siblingClusters;
if (iReceptor is Cluster iCluster)
elementIx = Cluster.GetNucleusIndex(iCluster.clusterNuclei, synapse.neuron);
}
}
EditorGUILayout.Space();
if (Application.isPlaying) {
if (synapse.neuron is Neuron synapseNeuron) {
Vector3 value = synapseNeuron.outputValue * synapse.weight;
GUIContent synapseValueLabel = new(synapse.neuron.name, synapseNeuron.outputValue.ToString());
EditorGUILayout.FloatField(synapseValueLabel, synapseNeuron.outputMagnitude);
}
}
else {
EditorGUILayout.BeginHorizontal();
if (synapse.neuron.clusterPrefab != this.currentNucleus.clusterPrefab) {
// If it is a cluster
GUIStyle labelStyle = new(GUI.skin.label);
float labelWidth = 200;
if (synapse.neuron.clusterPrefab != null) {
labelWidth = labelStyle.CalcSize(new GUIContent($"{synapse.neuron.clusterPrefab.name}.")).x;
GUILayout.Label($"{synapse.neuron.clusterPrefab.name}", GUILayout.Width(labelWidth));
}
//string[] options = synapse.neuron.parent.clusterNuclei.Select(n => n.name).ToArray();
string[] options = synapse.neuron.clusterPrefab.nuclei.Select(n => n.name).ToArray();
int selectedIndex = System.Array.IndexOf(options, synapse.neuron.name);
int newIndex = EditorGUILayout.Popup(selectedIndex, options);
// if (newIndex != selectedIndex && synapse.neuron.clusterPrefab.nuclei[newIndex] is Neuron newNeuron)
// ChangeSynapse(synapse, newNeuron);
if (newIndex != selectedIndex) {
// It shall be ensured that the parent.clusterNuclei and
// clusterPrefab.nuclei contain the same neurons in the same order....
Nucleus selectedNucleus = synapse.neuron.parent.clusterNuclei[newIndex];
Neuron newNeuron = selectedNucleus as Neuron;
ChangeSynapse(synapse, newNeuron);
}
}
else
GUILayout.Label(synapse.neuron.name);
bool disconnecting = GUILayout.Button("Disconnect", GUILayout.Width(80));
if (disconnecting && synapse.neuron is Neuron synapseNeuron) {
synapseNeuron.RemoveReceiver(this.currentNucleus);
this.prefab.GarbageCollection();
anythingChanged = true;
}
EditorGUILayout.EndHorizontal();
}
EditorGUI.indentLevel++;
float newWeight = EditorGUILayout.FloatField("Weight", synapse.weight);
if (newWeight != synapse.weight) {
// if (synapse.neuron.parent is IReceptor receptor) {
// Nucleus[] receptorArray = receptor.nucleiArray;
// foreach (Synapse s in this.currentNucleus.synapses) {
// if (s.neuron.parent is IReceptor r && r.nucleiArray == receptorArray)
// s.weight = newWeight;
// }
// }
// else
synapse.weight = newWeight;
anythingChanged = true;
}
EditorGUI.indentLevel--;
}
}
EditorGUILayout.Space();
anythingChanged |= ConnectNucleus(this.prefab, this.currentNucleus);
anythingChanged |= AddSynapse(this.prefab, this.currentNucleus);
}
EditorGUILayout.EndFoldoutHeaderGroup();
}
protected void ActivationInspector(ref bool anythingChanged) {
EditorGUILayout.Space();
showActivation = EditorGUILayout.BeginFoldoutHeaderGroup(showActivation, "Activation");
if (showActivation) {
if (this.currentNucleus is Neuron neuron) {
if (this.currentNucleus is not MemoryCell) {
EditorGUILayout.BeginHorizontal();
EditorGUILayout.LabelField("Activation Curve", GUILayout.MinWidth(60));
if (neuron.curveMax > 0)
EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, 0, 1, neuron.curveMax), GUILayout.Width(40));
else
EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, neuron.curveMax, 1, -neuron.curveMax), GUILayout.Width(40));
Neuron.ActivationType newPreset = (Neuron.ActivationType)EditorGUILayout.EnumPopup(neuron.curvePreset, GUILayout.MinWidth(50));
anythingChanged |= newPreset != neuron.curvePreset;
neuron.curvePreset = newPreset;
EditorGUILayout.EndHorizontal();
}
// if (neuron is Receptor receptor2) {
// if (receptor2.nucleiArray == null || receptor2.nucleiArray.Count() == 0)
// receptor2.array = new NucleusArray(neuron);
// }
}
EditorGUILayout.Space();
}
EditorGUILayout.EndFoldoutHeaderGroup();
}
#region Synapses
protected virtual void AddInput(Nucleus.Type selectedType, Nucleus nucleus) {
switch (selectedType) {
case Nucleus.Type.Neuron:
AddNeuronInput(nucleus);
break;
case Nucleus.Type.MemoryCell:
AddMemoryCellInput(nucleus);
break;
case Nucleus.Type.Cluster:
AddClusterInput(nucleus);
break;
// case Nucleus.Type.Receptor:
// AddReceptorInput(nucleus);
// break;
// case Nucleus.Type.ClusterReceptor:
// AddClusterReceptorInput(nucleus);
// break;
// case Nucleus.Type.ClusterArray:
// AddClusterArrayInput(nucleus);
// break;
default:
break;
}
}
protected virtual void AddNeuronInput(Nucleus nucleus) {
Neuron newNeuroid = new(this.prefab, "New neuron");
newNeuroid.AddReceiver(nucleus);
this.currentNucleus = newNeuroid;
}
protected virtual void AddMemoryCellInput(Nucleus nucleus) {
MemoryCell newMemory = new(this.prefab, "New memory cell");
newMemory.AddReceiver(nucleus);
this.currentNucleus = newMemory;
}
protected virtual void AddClusterInput(Nucleus nucleus) {
ClusterPickerWindow.ShowPicker(brain => OnClusterPicked(nucleus, brain), "Select Cluster");
}
private void OnClusterPicked(Nucleus nucleus, ClusterPrefab selectedPrefab) {
Cluster subclusterInstance = new(selectedPrefab, this.prefab);
subclusterInstance.defaultOutput.AddReceiver(nucleus);
}
private void ReimportCluster(Cluster subCluster) {
if (subCluster.siblingClusters == null || subCluster.siblingClusters.Length <= 0) {
Cluster reimportedCluster = new(subCluster.prefab, this.prefab);
subCluster.MoveReceivers(reimportedCluster);
// subcluster should be garbage now...
this.currentNucleus = reimportedCluster;
}
else {
this.currentNucleus = null;
List<Cluster> newSiblingsList = new();
foreach (Cluster sibling in subCluster.siblingClusters) {
Cluster reimportedCluster = new(sibling.prefab, this.prefab) {
name = sibling.name
};
sibling.MoveReceivers(reimportedCluster);
newSiblingsList.Add(reimportedCluster);
// make the first reimportedCluster the new current nucleus
this.currentNucleus ??= reimportedCluster;
}
Cluster[] newSiblings = newSiblingsList.ToArray();
foreach (Cluster sibling in newSiblings)
sibling.siblingClusters = newSiblings;
}
}
int selectedConnectNucleus = -1;
// Connect to another nucleus
protected virtual bool ConnectNucleus(ClusterPrefab cluster, Nucleus nucleusToConnect) {
if (cluster == null)
return false;
IEnumerable<Nucleus> synapseNuclei = this.currentNucleus.synapses
.Where(synapse => synapse.neuron != null)
.Select(synapse => synapse.neuron);
IEnumerable<Nucleus> nuclei = cluster.nuclei
.Except(synapseNuclei);
IEnumerable<string> nucleiNames = nuclei
.Select(n => {
int idx = n.name.IndexOf(':');
return idx < 0 ? n.name : n.name[..idx];
})
.Distinct();
string[] names = nucleiNames.ToArray();
EditorGUILayout.BeginHorizontal();
selectedConnectNucleus = EditorGUILayout.Popup(selectedConnectNucleus, names);
bool connecting = GUILayout.Button("Connect", GUILayout.Width(80));
EditorGUILayout.EndHorizontal();
if (connecting) {
Nucleus nucleus = nuclei.ElementAt(selectedConnectNucleus);
if (nucleus is Cluster subCluster)
subCluster.AddArrayReceiver(this.currentNucleus);
else if (nucleus is Neuron neuron)
neuron.AddReceiver(this.currentNucleus);
}
return connecting;
}
protected virtual void DeleteNucleus(Nucleus nucleus) {
if (nucleus == null)
return;
if (nucleus is Neuron neuron) {
foreach (Nucleus receiver in neuron.receivers) {
if (receiver != null) {
this.currentNucleus = receiver;
break;
}
}
}
this.prefab.nuclei.Remove(nucleus);
if (outputsPopup.value == nucleus.name) {
this.prefab.RefreshOutputs();
outputsPopup.choices = this.prefab.outputs.Select(output => output.name).ToList();
outputsPopup.index = 0;
}
Neuron.Delete(nucleus);
this.currentNucleus = this.prefab.output;
}
Nucleus.Type selectedType = Nucleus.Type.None;
protected virtual bool AddSynapse(ClusterPrefab cluster, Nucleus nucleus) {
if (cluster == null)
return false;
EditorGUILayout.BeginHorizontal();
selectedType = (Nucleus.Type)EditorGUILayout.EnumPopup(selectedType);
bool connecting = GUILayout.Button("Add", GUILayout.Width(80));
EditorGUILayout.EndHorizontal();
if (connecting) {
AddInput(selectedType, this.currentNucleus);
}
return connecting;
// if (selectedType == Nucleus.Type.None)
// return false;
// AddInput(selectedType, this.currentNucleus);
// return true;
}
protected virtual void ChangeSynapse(Synapse synapse, Neuron newNucleus) {
Neuron synapseNeuron = synapse.neuron as Neuron;
if (synapse.neuron.parent is Cluster subCluster && subCluster.prefab != this.prefab) {
// if (synapse.neuron.parent is ClusterReceptor receptor) {
// // the new nucleus is part of a (cluster) receptor,
// // so we have to change all synapses to this nucleus array elements
// int oldNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, synapse.neuron);
// int newNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, newNucleus);
// foreach (Nucleus element in receptor.nucleiArray) {
// if (element is not ClusterReceptor clusterReceptor)
// continue;
// // Get the same neuron as the synapse.nucleus in a different element
// // of the ClusterReceptor array
// Nucleus oldElementNucleus = clusterReceptor.clusterNuclei[oldNucleusIx];
// if (oldElementNucleus is not Neuron oldElementNeuron)
// continue;
// // Get the same neuron as newNucleus in a different element
// // of the ClusterReceptor array
// Nucleus newElementNucleus = clusterReceptor.clusterNuclei[newNucleusIx];
// if (newElementNucleus is not Neuron newElementNeuron)
// continue;
// oldElementNeuron.RemoveReceiver(this.currentNucleus);
// newElementNeuron.AddReceiver(this.currentNucleus);
// // Now find the synapse which pointed to the old Neuron
// // Synapse synapseForUpdate = this.currentNucleus.GetSynapse(oldElementNeuron);
// // synapseForUpdate.nucleus = newElementNeuron;
// }
// }
// else {
// it is a neuron in a subcluster
synapseNeuron.RemoveReceiver(this.currentNucleus);
newNucleus.AddReceiver(this.currentNucleus);
// }
}
else {
synapseNeuron.RemoveReceiver(this.currentNucleus);
newNucleus.AddReceiver(this.currentNucleus);
}
}
protected virtual void DisconnectNucleus(Neuron nucleus) {
if (this.currentNucleus.clusterPrefab == null)
return;
string[] names = this.currentNucleus.synapses.Select(synapse => synapse.neuron.name).ToArray();
int selectedIndex = -1;
selectedIndex = EditorGUILayout.Popup("Disconnect from", selectedIndex, names);
if (selectedIndex >= 0 && selectedIndex < this.currentNucleus.clusterPrefab.nuclei.Count) {
Synapse synapse = this.currentNucleus.synapses[selectedIndex];
Neuron synapseNeuron = synapse.neuron as Neuron;
synapseNeuron.RemoveReceiver(this.currentNucleus);
}
}
#endregion Synapses
#endregion Inspector
}
}
}

File diff suppressed because it is too large Load Diff

950
Editor/ClusterViewer.cs Normal file
View File

@ -0,0 +1,950 @@
using System.Collections.Generic;
using System.Linq;
using UnityEditor;
using UnityEngine;
using UnityEngine.UIElements;
namespace NanoBrain {
public class ClusterViewer : Editor {
public static ClusterPrefab previousPrefab;
public class GraphView : VisualElement {
//protected readonly ClusterPrefab prefab;
protected Cluster currentCluster;
protected SerializedObject serializedBrain;
protected Nucleus currentNucleus;
protected Nucleus selectedOutput;
protected GameObject gameObject;
private bool expandArray = false;
protected ClusterPrefab prefabAsset;
protected VisualElement topMenuContainer;
protected ScrollView scrollView;
protected IMGUIContainer graphContainer;
protected readonly PopupField<string> outputsPopup;
public enum Mode {
Focus,
Full
}
public Mode mode = Mode.Focus;
public GraphView(Cluster cluster) {
this.currentCluster = cluster;
name = "content";
style.flexGrow = 1;
topMenuContainer = new() {
style = {
flexDirection = FlexDirection.Row,
alignItems = Align.Center,
}
};
EnumField modePopup = new(mode);
modePopup.style.width = 80;
modePopup.RegisterValueChangedCallback(OnModeChange);
topMenuContainer.Add(modePopup);
scrollView = new(ScrollViewMode.Horizontal);
scrollView.style.position = Position.Absolute;
scrollView.style.left = 0; scrollView.style.top = 0;
scrollView.style.right = 0; scrollView.style.bottom = 0;
//scrollView.style.flexGrow = 1;
scrollView.horizontalScrollerVisibility = ScrollerVisibility.Auto; // Auto shows when needed
scrollView.verticalScrollerVisibility = ScrollerVisibility.Hidden;
graphContainer = new(OnIMGUI);
//graphContainer.style.position = Position.Relative; // or omit this line
//graphContainer.style.position = Position.Absolute;
// graphContainer.style.left = 0; graphContainer.style.top = 0;
// graphContainer.style.right = 0; graphContainer.style.bottom = 0;
graphContainer.pickingMode = PickingMode.Position;
graphContainer.focusable = true;
//graphContainer.style.width = 1200;
//graphContainer.style.width = new StyleLength(StyleKeyword.Null); // allow content to determine width
scrollView.contentContainer.Add(graphContainer);
Add(scrollView);
Add(topMenuContainer);
// Subscribe when added to panel (editor UI ready)
RegisterCallback<AttachToPanelEvent>(evt => Subscribe());
RegisterCallback<DetachFromPanelEvent>(evt => Unsubscribe());
}
protected virtual void OnModeChange(ChangeEvent<System.Enum> changeEvent) {
this.mode = (Mode)changeEvent.newValue;
}
bool subscribed = false;
void Subscribe() {
if (subscribed) return;
SceneView.duringSceneGui += OnSceneGUI;
subscribed = true;
SceneView.RepaintAll();
}
void Unsubscribe() {
if (!subscribed) return;
SceneView.duringSceneGui -= OnSceneGUI;
subscribed = false;
}
public void SetGraph(GameObject gameObject) {
this.gameObject = gameObject;
if (Application.isPlaying == false)
this.serializedBrain = new SerializedObject(this.currentCluster.prefab);
this.selectedOutput = this.currentCluster.outputs[0];
this.currentNucleus = this.selectedOutput;
Rebuild();
}
void Rebuild() {
if (this.currentNucleus == null)
return;
string path = AssetDatabase.GetAssetPath(this.currentCluster.prefab); // or known path
this.prefabAsset = AssetDatabase.LoadAssetAtPath<ClusterPrefab>(path);
if (this.prefabAsset == null) {
// create in memory save if it doesn't exist
this.prefabAsset = CreateInstance<ClusterPrefab>();
//Debug.LogError("Cluster Prefab is not found on disk");
}
}
public void OnIMGUI() {
if (Application.isPlaying == false)
serializedBrain.Update();
Handles.BeginGUI();
DrawGraph();
Handles.EndGUI();
}
#region Graph
protected virtual void DrawGraph() {
if (mode == Mode.Focus)
DrawFocusGraph();
else
DrawFullGraph();
}
#region Full Graph
protected void DrawFullGraph() {
//Dag dag = GenerateGraph(this.prefab);
Dag dag = GenerateGraph(this.selectedOutput);
Dag.ComputeLayout(dag);
// Draw edges
foreach (Dag.Edge e in dag.edges) {
Dag.Node from = dag.nodes.FirstOrDefault(x => x.id == e.fromId);
Dag.Node to = dag.nodes.FirstOrDefault(x => x.id == e.toId);
if (from == null || to == null)
continue;
Vector2 fromPosition = from.position;
Vector2 toPosition = to.position;
DrawEdge(fromPosition, toPosition);
}
// Draw nodes
foreach (Dag.Node n in dag.nodes)
DrawNucleus(n.nucleus, n.position, 1, n.radius);
// Determine graph width
float width = 0;
float currentNucleusPosition = 0;
foreach (Dag.Node node in dag.nodes) {
if (node.position.x > width)
width = node.position.x;
if (node.nucleus == currentNucleus)
currentNucleusPosition = node.position.x;
}
// Resize the graph container to the full graph width
float margin = 50f;
graphContainer.style.width = width + 2 * margin;
// Scroll to the current nucleus
float viewportWidth = scrollView.layout.width;
// center currentNucleus in viewport
float desiredScrollX = currentNucleusPosition - viewportWidth * 0.5f;
// clamp between 0 and maximum scrollable range
float maxScrollX = Mathf.Max(0f, graphContainer.resolvedStyle.width - viewportWidth);
desiredScrollX = Mathf.Clamp(desiredScrollX, 0f, maxScrollX);
Vector2 current = scrollView.scrollOffset;
scrollView.scrollOffset = new Vector2(desiredScrollX, current.y);
}
public Dag GenerateGraph(Nucleus rootNucleus) {
Dag dag = new();
if (rootNucleus == null)
return dag;
int ix = 0;
Dag.Node receiver = new() {
id = ix,
//title = nucleus.name,
nucleus = rootNucleus
};
dag.nodes.Add(receiver);
ix++;
DescendGraph(receiver, ref ix, dag);
return dag;
}
private void DescendGraph(Dag.Node receiver, ref int ix, Dag dag) {
foreach (Synapse synapse in receiver.nucleus.synapses) {
Nucleus nucleus = synapse.neuron;
if (nucleus.parent != null && nucleus.parent != currentNucleus.parent) {
nucleus = nucleus.parent;
}
string nucleusName = nucleus.name;
Dag.Node synapseNode = dag.FindNode(nucleusName);
if (synapseNode == null) {
synapseNode = new() {
id = ix,
nucleus = nucleus
};
dag.nodes.Add(synapseNode);
}
Dag.Edge edge = new() {
fromId = synapseNode.id,
toId = receiver.id
};
dag.edges.Add(edge);
ix++;
DescendGraph(synapseNode, ref ix, dag);
}
}
#endregion Full Graph
#region Focus Graph
protected void DrawFocusGraph() {
float size = 20;
Vector3 position = new(150, 210, 0);
if (this.currentNucleus != null) {
DrawReceivers(this.currentNucleus, position, size);
DrawSynapses(this.currentNucleus, position, size);
// Draw selected Nucleus
if (expandArray) {
float maxValue = 1;
if (this.currentNucleus is Cluster cluster) {
float spacing = 400f / cluster.instanceCount;
float margin = 10 + spacing / 2;
float xMin = 150 - size;
float xMax = 150 + size;
float yMin = 10 + margin - size / 2;
float yMax = 400 - margin + size;
Vector3[] verts = new Vector3[4] {
new(xMin, yMin, 0),
new(xMax, yMin, 0),
new(xMax, yMax, 0),
new(xMin, yMax, 0)
};
Handles.color = Color.black;
Handles.DrawAAConvexPolygon(verts);
int row = 0;
if (cluster.siblingClusters == null) {
Vector3 pos = new(150, margin + row * spacing, 0.0f);
Handles.color = Color.white;
// The selected sibling highlight ring
Handles.DrawSolidDisc(pos, Vector3.forward, size + 2);
DrawNucleus(cluster, pos, maxValue, size);
row++;
}
else {
foreach (Cluster sibling in cluster.siblingClusters) {
Vector3 pos = new(150, margin + row * spacing, 0.0f);
Handles.color = Color.white;
// The selected sibling highlight ring
Handles.DrawSolidDisc(pos, Vector3.forward, size + 2);
DrawNucleus(sibling, pos, maxValue, size);
row++;
}
}
GUIStyle style = new(EditorStyles.label) {
alignment = TextAnchor.UpperCenter,
normal = { textColor = Color.white },
fontStyle = FontStyle.Bold,
};
Vector3 labelPos = new(150, yMax + size + 5, 0);
string clusterName = cluster.name;
int colonPos = clusterName.IndexOf(":");
if (colonPos > 0) {
string baseName = clusterName[..colonPos];
Handles.Label(labelPos, baseName, style);
}
else
Handles.Label(labelPos, clusterName, style);
}
else {
if (this.currentNucleus is Neuron neuron)
maxValue = neuron.outputMagnitude;
DrawNucleus(this.currentNucleus, position, maxValue, 20);
}
}
else {
float maxValue = 1;
if (this.currentNucleus is Neuron neuron)
maxValue = neuron.outputMagnitude;
else if (this.currentNucleus is Cluster cluster)
maxValue = cluster.defaultOutput.outputMagnitude;
DrawNucleus(this.currentNucleus, position, maxValue, 20);
}
}
else {
DrawAllOutputs(position, size);
DrawOutputs(position, size);
}
graphContainer.style.width = 300;
}
protected void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) {
List<Nucleus> receivers;
if (nucleus is Neuron neuron)
receivers = neuron.receivers;
else if (nucleus is Cluster cluster)
receivers = cluster.CollectReceivers();
else
return;
// For top-level nodes, add link to previous editor and/or 'Outputs'
int nodeCount = receivers.Count();
if (nucleus == this.selectedOutput) {
// Add link to 'Outpus'
nodeCount++;
if (ClusterViewer.previousPrefab != null)
// Add link to previous editor
nodeCount++;
}
// Determine the maximum value in this layer
// This is used to 'scale' the output value colors of the nuclei
float maxValue = 0;
foreach (Nucleus receiver in receivers) {
if (receiver is Neuron neuroid) {
float value = neuroid.outputMagnitude;
if (value > maxValue)
maxValue = value;
}
}
// Determine the spacing of the nuclei in the layer
float spacing = 400f / nodeCount;
float margin = 10 + spacing / 2;
int row = 0;
List<Nucleus[]> drawnArrays = new();
foreach (Nucleus receiver in receivers) {
Nucleus receiverNucleus = receiver;
if (receiverNucleus == null)
continue;
Vector3 pos = new(50, margin + row * spacing, 0.0f);
DrawEdge(parentPos, pos);
DrawNucleus(receiverNucleus, pos, maxValue, size);
row++;
}
if (nucleus == this.selectedOutput) {
Vector3 pos = new(50, margin + row * spacing, 0);
if (ClusterViewer.previousPrefab != null) {
DrawEdge(parentPos, pos);
DrawClusterPrefab(ClusterViewer.previousPrefab, pos, size);
row++;
}
pos = new(50, margin + row * spacing, 0);
DrawEdge(parentPos, pos);
DrawAllOutputs(pos, size);
}
}
protected void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) {
if (nucleus == null)
return;
// Determine the maximum value in this layer
// This is used to 'scale' the output value colors of the nuclei
float maxValue = 0;
int neuronCount = 0;
List<Neuron> drawnNeurons = new();
foreach (Synapse synapse in nucleus.synapses) {
if (synapse.neuron == null)
continue;
// Count multiple synapses to the same neuron only once
if (drawnNeurons.Contains(synapse.neuron))
continue;
drawnNeurons.Add(synapse.neuron);
float value = synapse.neuron.outputMagnitude * synapse.weight;
if (value > maxValue)
maxValue = value;
neuronCount++;
}
// Determine the spacing of the nuclei in the layer
float spacing = 400f / neuronCount;
float margin = 10 + spacing / 2;
int row = 0;
drawnNeurons = new();
foreach (Synapse synapse in nucleus.synapses) {
if (synapse.neuron is null)
continue;
// Draw multiple synapses to the same neuron only once
if (drawnNeurons.Contains(synapse.neuron))
continue;
drawnNeurons.Add(synapse.neuron);
Vector3 pos = new(250, margin + row * spacing, 0.0f);
DrawEdge(parentPos, pos);
// Handles.color = Color.white;
// Handles.DrawLine(parentPos, pos);
Color color = Color.black;
if (Application.isPlaying) {
if (maxValue == 0 || !float.IsFinite(maxValue))
maxValue = 1;
float brightness = synapse.neuron.outputMagnitude * synapse.weight / maxValue;
color = new Color(brightness, brightness, brightness, 1f);
}
DrawNucleus(synapse.neuron, pos, size, color);
row++;
}
}
protected void DrawOutputs(Vector2 parentPos, float size) {
// Determine the maximum value in this layer
// This is used to 'scale' the output value colors of the nuclei
float maxValue = 0;
int neuronCount = 0;
List<Nucleus> drawnNuclei = new();
foreach (Nucleus nucleus in this.currentCluster.outputs) {
if (nucleus is not Neuron neuron)
continue;
// Draw multiple synapses to the same neuron only once
if (drawnNuclei.Contains(nucleus))
continue;
drawnNuclei.Add(nucleus);
float value = neuron.outputMagnitude;
if (value > maxValue)
maxValue = value;
neuronCount++;
}
// Determine the spacing of the nuclei in the layer
float spacing = 400f / neuronCount;
float margin = 10 + spacing / 2;
int row = 0;
drawnNuclei = new();
foreach (Nucleus nucleus in this.currentCluster.outputs) {
if (nucleus is not Neuron neuron)
continue;
// Draw multiple synapses to the same neuron only once
if (drawnNuclei.Contains(nucleus))
continue;
drawnNuclei.Add(nucleus);
Vector3 pos = new(250, margin + row * spacing, 0.0f);
DrawEdge(parentPos, pos);
Color color = Color.black;
if (Application.isPlaying) {
if (maxValue == 0 || !float.IsFinite(maxValue))
maxValue = 1;
float brightness = neuron.outputMagnitude / maxValue;
color = new Color(brightness, brightness, brightness, 1f);
}
DrawNucleus(nucleus, pos, size, color);
row++;
}
}
#endregion Focus Graph
protected void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) {
Color color;
if (Application.isPlaying) {
float brightness = 0;
if (nucleus is Neuron neuron)
brightness = neuron.outputMagnitude / maxValue;
color = new Color(brightness, brightness, brightness, 1f);
}
else
color = Color.black;
DrawNucleus(nucleus, position, size, color);
}
protected void DrawNucleus(Nucleus nucleus, Vector3 position, float size, Color color) {
if (nucleus == null)
return;
if (nucleus == this.currentNucleus) {
// The selected nucleus highlight ring
Handles.color = Color.white;
Handles.DrawSolidDisc(position, Vector3.forward, size + 2);
}
if (nucleus is MemoryCell) {
Handles.color = Color.white;
Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size);
}
Handles.color = color;
Handles.DrawSolidDisc(position, Vector3.forward, size);
Handles.color = Color.white;
// Position the label in front of the disc
Vector3 labelPosition = position + (Vector3.forward * 0.1f);
GUIStyle style = new(EditorStyles.label) {
alignment = TextAnchor.MiddleCenter,
normal = { textColor = Color.white },
fontStyle = FontStyle.Bold,
};
if (nucleus.parent is Cluster parentCluster && currentNucleus != null && parentCluster != currentNucleus.parent)
DrawCluster(parentCluster, position, color, size);
else if (nucleus is Cluster cluster)
DrawCluster(cluster, position, color, size);
if (expandArray == false || nucleus != currentNucleus) {
// put name below nucleus
Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron
style.alignment = TextAnchor.UpperCenter;
if (nucleus.parent != null && currentNucleus != null && nucleus.parent != currentNucleus.parent && nucleus.parent is Cluster parentCluster1) {
// This neuron is part of another cluster
parentCluster1.name ??= "";
string baseName = "";
int colonPos = parentCluster1.name.IndexOf(":");
if (colonPos > 0 && colonPos < parentCluster1.name.Length - 2)
baseName = parentCluster1.name[..colonPos] + ".";
else
baseName = parentCluster1.name + ".";
// if (colonPos > 0 && colonPos < parentCluster1.name.Length - 2) {
// // if it is an array, we should not show the :0 of the first element
// //baseName = baseName[..colonPos];
// Handles.Label(labelPos, baseName + nucleus.name, style);
// }
// else
Handles.Label(labelPos, baseName + nucleus.name, style);
}
else {
nucleus.name ??= "";
int colonPos = nucleus.name.IndexOf(":");
if (colonPos > 0 && colonPos < nucleus.name.Length - 2) {
// if it is an array, we should not show the :0 of the first element
string baseName = nucleus.name[..colonPos];
Handles.Label(labelPos, baseName, style);
}
else
Handles.Label(labelPos, nucleus.name, style);
}
}
// Tooltip
Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2);
int id = GUIUtility.GetControlID(FocusType.Passive);
Event e = Event.current;
EventType et = e.GetTypeForControl(id);
if (e != null && neuronRect.Contains(e.mousePosition)) {
// Process Hover
HandleMouseHover(nucleus, neuronRect);
// Process click
if (e.type == EventType.MouseDown && e.button == 0) {
// Consume the event so the scene doesn't also handle it
e.Use();
if (nucleus is Cluster parentCluster2)
OnNeuronClick(parentCluster2);
else
OnNeuronClick(nucleus);
}
}
}
protected void DrawCluster(Cluster cluster, Vector3 position, Color color, float size) {
GUIStyle labelTextStyle = new(EditorStyles.label) {
normal = { textColor = Color.white },
fontStyle = FontStyle.Bold,
};
if (expandArray) {
// Put array indices above the discs
labelTextStyle.alignment = TextAnchor.LowerCenter;
Vector3 labelPosition = position + Vector3.down * (size + 5); // below disc
// Strip the instance number in the name
int colonPos1 = cluster.name.IndexOf(":");
if (colonPos1 > 0) {
string extName = cluster.name[(colonPos1 + 2)..];
Handles.Label(labelPosition, extName, labelTextStyle);
}
else
Handles.Label(labelPosition, "0", labelTextStyle);
}
else {
// Put instance count inside the disc
labelTextStyle.alignment = TextAnchor.MiddleCenter;
Vector3 labelPosition = position + (Vector3.forward * 0.1f);
// Adjust text color based on disc color
if (color.grayscale > 0.5f)
labelTextStyle.normal.textColor = Color.black;
else
labelTextStyle.normal.textColor = Color.white;
if (cluster.instanceCount > 1) {
Handles.Label(labelPosition, cluster.instanceCount.ToString(), labelTextStyle);
labelTextStyle.normal.textColor = Color.white;
}
else if (cluster.siblingClusters != null && cluster.siblingClusters.Length > 1) {
Handles.Label(labelPosition, cluster.siblingClusters.Length.ToString(), labelTextStyle);
labelTextStyle.normal.textColor = Color.white;
}
}
// Draw a circle around the disc to indicate this is a Cluster
Handles.color = Color.white;
Handles.DrawWireDisc(position, Vector3.forward, size + 5);
}
protected void DrawClusterPrefab(ClusterPrefab prefab, Vector2 position, float size) {
Handles.color = Color.black;
Handles.DrawSolidDisc(position, Vector3.forward, size);
// Draw a circle around the disc to indicate this is a Cluster
Handles.color = Color.white;
Handles.DrawWireDisc(position, Vector3.forward, size + 5);
// put name below nucleus
GUIStyle style = new(EditorStyles.label) {
alignment = TextAnchor.MiddleCenter,
normal = { textColor = Color.white },
fontStyle = FontStyle.Bold,
};
Vector2 labelPos = position - Vector2.down * (size + 5); // below neuron
style.alignment = TextAnchor.UpperCenter;
Handles.Label(labelPos, prefab.name, style);
Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2);
int id = GUIUtility.GetControlID(FocusType.Passive);
Event e = Event.current;
EventType et = e.GetTypeForControl(id);
if (e != null && neuronRect.Contains(e.mousePosition)) {
// Process click
if (e.type == EventType.MouseDown && e.button == 0) {
// Consume the event so the scene doesn't also handle it
e.Use();
Selection.activeObject = prefab;
EditorGUIUtility.PingObject(prefab);
ClusterViewer.previousPrefab = null;
CreateEditor(prefab);
}
}
}
protected void DrawAllOutputs(Vector2 position, float size) {
GUIStyle labelTextStyle = new(EditorStyles.label) {
normal = { textColor = Color.white },
fontStyle = FontStyle.Bold,
alignment = TextAnchor.MiddleCenter,
};
Handles.Label(position, "Outputs", labelTextStyle);
Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2);
Event e = Event.current;
if (e != null && neuronRect.Contains(e.mousePosition)) {
// Process click
if (e.type == EventType.MouseDown && e.button == 0) {
// Consume the event so the scene doesn't also handle it
e.Use();
OnAllOutputsClick();
}
}
}
protected void DrawEdge(Vector2 from, Vector2 to, float radius = 20) {
Handles.color = Color.white;
// Handles.DrawLine(from, to);
Vector2 dir = to - from;
float len = dir.magnitude;
if (len <= 2f * radius || len <= Mathf.Epsilon)
// line too short
return;
Vector2 n = dir / len; // normalized
Vector2 a = from + n * radius;
Vector2 b = to - n * radius;
Handles.DrawLine(a, b);
}
protected void HandleMouseHover(Nucleus nucleus, Rect rect) {
GUIContent tooltip;
if (nucleus is Neuron neuron) {
tooltip = new(
$"{nucleus.name}" +
$"\nValue: {neuron.outputMagnitude}");
}
else
tooltip = new($"{nucleus.name}");
Vector2 mousePosition = Event.current.mousePosition;
// Display tooltip with some offset
Vector2 tooltipSize = GUI.skin.box.CalcSize(tooltip);
Rect tooltipRect = new Rect(mousePosition.x + 10, mousePosition.y + 10, tooltipSize.x, tooltipSize.y);
GUI.Box(tooltipRect, tooltip);
}
protected void OnNeuronClick(Nucleus nucleus) {
if (nucleus == this.currentNucleus) {
if (Application.isPlaying) {
if (nucleus is Cluster)
expandArray = !expandArray;
else
expandArray = false;
}
else {
if (nucleus is Cluster cluster)
OnClusterClick(cluster);
}
}
else if (nucleus.parent != null && this.currentNucleus != null && nucleus.parent != this.currentNucleus.parent) {
// We go to a different cluster
if (Application.isPlaying) {
this.currentNucleus = nucleus;
if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0)
this.selectedOutput = this.currentNucleus;
expandArray = false;
}
else {
// select the cluster, not the neuron in the cluster
this.currentNucleus = nucleus.parent;
expandArray = false;
}
}
else {
this.currentNucleus = nucleus;
if (this.currentNucleus is Neuron neuron && neuron.receivers.Count == 0)
this.selectedOutput = this.currentNucleus;
expandArray = false;
}
}
protected void OnClusterClick(Cluster subCluster) {
// May be used with storedPrefab...
Selection.activeObject = subCluster.prefab;
EditorGUIUtility.PingObject(subCluster.prefab);
ClusterViewer.previousPrefab = this.currentCluster.prefab;
ClusterEditor newEditor = CreateEditor(subCluster.prefab) as ClusterEditor;
}
protected void OnAllOutputsClick() {
this.currentNucleus = null;
this.selectedOutput = null;
expandArray = false;
}
#endregion Graph
void OnSceneGUI(SceneView sceneView) {
if (this.gameObject != null) {
// if (this.currentNucleus is IReceptor receptor) {
// foreach (Nucleus nucleus in receptor.nucleiArray) {
// if (nucleus is Neuron neuron) {
// Vector3 worldVector = this.gameObject.transform.TransformVector(neuron.outputValue);
// Handles.color = Color.yellow;
// Handles.DrawLine(this.gameObject.transform.position, this.gameObject.transform.position + worldVector);
// }
// }
// }
// else {
if (this.currentNucleus is Neuron currentNeuron) {
Vector3 worldVector = this.gameObject.transform.TransformVector(currentNeuron.outputValue);
Handles.color = Color.yellow;
Handles.DrawLine(this.gameObject.transform.position, this.gameObject.transform.position + worldVector);
}
// }
}
}
}
}
public class NeuroidLayer {
public int ix = 0;
public List<Nucleus> neuroids = new();
}
public class Dag {
public class Node {
public int id;
public Vector2 position;
public float radius = 20f; // circle radius
public Nucleus nucleus;
}
public class Edge {
public int fromId;
public int toId;
}
public List<Node> nodes = new();
public List<Edge> edges = new();
public Node FindNode(string name, bool justBaseName = true) {
if (justBaseName) {
int colonPos = name.IndexOf(":");
if (colonPos > 0)
name = name[..colonPos];
}
foreach (Node node in this.nodes) {
string nodeName = node.nucleus.name;
if (justBaseName) {
int colonPos = nodeName.IndexOf(":");
if (colonPos > 0)
nodeName = nodeName[..colonPos];
}
if (nodeName == name)
return node;
}
return null;
}
public static Node GetNodeById(Dag dag, int id) => dag.nodes.FirstOrDefault(x => x.id == id);
public static void ComputeLayout(Dag dag) {
Dictionary<int, List<int>> adjacency = dag.nodes.ToDictionary(n => n.id, n => new List<int>());
Dictionary<int, int> outdegree = dag.nodes.ToDictionary(node => node.id, n => 0);
foreach (Edge edge in dag.edges) {
if (!adjacency.ContainsKey(edge.fromId) || !adjacency.ContainsKey(edge.toId))
continue;
adjacency[edge.fromId].Add(edge.toId);
outdegree[edge.fromId]++;
}
// Kahn's algorithm to compute topological layers (horizontal layers)
// build parent list (reverse adjacency) and parentIndegree = number of children each parent has
Dictionary<int, List<int>> parents = dag.nodes.ToDictionary(n => n.id, _ => new List<int>());
Dictionary<int, int> childCount = dag.nodes.ToDictionary(n => n.id, _ => 0);
foreach (Edge edge in dag.edges) {
if (!adjacency.ContainsKey(edge.fromId) || !adjacency.ContainsKey(edge.toId)) continue;
adjacency[edge.fromId].Add(edge.toId);
parents[edge.toId].Add(edge.fromId); // parent of 'to' is 'from'
childCount[edge.fromId]++; // outdegree
}
Dictionary<int, int> layer = new();
Queue<int> queue = new(outdegree.Where(kv => kv.Value == 0).Select(kv => kv.Key));
foreach (int id in queue)
layer[id] = 0;
// process parents (reverse traversal)
while (queue.Count > 0) {
int u = queue.Dequeue();
int l = layer[u];
foreach (int p in parents[u]) {
if (!layer.ContainsKey(p) || layer[p] < l + 1)
layer[p] = l + 1;
childCount[p]--; // decrement remaining unprocessed children
if (childCount[p] == 0)
queue.Enqueue(p);
}
}
// Any unreachable nodes -> assign next layers
int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
foreach (Node node in dag.nodes) {
if (!layer.ContainsKey(node.id)) {
maxLayer++;
layer[node.id] = maxLayer;
}
}
// Group nodes by layer (left to right)
List<List<int>> layers =
layer.
GroupBy(kv => kv.Value).
OrderBy(g => g.Key).
Select(g => g.Select(x => x.Key).ToList()).
ToList();
// Same code without using Linq
// Build layers dictionary: layerIndex -> List<int> nodeIds
// Dictionary<int, List<int>> layersDict = new();
// foreach (KeyValuePair<int, int> kv in layer) {
// int nodeId = kv.Key;
// int layerIndex = kv.Value;
// if (!layersDict.TryGetValue(layerIndex, out List<int> list)) {
// list = new List<int>();
// layersDict[layerIndex] = list;
// }
// list.Add(nodeId);
// }
// // Determine sorted layer indices
// List<int> layerIndices = new(layersDict.Keys);
// layerIndices.Sort(); // ascending order
// // Build final List<List<int>> in sorted order
// List<List<int>> layers = new();
// foreach (int idx in layerIndices) {
// layers.Add(layersDict[idx]);
// }
float hSpacing = 100f;
float totalHeight = 400f;
// Place nodes: x increases with layer index, y spaced within layer
for (int layerIx = 0; layerIx < layers.Count; layerIx++) {
List<int> nodeList = layers[layerIx];
float spacing = totalHeight / nodeList.Count;
float margin = 10 + spacing / 2;
for (int i = 0; i < nodeList.Count; i++) {
int index = nodeList[i];
Node node = GetNodeById(dag, index);
if (node == null)
continue;
float x = hSpacing + layerIx * hSpacing;
//float y = 400 - totalHeight / 2f + i * vSpacing;
float y = margin + i * spacing;
// Debug.Log($"({li}, {i}) -> {x}, {y}");
node.position = new Vector2(x, y);
}
}
//Repaint();
}
}
}

View File

@ -0,0 +1,2 @@
fileFormatVersion: 2
guid: 4fe58945c76d153edacc220597474ad2

View File

@ -1,356 +0,0 @@
using UnityEngine;
using UnityEditor;
using System.Collections.Generic;
using System.Linq;
namespace NanoBrain {
// Simple DAG data model
// [System.Serializable]
// public class DagNode
// {
// public int id;
// public string title;
// public Vector2 position;
// public float radius = 36f; // circle radius
// }
// [System.Serializable]
// public class DagEdge
// {
// public int fromId;
// public int toId;
// }
public class DAGEditorWindow : EditorWindow {
List<DagNode> nodes = new List<DagNode>();
List<DagEdge> edges = new List<DagEdge>();
Vector2 pan = Vector2.zero;
float zoom = 1.0f;
const float minZoom = 0.5f;
const float maxZoom = 2.0f;
GUIStyle labelStyle;
int selectedNodeId = -1;
Vector2 dragStart;
bool draggingNode = false;
int draggingNodeId = -1;
[MenuItem("Window/DAG Viewer (LR, Circles)")]
public static void ShowWindow() {
var w = GetWindow<DAGEditorWindow>("DAG Viewer (LR)");
w.minSize = new Vector2(500, 300);
}
void OnEnable() {
labelStyle = new GUIStyle(EditorStyles.label);
labelStyle.alignment = TextAnchor.MiddleCenter;
labelStyle.normal.textColor = Color.white;
labelStyle.fontStyle = FontStyle.Bold;
if (nodes.Count == 0)
CreateSampleGraph();
ComputeLeftToRightLayout();
}
void CreateSampleGraph() {
nodes.Clear();
edges.Clear();
nodes.Add(new DagNode() { id = 0, title = "In1" });
nodes.Add(new DagNode() { id = 1, title = "In2" });
nodes.Add(new DagNode() { id = 2, title = "A" });
nodes.Add(new DagNode() { id = 3, title = "B" });
nodes.Add(new DagNode() { id = 4, title = "C" });
nodes.Add(new DagNode() { id = 5, title = "Out1" });
nodes.Add(new DagNode() { id = 6, title = "Out2" });
edges.Add(new DagEdge() { fromId = 0, toId = 2 });
edges.Add(new DagEdge() { fromId = 1, toId = 2 });
edges.Add(new DagEdge() { fromId = 2, toId = 3 });
edges.Add(new DagEdge() { fromId = 2, toId = 4 });
edges.Add(new DagEdge() { fromId = 3, toId = 5 });
edges.Add(new DagEdge() { fromId = 4, toId = 6 });
}
void OnGUI() {
HandleInput();
Rect rect = new Rect(0, 0, position.width, position.height);
EditorGUI.DrawRect(rect, new Color(0.11f, 0.11f, 0.11f));
Matrix4x4 oldMatrix = GUI.matrix;
Vector2 origin = new Vector2(position.width / 2, position.height / 2);
GUI.matrix = Matrix4x4.TRS(origin + pan, Quaternion.identity, Vector3.one * zoom) *
Matrix4x4.TRS(-origin, Quaternion.identity, Vector3.one);
// Draw edges first
foreach (var e in edges) {
var from = GetNodeById(e.fromId);
var to = GetNodeById(e.toId);
if (from == null || to == null) continue;
DrawEdgeCircleNodes(from, to);
}
// Draw nodes (circles)
foreach (var n in nodes) {
DrawNodeCircle(n);
}
GUI.matrix = oldMatrix;
// Footer toolbar
GUILayout.FlexibleSpace();
EditorGUILayout.BeginHorizontal(EditorStyles.toolbar);
if (GUILayout.Button("Fit", EditorStyles.toolbarButton)) FitToView();
if (GUILayout.Button("Layout LR", EditorStyles.toolbarButton)) ComputeLeftToRightLayout();
if (GUILayout.Button("Add Node", EditorStyles.toolbarButton)) {
AddNode("N" + nodes.Count);
ComputeLeftToRightLayout();
}
if (GUILayout.Button("Add Edge (selected->new)", EditorStyles.toolbarButton)) {
if (selectedNodeId != -1) {
var newNode = AddNode("N" + nodes.Count);
edges.Add(new DagEdge() { fromId = selectedNodeId, toId = newNode.id });
ComputeLeftToRightLayout();
}
}
EditorGUILayout.EndHorizontal();
}
void HandleInput() {
Event e = Event.current;
// Zoom with scroll
if (e.type == EventType.ScrollWheel) {
float oldZoom = zoom;
float delta = -e.delta.y * 0.01f;
zoom = Mathf.Clamp(zoom + delta, minZoom, maxZoom);
Vector2 mouse = e.mousePosition;
pan += (mouse - new Vector2(position.width / 2, position.height / 2)) * (1 - zoom / oldZoom);
e.Use();
}
// Pan with middle or right+ctrl drag
if (e.type == EventType.MouseDrag && (e.button == 2 || (e.button == 1 && e.control))) {
pan += e.delta;
e.Use();
}
// Node dragging & selection (convert mouse to graph space)
Vector2 graphMouse = ScreenToGraph(e.mousePosition);
if (e.type == EventType.MouseDown && e.button == 0) {
int hit = HitTestNode(graphMouse);
if (hit != -1) {
selectedNodeId = hit;
draggingNode = true;
draggingNodeId = hit;
dragStart = graphMouse;
e.Use();
}
else {
selectedNodeId = -1;
}
}
if (draggingNode && draggingNodeId != -1) {
if (e.type == EventType.MouseDrag && e.button == 0) {
Vector2 graphDelta = e.delta / zoom;
var n = GetNodeById(draggingNodeId);
if (n != null) {
n.position += graphDelta;
Repaint();
e.Use();
}
}
if (e.type == EventType.MouseUp && e.button == 0) {
draggingNode = false;
draggingNodeId = -1;
e.Use();
}
}
}
DagNode AddNode(string title) {
int nextId = nodes.Count > 0 ? nodes.Max(n => n.id) + 1 : 0;
var n = new DagNode() { id = nextId, title = title, position = Vector2.zero };
nodes.Add(n);
return n;
}
DagNode GetNodeById(int id) => nodes.FirstOrDefault(x => x.id == id);
void DrawNodeCircle(DagNode n) {
Vector2 center = n.position;
float r = n.radius;
Rect nodeRect = new Rect(center.x - r, center.y - r, r * 2, r * 2);
// circle background
Color bg = (n.id == selectedNodeId) ? new Color(0.15f, 0.5f, 0.9f) : new Color(0.2f, 0.2f, 0.2f);
EditorGUI.DrawRect(nodeRect, bg);
// anti-aliased circle outline
Handles.color = Color.white * 0.9f;
Handles.DrawAAPolyLine(3f / zoom, GetCircleOutlinePoints(center, r, 48).ToArray());
// label
Vector2 labelPos = center - new Vector2(0, 8);
GUI.Label(new Rect(labelPos.x - r, labelPos.y - 8, r * 2, 18), n.title, labelStyle);
}
List<Vector3> GetCircleOutlinePoints(Vector2 center, float radius, int segments) {
var pts = new List<Vector3>(segments + 1);
for (int i = 0; i <= segments; i++) {
float a = (float)i / segments * Mathf.PI * 2f;
pts.Add(new Vector3(center.x + Mathf.Cos(a) * radius, center.y + Mathf.Sin(a) * radius, 0));
}
return pts;
}
void DrawEdgeCircleNodes(DagNode from, DagNode to) {
Vector2 a = from.position;
Vector2 b = to.position;
if (a == b) return;
// Compute edge line that starts/ends at circle circumferences
Vector2 dir = (b - a).normalized;
Vector2 start = a + dir * from.radius;
Vector2 end = b - dir * to.radius;
// Use a simple curved line: start -> control -> end (bezier)
Vector2 control = new Vector2((start.x + end.x) / 2f, (start.y + end.y) / 2f);
// Slight vertical offset to separate overlapping lines based on node ids
float offset = ((from.id * 7 + to.id * 11) % 7 - 3) * 6f / zoom;
control += new Vector2(0, offset);
Handles.color = Color.white * 0.9f;
Handles.DrawAAPolyLine(3f / zoom, 20, GetBezierPoints(start, control, end, 24).ToArray());
// Arrow at end pointing towards 'b'
DrawArrowHead(end - dir * 2f, end, 10f / zoom, 12f / zoom, Color.white);
}
List<Vector3> GetBezierPoints(Vector2 p0, Vector2 p1, Vector2 p2, int seg) {
var pts = new List<Vector3>(seg + 1);
for (int i = 0; i <= seg; i++) {
float t = (float)i / seg;
Vector2 p = (1 - t) * (1 - t) * p0 + 2 * (1 - t) * t * p1 + t * t * p2;
pts.Add(new Vector3(p.x, p.y, 0));
}
return pts;
}
void DrawArrowHead(Vector2 from, Vector2 to, float headWidth, float headLength, Color color) {
Vector2 dir = (to - from).normalized;
if (dir == Vector2.zero) return;
Vector2 right = new Vector2(-dir.y, dir.x);
Vector3 p1 = to;
Vector3 p2 = to - dir * headLength + right * headWidth * 0.5f;
Vector3 p3 = to - dir * headLength - right * headWidth * 0.5f;
Handles.color = color;
Handles.DrawAAConvexPolygon(p1, p2, p3);
}
// Left-to-right layered layout (sources on the left, sinks on the right)
void ComputeLeftToRightLayout() {
// build adjacency and indegree
var adj = nodes.ToDictionary(n => n.id, n => new List<int>());
var indeg = nodes.ToDictionary(n => n.id, n => 0);
foreach (var e in edges) {
if (!adj.ContainsKey(e.fromId) || !adj.ContainsKey(e.toId)) continue;
adj[e.fromId].Add(e.toId);
indeg[e.toId]++;
}
// Kahn's algorithm to compute topological layers (horizontal layers)
Dictionary<int, int> layer = new Dictionary<int, int>();
Queue<int> q = new Queue<int>(indeg.Where(kv => kv.Value == 0).Select(kv => kv.Key));
foreach (var id in q) layer[id] = 0;
while (q.Count > 0) {
int u = q.Dequeue();
int l = layer[u];
foreach (var v in adj[u]) {
// prefer placing v at least one layer after u
if (!layer.ContainsKey(v) || layer[v] < l + 1) layer[v] = l + 1;
indeg[v]--;
if (indeg[v] == 0) q.Enqueue(v);
}
}
// Any unreachable nodes -> assign next layers
int maxLayer = layer.Count > 0 ? layer.Values.Max() : 0;
foreach (var n in nodes) {
if (!layer.ContainsKey(n.id)) {
maxLayer++;
layer[n.id] = maxLayer;
}
}
// Group nodes by layer (left to right)
var layers = layer.GroupBy(kv => kv.Value).OrderBy(g => g.Key).Select(g => g.Select(x => x.Key).ToList()).ToList();
// Layout parameters (horizontal spacing drives left->right)
float hSpacing = 220f;
float vSpacing = 120f;
// Place nodes: x increases with layer index, y spaced within layer
for (int li = 0; li < layers.Count; li++) {
var lst = layers[li];
float totalHeight = (lst.Count - 1) * vSpacing;
for (int i = 0; i < lst.Count; i++) {
int id = lst[i];
var n = GetNodeById(id);
if (n == null) continue;
float x = li * hSpacing;
float y = -totalHeight / 2f + i * vSpacing;
n.position = new Vector2(x, y);
}
}
Repaint();
}
void FitToView() {
if (nodes.Count == 0) return;
Rect bounds = new Rect(nodes[0].position - Vector2.one * nodes[0].radius, Vector2.one * nodes[0].radius * 2f);
foreach (var n in nodes)
bounds = RectUnion(bounds, new Rect(n.position - Vector2.one * n.radius, Vector2.one * n.radius * 2f));
Vector2 center = bounds.center;
pan = -center;
zoom = 1.0f;
Repaint();
}
static Rect RectUnion(Rect a, Rect b) {
float xMin = Mathf.Min(a.xMin, b.xMin);
float xMax = Mathf.Max(a.xMax, b.xMax);
float yMin = Mathf.Min(a.yMin, b.yMin);
float yMax = Mathf.Max(a.yMax, b.yMax);
return Rect.MinMaxRect(xMin, yMin, xMax, yMax);
}
Vector2 ScreenToGraph(Vector2 screenPos) {
Vector2 origin = new Vector2(position.width / 2, position.height / 2);
// invert the GUI.matrix transform (approx for current simple transforms)
return (screenPos - (origin + pan)) / zoom + origin * (1 - 1 / zoom);
}
int HitTestNode(Vector2 graphPos) {
// returns node id under point or -1
for (int i = nodes.Count - 1; i >= 0; i--) {
var n = nodes[i];
if ((graphPos - n.position).sqrMagnitude <= n.radius * n.radius) return n.id;
}
return -1;
}
}
}

View File

@ -1,2 +0,0 @@
fileFormatVersion: 2
guid: 95393aed582b8b30d965400672aec4d8

View File

@ -1,53 +0,0 @@
using UnityEditor;
using UnityEditor.UIElements;
using UnityEngine;
using UnityEngine.UIElements;
namespace NanoBrain {
[CustomEditor(typeof(Brain))]
public class NanoBrainComponent_Editor : Editor {
protected static VisualElement mainContainer;
protected static VisualElement inspectorContainer;
protected Brain component;
private SerializedProperty brainProp;
//ClusterInspector.GraphView board;
public void OnEnable() {
component = target as Brain;
if (Application.isPlaying == false && serializedObject != null) {
string propertyName = nameof(Brain.defaultBrain);
brainProp = serializedObject.FindProperty(propertyName);
}
}
public override VisualElement CreateInspectorGUI() {
Cluster brain = component.brain;
if (Application.isPlaying == false)
serializedObject.Update();
VisualElement root = new();
if (Application.isPlaying == false) {
PropertyField brainField = new(brainProp) {
label = "Cluster Prefab"
};
root.Add(brainField);
}
if (brain != null)
ClusterInspector.CreateInspector(root, brain.prefab, brain.defaultOutput, component.gameObject);
if (Application.isPlaying == false)
serializedObject.ApplyModifiedProperties();
return root;
}
}
}

8
LinearAlgebra.meta Normal file
View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: a4c7dfe43bdf504e29c5c97919d7a1c0
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
LinearAlgebra/src.meta Normal file
View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 6a602cec2c4009925b1d19ed36a98c6a
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

8
LinearAlgebra/test.meta Normal file
View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: 9b84f664459d02b90894e460de42c219
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -11,7 +11,7 @@ namespace NanoBrain {
/// <summary>
/// The Cluster prefab from which the cluster is created
/// </summary>
public ClusterPrefab defaultBrain;
public ClusterPrefab brainPrefab;
[NonSerialized]
private Cluster brainInstance;
@ -20,15 +20,24 @@ namespace NanoBrain {
/// </summary>
public Cluster brain {
get {
if (brainInstance == null && defaultBrain != null) {
brainInstance = new Cluster(defaultBrain) {
name = defaultBrain.name + " (Instance)"
if (brainInstance == null && brainPrefab != null) {
brainInstance = new Cluster(brainPrefab) {
name = brainPrefab.name
};
} else if (brainInstance != null && brainPrefab == null) {
brainInstance = null;
}
return brainInstance;
}
}
// public Cluster InitializeBrain() {
// brainInstance = new Cluster(brainPrefab) {
// name = brainPrefab.name
// };
// return brainInstance;
// }
/// <summary>
/// Update the weight for all Synapses coming from the Neuron with the given name
/// </summary>

File diff suppressed because it is too large Load Diff

View File

@ -1,277 +0,0 @@
using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif
using System.Linq;
namespace NanoBrain {
[Serializable]
public class ClusterReceptor : Cluster, IReceptor {
public ClusterReceptor(ClusterPrefab prefab, Cluster parent, string name) : base(prefab, parent) {
this.name = name;
this.array = new NucleusArray(this);
if (this.name.IndexOf(":") < 0)
this.name += ": 0";
}
public ClusterReceptor(ClusterPrefab prefab, ClusterPrefab parent, string name) : base(prefab, parent) {
this.name = name;
this.array = new NucleusArray(this);
}
public string GetName() {
return this.name;
}
public override Nucleus ShallowCloneTo(Cluster parent) {
ClusterReceptor clone = new(this.prefab, parent, this.name) {
clusterPrefab = this.clusterPrefab,
};
return clone;
}
public override Nucleus Clone(ClusterPrefab parent) {
ClusterReceptor clone = new(prefab, parent, this.name) {
array = this._array
};
foreach (Synapse synapse in this.synapses) {
Synapse clonedSynapse = clone.AddSynapse(synapse.neuron);
clonedSynapse.weight = synapse.weight;
}
this._outputs = null; // Make sure the output are regenerated
foreach (Neuron output in this.outputs) {
int ix = GetNucleusIndex(this.clusterNuclei, output);
if (ix < 0 || clone.clusterNuclei[ix] is not Neuron clonedOutput)
continue;
foreach (Nucleus receiver in output.receivers)
clonedOutput.AddReceiver(receiver);
}
return clone;
}
public override List<Nucleus> CollectReceivers() {
List<Nucleus> receivers = new();
foreach (Nucleus element in this.nucleiArray) {
if (element is not Cluster clusterElement)
continue;
foreach (Nucleus outputNucleus in clusterElement.clusterNuclei) {
if (outputNucleus is not Neuron output)
continue;
// this should be clusterElement.outputs,
// but outputs is not updated when correctly and may contain old data...
foreach (Nucleus receiver in output.receivers) {
// Only add receivers outside clusterElement cluster
if (receiver.clusterPrefab != clusterElement.prefab &&
receivers.Contains(receiver) == false)
receivers.Add(receiver);
}
}
}
return receivers;
}
[SerializeReference]
private NucleusArray _array;
public NucleusArray array {
set { _array = value; }
}
public Nucleus[] nucleiArray {
get { return _array.nuclei; }
set { _array.nuclei = value; }
}
public void AddReceptorElement(ClusterPrefab prefab) {
IReceptorHelpers.AddReceptorElement(this, prefab);
}
public void RemoveReceptorElement() {
IReceptorHelpers.RemoveReceptorElement(this);
}
public void AddArrayReceiver(Nucleus receiverToAdd, float weight = 1) {
IReceptorHelpers.AddArrayReceiver(this, receiverToAdd, weight);
}
public override void UpdateStateIsolated() {
// Clusters don't do anything,
// The nuclei in them do the work
// and should be called directly, not from the cluster
}
public override void UpdateNuclei() {
foreach (Nucleus nucleus in this.clusterNuclei)
nucleus.UpdateNuclei();
}
public override void ProcessStimulus(Vector3 inputValue, int thingId = 0, string thingName = null) {
Debug.LogError("Process Stimulus was called on clusterreceptor without a neuron specified");
}
private readonly Dictionary<int, ClusterReceptor> thingReceivers = new();
public virtual void ProcessStimulus(Neuron input, Vector3 inputValue, int thingId = 0, string thingName = null) {
CleanupReceivers();
if (!thingReceivers.TryGetValue(thingId, out ClusterReceptor selectedReceiver))
selectedReceiver = FindReceiver2(thingId, inputValue, input);
if (selectedReceiver == null)
return;
if (thingName != null) {
string baseName = selectedReceiver.name;
int colonPos = selectedReceiver.name.IndexOf(":");
if (colonPos > 0)
baseName = selectedReceiver.name[..colonPos];
selectedReceiver.name = baseName + ": " + thingName;
}
int inputIx = GetNucleusIndex(this.clusterNuclei, input);
if (inputIx < 0)
return;
if (selectedReceiver.clusterNuclei[inputIx] is Neuron selectedNeuron)
selectedNeuron.ProcessStimulusDirect(inputValue);
}
#if UNITY_MATHEMATICS
private ClusterReceptor FindReceiver2(int thingId, float3 inputValue, Neuron input) {
// No existing nucleus for this thing
ClusterReceptor selectedReceiver = null;
float selectedMagnitude = 0;
foreach (ClusterReceptor receiver in this.nucleiArray.Cast<ClusterReceptor>()) {
if (thingReceivers.ContainsValue(receiver) == false) {
// We found an unusued receiver
thingReceivers.Add(thingId, receiver);
return receiver;
}
else if (receiver.defaultOutput.isSleeping) {
// A sleeping receiver is not active and can therefore always be used
thingReceivers.Add(thingId, receiver);
receiver.bias = float3(0, 0, 0);
return receiver;
}
else if (selectedReceiver == null) {
// If we haven't found a receiver yet, just start by taking the first
selectedReceiver = receiver;
selectedMagnitude = length(selectedReceiver.defaultOutput.outputValue);
}
// Look for the receiver with the lowest output magnitude
else {
float magnitude = length(receiver.defaultOutput.outputValue);
if (length(receiver.defaultOutput.outputValue) < selectedMagnitude) {
selectedReceiver = receiver;
selectedMagnitude = length(selectedReceiver.defaultOutput.outputValue);
}
}
}
if (selectedReceiver != null) {
// To re-initialize the cluster (esp. memory cells)
// we update the cluster neuron twice.
// Bit of a hack.....
int inputIx = GetNucleusIndex(this.clusterNuclei, input);
if (inputIx >= 0) {
if (selectedReceiver.clusterNuclei[inputIx] is Neuron selectedNeuron)
selectedNeuron.ProcessStimulusDirect(inputValue);
}
// Replace the receiver
// Find the thingId current associated with the receiver
int keyToRemove = thingReceivers.FirstOrDefault(r => r.Value.Equals(selectedReceiver)).Key;
if (keyToRemove != 0 || thingReceivers.ContainsKey(keyToRemove))
thingReceivers.Remove(keyToRemove);
// And add the new association
thingReceivers.Add(thingId, selectedReceiver);
}
return selectedReceiver;
}
#else
private ClusterReceptor FindReceiver2(int thingId, Vector3 inputValue, Neuron input) {
// No existing nucleus for this thing
ClusterReceptor selectedReceiver = null;
float selectedMagnitude = 0;
foreach (ClusterReceptor receiver in this.nucleiArray.Cast<ClusterReceptor>()) {
if (thingReceivers.ContainsValue(receiver) == false) {
// We found an unusued receiver
thingReceivers.Add(thingId, receiver);
return receiver;
}
else if (receiver.defaultOutput.isSleeping) {
// A sleeping receiver is not active and can therefore always be used
thingReceivers.Add(thingId, receiver);
receiver.bias = new Vector3(0, 0, 0);
return receiver;
}
else if (selectedReceiver == null) {
// If we haven't found a receiver yet, just start by taking the first
selectedReceiver = receiver;
selectedMagnitude = selectedReceiver.defaultOutput.outputValue.magnitude;
}
// Look for the receiver with the lowest output magnitude
else {
float magnitude = receiver.defaultOutput.outputValue.magnitude;
if (receiver.defaultOutput.outputValue.magnitude < selectedMagnitude) {
selectedReceiver = receiver;
selectedMagnitude = selectedReceiver.defaultOutput.outputValue.magnitude;
}
}
}
if (selectedReceiver != null) {
// To re-initialize the cluster (esp. memory cells)
// we update the cluster neuron twice.
// Bit of a hack.....
int inputIx = GetNucleusIndex(this.clusterNuclei, input);
if (inputIx >= 0) {
if (selectedReceiver.clusterNuclei[inputIx] is Neuron selectedNeuron)
selectedNeuron.ProcessStimulusDirect(inputValue);
}
// Replace the receiver
// Find the thingId current associated with the receiver
int keyToRemove = thingReceivers.FirstOrDefault(r => r.Value.Equals(selectedReceiver)).Key;
if (keyToRemove != 0 || thingReceivers.ContainsKey(keyToRemove))
thingReceivers.Remove(keyToRemove);
// And add the new association
thingReceivers.Add(thingId, selectedReceiver);
}
return selectedReceiver;
}
#endif
private void CleanupReceivers() {
// Remove a thing-receiver connection when the nucleus is inactive
List<int> receiversToRemove = new();
foreach (KeyValuePair<int, ClusterReceptor> item in thingReceivers) {
if (item.Value != null && item.Value.defaultOutput.isSleeping)
receiversToRemove.Add(item.Key);
}
foreach (int thingId in receiversToRemove) {
Nucleus selectedReceiver = thingReceivers[thingId];
thingReceivers.Remove(thingId);
int colonPos = selectedReceiver.name.IndexOf(":");
if (colonPos > 0)
selectedReceiver.name = selectedReceiver.name[..colonPos];
}
}
}
}

View File

@ -1,2 +0,0 @@
fileFormatVersion: 2
guid: 4f64f5d72a422a7c8bb9ace598432aad

View File

@ -1,123 +0,0 @@
using UnityEngine;
namespace NanoBrain {
/// <summary>
/// A Receptor is a Nucleus which can receive input (called Stimulus) from outside the the cluster/brain
/// </summary>
/// It has the ability to distinguish stimuli from different things using an array of Nuclei
public interface IReceptor {
/// <summary>
/// Get the name of the receptor
/// </summary>
/// <returns>The name of the receptor</returns>
public string GetName();
/// <summary>
/// The array of nuclei used to track multiple things sending stimuli
/// </summary>
/// The size of the array determines the maximum number of things which can be distinguished
public Nucleus[] nucleiArray { get; set; }
/// <summary>
/// Extends the nucleiArray with an additional element
/// </summary>
/// <param name="prefab">A prefab of the nucleus to add?</param>
public void AddReceptorElement(ClusterPrefab prefab);
/// <summary>
/// Removes the last element from the nucleiArray
/// </summary>
public void RemoveReceptorElement();
/// <summary>
/// Add a receiver for this receptor array
/// </summary>
/// <param name="receiverToAdd">The receiving Nucleus</param>
/// <param name="weight">The initial weight to use for the synapses</param>
/// This function will add a synapse to the receiver for each element in the nucleiArray.
public void AddArrayReceiver(Nucleus receiverToAdd, float weight = 1);
/// <summary>
/// Process an external stimulus
/// </summary>
/// <param name="inputValue">The value of the stimulus</param>
/// <param name="thingId">The id of the thing causing the stimulus</param>
/// <param name="thingName">The name of the thing causing the stimulus</param>
public void ProcessStimulus(Vector3 inputValue, int thingId = 0, string thingName = null);
}
public static class IReceptorHelpers {
/// <summary>
/// Implementation for the NanoBrain::IReceptor::AddReceptorElement which can be used for all implementations of IReceptor
/// </summary>
/// <param name="receptor">The IReceptor which needs to extend its nucleiArray</param>
/// <param name="prefab">A prefab of the nucleus to add?</param>
public static void AddReceptorElement(IReceptor receptor, ClusterPrefab prefab) {
if (receptor.nucleiArray.Length == 0) {
Debug.LogError("Empty perceptoid array, cannot add");
}
int newLength = receptor.nucleiArray.Length + 1;
Nucleus[] newArray = new Nucleus[newLength];
string baseName = receptor.GetName();
int colonPos = baseName.IndexOf(":");
if (colonPos > 0)
baseName = baseName[..colonPos];
for (int i = 0; i < receptor.nucleiArray.Length; i++)
newArray[i] = receptor.nucleiArray[i];
if (receptor.nucleiArray[0] is Nucleus nucleus) {
newArray[newLength - 1] = nucleus.Clone(prefab);
newArray[newLength - 1].name = $"{baseName}: {newLength - 1}";
}
foreach (Nucleus element in receptor.nucleiArray) {
if (element is IReceptor receptorElement) {
receptorElement.nucleiArray = newArray;
}
}
}
/// <summary>
/// Implementation for the NanoBrain::IReceptor::RemoteReceptorElement which can be used for all implementations of IReceptor
/// </summary>
/// <param name="receptor">The IReceptor which needs to shorten its nucleiArray</param>
public static void RemoveReceptorElement(IReceptor receptor) {
int newLength = receptor.nucleiArray.Length - 1;
if (newLength == 0) {
Debug.LogWarning("Perceptoid array cannot be empty");
}
Nucleus[] newArray = new Nucleus[newLength];
for (int i = 0; i < newLength; i++)
newArray[i] = receptor.nucleiArray[i];
// Delete the last perception
if (receptor.nucleiArray[newLength] is Nucleus nucleus)
Neuron.Delete(nucleus);
foreach (Nucleus element in receptor.nucleiArray) {
if (element is IReceptor receptorElement) {
receptorElement.nucleiArray = newArray;
}
}
}
/// <summary>
/// Implementation for the NanoBreain::IRceptor::AddArrayReceiver which can be used for all implementations of IReceptor
/// </summary>
/// <param name="receptor">The IReceptor for which a receiving nuclues needs to be added</param>
/// <param name="receiverToAdd">The nucleus to receive input from the receptor</param>
/// <param name="weight">The initial weight for the synapses</param>
public static void AddArrayReceiver(IReceptor receptor, Nucleus receiverToAdd, float weight = 1) {
foreach (Nucleus element in receptor.nucleiArray) {
if (element is Cluster cluster)
cluster.defaultOutput.AddReceiver(receiverToAdd, weight);
if (element is Neuron neuron)
neuron.AddReceiver(receiverToAdd, weight);
}
}
}
}

View File

@ -1,2 +0,0 @@
fileFormatVersion: 2
guid: 73f052292ad16bb53a3c07aa1694c705

View File

@ -61,16 +61,19 @@ namespace NanoBrain {
/// <summary>
/// The type of
/// </summary>
public enum CurvePresets {
public enum ActivationType {
Linear,
Power,
Sqrt,
Reciprocal,
Tanh,
Binary,
Normalized,
Custom
}
[SerializeField]
public CurvePresets _curvePreset;
public CurvePresets curvePreset {
public ActivationType _curvePreset;
public ActivationType curvePreset {
get { return _curvePreset; }
set {
_curvePreset = value;
@ -82,18 +85,27 @@ namespace NanoBrain {
public AnimationCurve GenerateCurve() {
switch (this.curvePreset) {
case CurvePresets.Linear:
case ActivationType.Linear:
this.curveMax = 1;
return Presets.Linear(1);
case CurvePresets.Power:
case ActivationType.Power:
this.curveMax = 1;
return Presets.Power(2.0f, 1);
case CurvePresets.Sqrt:
case ActivationType.Sqrt:
this.curveMax = 1;
return Presets.Power(0.5f, 1);
case CurvePresets.Reciprocal:
case ActivationType.Reciprocal:
this.curveMax = 1 / 0.01f * 1;
return Presets.Reciprocal(1);
case ActivationType.Tanh:
this.curveMax = 1;
return Presets.Tanh(1);
case ActivationType.Binary:
this.curveMax = 1;
return Presets.Binary();
case ActivationType.Normalized:
this.curveMax = 1;
return Presets.Binary();
default:
this.curveMax = 1;
return this.curve;
@ -142,6 +154,28 @@ namespace NanoBrain {
}
return curve;
}
public static AnimationCurve Tanh(float weight) {
//int samples = 128;
float xMin = 0.001f;
float xMax = 1;
var keys = new Keyframe[samples];
for (int i = 0; i < samples; i++) {
float t = i / (float)(samples - 1);
float x = Mathf.Lerp(xMin, xMax, t);
float y = MathF.Tanh(x * weight);
keys[i] = new Keyframe(x, y);
}
var curve = new AnimationCurve(keys);
for (int i = 0; i < curve.length; i++) {
AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear);
}
return curve;
}
public static AnimationCurve Binary() {
return AnimationCurve.Linear(0, 0, 1, 1);
}
}
#endregion Serialization
@ -175,15 +209,29 @@ namespace NanoBrain {
public float outputSqrMagnitude => _outputValue.sqrMagnitude;
#endif
public bool isFiring => this.outputMagnitude > 0.5f;
public bool isFiring {
get {
SleepCheck();
return this.outputMagnitude > 0.5f;
}
}
public Action WhenFiring;
public virtual bool isSleeping => this.outputMagnitude == 0;
public virtual bool isSleeping => Time.time - this.lastUpdate > this.timeToSleep; //this.outputMagnitude == 0;
public void SleepCheck() {
if (this.isSleeping) {
#if UNITY_MATHEMATICS
this._outputValue = new float3(0, 0, 0);
#else
this._outputValue = new Vector3(0,0,0);
#endif
}
}
[NonSerialized]
public int stale = 1000;
public readonly int staleValueForSleep = 20;
public float lastUpdate = 0;
public readonly float timeToSleep = 1f;
/// \copydoc NanoBrain::Nucleus::ShallowCloneTo
public override Nucleus ShallowCloneTo(Cluster newParent) {
@ -236,9 +284,11 @@ namespace NanoBrain {
}
else if (nucleus is Cluster cluster) {
// remove all receivers for this cluster
foreach (Neuron output in cluster.outputs) {
foreach (Nucleus receiver in output.receivers) {
receiver.synapses.RemoveAll(s => s.neuron == output);
foreach (Nucleus clusterNucleus in cluster.clusterNuclei) {
if (clusterNucleus is Neuron output) {
foreach (Nucleus receiver in output.receivers) {
receiver.synapses.RemoveAll(s => s.neuron == output);
}
}
}
}
@ -252,8 +302,18 @@ namespace NanoBrain {
}
public override void UpdateStateIsolated() {
CheckSleepingSynapses();
var result = Combinator();
this.outputValue = Activator(result);
this.lastUpdate = Time.time;
}
protected void CheckSleepingSynapses() {
foreach (Synapse synapse in this.synapses) {
if (synapse.isSleeping) {
synapse.neuron.outputValue = Vector3.zero;
}
}
}
#region Combinator
@ -348,10 +408,13 @@ namespace NanoBrain {
#if UNITY_MATHEMATICS
public Func<float3, float3> Activator => this.curvePreset switch {
CurvePresets.Linear => ActivatorLinear,
CurvePresets.Sqrt => ActivatorSqrt,
CurvePresets.Power => ActivatorPower,
CurvePresets.Reciprocal => ActivatorReciprocal,
ActivationType.Linear => ActivatorLinear,
ActivationType.Sqrt => ActivatorSqrt,
ActivationType.Power => ActivatorPower,
ActivationType.Reciprocal => ActivatorReciprocal,
ActivationType.Tanh => ActivatorTanh,
ActivationType.Binary => ActivatorBinary,
ActivationType.Normalized => ActivatorNormalized,
_ => ActivatorCustom
};
@ -378,6 +441,24 @@ namespace NanoBrain {
return result;
}
protected float3 ActivatorTanh(float3 input) {
float magnitude = length(input);
float3 result = normalize(input) * MathF.Tanh(magnitude);
return result;
}
protected float3 ActivatorBinary(float3 input) {
float magnitude = length(input);
float value = Mathf.Clamp01(magnitude);
return float3(value, value, value);
}
protected float3 ActivatorNormalized(float3 input) {
if (lengthsq(input) == 0)
return input;
float3 result = normalize(input);
return result;
}
protected float3 ActivatorCustom(float3 input) {
float activatedValue = this.curve.Evaluate(length(input));
float3 result = normalize(input) * activatedValue;
@ -442,32 +523,16 @@ namespace NanoBrain {
}
public virtual void RemoveReceiver(Nucleus receiverToRemove) {
if (this is IReceptor receptor) {
foreach (Nucleus element in receptor.nucleiArray) {
if (element is Neuron neuron) {
neuron._receivers.RemoveAll(receiver => receiver == receiverToRemove);
receiverToRemove.synapses.RemoveAll(synapse => synapse.neuron == neuron);
}
}
}
else {
this._receivers.RemoveAll(receiver => receiver == receiverToRemove);
receiverToRemove.synapses.RemoveAll(synapse => synapse.neuron == this);
}
this._receivers.RemoveAll(receiver => receiver == receiverToRemove);
receiverToRemove.synapses.RemoveAll(synapse => synapse.neuron == this);
}
#endregion Receivers
public override void ProcessStimulus(Vector3 inputValue, int thingId = 0, string thingName = null) {
if (this.parent is ClusterReceptor clusterReceptor)
clusterReceptor.ProcessStimulus(this, inputValue, thingId, thingName);
else
ProcessStimulusDirect(inputValue, thingId, thingName);
}
public void ProcessStimulusDirect(Vector3 inputValue, int thingId = 0, string thingName = null) {
this.stale = 0;
public override void ProcessStimulus(Vector3 inputValue) {
;
this.lastUpdate = Time.time;
this.bias = inputValue;
this.parent.UpdateFromNucleus(this);
}

View File

@ -54,10 +54,13 @@ public abstract class Nucleus {
Neuron,
MemoryCell,
Cluster,
Receptor,
ClusterReceptor,
//Receptor,
//ClusterReceptor,
//ClusterArray,
}
public virtual void Initialize() {}
#region Synapses
/// <summary>
@ -87,6 +90,10 @@ public abstract class Nucleus {
return synapse;
}
// public Synapse AddSynapse(ClusterPrefab clusterPrefab, string neuronName, float weight = 1) {
// }
/// <summary>
/// Find a synapse
/// </summary>
@ -137,7 +144,7 @@ public abstract class Nucleus {
/// <param name="inputValue">The value of the stimulus</param>
/// <param name="thingId">The id of the thing causing the stimulus</param>
/// <param name="thingName">The name of the thing causing the stimulus</param>
public virtual void ProcessStimulus(Vector3 inputValue, int thingId = 0, string thingName = "") {
public virtual void ProcessStimulus(Vector3 inputValue) { //, int thingId = 0, string thingName = "") {
}
#endregion Update

View File

@ -1,197 +0,0 @@
using System.Linq;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif
namespace NanoBrain {
/// <summary>
/// Class to manage an array of nuclei for an IReceptor
/// </summary>
/// Would love to get rid of this class.
[System.Serializable]
public class NucleusArray {
/// <summary>
/// The nuclei in this array
/// </summary>
[SerializeReference]
private Nucleus[] _nuclei;
public Nucleus[] nuclei {
get {
return _nuclei;
}
set {
_nuclei = value;
}
}
/// <summary>
/// Create a new NucleusArray with the given nucleus
/// </summary>
/// <param name="nucleus">The Nucleus to put in the NucleusArray</param>
/// This results in an nucleus array of size 1
public NucleusArray(Nucleus nucleus) {
this._nuclei = new Nucleus[1];
this._nuclei[0] = nucleus;
}
/// <summary>
/// Create a new NucleusArray of the given size
/// </summary>
/// <param name="size">The size of the nucluesArray</param>
public NucleusArray(int size) {
this._nuclei = new Nucleus[size];
}
// public void AddNucleus(ClusterPrefab prefab) {
// if (this._nuclei.Length == 0) {
// Debug.LogError("Empty perceptoid array, cannot add");
// return;
// }
// int newLength = this._nuclei.Length + 1;
// Nucleus[] newArray = new Nucleus[newLength];
// for (int i = 0; i < this._nuclei.Length; i++)
// newArray[i] = this._nuclei[i];
// if (this._nuclei[0] is Nucleus nucleus) {
// newArray[newLength - 1] = nucleus.Clone(prefab);
// newArray[newLength - 1].name += $": {newLength - 1}";
// }
// this._nuclei = newArray;
// }
// public void RemoveNucleus() {
// int newLength = this._nuclei.Length - 1;
// if (newLength == 0) {
// Debug.LogWarning("Perceptoid array cannot be empty");
// return;
// }
// Nucleus[] newPerceptei = new Nucleus[newLength];
// for (int i = 0; i < newLength; i++)
// newPerceptei[i] = this._nuclei[i];
// // Delete the last perception
// if (this._nuclei[newLength] is Nucleus nucleus)
// Neuron.Delete(nucleus); //this._nuclei[newLength]);
// this._nuclei = newPerceptei;
// }
public Dictionary<int, Nucleus> thingReceivers = new();
#if UNITY_MATHEMATICS
private Nucleus FindReceiver(int thingId, float3 inputValue) {
float inputMagnitude = length(inputValue);
return FindReceiverMagnitude(thingId, inputMagnitude);
}
#else
private Nucleus FindReceiver(int thingId, Vector3 inputValue) {
float inputMagnitude = inputValue.magnitude;
return FindReceiverMagnitude(thingId, inputMagnitude);
}
#endif
private Nucleus FindReceiverMagnitude(int thingId, float inputMagnitude) {
Neuron selectedReceiver = null;
float selectedMagnitude = 0;
foreach (Nucleus nucleusReceiver in this._nuclei) {
if (nucleusReceiver is not Neuron receiver)
continue;
if (thingReceivers.ContainsValue(receiver) == false) {
// We found an unusued receiver
thingReceivers.Add(thingId, receiver);
return receiver;
}
else if (receiver.isSleeping) {
// A sleeping receiver is not active and can therefore always be used
thingReceivers.Add(thingId, receiver);
return receiver;
}
else if (selectedReceiver == null) {
// If we haven't found a receiver yet, just start by taking the first
selectedReceiver = receiver;
selectedMagnitude = selectedReceiver.outputMagnitude;
}
// Look for the receiver with the lowest magnitude
else {
float magnitude = receiver.outputMagnitude;
if (magnitude < inputMagnitude && receiver.outputMagnitude < selectedMagnitude) {
selectedReceiver = receiver;
selectedMagnitude = selectedReceiver.outputMagnitude;
}
}
}
if (selectedReceiver != null) {
// Replace the receiver
// Find the thingId current associated with the receiver
int keyToRemove = thingReceivers.FirstOrDefault(r => r.Value.Equals(selectedReceiver)).Key;
if (keyToRemove != 0 || thingReceivers.ContainsKey(keyToRemove))
thingReceivers.Remove(keyToRemove);
// And add the new association
thingReceivers.Add(thingId, selectedReceiver);
}
return selectedReceiver;
}
/// <summary>
/// Process an external stimulus
/// </summary>
/// <param name="inputValue">The value of the stimulus</param>
/// <param name="thingId">The id of the thing causing the stimulus</param>
/// <param name="thingName">The name of the thing causing the stimulus</param>
public virtual void ProcessStimulus(int thingId, Vector3 inputValue, string thingName = null) {
CleanupReceivers();
if (this._nuclei[0] is Neuron neuron)
inputValue = neuron.Activator(inputValue);
if (!thingReceivers.TryGetValue(thingId, out Nucleus selectedReceiver)) {
// No existing nucleus for this thing
selectedReceiver = FindReceiver(thingId, inputValue);
}
if (selectedReceiver == null)
return;
if (thingName != null) {
string baseName = selectedReceiver.name;
int colonPos = selectedReceiver.name.IndexOf(":");
if (colonPos > 0)
baseName = selectedReceiver.name[..colonPos];
selectedReceiver.name = baseName + ": " + thingName;
}
if (selectedReceiver is Neuron selectedNucleus)
selectedNucleus.ProcessStimulusDirect(inputValue);
}
/// <summary>
/// Remove a thing-receiver connection when the nucleus is inactive
/// </summary>
private void CleanupReceivers() {
List<int> receiversToRemove = new();
foreach (KeyValuePair<int, Nucleus> item in thingReceivers) {
if (item.Value != null && item.Value is Neuron neuron && neuron.isSleeping)
receiversToRemove.Add(item.Key);
}
foreach (int thingId in receiversToRemove) {
Nucleus selectedReceiver = thingReceivers[thingId];
thingReceivers.Remove(thingId);
int colonPos = selectedReceiver.name.IndexOf(":");
if (colonPos > 0)
selectedReceiver.name = selectedReceiver.name[..colonPos];
}
}
}
}

View File

@ -1,2 +0,0 @@
fileFormatVersion: 2
guid: f8cac60bd79854595a8571c042f77998

View File

@ -1,113 +0,0 @@
using UnityEngine;
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif
namespace NanoBrain {
/// <summary>
/// Basic IReceptor to receive external input
/// </summary>
[System.Serializable]
public class Receptor : Neuron, IReceptor {
/// <summary>
/// Create a new Receptor in a Cluster instance
/// </summary>
/// <param name="parent">The Cluster in which the Receptor is created</param>
/// <param name="name">The name of the new Receptor</param>
public Receptor(Cluster parent, string name) : base(parent, name) {
this.array = new NucleusArray(this);
if (this.name.IndexOf(":") < 0)
this.name += ": 0";
}
/// <summary>
/// Create a new Receptor in a Cluster Prefab
/// </summary>
/// <param name="prefab">The Cluster Prefab in which the Receptor is created</param>
/// <param name="name">The name of the new Receptor</param>
public Receptor(ClusterPrefab prefab, string name) : base(prefab, name) {
this.array = new NucleusArray(this);
}
public string GetName() {
return this.name;
}
/// \copydoc NanoBrain::Neuron::ShallowCloneTo
public override Nucleus ShallowCloneTo(Cluster parent) {
Receptor clone = new(parent, name) {
};
CloneFields(clone);
return clone;
}
/// \copydoc NanoBrain::Neuron::Clone
public override Nucleus Clone(ClusterPrefab prefab) {
Receptor clone = new(prefab, name) {
array = this._array
};
CloneFields(clone);
// Adding receivers will also add synapses to the receivers
foreach (Nucleus receiver in this.receivers.ToArray())
clone.AddReceiver(receiver);
return clone;
}
[SerializeReference]
private NucleusArray _array;
public NucleusArray array {
set { _array = value; }
}
public Nucleus[] nucleiArray {
get { return _array.nuclei; }
set { _array.nuclei = value; }
}
public void AddReceptorElement(ClusterPrefab prefab) {
IReceptorHelpers.AddReceptorElement(this, prefab);
}
public void RemoveReceptorElement() {
IReceptorHelpers.RemoveReceptorElement(this);
}
public virtual void AddArrayReceiver(Nucleus receiverToAdd, float weight = 1) {
IReceptorHelpers.AddArrayReceiver(this, receiverToAdd, weight);
}
public override void UpdateStateIsolated() {
this.outputValue = this.bias;
}
#if UNITY_MATHEMATICS
public override void UpdateNuclei() {
this.stale++;
if (this.stale > staleValueForSleep && lengthsq(this.bias) > 0) {
this.bias = new float3(0, 0, 0);
this.parent.UpdateFromNucleus(this);
}
}
#else
public override void UpdateNuclei() {
this.stale++;
if (this.stale > staleValueForSleep && this.bias.sqrMagnitude > 0) {
this.bias = new Vector3(0, 0, 0);
this.parent.UpdateFromNucleus(this);
}
}
#endif
public override void ProcessStimulus(Vector3 inputValue, int thingId = 0, string thingName = null) {
this._array ??= new NucleusArray(this.parent);
this._array.ProcessStimulus(thingId, inputValue, thingName);
}
}
}

View File

@ -1,2 +0,0 @@
fileFormatVersion: 2
guid: cfb9734aebc3ab85aacf87d26fb92e55

View File

@ -28,6 +28,12 @@ namespace NanoBrain {
this.neuron = nucleus;
this.weight = weight;
}
public bool isSleeping {
get {
return this.neuron.isSleeping;
}
}
}
}

View File

@ -10,6 +10,7 @@ namespace NanoBrain {
public class ClusterPrefab : ScriptableObject {
/// The nuclei in this cluster
[SerializeReference]
// This list should not include any clusters...
public List<Nucleus> nuclei = new();
/// <summary>