743 lines
30 KiB
C#

using System;
using System.Collections.Generic;
using UnityEngine;
#if UNITY_MATHEMATICS
using Unity.Mathematics;
using static Unity.Mathematics.math;
#endif
namespace NanoBrain {
    /// <summary>
    /// A Cluster combines a collection of Nuclei to implement reusable behaviour
    /// </summary>
    /// <remarks>
    /// A Cluster is an instantiation of a ClusterPrefab.
    /// Clusters can be nested inside other clusters.
    /// </remarks>
    [Serializable]
    public class Cluster : Nucleus {
        // It may be that clusters will not be nuclei anymore in the future....

        /// <summary>
        /// The name of this cluster without its instance suffix: the part of the name before
        /// the first ':' (e.g. "Foo" for "Foo: 1"), or the whole name when it has no ':'.
        /// </summary>
        public string baseName {
            get {
                int colonPosition = this.name.IndexOf(':');
                if (colonPosition < 0)
                    return this.name;
                return this.name[..colonPosition];
            }
        }

        /// <summary>
        /// The clusters that form an array together with this cluster.
        /// AddInstance(ClusterPrefab) and RemoveInstance() make all siblings share one array instance.
        /// </summary>
        /// <remarks>NOTE: this should probably not be serialized.</remarks>
        [SerializeReference]
        public Cluster[] siblingClusters;

        /// <summary>
        /// The number of instances of this cluster. When a prefab whose nuclei receive from this
        /// cluster is cloned, this many instances of this cluster are created (see ClonePrefab).
        /// </summary>
        /// <remarks>NOTE: serializing this field should be enough.</remarks>
        public int instanceCount = 1;

        /// <summary>
        /// The sibling cluster assigned to each thing id by GetThingCluster(int, string).
        /// </summary>
        public Dictionary<int, Cluster> thingClusters = new();

        #region Init

        /// <summary>
        /// Instantiate a new copy of a ClusterPrefab in the given parent
        /// </summary>
        /// <param name="prefab">The prefab to use</param>
        /// <param name="parent">The cluster in which this new cluster will be placed (may be null)</param>
        public Cluster(ClusterPrefab prefab, Cluster parent) {
            this.prefab = prefab;
            this.name = prefab.name;
            this.parent = parent;
            this.parent?.clusterNuclei.Add(this);
            InitializeFromPrefab();
        }

        /// <summary>
        /// Add a new cluster to a ClusterPrefab
        /// </summary>
        /// <param name="prefab">The prefab to copy</param>
        /// <param name="parent">The prefab in which the new copy is placed; when null it is not added to any prefab</param>
        public Cluster(ClusterPrefab prefab, ClusterPrefab parent = null) {
            this.prefab = prefab;
            this.name = prefab.name;
            this.clusterPrefab = parent;
            // Explicit != null check (not ?.) so an overloaded equality operator on ClusterPrefab is respected
            if (this.clusterPrefab != null)
                this.clusterPrefab.nuclei.Add(this);
            InitializeFromPrefab();
        }

        /// <summary>
        /// Shared tail of both constructors: clone the prefab contents, build the inputs
        /// (which also fills computeOrders) and compute the full evaluation order.
        /// </summary>
        private void InitializeFromPrefab() {
            ClonePrefab();
            // Reading inputs builds _inputs and computeOrders as a side effect
            _ = this.inputs;
            this.sortedNuclei = TopologicalSort(this.clusterNuclei);
        }

        /// <summary>
        /// Populate <see cref="clusterNuclei"/> with copies of the nuclei of <see cref="prefab"/>
        /// and recreate their connections, sibling arrays and sub-cluster instances.
        /// </summary>
        /// <remarks>
        /// The clone is stored in this cluster itself (clusterNuclei); nothing is returned.
        /// Prefab nuclei and their clones are matched by index, which assumes that
        /// ShallowCloneTo appends exactly one clone per prefab nucleus to clusterNuclei.
        /// </remarks>
        private void ClonePrefab() {
            Nucleus[] prefabNuclei = this.prefab.nuclei.ToArray();
            // First clone the nuclei without their connections
            foreach (Nucleus nucleus in prefabNuclei)
                nucleus.ShallowCloneTo(this);
            Nucleus[] clonedNuclei = this.clusterNuclei.ToArray();

            CloneConnections(prefabNuclei, clonedNuclei);
            CloneSiblingArrays(prefabNuclei, clonedNuclei);
            InstantiateSubClusters(prefabNuclei, clonedNuclei);

            foreach (Nucleus nucleus in this.clusterNuclei) {
                if (nucleus is Cluster clonedSubCluster)
                    RestoreAllExternalReceivers(clonedSubCluster, this.prefab, this);
            }
        }

