diff --git a/Editor/ClusterInspector.cs b/Editor/ClusterInspector.cs index b081cb3..d7a4008 100644 --- a/Editor/ClusterInspector.cs +++ b/Editor/ClusterInspector.cs @@ -8,8 +8,7 @@ using UnityEngine.UIElements; namespace NanoBrain { [CustomEditor(typeof(ClusterPrefab))] - public class ClusterInspector : Editor { - + public class ClusterInspector : ClusterViewer { public override VisualElement CreateInspectorGUI() { ClusterPrefab prefab = target as ClusterPrefab; if (prefab != null) @@ -37,11 +36,11 @@ namespace NanoBrain { flexDirection = FlexDirection.Row } }; - GraphView graph = new(cluster); + GraphEditor graph = new(cluster); graph.style.flexGrow = 1; - VisualElement inspectorContainer = new VisualElement { - name = "inspector", + VisualElement inspectorContainer = new() { + name = "inspector", style = { alignSelf = Align.Stretch, minHeight = 450, @@ -59,47 +58,10 @@ namespace NanoBrain { return graph; } - public class GraphView : VisualElement { - readonly ClusterPrefab prefab; - SerializedObject serializedBrain; - Nucleus currentNucleus; - GameObject gameObject; - private List layers = new(); - private readonly Dictionary neuroidPositions = new(); - private bool expandArray = false; + public class GraphEditor : GraphView { - ClusterPrefab prefabAsset; - readonly PopupField outputsField; + public GraphEditor(ClusterPrefab prefab) : base(prefab) { - public GraphView(ClusterPrefab prefab) { - this.prefab = prefab; - - name = "content"; - style.flexGrow = 1; - - IMGUIContainer graphContainer = new(OnIMGUI); - graphContainer.style.position = Position.Absolute; - graphContainer.style.left = 0; graphContainer.style.top = 0; - graphContainer.style.right = 0; graphContainer.style.bottom = 0; - graphContainer.pickingMode = PickingMode.Position; - graphContainer.focusable = true; - Add(graphContainer); - - VisualElement outputContainer = new() { - style = { - flexDirection = FlexDirection.Row, - alignItems = Align.Center, - } - }; - - List names = this.prefab.outputs.Select(output => output.name).ToList(); - if (names.Count > 0 && names.First() != null) { - outputsField = new(names, names.First()) { - style = { flexGrow = 1 } - }; - outputsField.RegisterValueChangedCallback(evt => OnOutputChanged(evt.newValue)); - outputContainer.Add(outputsField); - } Button addButton = new(() => OnAddClusterOutput()) { text = "Add" @@ -108,18 +70,6 @@ namespace NanoBrain { Add(outputContainer); - // Subscribe when added to panel (editor UI ready) - RegisterCallback(evt => Subscribe()); - RegisterCallback(evt => Unsubscribe()); - } - - void OnOutputChanged(string outputName) { - if (this.currentNucleus.parent != null) - // Get nucleus in the parent instance - this.currentNucleus = this.currentNucleus.parent.GetNucleus(outputName); - else - // Get nucleus in the prefab - this.currentNucleus = this.prefab.GetNucleus(outputName); } void OnAddClusterOutput() { @@ -131,19 +81,6 @@ namespace NanoBrain { this.currentNucleus = newOutput; } - bool subscribed = false; - void Subscribe() { - if (subscribed) return; - SceneView.duringSceneGui += OnSceneGUI; - subscribed = true; - SceneView.RepaintAll(); - } - - void Unsubscribe() { - if (!subscribed) return; - SceneView.duringSceneGui -= OnSceneGUI; - subscribed = false; - } public void SetGraph(GameObject gameObject, Nucleus nucleus, VisualElement inspectorContainer) { this.gameObject = gameObject; @@ -172,413 +109,7 @@ namespace NanoBrain { DrawInspector(inspectorContainer); } - private void BuildLayers() { - // A temporary list to track what's been added to layers - this.layers = new(); - int layerIx = 0; - - Nucleus selectedNucleus = this.currentNucleus; - if (selectedNucleus == null) - return; - NeuroidLayer currentLayer = new() { ix = layerIx }; - - if (selectedNucleus is Neuron selectedNeuron && selectedNeuron.receivers != null) { - foreach (Nucleus receiver in selectedNeuron.receivers) { - Nucleus outputNeuroid = receiver; - if (outputNeuroid != null) { - AddToLayer(currentLayer, outputNeuroid); - // Debug.Log($"layer {layerIx} nucleus {outputNeuroid.name}"); - } - } - } - if (currentLayer.neuroids.Count > 0) { - this.layers.Add(currentLayer); - layerIx++; - currentLayer = new() { ix = layerIx }; - } - - AddToLayer(currentLayer, selectedNucleus); - this.layers.Add(currentLayer); - // Debug.Log($"layer {layerIx} nucleus {selectedNucleus.name}"); - - layerIx++; - currentLayer = new() { ix = layerIx }; - - if (selectedNucleus.synapses != null) { - foreach (Synapse synapse in selectedNucleus.synapses) { - Nucleus input = synapse.neuron; - AddToLayer(currentLayer, input); - // Debug.Log($"layer {layerIx} nucleus {input.name}"); - } - } - if (currentLayer.neuroids.Count > 0) { - this.layers.Add(currentLayer); - } - } - - private void AddToLayer(NeuroidLayer layer, Nucleus nucleus) { - if (nucleus == null) - return; - layer.neuroids.Add(nucleus); - //nucleus.layerIx = layer.ix; - // Store its position - Vector2Int neuroidPosition = new(layer.ix, layer.neuroids.Count - 1); - neuroidPositions[nucleus] = neuroidPosition; - - } - - public void OnIMGUI() { - if (currentNucleus == null) - return; - - if (Application.isPlaying == false) - serializedBrain.Update(); - - Handles.BeginGUI(); - DrawGraph(); - Handles.EndGUI(); - - } - - private void DrawGraph() { - float size = 20; - Vector3 position = new(150, 210, 0); - - DrawReceivers(this.currentNucleus, position, size); - DrawSynapses(this.currentNucleus, position, size); - - // Draw selected Nucleus - if (expandArray) { - if (this.currentNucleus is IReceptor receptor1) { - float maxValue = 0; - foreach (Nucleus nucleus in receptor1.nucleiArray) { - if (nucleus is Neuron neuron) { - float value = neuron.outputMagnitude; - if (value > maxValue) - maxValue = value; - } - } - - float spacing = 400f / receptor1.nucleiArray.Count(); - float margin = 10 + spacing / 2; - float xMin = 150 - size; - float xMax = 150 + size; - float yMin = 10 + margin - size / 2; - float yMax = 400 - margin + size; - Vector3[] verts = new Vector3[4] { - new(xMin, yMin, 0), - new(xMax, yMin, 0), - new(xMax, yMax, 0), - new(xMin, yMax, 0) - }; - Handles.color = Color.black; - Handles.DrawAAConvexPolygon(verts); - int row = 0; - foreach (Nucleus nucleus in receptor1.nucleiArray) { - Vector3 pos = new(150, margin + row * spacing, 0.0f); - Handles.color = Color.white; - // The selected nucleus highlight ring - Handles.DrawSolidDisc(pos, Vector3.forward, size + 2); - DrawNucleus(nucleus, pos, maxValue, size); - row++; - } - GUIStyle style = new(EditorStyles.label) { - alignment = TextAnchor.UpperCenter, - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - }; - Vector3 labelPos = new(150, yMax + size + 5, 0); - string receptorName = receptor1.GetName(); - int colonPos = receptorName.IndexOf(":"); - if (colonPos > 0) { - string baseName = receptorName[..colonPos]; - Handles.Label(labelPos, baseName, style); - } - else - Handles.Label(labelPos, receptorName, style); - } - else { - Handles.color = Color.white; - // The selected nucleus highlight ring - Handles.DrawSolidDisc(position, Vector3.forward, size + 2); - float maxValue = 1; - if (this.currentNucleus is Neuron neuron) - maxValue = neuron.outputMagnitude; - else if (this.currentNucleus is Cluster cluster) - maxValue = cluster.defaultOutput.outputMagnitude; - - DrawNucleus(this.currentNucleus, position, maxValue, 20); - - } - } - else { - Handles.color = Color.white; - // The selected nucleus highlight ring - Handles.DrawSolidDisc(position, Vector3.forward, size + 2); - float maxValue = 1; - if (this.currentNucleus is Neuron neuron) - maxValue = neuron.outputMagnitude; - else if (this.currentNucleus is Cluster cluster) - maxValue = cluster.defaultOutput.outputMagnitude; - DrawNucleus(this.currentNucleus, position, maxValue, 20); - } - } - - private void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) { - List receivers; - if (nucleus is Neuron neuron) - receivers = neuron.receivers; - else if (nucleus is Cluster cluster) - receivers = cluster.CollectReceivers(); - else - return; - - int nodeCount = receivers.Count(); //neuron != null ? neuron.receivers.Count() : 1; - - // Determine the maximum value in this layer - // This is used to 'scale' the output value colors of the nuclei - float maxValue = 0; - foreach (Nucleus receiver in receivers) { - if (receiver is Neuron neuroid) { - float value = neuroid.outputMagnitude; - if (value > maxValue) - maxValue = value; - } - } - - // Determine the spacing of the nuclei in the layer - float spacing = 400f / nodeCount; - float margin = 10 + spacing / 2; - - int row = 0; - List drawnArrays = new(); - foreach (Nucleus receiver in receivers) { - if (receiver is Receptor receptor) { - if (drawnArrays.Contains(receptor.nucleiArray)) - continue; - drawnArrays.Add(receptor.nucleiArray); - } - - Nucleus receiverNucleus = receiver; - if (receiverNucleus == null) - continue; - - Vector3 pos = new(50, margin + row * spacing, 0.0f); - Handles.color = Color.white; - Handles.DrawLine(parentPos, pos); - - DrawNucleus(receiverNucleus, pos, maxValue, size); - row++; - } - } - - private void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) { - int nodeCount = nucleus.synapses.Count; - - // Determine the maximum value in this layer - // This is used to 'scale' the output value colors of the nuclei - float maxValue = 0; - int neuronCount = 0; - List drawnArrays = new(); - foreach (Synapse synapse in nucleus.synapses) { - if (synapse.neuron == null) - continue; - - if (synapse.neuron is Receptor receptor) { - if (drawnArrays.Contains(receptor.nucleiArray)) - continue; - drawnArrays.Add(receptor.nucleiArray); - } - else if (synapse.neuron.parent is ClusterReceptor clusterReceptor) { - if (drawnArrays.Contains(clusterReceptor.nucleiArray)) - continue; - drawnArrays.Add(clusterReceptor.nucleiArray); - } - if (synapse.neuron is Neuron synapseNeuron) { - float value = synapseNeuron.outputMagnitude * synapse.weight; - // Debug.Log($"{synapse.nucleus.name}: {value} {length(synapse.nucleus.outputValue)} {synapse.weight}"); - if (value > maxValue) - maxValue = value; - } - neuronCount++; - } - - // Determine the spacing of the nuclei in the layer - float spacing = 400f / neuronCount; - float margin = 10 + spacing / 2; - - int row = 0; - drawnArrays = new(); - foreach (Synapse synapse in nucleus.synapses) { - if (synapse.neuron is null) - continue; - - if (synapse.neuron is Receptor neuron) { - if (drawnArrays.Contains(neuron.nucleiArray)) - continue; - drawnArrays.Add(neuron.nucleiArray); - } - else if (synapse.neuron.parent is ClusterReceptor clusterReceptor) { - if (drawnArrays.Contains(clusterReceptor.nucleiArray)) - continue; - drawnArrays.Add(clusterReceptor.nucleiArray); - } - Vector3 pos = new(250, margin + row * spacing, 0.0f); - Handles.color = Color.white; - Handles.DrawLine(parentPos, pos); - Color color = Color.black; - if (Application.isPlaying) { - if (maxValue == 0 || !float.IsFinite(maxValue)) - maxValue = 1; - float brightness = 0; - if (synapse.neuron is Neuron synapseNeuron) - brightness = synapseNeuron.outputMagnitude * synapse.weight / maxValue; - color = new Color(brightness, brightness, brightness, 1f); - } - if (synapse.neuron.parent != null && synapse.neuron.parent != this.currentNucleus.parent) { - // the synapse nucleus is part of a subcluster - DrawNucleus(synapse.neuron.parent, pos, maxValue, size, color); - } - // else if (synapse.nucleus.cluster != null && synapse.nucleus.cluster != this.currentNucleus.cluster) { - // DrawNucleus(synapse.nucleus.parent, pos, maxValue, size, color); - // } - else { - DrawNucleus(synapse.neuron, pos, maxValue, size, color); - } - row++; - } - } - - private void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) { - Color color; - if (Application.isPlaying) { - float brightness = 0; - if (nucleus is Neuron neuron) - brightness = neuron.outputMagnitude / maxValue; - color = new Color(brightness, brightness, brightness, 1f); - } - else - color = Color.black; - DrawNucleus(nucleus, position, maxValue, size, color); - } - - private void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size, Color color) { - if (nucleus is MemoryCell) { - Handles.color = Color.white; - Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size); - } - - Handles.color = color; - Handles.DrawSolidDisc(position, Vector3.forward, size); - - Handles.color = Color.white; - // Position the label in front of the disc - Vector3 labelPosition = position + (Vector3.forward * 0.1f); - - GUIStyle style = new(EditorStyles.label) { - alignment = TextAnchor.MiddleCenter, - normal = { textColor = Color.white }, - fontStyle = FontStyle.Bold, - }; - - if (nucleus is IReceptor receptor1) { - if (expandArray) { - // Put array indices above elements - style.alignment = TextAnchor.LowerCenter; - Vector3 labelPos1 = position + Vector3.down * (size + 5); // below disc - int colonPos1 = nucleus.name.IndexOf(":"); - if (colonPos1 > 0) { - string extName = nucleus.name[(colonPos1 + 2)..]; - Handles.Label(labelPos1, extName, style); - } - } - else { - // draw the array size label - if (color.grayscale > 0.5f) - style.normal.textColor = Color.black; - else - style.normal.textColor = Color.white; - Handles.Label(labelPosition, receptor1.nucleiArray.Length.ToString(), style); - style.normal.textColor = Color.white; - } - } - - if (expandArray == false || nucleus is not IReceptor) { - // put name below nucleus - Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron - style.alignment = TextAnchor.UpperCenter; - - int colonPos = nucleus.name.IndexOf(":"); - if (colonPos > 0 && colonPos < nucleus.name.Length - 2) { - // if it is an array, we should not show the :0 of the first element - string baseName = nucleus.name[..colonPos]; - Handles.Label(labelPos, baseName, style); - } - else - Handles.Label(labelPos, nucleus.name, style); - - } - - // Draw Cluster ring - if (nucleus is Cluster) { - Handles.color = Color.white; - Handles.DrawWireDisc(position, Vector3.forward, size + 5); - } - - // Tooltip - Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); - int id = GUIUtility.GetControlID(FocusType.Passive); - Event e = Event.current; - EventType et = e.GetTypeForControl(id); - if (e != null && neuronRect.Contains(e.mousePosition)) { - // Process Hover - HandleMouseHover(nucleus, neuronRect); - // Process click - if (e.type == EventType.MouseDown && e.button == 0) { - // Consume the event so the scene doesn't also handle it - e.Use(); - HandleClicked(nucleus); - } - } - } - - private void HandleMouseHover(Nucleus nucleus, Rect rect) { - GUIContent tooltip; - if (nucleus is Neuron neuron) { - tooltip = new( - $"{nucleus.name}" + - $"\nValue: {neuron.outputMagnitude}"); - } - else - tooltip = new($"{nucleus.name}"); - - Vector2 mousePosition = Event.current.mousePosition; - - // Display tooltip with some offset - Vector2 tooltipSize = GUI.skin.box.CalcSize(tooltip); - Rect tooltipRect = new Rect(mousePosition.x + 10, mousePosition.y + 10, tooltipSize.x, tooltipSize.y); - - GUI.Box(tooltipRect, tooltip); - } - - private void HandleClicked(Nucleus nucleus) { - if (nucleus == this.currentNucleus) { - if (nucleus is Receptor || nucleus is ClusterReceptor) - expandArray = !expandArray; - else - expandArray = false; - } - // else if (nucleus is ReceptorInstance receptor) { - // this.currentNucleus = receptor.receptor; - // expandArray = false; - // BuildLayers(); - // } - else { - this.currentNucleus = nucleus; - expandArray = false; - BuildLayers(); - } - } + #region Inspector private VisualElement inspectorIMGUIContainer; private bool showSynapses = true; @@ -785,7 +316,7 @@ namespace NanoBrain { EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, 0, 1, neuron.curveMax)); else EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, neuron.curveMax, 1, -neuron.curveMax)); - Neuron.CurvePresets newPreset = (Neuron.CurvePresets)EditorGUILayout.EnumPopup(neuron.curvePreset, GUILayout.Width(100)); + Neuron.ActivationFunction newPreset = (Neuron.ActivationFunction)EditorGUILayout.EnumPopup(neuron.curvePreset, GUILayout.Width(100)); anythingChanged |= newPreset != neuron.curvePreset; neuron.curvePreset = newPreset; EditorGUILayout.EndHorizontal(); @@ -1063,14 +594,1076 @@ namespace NanoBrain { } } - #endregion Synapses + #endregion Synapses + + #endregion Inspector + } + } + /* + [CustomEditor(typeof(ClusterPrefab))] + public class ClusterInspector : Editor { + + public override VisualElement CreateInspectorGUI() { + ClusterPrefab prefab = target as ClusterPrefab; + if (prefab != null) + prefab.EnsureInitialization(); + + serializedObject.Update(); + + VisualElement root = new(); + CreateInspector(root, prefab, prefab.output, null); + + serializedObject.ApplyModifiedProperties(); + return root; + } + + public static GraphView CreateInspector(VisualElement root, ClusterPrefab cluster, Nucleus output, GameObject gameObject) { + root.style.paddingLeft = 0; + root.style.paddingRight = 0; + root.style.paddingTop = 0; + root.style.paddingBottom = 0; + + root.styleSheets.Add(Resources.Load("GraphStyles")); + + VisualElement mainContainer = new() { + style = { + flexDirection = FlexDirection.Row + } + }; + GraphView graph = new(cluster); + graph.style.flexGrow = 1; + + VisualElement inspectorContainer = new VisualElement { + name = "inspector", + style = { + alignSelf = Align.Stretch, + minHeight = 450, + width = 300, + flexGrow = 0 + } + }; + + mainContainer.Add(graph); + mainContainer.Add(inspectorContainer); + root.Add(mainContainer); + + graph.SetGraph(gameObject, output, inspectorContainer); + + return graph; + } + + public class GraphView : VisualElement { + readonly ClusterPrefab prefab; + SerializedObject serializedBrain; + Nucleus currentNucleus; + GameObject gameObject; + private List layers = new(); + private readonly Dictionary neuroidPositions = new(); + private bool expandArray = false; + + ClusterPrefab prefabAsset; + readonly PopupField outputsField; + + public GraphView(ClusterPrefab prefab) { + this.prefab = prefab; + + name = "content"; + style.flexGrow = 1; + + IMGUIContainer graphContainer = new(OnIMGUI); + graphContainer.style.position = Position.Absolute; + graphContainer.style.left = 0; graphContainer.style.top = 0; + graphContainer.style.right = 0; graphContainer.style.bottom = 0; + graphContainer.pickingMode = PickingMode.Position; + graphContainer.focusable = true; + Add(graphContainer); + + VisualElement outputContainer = new() { + style = { + flexDirection = FlexDirection.Row, + alignItems = Align.Center, + } + }; + + List names = this.prefab.outputs.Select(output => output.name).ToList(); + if (names.Count > 0 && names.First() != null) { + outputsField = new(names, names.First()) { + style = { flexGrow = 1 } + }; + outputsField.RegisterValueChangedCallback(evt => OnOutputChanged(evt.newValue)); + outputContainer.Add(outputsField); + } + + Button addButton = new(() => OnAddClusterOutput()) { + text = "Add" + }; + outputContainer.Add(addButton); + + Add(outputContainer); + + // Subscribe when added to panel (editor UI ready) + RegisterCallback(evt => Subscribe()); + RegisterCallback(evt => Unsubscribe()); + } + + void OnOutputChanged(string outputName) { + if (this.currentNucleus.parent != null) + // Get nucleus in the parent instance + this.currentNucleus = this.currentNucleus.parent.GetNucleus(outputName); + else + // Get nucleus in the prefab + this.currentNucleus = this.prefab.GetNucleus(outputName); + } + + void OnAddClusterOutput() { + Nucleus newOutput = new Neuron(this.prefab, "New Output"); + this.prefab.RefreshOutputs(); + outputsField.choices = this.prefab.outputs.Select(output => output.name).ToList(); + outputsField.value = newOutput.name; + + this.currentNucleus = newOutput; + } + + bool subscribed = false; + void Subscribe() { + if (subscribed) return; + SceneView.duringSceneGui += OnSceneGUI; + subscribed = true; + SceneView.RepaintAll(); + } + + void Unsubscribe() { + if (!subscribed) return; + SceneView.duringSceneGui -= OnSceneGUI; + subscribed = false; + } + + public void SetGraph(GameObject gameObject, Nucleus nucleus, VisualElement inspectorContainer) { + this.gameObject = gameObject; + //this.cluster = brain; + if (Application.isPlaying == false) + this.serializedBrain = new SerializedObject(this.prefab); + this.currentNucleus = nucleus; + Rebuild(inspectorContainer); + } + + void Rebuild(VisualElement inspectorContainer) { + BuildLayers(); + + if (this.currentNucleus == null) { + inspectorContainer.Clear(); + return; + } + + string path = AssetDatabase.GetAssetPath(this.prefab); // or known path + this.prefabAsset = AssetDatabase.LoadAssetAtPath(path); + if (this.prefabAsset == null) { + // create in memory save if it doesn't exist + this.prefabAsset = CreateInstance(); + //Debug.LogError("Cluster Prefab is not found on disk"); + } + DrawInspector(inspectorContainer); + } + + private void BuildLayers() { + // A temporary list to track what's been added to layers + this.layers = new(); + int layerIx = 0; + + Nucleus selectedNucleus = this.currentNucleus; + if (selectedNucleus == null) + return; + NeuroidLayer currentLayer = new() { ix = layerIx }; + + if (selectedNucleus is Neuron selectedNeuron && selectedNeuron.receivers != null) { + foreach (Nucleus receiver in selectedNeuron.receivers) { + Nucleus outputNeuroid = receiver; + if (outputNeuroid != null) { + AddToLayer(currentLayer, outputNeuroid); + // Debug.Log($"layer {layerIx} nucleus {outputNeuroid.name}"); + } + } + } + if (currentLayer.neuroids.Count > 0) { + this.layers.Add(currentLayer); + layerIx++; + currentLayer = new() { ix = layerIx }; + } + + AddToLayer(currentLayer, selectedNucleus); + this.layers.Add(currentLayer); + // Debug.Log($"layer {layerIx} nucleus {selectedNucleus.name}"); + + layerIx++; + currentLayer = new() { ix = layerIx }; + + if (selectedNucleus.synapses != null) { + foreach (Synapse synapse in selectedNucleus.synapses) { + Nucleus input = synapse.neuron; + AddToLayer(currentLayer, input); + // Debug.Log($"layer {layerIx} nucleus {input.name}"); + } + } + if (currentLayer.neuroids.Count > 0) { + this.layers.Add(currentLayer); + } + } + + private void AddToLayer(NeuroidLayer layer, Nucleus nucleus) { + if (nucleus == null) + return; + layer.neuroids.Add(nucleus); + //nucleus.layerIx = layer.ix; + // Store its position + Vector2Int neuroidPosition = new(layer.ix, layer.neuroids.Count - 1); + neuroidPositions[nucleus] = neuroidPosition; + + } + + public void OnIMGUI() { + if (currentNucleus == null) + return; + + if (Application.isPlaying == false) + serializedBrain.Update(); + + Handles.BeginGUI(); + DrawGraph(); + Handles.EndGUI(); + + } + + private void DrawGraph() { + float size = 20; + Vector3 position = new(150, 210, 0); + + DrawReceivers(this.currentNucleus, position, size); + DrawSynapses(this.currentNucleus, position, size); + + // Draw selected Nucleus + if (expandArray) { + if (this.currentNucleus is IReceptor receptor1) { + float maxValue = 0; + foreach (Nucleus nucleus in receptor1.nucleiArray) { + if (nucleus is Neuron neuron) { + float value = neuron.outputMagnitude; + if (value > maxValue) + maxValue = value; + } + } + + float spacing = 400f / receptor1.nucleiArray.Count(); + float margin = 10 + spacing / 2; + float xMin = 150 - size; + float xMax = 150 + size; + float yMin = 10 + margin - size / 2; + float yMax = 400 - margin + size; + Vector3[] verts = new Vector3[4] { + new(xMin, yMin, 0), + new(xMax, yMin, 0), + new(xMax, yMax, 0), + new(xMin, yMax, 0) + }; + Handles.color = Color.black; + Handles.DrawAAConvexPolygon(verts); + int row = 0; + foreach (Nucleus nucleus in receptor1.nucleiArray) { + Vector3 pos = new(150, margin + row * spacing, 0.0f); + Handles.color = Color.white; + // The selected nucleus highlight ring + Handles.DrawSolidDisc(pos, Vector3.forward, size + 2); + DrawNucleus(nucleus, pos, maxValue, size); + row++; + } + GUIStyle style = new(EditorStyles.label) { + alignment = TextAnchor.UpperCenter, + normal = { textColor = Color.white }, + fontStyle = FontStyle.Bold, + }; + Vector3 labelPos = new(150, yMax + size + 5, 0); + string receptorName = receptor1.GetName(); + int colonPos = receptorName.IndexOf(":"); + if (colonPos > 0) { + string baseName = receptorName[..colonPos]; + Handles.Label(labelPos, baseName, style); + } + else + Handles.Label(labelPos, receptorName, style); + } + else { + Handles.color = Color.white; + // The selected nucleus highlight ring + Handles.DrawSolidDisc(position, Vector3.forward, size + 2); + float maxValue = 1; + if (this.currentNucleus is Neuron neuron) + maxValue = neuron.outputMagnitude; + else if (this.currentNucleus is Cluster cluster) + maxValue = cluster.defaultOutput.outputMagnitude; + + DrawNucleus(this.currentNucleus, position, maxValue, 20); + + } + } + else { + Handles.color = Color.white; + // The selected nucleus highlight ring + Handles.DrawSolidDisc(position, Vector3.forward, size + 2); + float maxValue = 1; + if (this.currentNucleus is Neuron neuron) + maxValue = neuron.outputMagnitude; + else if (this.currentNucleus is Cluster cluster) + maxValue = cluster.defaultOutput.outputMagnitude; + DrawNucleus(this.currentNucleus, position, maxValue, 20); + } + } + + private void DrawReceivers(Nucleus nucleus, Vector3 parentPos, float size) { + List receivers; + if (nucleus is Neuron neuron) + receivers = neuron.receivers; + else if (nucleus is Cluster cluster) + receivers = cluster.CollectReceivers(); + else + return; + + int nodeCount = receivers.Count(); //neuron != null ? neuron.receivers.Count() : 1; + + // Determine the maximum value in this layer + // This is used to 'scale' the output value colors of the nuclei + float maxValue = 0; + foreach (Nucleus receiver in receivers) { + if (receiver is Neuron neuroid) { + float value = neuroid.outputMagnitude; + if (value > maxValue) + maxValue = value; + } + } + + // Determine the spacing of the nuclei in the layer + float spacing = 400f / nodeCount; + float margin = 10 + spacing / 2; + + int row = 0; + List drawnArrays = new(); + foreach (Nucleus receiver in receivers) { + if (receiver is Receptor receptor) { + if (drawnArrays.Contains(receptor.nucleiArray)) + continue; + drawnArrays.Add(receptor.nucleiArray); + } + + Nucleus receiverNucleus = receiver; + if (receiverNucleus == null) + continue; + + Vector3 pos = new(50, margin + row * spacing, 0.0f); + Handles.color = Color.white; + Handles.DrawLine(parentPos, pos); + + DrawNucleus(receiverNucleus, pos, maxValue, size); + row++; + } + } + + private void DrawSynapses(Nucleus nucleus, Vector3 parentPos, float size) { + int nodeCount = nucleus.synapses.Count; + + // Determine the maximum value in this layer + // This is used to 'scale' the output value colors of the nuclei + float maxValue = 0; + int neuronCount = 0; + List drawnArrays = new(); + foreach (Synapse synapse in nucleus.synapses) { + if (synapse.neuron == null) + continue; + + if (synapse.neuron is Receptor receptor) { + if (drawnArrays.Contains(receptor.nucleiArray)) + continue; + drawnArrays.Add(receptor.nucleiArray); + } + else if (synapse.neuron.parent is ClusterReceptor clusterReceptor) { + if (drawnArrays.Contains(clusterReceptor.nucleiArray)) + continue; + drawnArrays.Add(clusterReceptor.nucleiArray); + } + if (synapse.neuron is Neuron synapseNeuron) { + float value = synapseNeuron.outputMagnitude * synapse.weight; + // Debug.Log($"{synapse.nucleus.name}: {value} {length(synapse.nucleus.outputValue)} {synapse.weight}"); + if (value > maxValue) + maxValue = value; + } + neuronCount++; + } + + // Determine the spacing of the nuclei in the layer + float spacing = 400f / neuronCount; + float margin = 10 + spacing / 2; + + int row = 0; + drawnArrays = new(); + foreach (Synapse synapse in nucleus.synapses) { + if (synapse.neuron is null) + continue; + + if (synapse.neuron is Receptor neuron) { + if (drawnArrays.Contains(neuron.nucleiArray)) + continue; + drawnArrays.Add(neuron.nucleiArray); + } + else if (synapse.neuron.parent is ClusterReceptor clusterReceptor) { + if (drawnArrays.Contains(clusterReceptor.nucleiArray)) + continue; + drawnArrays.Add(clusterReceptor.nucleiArray); + } + Vector3 pos = new(250, margin + row * spacing, 0.0f); + Handles.color = Color.white; + Handles.DrawLine(parentPos, pos); + Color color = Color.black; + if (Application.isPlaying) { + if (maxValue == 0 || !float.IsFinite(maxValue)) + maxValue = 1; + float brightness = 0; + if (synapse.neuron is Neuron synapseNeuron) + brightness = synapseNeuron.outputMagnitude * synapse.weight / maxValue; + color = new Color(brightness, brightness, brightness, 1f); + } + if (synapse.neuron.parent != null && synapse.neuron.parent != this.currentNucleus.parent) { + // the synapse nucleus is part of a subcluster + DrawNucleus(synapse.neuron.parent, pos, maxValue, size, color); + } + // else if (synapse.nucleus.cluster != null && synapse.nucleus.cluster != this.currentNucleus.cluster) { + // DrawNucleus(synapse.nucleus.parent, pos, maxValue, size, color); + // } + else { + DrawNucleus(synapse.neuron, pos, maxValue, size, color); + } + row++; + } + } + + private void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size) { + Color color; + if (Application.isPlaying) { + float brightness = 0; + if (nucleus is Neuron neuron) + brightness = neuron.outputMagnitude / maxValue; + color = new Color(brightness, brightness, brightness, 1f); + } + else + color = Color.black; + DrawNucleus(nucleus, position, maxValue, size, color); + } + + private void DrawNucleus(Nucleus nucleus, Vector3 position, float maxValue, float size, Color color) { + if (nucleus is MemoryCell) { + Handles.color = Color.white; + Handles.DrawWireDisc(position + Vector3.right * 10, Vector3.forward, size); + } + + Handles.color = color; + Handles.DrawSolidDisc(position, Vector3.forward, size); + + Handles.color = Color.white; + // Position the label in front of the disc + Vector3 labelPosition = position + (Vector3.forward * 0.1f); + + GUIStyle style = new(EditorStyles.label) { + alignment = TextAnchor.MiddleCenter, + normal = { textColor = Color.white }, + fontStyle = FontStyle.Bold, + }; + + if (nucleus is IReceptor receptor1) { + if (expandArray) { + // Put array indices above elements + style.alignment = TextAnchor.LowerCenter; + Vector3 labelPos1 = position + Vector3.down * (size + 5); // below disc + int colonPos1 = nucleus.name.IndexOf(":"); + if (colonPos1 > 0) { + string extName = nucleus.name[(colonPos1 + 2)..]; + Handles.Label(labelPos1, extName, style); + } + } + else { + // draw the array size label + if (color.grayscale > 0.5f) + style.normal.textColor = Color.black; + else + style.normal.textColor = Color.white; + Handles.Label(labelPosition, receptor1.nucleiArray.Length.ToString(), style); + style.normal.textColor = Color.white; + } + } + + if (expandArray == false || nucleus is not IReceptor) { + // put name below nucleus + Vector3 labelPos = position - Vector3.down * (size + 5); // below neuron + style.alignment = TextAnchor.UpperCenter; + + int colonPos = nucleus.name.IndexOf(":"); + if (colonPos > 0 && colonPos < nucleus.name.Length - 2) { + // if it is an array, we should not show the :0 of the first element + string baseName = nucleus.name[..colonPos]; + Handles.Label(labelPos, baseName, style); + } + else + Handles.Label(labelPos, nucleus.name, style); + + } + + // Draw Cluster ring + if (nucleus is Cluster) { + Handles.color = Color.white; + Handles.DrawWireDisc(position, Vector3.forward, size + 5); + } + + // Tooltip + Rect neuronRect = new(position.x - size, position.y - size, size * 2, size * 2); + int id = GUIUtility.GetControlID(FocusType.Passive); + Event e = Event.current; + EventType et = e.GetTypeForControl(id); + if (e != null && neuronRect.Contains(e.mousePosition)) { + // Process Hover + HandleMouseHover(nucleus, neuronRect); + // Process click + if (e.type == EventType.MouseDown && e.button == 0) { + // Consume the event so the scene doesn't also handle it + e.Use(); + HandleClicked(nucleus); + } + } + } + + private void HandleMouseHover(Nucleus nucleus, Rect rect) { + GUIContent tooltip; + if (nucleus is Neuron neuron) { + tooltip = new( + $"{nucleus.name}" + + $"\nValue: {neuron.outputMagnitude}"); + } + else + tooltip = new($"{nucleus.name}"); + + Vector2 mousePosition = Event.current.mousePosition; + + // Display tooltip with some offset + Vector2 tooltipSize = GUI.skin.box.CalcSize(tooltip); + Rect tooltipRect = new Rect(mousePosition.x + 10, mousePosition.y + 10, tooltipSize.x, tooltipSize.y); + + GUI.Box(tooltipRect, tooltip); + } + + private void HandleClicked(Nucleus nucleus) { + if (nucleus == this.currentNucleus) { + if (nucleus is Receptor || nucleus is ClusterReceptor) + expandArray = !expandArray; + else + expandArray = false; + } + // else if (nucleus is ReceptorInstance receptor) { + // this.currentNucleus = receptor.receptor; + // expandArray = false; + // BuildLayers(); + // } + else { + this.currentNucleus = nucleus; + expandArray = false; + BuildLayers(); + } + } + + private VisualElement inspectorIMGUIContainer; + private bool showSynapses = true; + private bool showActivation = true; + protected bool breakOnWake = false; + protected bool trace = false; + void DrawInspector(VisualElement inspectorContainer) { + if (inspectorContainer == null) + return; + + inspectorContainer.Clear(); + if (this.currentNucleus == null) + return; + + // create a SerializedObject wrapper so Unity inspector controls work (and Undo) + SerializedObject so = new(prefabAsset); + this.inspectorIMGUIContainer = new IMGUIContainer(() => InspectorHandler(so)); + + inspectorContainer.Add(inspectorIMGUIContainer); + } + + void InspectorHandler(SerializedObject serializedObject) { + bool anythingChanged = false; + + if (serializedObject == null || serializedObject.targetObject == null) + return; + + if (this.currentNucleus == null) + return; + + serializedObject.Update(); + + GUIStyle headerStyle = new(EditorStyles.boldLabel) { + alignment = TextAnchor.MiddleLeft, + margin = new RectOffset(10, 0, 4, 4) + }; + GUIStyle boldTextFieldStyle = new(EditorStyles.textField) { + fontStyle = FontStyle.Bold + }; + + GUILayout.Label(this.currentNucleus.GetType().ToString(), headerStyle); + string newName = EditorGUILayout.TextField(this.currentNucleus.name, boldTextFieldStyle); + if (newName != this.currentNucleus.name) { + this.currentNucleus.name = newName; + this.prefab.RefreshOutputs(); + outputsField.choices = this.prefab.outputs.Select(output => output.name).ToList(); + anythingChanged = true; + } + + if (Application.isPlaying) { + if (currentNucleus is Neuron currentNeuron1) { + GUIContent nameLabel = new("Output", currentNeuron1.outputValue.ToString()); + EditorGUILayout.FloatField(nameLabel, currentNeuron1.outputMagnitude); + } + else + EditorGUILayout.LabelField(" "); + } + else + EditorGUILayout.LabelField(" "); + + if (this.currentNucleus is MemoryCell memory) { + memory.staticMemory = EditorGUILayout.Toggle("Static Memory", memory.staticMemory); + } + + if (this.currentNucleus is IReceptor receptor1) { + EditorGUILayout.BeginHorizontal(); + EditorGUILayout.IntField("Array size", receptor1.nucleiArray.Count()); + if (GUILayout.Button("Add")) { + Undo.RecordObject(prefabAsset, "Array add " + prefabAsset.name); + receptor1.AddReceptorElement(this.prefab); + anythingChanged = true; + } + if (GUILayout.Button("Del")) { + Undo.RecordObject(prefabAsset, "Array delete " + prefabAsset.name); + receptor1.RemoveReceptorElement(); + anythingChanged = true; + } + EditorGUILayout.EndHorizontal(); + } + + // Synapses + + if (this.currentNucleus is not Receptor && this.currentNucleus is not ClusterReceptor) { + showSynapses = EditorGUILayout.BeginFoldoutHeaderGroup(showSynapses, "Synapses"); + if (showSynapses) { + if (this.currentNucleus is Neuron neuron2) { + Neuron.CombinatorType newCombinator = (Neuron.CombinatorType)EditorGUILayout.EnumPopup("Combinator", neuron2.combinator); + anythingChanged |= newCombinator != neuron2.combinator; + neuron2.combinator = newCombinator; + } + + EditorGUIUtility.wideMode = true; + EditorGUIUtility.labelWidth = 100; + Vector3 newBias = EditorGUILayout.Vector3Field("Bias", this.currentNucleus.bias); + anythingChanged |= newBias != this.currentNucleus.bias; + this.currentNucleus.bias = newBias; + + Nucleus[] array = null; + int elementIx = -1; + if (this.currentNucleus.synapses.Count > 0) { + Synapse[] synapses = this.currentNucleus.synapses.ToArray(); + foreach (Synapse synapse in synapses) { + if (synapse.neuron == null) + continue; + + if (array != null) { + if (synapse.neuron.parent is Cluster iCluster && elementIx > 0) { + int thisElementIx = Cluster.GetNucleusIndex(iCluster.clusterNuclei, synapse.neuron); + if (thisElementIx == elementIx) + continue; + else + elementIx = thisElementIx; + } + // if (array.Contains(synapse.nucleus)) + // continue; + else if (array.Contains(synapse.neuron.parent)) + continue; + } + else { + if (synapse.neuron.parent is IReceptor iReceptor) { + array = iReceptor.nucleiArray; + if (iReceptor is Cluster iCluster) + elementIx = Cluster.GetNucleusIndex(iCluster.clusterNuclei, synapse.neuron); + } + // else if (synapse.nucleus is Receptor receptor2) // && receptor2.array != null && receptor2.array.nuclei.Length > 1) + // array = receptor2.nucleiArray; + } + + EditorGUILayout.Space(); + + if (Application.isPlaying) { + if (synapse.neuron is Neuron synapseNeuron) { + Vector3 value = synapseNeuron.outputValue * synapse.weight; + GUIContent synapseValueLabel = new(synapse.neuron.name, synapseNeuron.outputValue.ToString()); + EditorGUILayout.FloatField(synapseValueLabel, synapseNeuron.outputMagnitude); + } + } + else { + EditorGUILayout.BeginHorizontal(); + + if (synapse.neuron.parent != null && synapse.neuron.parent != this.currentNucleus) { + // If it is a cluster + GUIStyle labelStyle = new(GUI.skin.label); + float labelWidth = 200; + if (synapse.neuron.clusterPrefab != null) { + labelWidth = labelStyle.CalcSize(new GUIContent($"{synapse.neuron.parent.baseName}.")).x; + GUILayout.Label($"{synapse.neuron.parent.baseName}", GUILayout.Width(labelWidth)); + } + string[] options = synapse.neuron.parent.clusterNuclei.Select(n => n.name).ToArray(); + int selectedIndex = System.Array.IndexOf(options, synapse.neuron.name); + int newIndex = EditorGUILayout.Popup(selectedIndex, options); + if (newIndex != selectedIndex && synapse.neuron.parent.clusterNuclei[newIndex] is Neuron newNeuron) + ChangeSynapse(synapse, newNeuron); + } + else + GUILayout.Label(synapse.neuron.name); + + bool disconnecting = GUILayout.Button("Disconnect", GUILayout.Width(80)); + if (disconnecting && synapse.neuron is Neuron synapseNeuron) { + synapseNeuron.RemoveReceiver(this.currentNucleus); + this.prefab.GarbageCollection(); + anythingChanged = true; + } + EditorGUILayout.EndHorizontal(); + + } + + EditorGUI.indentLevel++; + float newWeight = EditorGUILayout.FloatField("Weight", synapse.weight); + if (newWeight != synapse.weight) { + if (synapse.neuron.parent is IReceptor receptor) { + Nucleus[] receptorArray = receptor.nucleiArray; + foreach (Synapse s in this.currentNucleus.synapses) { + if (s.neuron.parent is IReceptor r && r.nucleiArray == receptorArray) + s.weight = newWeight; + } + } + else + synapse.weight = newWeight; + anythingChanged = true; + } + EditorGUI.indentLevel--; + } + } + + EditorGUILayout.Space(); + anythingChanged |= ConnectNucleus(this.prefab, this.currentNucleus); + anythingChanged |= AddSynapse(this.prefab, this.currentNucleus); + } + EditorGUILayout.EndFoldoutHeaderGroup(); + } + + // Activation + + if (this.currentNucleus is not Cluster) { + EditorGUILayout.Space(); + showActivation = EditorGUILayout.BeginFoldoutHeaderGroup(showActivation, "Activation"); + if (showActivation) { + if (this.currentNucleus is Neuron neuron) { + if (this.currentNucleus is not MemoryCell) { + EditorGUILayout.BeginHorizontal(); + EditorGUILayout.LabelField("Activation Curve", GUILayout.Width(150)); + if (neuron.curveMax > 0) + EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, 0, 1, neuron.curveMax)); + else + EditorGUILayout.CurveField(neuron.curve, Color.cyan, new Rect(0, neuron.curveMax, 1, -neuron.curveMax)); + Neuron.CurvePresets newPreset = (Neuron.CurvePresets)EditorGUILayout.EnumPopup(neuron.curvePreset, GUILayout.Width(100)); + anythingChanged |= newPreset != neuron.curvePreset; + neuron.curvePreset = newPreset; + EditorGUILayout.EndHorizontal(); + } + if (neuron is Receptor receptor2) { + if (receptor2.nucleiArray == null || receptor2.nucleiArray.Count() == 0) + receptor2.array = new NucleusArray(neuron); + } + } + + EditorGUILayout.Space(); + } + EditorGUILayout.EndFoldoutHeaderGroup(); + } + + if (GUILayout.Button("Delete this neuron")) + DeleteNucleus(this.currentNucleus); + + if (this.currentNucleus is Cluster subCluster) { + if (GUILayout.Button("Edit Cluster")) + EditCluster(subCluster); + } + + EditorGUILayout.Space(); + breakOnWake = EditorGUILayout.Toggle("Break on wake", breakOnWake); + if (breakOnWake && this.currentNucleus is Neuron currentNeuron) { + if (currentNeuron.isSleeping == false) + Debug.Break(); + } + trace = EditorGUILayout.Toggle("Trace", trace); + this.currentNucleus.trace = trace; + + serializedObject.ApplyModifiedProperties(); + if (anythingChanged) { + EditorUtility.SetDirty(prefabAsset); + AssetDatabase.SaveAssets(); + } + } + + void OnSceneGUI(SceneView sceneView) { + if (this.gameObject != null) { + if (this.currentNucleus is IReceptor receptor) { + foreach (Nucleus nucleus in receptor.nucleiArray) { + if (nucleus is Neuron neuron) { + Vector3 worldVector = this.gameObject.transform.TransformVector(neuron.outputValue); + Handles.color = Color.yellow; + Handles.DrawLine(this.gameObject.transform.position, this.gameObject.transform.position + worldVector); + } + } + } + else { + if (this.currentNucleus is Neuron currentNeuron) { + Vector3 worldVector = this.gameObject.transform.TransformVector(currentNeuron.outputValue); + Handles.color = Color.yellow; + Handles.DrawLine(this.gameObject.transform.position, this.gameObject.transform.position + worldVector); + } + } + } + } + + #region Synapses + + protected virtual void AddInput(Nucleus.Type selectedType, Nucleus nucleus) { + switch (selectedType) { + case Nucleus.Type.Neuron: + AddNeuronInput(nucleus); + break; + case Nucleus.Type.MemoryCell: + AddMemoryCellInput(nucleus); + break; + // case Nucleus.Type.Selector: + // AddSelectorInput(nucleus); + // break; + case Nucleus.Type.Cluster: + AddClusterInput(nucleus); + break; + // case Nucleus.Type.Pulsar: + // AddPulsarInput(nucleus); + // break; + case Nucleus.Type.Receptor: + AddReceptorInput(nucleus); + break; + // case Nucleus.Type.ReceptorArray: + // AddReceptorArrayInput(nucleus); + // break; + case Nucleus.Type.ClusterReceptor: + AddClusterReceptorInput(nucleus); + break; + default: + break; + } + } + + protected virtual void AddNeuronInput(Nucleus nucleus) { + Neuron newNeuroid = new(this.prefab, "New neuron"); + newNeuroid.AddReceiver(nucleus); + this.currentNucleus = newNeuroid; + BuildLayers(); + } + + protected virtual void AddMemoryCellInput(Nucleus nucleus) { + MemoryCell newMemory = new(this.prefab, "New memory cell"); + newMemory.AddReceiver(nucleus); + this.currentNucleus = newMemory; + BuildLayers(); + } + + protected virtual void AddClusterInput(Nucleus nucleus) { + ClusterPickerWindow.ShowPicker(brain => OnClusterPicked(nucleus, brain), "Select Cluster"); + } + private void OnClusterPicked(Nucleus nucleus, ClusterPrefab prefab) { + Cluster subclusterInstance = new(prefab, this.prefab); + subclusterInstance.defaultOutput.AddReceiver(nucleus); + } + + protected virtual void AddReceptorInput(Nucleus nucleus) { + Receptor newReceptor = new(this.prefab, "New Receptor"); + newReceptor.AddReceiver(nucleus); + this.currentNucleus = newReceptor; + BuildLayers(); + } + + protected virtual void AddClusterReceptorInput(Nucleus nucleus) { + ClusterPickerWindow.ShowPicker(prefab => OnClusterReceptorPicked(nucleus, prefab), "Select Cluster"); + } + private void OnClusterReceptorPicked(Nucleus nucleus, ClusterPrefab selectedPrefab) { + ClusterReceptor clusterInstance = new(selectedPrefab, this.prefab, "New " + selectedPrefab.name); + clusterInstance.defaultOutput.AddReceiver(nucleus); + this.currentNucleus = clusterInstance; + BuildLayers(); + } + + private void EditCluster(Cluster subCluster) { + // May be used with storedPrefab... + Selection.activeObject = subCluster.prefab; + EditorGUIUtility.PingObject(subCluster.prefab); + var editor = Editor.CreateEditor(subCluster.prefab); + } + + int selectedConnectNucleus = -1; + // Connect to another nucleus in the same cluster + protected virtual bool ConnectNucleus(ClusterPrefab cluster, Nucleus nucleusToConnect) { + if (cluster == null) + return false; + + IEnumerable synapseNuclei = this.currentNucleus.synapses + .Where(synapse => synapse.neuron != null) + .Select(synapse => synapse.neuron); + + IEnumerable nuclei = cluster.nuclei + .Except(synapseNuclei); + IEnumerable nucleiNames = nuclei + .Select(n => { + int idx = n.name.IndexOf(':'); + return idx < 0 ? n.name : n.name[..idx]; + }) + .Distinct(); + + string[] names = nucleiNames.ToArray(); + EditorGUILayout.BeginHorizontal(); + selectedConnectNucleus = EditorGUILayout.Popup(selectedConnectNucleus, names); + bool connecting = GUILayout.Button("Connect", GUILayout.Width(80)); + EditorGUILayout.EndHorizontal(); + if (connecting) { + Nucleus nucleus = nuclei.ElementAt(selectedConnectNucleus); + if (nucleus is IReceptor receptor) + receptor.AddArrayReceiver(this.currentNucleus); + else if (nucleus is Neuron neuron) + neuron.AddReceiver(this.currentNucleus); + else if (nucleus is Cluster subCluster) + subCluster.defaultOutput.AddReceiver(this.currentNucleus); + + } + return connecting; + } + + protected virtual void DeleteNucleus(Nucleus nucleus) { + if (nucleus == null) + return; + + if (nucleus is Neuron neuron) { + foreach (Nucleus receiver in neuron.receivers) { + if (receiver != null) { + this.currentNucleus = receiver; + break; + } + } + } + this.prefab.nuclei.Remove(nucleus); + + if (outputsField.value == nucleus.name) { + this.prefab.RefreshOutputs(); + outputsField.choices = this.prefab.outputs.Select(output => output.name).ToList(); + outputsField.index = 0; + } + + Neuron.Delete(nucleus); + + this.currentNucleus = this.prefab.output; + BuildLayers(); + } + + Nucleus.Type selectedType = Nucleus.Type.None; + protected virtual bool AddSynapse(ClusterPrefab cluster, Nucleus nucleus) { + if (cluster == null) + return false; + + EditorGUILayout.BeginHorizontal(); + selectedType = (Nucleus.Type)EditorGUILayout.EnumPopup(selectedType); + bool connecting = GUILayout.Button("Add", GUILayout.Width(80)); + EditorGUILayout.EndHorizontal(); + + if (connecting) { + AddInput(selectedType, this.currentNucleus); + } + return connecting; + // if (selectedType == Nucleus.Type.None) + // return false; + + // AddInput(selectedType, this.currentNucleus); + // return true; + } + + protected virtual void ChangeSynapse(Synapse synapse, Neuron newNucleus) { + Neuron synapseNeuron = synapse.neuron as Neuron; + if (synapse.neuron.parent is Cluster subCluster && subCluster.prefab != this.prefab) { + if (synapse.neuron.parent is ClusterReceptor receptor) { + // the new nucleus is part of a (cluster) receptor, + // so we have to change all synapses to this nucleus array elements + int oldNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, synapse.neuron); + int newNucleusIx = Cluster.GetNucleusIndex(subCluster.clusterNuclei, newNucleus); + foreach (Nucleus element in receptor.nucleiArray) { + if (element is not ClusterReceptor clusterReceptor) + continue; + // Get the same neuron as the synapse.nucleus in a different element + // of the ClusterReceptor array + Nucleus oldElementNucleus = clusterReceptor.clusterNuclei[oldNucleusIx]; + if (oldElementNucleus is not Neuron oldElementNeuron) + continue; + // Get the same neuron as newNucleus in a different element + // of the ClusterReceptor array + Nucleus newElementNucleus = clusterReceptor.clusterNuclei[newNucleusIx]; + if (newElementNucleus is not Neuron newElementNeuron) + continue; + + oldElementNeuron.RemoveReceiver(this.currentNucleus); + newElementNeuron.AddReceiver(this.currentNucleus); + // Now find the synapse which pointed to the old Neuron + // Synapse synapseForUpdate = this.currentNucleus.GetSynapse(oldElementNeuron); + // synapseForUpdate.nucleus = newElementNeuron; + } + } + else { + // it is a neuron in a subcluster + synapseNeuron.RemoveReceiver(this.currentNucleus); + newNucleus.AddReceiver(this.currentNucleus); + } + } + else { + synapseNeuron.RemoveReceiver(this.currentNucleus); + newNucleus.AddReceiver(this.currentNucleus); + } + } + + protected virtual void DisconnectNucleus(Neuron nucleus) { + if (this.currentNucleus.clusterPrefab == null) + return; + string[] names = this.currentNucleus.synapses.Select(synapse => synapse.neuron.name).ToArray(); + int selectedIndex = -1; + selectedIndex = EditorGUILayout.Popup("Disconnect from", selectedIndex, names); + if (selectedIndex >= 0 && selectedIndex < this.currentNucleus.clusterPrefab.nuclei.Count) { + Synapse synapse = this.currentNucleus.synapses[selectedIndex]; + Neuron synapseNeuron = synapse.neuron as Neuron; + synapseNeuron.RemoveReceiver(this.currentNucleus); + } + } + + #endregion Synapses + } + } - } - - public class NeuroidLayer { - public int ix = 0; - public List neuroids = new(); - } - + public class NeuroidLayer { + public int ix = 0; + public List neuroids = new(); + } + */ } \ No newline at end of file diff --git a/Editor/ClusterViewer.cs b/Editor/ClusterViewer.cs index 7edadff..510de71 100644 --- a/Editor/ClusterViewer.cs +++ b/Editor/ClusterViewer.cs @@ -10,16 +10,17 @@ namespace NanoBrain { public class ClusterViewer : Editor { public class GraphView : VisualElement { - readonly ClusterPrefab prefab; - SerializedObject serializedBrain; - Nucleus currentNucleus; - GameObject gameObject; + protected readonly ClusterPrefab prefab; + protected SerializedObject serializedBrain; + protected Nucleus currentNucleus; + protected GameObject gameObject; private List layers = new(); private readonly Dictionary neuroidPositions = new(); private bool expandArray = false; - ClusterPrefab prefabAsset; - readonly PopupField outputsField; + protected ClusterPrefab prefabAsset; + protected VisualElement outputContainer; + protected readonly PopupField outputsField; public GraphView(ClusterPrefab prefab) { this.prefab = prefab; @@ -35,7 +36,7 @@ namespace NanoBrain { graphContainer.focusable = true; Add(graphContainer); - VisualElement outputContainer = new() { + outputContainer = new() { style = { flexDirection = FlexDirection.Row, alignItems = Align.Center, @@ -108,7 +109,7 @@ namespace NanoBrain { //DrawInspector(inspectorContainer); } - private void BuildLayers() { + protected void BuildLayers() { // A temporary list to track what's been added to layers this.layers = new(); int layerIx = 0; @@ -534,4 +535,9 @@ namespace NanoBrain { } } + + public class NeuroidLayer { + public int ix = 0; + public List neuroids = new(); + } } \ No newline at end of file diff --git a/Runtime/Scripts/Core/Neuron.cs b/Runtime/Scripts/Core/Neuron.cs index fcadfea..9cc021f 100644 --- a/Runtime/Scripts/Core/Neuron.cs +++ b/Runtime/Scripts/Core/Neuron.cs @@ -61,16 +61,17 @@ namespace NanoBrain { /// /// The type of /// - public enum CurvePresets { + public enum ActivationFunction { Linear, Power, Sqrt, Reciprocal, + Tanh, Custom } [SerializeField] - public CurvePresets _curvePreset; - public CurvePresets curvePreset { + public ActivationFunction _curvePreset; + public ActivationFunction curvePreset { get { return _curvePreset; } set { _curvePreset = value; @@ -82,18 +83,21 @@ namespace NanoBrain { public AnimationCurve GenerateCurve() { switch (this.curvePreset) { - case CurvePresets.Linear: + case ActivationFunction.Linear: this.curveMax = 1; return Presets.Linear(1); - case CurvePresets.Power: + case ActivationFunction.Power: this.curveMax = 1; return Presets.Power(2.0f, 1); - case CurvePresets.Sqrt: + case ActivationFunction.Sqrt: this.curveMax = 1; return Presets.Power(0.5f, 1); - case CurvePresets.Reciprocal: + case ActivationFunction.Reciprocal: this.curveMax = 1 / 0.01f * 1; return Presets.Reciprocal(1); + case ActivationFunction.Tanh: + this.curveMax = 1; + return Presets.Tanh(1); default: this.curveMax = 1; return this.curve; @@ -142,6 +146,25 @@ namespace NanoBrain { } return curve; } + public static AnimationCurve Tanh(float weight) { + //int samples = 128; + float xMin = 0.001f; + float xMax = 1; + var keys = new Keyframe[samples]; + for (int i = 0; i < samples; i++) { + float t = i / (float)(samples - 1); + float x = Mathf.Lerp(xMin, xMax, t); + float y = MathF.Tanh(x * weight); + keys[i] = new Keyframe(x, y); + } + var curve = new AnimationCurve(keys); + for (int i = 0; i < curve.length; i++) { + AnimationUtility.SetKeyLeftTangentMode(curve, i, AnimationUtility.TangentMode.Linear); + AnimationUtility.SetKeyRightTangentMode(curve, i, AnimationUtility.TangentMode.Linear); + } + return curve; + + } } #endregion Serialization @@ -348,10 +371,11 @@ namespace NanoBrain { #if UNITY_MATHEMATICS public Func Activator => this.curvePreset switch { - CurvePresets.Linear => ActivatorLinear, - CurvePresets.Sqrt => ActivatorSqrt, - CurvePresets.Power => ActivatorPower, - CurvePresets.Reciprocal => ActivatorReciprocal, + ActivationFunction.Linear => ActivatorLinear, + ActivationFunction.Sqrt => ActivatorSqrt, + ActivationFunction.Power => ActivatorPower, + ActivationFunction.Reciprocal => ActivatorReciprocal, + ActivationFunction.Tanh => ActivatorTanh, _ => ActivatorCustom }; @@ -378,6 +402,12 @@ namespace NanoBrain { return result; } + protected float3 ActivatorTanh(float3 input) { + float magnitude = length(input); + float3 result = normalize(input) * MathF.Tanh(magnitude); + return result; + } + protected float3 ActivatorCustom(float3 input) { float activatedValue = this.curve.Evaluate(length(input)); float3 result = normalize(input) * activatedValue;