        /// <summary>
        /// Recreate the neuron-to-nucleus connections of the prefab between the cloned nuclei,
        /// keeping the synapse weights. Clusters have no receivers and are skipped.
        /// </summary>
        private void CloneConnections(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
            for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length && nucleusIx < clonedNuclei.Length; nucleusIx++) {
                if (prefabNuclei[nucleusIx] is not Neuron prefabNeuron)
                    continue;
                if (clonedNuclei[nucleusIx] is not Neuron clonedNeuron)
                    continue;
                // Copy the receivers, which will also create the synapses
                foreach (Nucleus receiver in prefabNeuron.receivers.ToArray()) {
                    int receiverIx = GetNucleusIndex(prefabNuclei, receiver);
                    if (receiverIx < 0 || receiverIx >= clonedNuclei.Length)
                        continue;
                    Nucleus clonedReceiver = clonedNuclei[receiverIx];
                    if (clonedReceiver is null)
                        continue;
                    clonedNeuron.AddReceiver(clonedReceiver, SynapseWeight(receiver, prefabNeuron));
                }
            }
        }

        /// <summary>
        /// Give every cloned cluster that belongs to a sibling array in the prefab a cloned
        /// sibling array. All clones of one prefab array share a single cloned array.
        /// </summary>
        private void CloneSiblingArrays(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
            // Keyed by the first prefab sibling so siblings processed in any order share one array
            Dictionary<Cluster, Cluster[]> clonedArrays = new();
            for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length && nucleusIx < clonedNuclei.Length; nucleusIx++) {
                if (prefabNuclei[nucleusIx] is not Cluster prefabCluster)
                    continue;
                if (prefabCluster.siblingClusters == null || prefabCluster.siblingClusters.Length == 0)
                    continue;
                if (clonedNuclei[nucleusIx] is not Cluster clonedCluster)
                    continue;
                Cluster firstSibling = prefabCluster.siblingClusters[0];
                if (firstSibling is null)
                    continue;
                if (!clonedArrays.TryGetValue(firstSibling, out Cluster[] clonedArray)) {
                    // Like the original logic, the first sibling's array is authoritative
                    Cluster[] prefabArray = firstSibling.siblingClusters is { Length: > 0 } firstArray
                        ? firstArray
                        : prefabCluster.siblingClusters;
                    clonedArray = CloneSiblingArray(prefabArray, prefabNuclei, clonedNuclei);
                    clonedArrays[firstSibling] = clonedArray;
                }
                clonedCluster.siblingClusters = clonedArray;
            }
        }

        /// <summary>
        /// Map a prefab sibling array onto the corresponding cloned clusters.
        /// Entries whose prefab cluster is not among the prefab nuclei stay null (and are logged).
        /// </summary>
        private Cluster[] CloneSiblingArray(Cluster[] prefabArray, Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
            Cluster[] clonedArray = new Cluster[prefabArray.Length];
            for (int arrayIx = 0; arrayIx < prefabArray.Length; arrayIx++) {
                Cluster prefabSibling = prefabArray[arrayIx];
                int siblingIx = GetNucleusIndex(prefabNuclei, prefabSibling);
                if (siblingIx >= 0 && siblingIx < clonedNuclei.Length)
                    clonedArray[arrayIx] = clonedNuclei[siblingIx] as Cluster;
                else
                    Debug.LogError($" Could not find prefab nucleus {prefabSibling?.name} in the clones");
            }
            return clonedArray;
        }

        /// <summary>
        /// For every cluster that sends into a prefab nucleus, create instanceCount new instances
        /// of it and connect their senders to the cloned receivers with the prefab weights.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the created instances are not added to this cluster or to any prefab;
        /// they are only reachable through the connections made here.
        /// </remarks>
        private void InstantiateSubClusters(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
            // Distinct sending clusters, in discovery order
            List<Cluster> subClusters = new();
            HashSet<Cluster> seen = new();
            foreach (Nucleus nucleus in prefabNuclei) {
                foreach (Synapse synapse in nucleus.synapses) {
                    Nucleus synapseNucleus = synapse.neuron;
                    if (synapseNucleus is Cluster subCluster && seen.Add(subCluster))
                        subClusters.Add(subCluster);
                }
            }

            foreach (Cluster subCluster in subClusters) {
                for (int instanceIx = 0; instanceIx < subCluster.instanceCount; instanceIx++) {
                    Cluster clusterInstance = new(subCluster.prefab);
                    foreach ((Neuron sender, Nucleus receiver) in subCluster.CollectConnections()) {
                        int receiverIx = GetNucleusIndex(prefabNuclei, receiver);
                        if (receiverIx < 0 || receiverIx >= clonedNuclei.Length)
                            continue;
                        Nucleus clonedReceiver = clonedNuclei[receiverIx];
                        if (clonedReceiver is null)
                            continue;
                        if (clusterInstance.GetNucleus(sender.name) is not Neuron clonedSender)
                            continue;
                        clonedSender.AddReceiver(clonedReceiver, SynapseWeight(receiver, sender));
                    }
                }
            }
        }

        /// <summary>
        /// The weight of the first synapse of <paramref name="receiver"/> whose neuron is
        /// <paramref name="sender"/>, or 1 when there is no such synapse.
        /// </summary>
        private static float SynapseWeight(Nucleus receiver, Nucleus sender) {
            foreach (Synapse synapse in receiver.synapses) {
                if (synapse.neuron == sender)
                    return synapse.weight;
            }
            return 1;
        }

        /// <summary>
        /// The nuclei that directly receive from <paramref name="nucleus"/>: a neuron's receivers,
        /// a cluster's external receivers (CollectReceivers), or nothing for other nuclei.
        /// </summary>
        private static IEnumerable<Nucleus> ReceiversOf(Nucleus nucleus) {
            if (nucleus is Neuron neuron)
                return neuron.receivers;
            if (nucleus is Cluster cluster)
                return cluster.CollectReceivers();
            return Array.Empty<Nucleus>();
        }

        /// <summary>
        /// Sort the nuclei in a correct evaluation order (Kahn's algorithm): every nucleus comes
        /// after all nuclei in <paramref name="nodes"/> that send to it.
        /// </summary>
        /// <param name="nodes">The nuclei to sort</param>
        /// <returns>A new list with the nuclei in evaluation order</returns>
        /// <remarks>Receivers that are not in <paramref name="nodes"/> are ignored.</remarks>
        /// <exception cref="InvalidOperationException">The connections contain a cycle</exception>
        private List<Nucleus> TopologicalSort(List<Nucleus> nodes) {
            Dictionary<Nucleus, int> inDegree = new();
            foreach (Nucleus node in nodes)
                inDegree[node] = 0;

            // Count the incoming connections from within the node set
            foreach (Nucleus node in nodes) {
                foreach (Nucleus receiver in ReceiversOf(node)) {
                    if (inDegree.TryGetValue(receiver, out int degree))
                        inDegree[receiver] = degree + 1;
                }
            }

            // Start with the nodes that have no dependencies
            Queue<Nucleus> queue = new();
            foreach (Nucleus node in nodes) {
                if (inDegree[node] == 0)
                    queue.Enqueue(node);
            }

            List<Nucleus> sortedOrder = new();
            while (queue.Count > 0) {
                Nucleus current = queue.Dequeue();
                sortedOrder.Add(current);
                foreach (Nucleus receiver in ReceiversOf(current)) {
                    if (!inDegree.TryGetValue(receiver, out int degree))
                        continue;
                    degree--;
                    inDegree[receiver] = degree;
                    if (degree == 0) // All dependencies resolved
                        queue.Enqueue(receiver);
                }
            }

            // Nodes left unsorted are part of a cycle
            if (sortedOrder.Count != nodes.Count)
                throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");
            return sortedOrder;
        }

        /// <summary>
        /// Create a copy of this cluster from its prefab inside <paramref name="parent"/>, with
        /// the same synapses and the same connections to receivers outside this cluster.
        /// </summary>
        /// <param name="parent">The prefab in which the clone is placed</param>
        /// <returns>The new cluster</returns>
        public override Nucleus Clone(ClusterPrefab parent) {
            Cluster clone = new(this.prefab, parent);
            foreach (Synapse synapse in this.synapses) {
                Synapse clonedSynapse = clone.AddSynapse(synapse.neuron);
                clonedSynapse.weight = synapse.weight;
            }
            // The constructor already recreated the internal connections from the prefab,
            // so only the external receivers are copied; nuclei are matched by index.
            for (int ix = 0; ix < this.clusterNuclei.Count && ix < clone.clusterNuclei.Count; ix++) {
                if (this.clusterNuclei[ix] is not Neuron output)
                    continue;
                if (clone.clusterNuclei[ix] is not Neuron clonedOutput)
                    continue;
                foreach (Nucleus receiver in output.receivers) {
                    Debug.Log($"{output.name} -> {receiver.name}: {ix}");
                    // Same "outside this cluster" test as CollectReceivers
                    if (receiver.clusterPrefab == this.prefab)
                        continue;
                    clonedOutput.AddReceiver(receiver, SynapseWeight(receiver, output));
                }
            }
            return clone;
        }

        /// <summary>
        /// Create a new instance of this cluster's prefab inside <paramref name="parent"/>
        /// with the same name and clusterPrefab, without copying external connections.
        /// </summary>
        /// <remarks>The sibling arrays are restored by the parent's ClonePrefab.</remarks>
        public override Nucleus ShallowCloneTo(Cluster parent) {
            // Clusters should not be cloned, but instantiated from the prefab....
            Cluster clone = new(this.prefab, parent) {
                name = this.name,
                clusterPrefab = this.clusterPrefab,
            };
            return clone;
        }

        /// <summary>
        /// Give a cloned nested cluster the connections that its prefab counterpart has to the
        /// other nuclei of the prefab, keeping the synapse weights.
        /// </summary>
        /// <param name="clonedCluster">The cloned nested cluster</param>
        /// <param name="prefabParent">The prefab that was cloned</param>
        /// <param name="clonedParent">The cluster holding the clones of prefabParent's nuclei</param>
        /// <remarks>
        /// Nuclei are matched by index: clonedParent.clusterNuclei is assumed to line up with
        /// prefabParent.nuclei, and clonedCluster's nuclei with those of its source cluster.
        /// </remarks>
        private static void RestoreAllExternalReceivers(Cluster clonedCluster, ClusterPrefab prefabParent, Cluster clonedParent) {
            int clonedClusterIx = GetNucleusIndex(clonedParent.clusterNuclei, clonedCluster);
            if (clonedClusterIx < 0 || clonedClusterIx >= prefabParent.nuclei.Count)
                return;
            if (prefabParent.nuclei[clonedClusterIx] is not Cluster sourceCluster)
                return;
            for (int nucleusIx = 0; nucleusIx < sourceCluster.clusterNuclei.Count && nucleusIx < clonedCluster.clusterNuclei.Count; nucleusIx++) {
                if (sourceCluster.clusterNuclei[nucleusIx] is not Neuron sourceNeuron)
                    continue;
                if (clonedCluster.clusterNuclei[nucleusIx] is not Neuron clonedNeuron)
                    continue;
                // Copy the receivers (and thus synapses) from the source to the clone
                foreach (Nucleus receiver in sourceNeuron.receivers) {
                    int receiverIx = GetNucleusIndex(prefabParent.nuclei, receiver);
                    if (receiverIx < 0 || receiverIx >= clonedParent.clusterNuclei.Count)
                        continue;
                    Nucleus clonedReceiver = clonedParent.clusterNuclei[receiverIx];
                    clonedNeuron.AddReceiver(clonedReceiver, SynapseWeight(receiver, sourceNeuron));
                }
            }
        }

        /// <summary>
        /// The index of the first element of <paramref name="nuclei"/> that is the same reference
        /// as <paramref name="nucleus"/>, or -1 when it is not present.
        /// </summary>
        protected int GetNucleusIndex(Nucleus[] nuclei, Nucleus nucleus) {
            for (int i = 0; i < nuclei.Length; i++) {
                if (nucleus == nuclei[i])
                    return i;
            }
            return -1;
        }

        /// <summary>
        /// The index of the first element of <paramref name="nuclei"/> that equals
        /// <paramref name="nucleus"/> (using ==), or -1 when it is not present.
        /// </summary>
        public static int GetNucleusIndex(List<Nucleus> nuclei, Nucleus nucleus) {
            for (int i = 0; i < nuclei.Count; i++) {
                if (nucleus == nuclei[i])
                    return i;
            }
            return -1;
        }

        #endregion Init

        #region Cluster Array

        /// <summary>
        /// Increase the number of instances of this cluster by one.
        /// </summary>
        public void AddInstance() {
            this.instanceCount++;
        }

        /// <summary>
        /// Add a new sibling to this cluster's array by cloning this cluster into <paramref name="prefab"/>.
        /// </summary>
        /// <param name="prefab">The prefab in which the new sibling is placed</param>
        /// <remarks>
        /// The new sibling is named "baseName: index" and every sibling is given the enlarged array.
        /// </remarks>
        /// <exception cref="InvalidOperationException">Clone did not return a Cluster</exception>
        public void AddInstance(ClusterPrefab prefab) {
            // A cluster without siblings forms an array on its own
            if (this.siblingClusters == null || this.siblingClusters.Length == 0)
                this.siblingClusters = new Cluster[] { this };

            int newLength = this.siblingClusters.Length + 1;
            Cluster[] newSiblings = new Cluster[newLength];
            Array.Copy(this.siblingClusters, newSiblings, newLength - 1);

            if (this.Clone(prefab) is not Cluster newCluster)
                throw new InvalidOperationException($"Cloning {this.name} did not produce a cluster");
            newCluster.name = $"{this.baseName}: {newLength - 1}";
            newSiblings[newLength - 1] = newCluster;

            // All siblingClusters need to use this array!
            foreach (Cluster sibling in newSiblings) {
                if (sibling != null)
                    sibling.siblingClusters = newSiblings;
            }
        }

        /// <summary>
        /// Decrease the instance count, or, when it is 1, delete the last sibling of the array.
        /// </summary>
        /// <remarks>
        /// All remaining siblings are given the shortened array (mirroring AddInstance) and
        /// thing assignments to the deleted cluster are dropped.
        /// </remarks>
        public void RemoveInstance() {
            if (this.instanceCount > 1) {
                this.instanceCount--;
                return;
            }
            if (this.siblingClusters == null || this.siblingClusters.Length <= 1)
                return;

            int newLength = this.siblingClusters.Length - 1;
            Cluster removedCluster = this.siblingClusters[newLength];
            Cluster[] newSiblings = new Cluster[newLength];
            Array.Copy(this.siblingClusters, newSiblings, newLength);

            Neuron.Delete(removedCluster);

            RemoveThingCluster(removedCluster);
            foreach (Cluster sibling in newSiblings) {
                if (sibling == null)
                    continue;
                sibling.siblingClusters = newSiblings;
                sibling.RemoveThingCluster(removedCluster);
            }
            this.siblingClusters = newSiblings;
        }

        /// <summary>
        /// Select a sibling cluster to use (see SelectCluster) without remembering the choice.
        /// </summary>
        public virtual Cluster GetThingCluster() {
            return SelectCluster();
        }

        /// <summary>
        /// The sibling cluster assigned to <paramref name="thingId"/>; when there is none yet,
        /// a sibling is selected (see SelectCluster) and assigned to it.
        /// </summary>
        /// <param name="thingId">The id of the thing</param>
        /// <param name="thingName">Not used here</param>
        public virtual Cluster GetThingCluster(int thingId, string thingName = null) {
            if (this.thingClusters.TryGetValue(thingId, out Cluster cluster))
                return cluster;
            Cluster selectedCluster = SelectCluster();
            this.thingClusters[thingId] = selectedCluster;
            return selectedCluster;
        }

        /// <summary>
        /// Pick a sibling cluster: one whose default output is sleeping, otherwise the one whose
        /// default output has the smallest lastUpdate. Existing thing assignments to the picked
        /// cluster are removed. Returns this cluster when there are no (usable) siblings.
        /// </summary>
        private Cluster SelectCluster() {
            if (this.siblingClusters == null || this.siblingClusters.Length == 0)
                return this;

            // Find a sleeping cluster
            foreach (Cluster cluster in this.siblingClusters) {
                Neuron output = cluster?.defaultOutput;
                if (output != null && output.isSleeping) {
                    RemoveThingCluster(cluster);
                    return cluster;
                }
            }

            // Otherwise find the longest unused cluster (first one wins on ties)
            Cluster unusedCluster = null;
            foreach (Cluster cluster in this.siblingClusters) {
                Neuron output = cluster?.defaultOutput;
                if (output == null)
                    continue;
                if (unusedCluster == null || output.lastUpdate < unusedCluster.defaultOutput.lastUpdate)
                    unusedCluster = cluster;
            }
            if (unusedCluster == null)
                return this;
            RemoveThingCluster(unusedCluster);
            return unusedCluster;
        }

        /// <summary>
        /// Remove every thing id that is assigned to <paramref name="cluster"/>.
        /// </summary>
        private void RemoveThingCluster(Cluster cluster) {
            List<int> keysToRemove = new();
            foreach (KeyValuePair<int, Cluster> kvp in this.thingClusters) {
                if (kvp.Value == cluster)
                    keysToRemove.Add(kvp.Key);
            }
            foreach (int thingId in keysToRemove)
                this.thingClusters.Remove(thingId);
        }

        /// <summary>
        /// True when <paramref name="otherSiblingClusters"/> holds exactly the same clusters,
        /// in the same order, as this cluster's siblingClusters.
        /// False when either array is null or the lengths differ.
        /// </summary>
        public bool SameSiblingsAs(Cluster[] otherSiblingClusters) {
            if (this.siblingClusters == null || otherSiblingClusters == null)
                return false;
            if (this.siblingClusters.Length != otherSiblingClusters.Length)
                return false;
            for (int ix = 0; ix < this.siblingClusters.Length; ix++) {
                if (this.siblingClusters[ix] != otherSiblingClusters[ix])
                    return false;
            }
            return true;
        }

        /// <summary>
        /// Connect the default output of every sibling (or of this cluster alone when it has
        /// no siblings) to <paramref name="receiverToAdd"/>.
        /// </summary>
        /// <param name="receiverToAdd">The receiving nucleus</param>
        /// <param name="weight">The weight of the new synapses</param>
        public void AddArrayReceiver(Nucleus receiverToAdd, float weight = 1) {
            if (this.siblingClusters == null || this.siblingClusters.Length == 0) {
                this.defaultOutput?.AddReceiver(receiverToAdd, weight);
                return;
            }
            foreach (Cluster cluster in this.siblingClusters)
                cluster?.defaultOutput?.AddReceiver(receiverToAdd, weight);
        }

        #endregion Cluster Array

        /// <summary>The prefab this cluster was instantiated from.</summary>
        public ClusterPrefab prefab;

        /// <summary>The nuclei contained in this cluster.</summary>
        [SerializeReference]
        public List<Nucleus> clusterNuclei = new();

        /// <summary>
        /// The nuclei sorted using topological sorting
        /// to ensure that the cluster is computed in the right order.
        /// </summary>
        public List<Nucleus> sortedNuclei;

        /// <summary>Backing field of <see cref="inputs"/>; null until first computed.</summary>
        public List<Nucleus> _inputs = null;

        /// <summary>
        /// The nuclei of this cluster that have no synapses. Built on first access,
        /// which also computes <see cref="computeOrders"/>.
        /// </summary>
        public virtual List<Nucleus> inputs {
            get {
                if (this._inputs == null) {
                    this._inputs = new();
                    foreach (Nucleus nucleus in this.clusterNuclei) {
                        // inputs have no synapses
                        if (nucleus.synapses.Count == 0)
                            this._inputs.Add(nucleus);
                    }
                    ComputeOrders();
                }
                return this._inputs;
            }
        }

        /// <summary>
        /// For each input, the nuclei of this cluster reachable from it in update order
        /// (used by UpdateFromNucleus).
        /// </summary>
        public Dictionary<Nucleus, List<Nucleus>> computeOrders = new();

        /// <summary>Compute the update order for every input.</summary>
        private void ComputeOrders() {
            foreach (Nucleus input in this._inputs)
                this.computeOrders[input] = TopologicalSort2(input);
        }

        /// <summary>
        /// Topologically sort the nuclei of this cluster that are reachable from
        /// <paramref name="startNode"/>.
        /// </summary>
        /// <param name="startNode">The nucleus to start from (an input of this cluster)</param>
        /// <returns>The reachable nuclei, starting with startNode, in evaluation order</returns>
        /// <remarks>Receivers outside this cluster's nuclei are not followed.</remarks>
        /// <exception cref="InvalidOperationException">The reachable nuclei contain a cycle</exception>
        private List<Nucleus> TopologicalSort2(Nucleus startNode) {
            Dictionary<Nucleus, int> inDegree = new();
            foreach (Nucleus node in this.clusterNuclei)
                inDegree[node] = 0;

            // Breadth-first discovery of reachable nodes, counting the in-degrees on the way.
            // The list keeps a deterministic discovery order; the set gives fast membership tests.
            List<Nucleus> reachable = new() { startNode };
            HashSet<Nucleus> visited = new() { startNode };
            Queue<Nucleus> queue = new();
            queue.Enqueue(startNode);
            while (queue.Count > 0) {
                Nucleus current = queue.Dequeue();
                foreach (Nucleus receiver in ReceiversOf(current)) {
                    if (!inDegree.TryGetValue(receiver, out int degree))
                        continue; // not a nucleus of this cluster
                    if (visited.Add(receiver)) {
                        reachable.Add(receiver);
                        queue.Enqueue(receiver);
                    }
                    inDegree[receiver] = degree + 1;
                }
            }

            // Perform topological sort on all reachable nodes
            queue.Clear();
            foreach (Nucleus node in reachable) {
                if (!inDegree.TryGetValue(node, out int degree) || degree == 0)
                    queue.Enqueue(node);
            }
            List<Nucleus> sortedOrder = new();
            while (queue.Count > 0) {
                Nucleus current = queue.Dequeue();
                sortedOrder.Add(current);
                foreach (Nucleus receiver in ReceiversOf(current)) {
                    if (!visited.Contains(receiver))
                        continue;
                    int degree = inDegree[receiver] - 1;
                    inDegree[receiver] = degree;
                    if (degree == 0) // All dependencies resolved
                        queue.Enqueue(receiver);
                }
            }

            // Check for cycles in the graph
            if (sortedOrder.Count != reachable.Count)
                throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");
            return sortedOrder;
        }

        /// <summary>
        /// The first nucleus of this cluster when it is a Neuron; null when it is not or when
        /// the cluster is empty.
        /// </summary>
        public virtual Neuron defaultOutput {
            get {
                if (this.clusterNuclei.Count > 0)
                    return this.clusterNuclei[0] as Neuron;
                return null;
            }
        }

        /// <summary>Backing field of <see cref="outputs"/>; null until first computed.</summary>
        protected List<Neuron> _outputs = null;

        /// <summary>All Neuron nuclei of this cluster, collected on first access and cached.</summary>
        public List<Neuron> outputs {
            get {
                if (this._outputs == null) {
                    this._outputs = new();
                    foreach (Nucleus nucleus in this.clusterNuclei) {
                        if (nucleus is Neuron neuron)
                            this._outputs.Add(neuron);
                    }
                }
                return this._outputs;
            }
        }

        /// <summary>
        /// Find a direct nucleus of this cluster with exactly the given name.
        /// </summary>
        /// <param name="nucleusName">The name to look for</param>
        /// <param name="foundNucleus">The first nucleus with that name, or null</param>
        /// <returns>True when a nucleus was found</returns>
        public bool TryGetNucleus(string nucleusName, out Nucleus foundNucleus) {
            foreach (Nucleus nucleus in this.clusterNuclei) {
                if (nucleus is not null && nucleus.name == nucleusName) {
                    foundNucleus = nucleus;
                    return true;
                }
            }
            foundNucleus = null;
            return false;
        }

        /// <summary>
        /// Find a nucleus by name. A dotted name "Sub.Name" looks up "Name" inside the nested
        /// cluster "Sub". A cluster also matches when its name is the requested name followed
        /// by ": 0" (the first element of a cluster array).
        /// </summary>
        /// <param name="nucleusName">The (possibly dotted) name</param>
        /// <returns>The nucleus, or null when it is not found</returns>
        public Nucleus GetNucleus(string nucleusName) {
            int dotPosition = nucleusName.IndexOf('.');
            if (dotPosition >= 0) {
                string clusterName = nucleusName[..dotPosition];
                string clusterName0 = clusterName + ": 0";
                foreach (Nucleus nucleus in this.clusterNuclei) {
                    if (nucleus is Cluster cluster && (cluster.name == clusterName || cluster.name == clusterName0)) {
                        string subNucleusName = nucleusName[(dotPosition + 1)..];
                        return cluster.GetNucleus(subNucleusName);
                    }
                }
                return null;
            }

            string nucleusName0 = nucleusName + ": 0";
            foreach (Nucleus nucleus in this.clusterNuclei) {
                if (nucleus is Cluster) {
                    if (nucleus.name == nucleusName || nucleus.name == nucleusName0)
                        return nucleus;
                }
                else if (nucleus.name == nucleusName)
                    return nucleus;
            }
            return null;
        }

        #region Receivers

        /// <summary>
        /// The receivers of this cluster's neurons that are outside this cluster
        /// (their clusterPrefab differs from this cluster's prefab).
        /// </summary>
        public virtual List<Nucleus> CollectReceivers() {
            List<Nucleus> receivers = new();
            foreach (Nucleus outputNucleus in this.clusterNuclei) {
                if (outputNucleus is not Neuron output)
                    continue;
                foreach (Nucleus receiver in output.receivers) {
                    // Only add receivers outside this cluster
                    if (receiver.clusterPrefab != this.prefab)
                        receivers.Add(receiver);
                }
            }
            return receivers;
        }

        /// <summary>
        /// The (sending neuron, receiver) pairs of this cluster whose receiver is outside
        /// this cluster, using the same test as CollectReceivers.
        /// </summary>
        public List<(Neuron, Nucleus)> CollectConnections() {
            List<(Neuron, Nucleus)> connections = new();
            foreach (Nucleus outputNucleus in this.clusterNuclei) {
                if (outputNucleus is not Neuron output)
                    continue;
                foreach (Nucleus receiver in output.receivers) {
                    // Only add receivers outside this cluster
                    if (receiver.clusterPrefab != this.prefab)
                        connections.Add((output, receiver));
                }
            }
            return connections;
        }

        /// <summary>
        /// Move the connections to receivers outside this cluster from this cluster's neurons
        /// to the equally named neurons of <paramref name="newCluster"/>, keeping the weights.
        /// </summary>
        /// <param name="newCluster">The cluster that takes over the receivers</param>
        public void MoveReceivers(Cluster newCluster) {
            Debug.Log($"Move receivers for {this.name} to {newCluster.name}");
            foreach (Nucleus outputNucleus in this.clusterNuclei) {
                if (outputNucleus is not Neuron output)
                    continue;
                // Find the existing output in the new cluster
                if (newCluster.GetNucleus(output.name) is not Neuron newOutput) {
                    Debug.LogWarning($"Could not find output {this.name}.{output.name} in {newCluster.name}");
                    continue;
                }
                Debug.Log($"Check {this.name}.{output.name} receivers");
                // Iterate over a copy: RemoveReceiver changes output.receivers
                Nucleus[] receivers = output.receivers.ToArray();
                foreach (Nucleus receiver in receivers) {
                    if (receiver.clusterPrefab == this.prefab)
                        continue;
                    // Replace synapse with new synapse to the new cluster
                    Debug.Log($"move {receiver.name} from {this.name}.{output.name} to {newCluster.name}.{newOutput.name}");
                    newOutput.AddReceiver(receiver, SynapseWeight(receiver, output));
                    output.RemoveReceiver(receiver);
                }
            }
        }

        #endregion Receivers

        #region Update

        /// <summary>
        /// Update the nuclei reachable from <paramref name="startNucleus"/> in their compute
        /// order (nested clusters are skipped), continue in the parent cluster and finally
        /// call UpdateNuclei. Does nothing when there is no compute order for startNucleus.
        /// </summary>
        /// <param name="startNucleus">An input of this cluster</param>
        public void UpdateFromNucleus(Nucleus startNucleus) {
            // no bias+synapse input state calculation for now...
            if (!this.computeOrders.TryGetValue(startNucleus, out List<Nucleus> computeOrder))
                return;
            if (startNucleus.trace)
                Debug.Log($"Update from {startNucleus.name}");
            foreach (Nucleus nucleus in computeOrder) {
                if (nucleus is Cluster)
                    continue;
                nucleus.UpdateStateIsolated();
                if (startNucleus.trace && nucleus is Neuron)
                    Debug.Log($" {nucleus.name}[{nucleus.GetHashCode()}]");
            }
            // continue in parent
            this.parent?.UpdateFromNucleus(this);
            UpdateNuclei();
        }

        /// <summary>Clusters are not updated in isolation.</summary>
        /// <exception cref="InvalidOperationException">Always</exception>
        public override void UpdateStateIsolated() {
            throw new InvalidOperationException("Cluster should not be updated!");
        }

        /// <summary>Call UpdateNuclei on every nucleus of this cluster.</summary>
        public override void UpdateNuclei() {
            foreach (Nucleus nucleus in this.clusterNuclei)
                nucleus.UpdateNuclei();
        }

        #endregion Update
    }
}