using System;
using System.Collections.Generic;
using UnityEngine;
using Unity.Mathematics;
using static Unity.Mathematics.math;

/// <summary>
/// A nucleus composed of other nuclei, cloned from a <see cref="ClusterPrefab"/>.
/// Updates are propagated through the contained nuclei in topological (dependency) order.
/// </summary>
[Serializable]
public class Cluster : Nucleus {

    /// <summary>
    /// The name up to (not including) the first ':'; the full name when it contains no ':'.
    /// </summary>
    public string baseName {
        get {
            int colonPosition = this.name.IndexOf(':');
            return colonPosition < 0 ? this.name : this.name[..colonPosition];
        }
    }

    #region Init

    /// <summary>
    /// Creates a cluster from <paramref name="prefab"/> that lives inside the runtime cluster <paramref name="parent"/>.
    /// </summary>
    /// <param name="prefab">The prefab whose nuclei and connections are cloned into this cluster.</param>
    /// <param name="parent">The owning cluster; when not null this cluster is appended to its clusterNuclei.</param>
    public Cluster(ClusterPrefab prefab, Cluster parent) {
        this.prefab = prefab;
        this.name = prefab.name;
        this.parent = parent;
        this.parent?.clusterNuclei.Add(this);
        BuildFromPrefab();
    }

    /// <summary>
    /// Creates a cluster from <paramref name="prefab"/> that is (optionally) part of the prefab <paramref name="parent"/>.
    /// </summary>
    /// <param name="prefab">The prefab whose nuclei and connections are cloned into this cluster.</param>
    /// <param name="parent">The owning prefab; when not null this cluster is appended to its nuclei.</param>
    public Cluster(ClusterPrefab prefab, ClusterPrefab parent = null) {
        this.prefab = prefab;
        this.name = prefab.name;
        this.clusterPrefab = parent;
        // Deliberately '!= null' (not '?.') so that UnityEngine.Object's custom null check is used if ClusterPrefab is one.
        if (this.clusterPrefab != null)
            this.clusterPrefab.nuclei.Add(this);
        BuildFromPrefab();
    }

    // Shared constructor tail: clone the prefab content, then compute the per-input and full evaluation orders.
    private void BuildFromPrefab() {
        ClonePrefab();
        // The inputs getter fills _inputs and computeOrders as a side effect.
        _ = this.inputs;
        this.sortedNuclei = TopologicalSort(this.clusterNuclei);
    }

    // Clones all nuclei of the prefab into this cluster, followed by their connections and receptor arrays.
    private void ClonePrefab() {
        Nucleus[] prefabNuclei = this.prefab.nuclei.ToArray();

        // First clone the nuclei without their connections.
        // NOTE(review): the code below assumes ShallowCloneTo(this) appends exactly one clone per prefab nucleus
        // to this.clusterNuclei, so that index i refers to corresponding nuclei in both arrays — confirm for all Nucleus types.
        foreach (Nucleus nucleus in prefabNuclei)
            nucleus.ShallowCloneTo(this);
        Nucleus[] clonedNuclei = this.clusterNuclei.ToArray();

        CloneConnections(prefabNuclei, clonedNuclei);
        CloneReceptorArrays(prefabNuclei, clonedNuclei);

        // Sub-clusters cloned their own prefab in their constructor; here their output connections
        // towards nuclei of this cluster are restored.
        foreach (Nucleus nucleus in this.clusterNuclei) {
            if (nucleus is Cluster clonedSubCluster)
                RestoreAllExternalReceivers(clonedSubCluster, this.prefab, this);
        }
    }

    // Copies the receiver connections (with their synapse weights) between the prefab neurons onto the cloned neurons.
    private void CloneConnections(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
        for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length; nucleusIx++) {
            // Clusters do not have receivers, only neurons do
            if (prefabNuclei[nucleusIx] is not Neuron prefabNeuron)
                continue;
            if (NucleusAt(clonedNuclei, nucleusIx) is not Neuron clonedNeuron)
                continue;

            foreach (Nucleus receiver in prefabNeuron.receivers.ToArray()) {
                // Receivers that are not nuclei of this prefab (index -1) are not cloned here
                if (NucleusAt(clonedNuclei, GetNucleusIndex(prefabNuclei, receiver)) is not Nucleus clonedReceiver)
                    continue;
                // Per the original design, AddReceiver also creates the synapse on the receiver
                clonedNeuron.AddReceiver(clonedReceiver, FindSynapseWeight(receiver, prefabNeuron));
            }
        }
    }

    // Rebuilds the receptor nucleus arrays so that they refer to the cloned nuclei instead of the prefab nuclei.
    private void CloneReceptorArrays(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
        // Pass 1: the first receptor of each array creates the cloned array.
        // This runs before pass 2 so that the other receptors can share it regardless of their order in the prefab.
        for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length; nucleusIx++) {
            if (prefabNuclei[nucleusIx] is not IReceptor prefabReceptor)
                continue;
            if (prefabReceptor.nucleiArray == null || prefabReceptor.nucleiArray.Length == 0)
                continue;
            if (prefabReceptor != prefabReceptor.nucleiArray[0])
                continue;
            if (NucleusAt(clonedNuclei, nucleusIx) is not IReceptor clonedReceptor)
                continue;

            NucleusArray clonedArray = new(prefabReceptor.nucleiArray.Length, "array");
            int arrayIx = 0;
            foreach (Nucleus prefabArrayNucleus in prefabReceptor.nucleiArray) {
                int arrayNucleusIx = GetNucleusIndex(prefabNuclei, prefabArrayNucleus);
                if (arrayNucleusIx >= 0)
                    clonedArray.nuclei[arrayIx] = NucleusAt(clonedNuclei, arrayNucleusIx);
                else
                    Debug.LogError($" Could not find prefab nucleus {prefabArrayNucleus.name} in the clones");
                arrayIx++;
            }
            clonedReceptor.nucleiArray = clonedArray.nuclei;
        }

        // Pass 2: the other receptors refer to the array created for the first receptor of their array.
        for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length; nucleusIx++) {
            Nucleus prefabNucleus = prefabNuclei[nucleusIx];
            if (prefabNucleus is not IReceptor prefabReceptor)
                continue;
            if (prefabReceptor.nucleiArray == null || prefabReceptor.nucleiArray.Length == 0)
                continue;
            if (prefabReceptor == prefabReceptor.nucleiArray[0])
                continue;
            if (NucleusAt(clonedNuclei, nucleusIx) is not IReceptor clonedReceptor)
                continue;

            int firstNucleusIx = GetNucleusIndex(prefabNuclei, prefabReceptor.nucleiArray[0]);
            if (NucleusAt(clonedNuclei, firstNucleusIx) is not IReceptor clonedFirstReceptor) {
                Debug.LogError($" Could not find the first array nucleus of {prefabNucleus.name} in the clones");
                continue;
            }
            clonedReceptor.nucleiArray = clonedFirstReceptor.nucleiArray;
        }
    }

    // Returns nuclei[ix], or null when ix is outside the array (e.g. the -1 returned by GetNucleusIndex).
    private static Nucleus NucleusAt(Nucleus[] nuclei, int ix) {
        return ix >= 0 && ix < nuclei.Length ? nuclei[ix] : null;
    }

    // Returns the weight of the synapse on 'receiver' that comes from 'sender', or 1 when there is none.
    private static float FindSynapseWeight(Nucleus receiver, Nucleus sender) {
        foreach (Synapse synapse in receiver.synapses) {
            if (synapse.neuron == sender)
                return synapse.weight;
        }
        return 1;
    }

    // Returns the nuclei that receive from 'nucleus': a neuron's receivers, a cluster's CollectReceivers(), otherwise none.
    private static IEnumerable<Nucleus> ReceiversOf(Nucleus nucleus) {
        if (nucleus is Neuron neuron)
            return neuron.receivers;
        if (nucleus is Cluster cluster)
            return cluster.CollectReceivers();
        return Array.Empty<Nucleus>();
    }

    /// <summary>
    /// Sorts <paramref name="nodes"/> in a correct evaluation order using Kahn's algorithm:
    /// each nucleus is placed after every nucleus in <paramref name="nodes"/> that sends to it.
    /// Receivers that are not part of <paramref name="nodes"/> are ignored.
    /// </summary>
    /// <exception cref="InvalidOperationException">The connections between the nodes contain a cycle.</exception>
    private static List<Nucleus> TopologicalSort(List<Nucleus> nodes) {
        Dictionary<Nucleus, int> inDegree = new();
        foreach (Nucleus node in nodes)
            inDegree[node] = 0;

        // Count the incoming connections of every node
        foreach (Nucleus node in nodes) {
            foreach (Nucleus receiver in ReceiversOf(node)) {
                if (inDegree.TryGetValue(receiver, out int degree))
                    inDegree[receiver] = degree + 1;
            }
        }

        // Start with the nodes without dependencies
        Queue<Nucleus> ready = new();
        foreach (Nucleus node in nodes) {
            if (inDegree[node] == 0)
                ready.Enqueue(node);
        }

        List<Nucleus> sortedOrder = new(nodes.Count);
        while (ready.Count > 0) {
            Nucleus current = ready.Dequeue();
            sortedOrder.Add(current);
            foreach (Nucleus receiver in ReceiversOf(current)) {
                if (!inDegree.TryGetValue(receiver, out int degree))
                    continue;
                degree--;
                inDegree[receiver] = degree;
                if (degree == 0) // all dependencies resolved
                    ready.Enqueue(receiver);
            }
        }

        // Nodes that were never released are part of a cycle
        if (sortedOrder.Count != nodes.Count)
            throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");
        return sortedOrder;
    }

    /// <summary>
    /// Creates a copy of this cluster in prefab <paramref name="parent"/>, including the synapses of this cluster
    /// and the receivers of its output neurons.
    /// </summary>
    public override Nucleus Clone(ClusterPrefab parent) {
        Cluster clone = new(this.prefab, parent);
        foreach (Synapse synapse in this.synapses) {
            Synapse clonedSynapse = clone.AddSynapse(synapse.neuron);
            clonedSynapse.weight = synapse.weight;
        }

        foreach (Neuron output in this.outputs) {
            // The clone has the same nucleus layout, so the output index maps onto the cloned output
            int ix = GetNucleusIndex(this.clusterNuclei, output);
            if (ix < 0 || ix >= clone.clusterNuclei.Count)
                continue;
            if (clone.clusterNuclei[ix] is not Neuron clonedOutput)
                continue;
            foreach (Nucleus receiver in output.receivers)
                clonedOutput.AddReceiver(receiver);
        }
        return clone;
    }

    /// <summary>
    /// Creates a copy of this cluster inside <paramref name="parent"/> without copying this cluster's own synapses.
    /// The constructor appends the copy to parent.clusterNuclei and clones the prefab content.
    /// </summary>
    public override Nucleus ShallowCloneTo(Cluster parent) {
        Cluster clone = new(this.prefab, parent) {
            name = this.name,
            clusterPrefab = this.clusterPrefab,
        };
        return clone;
    }

    // Copies the receivers of the neurons of the prefab's sub-cluster that corresponds to 'clonedCluster'
    // onto the neurons of 'clonedCluster', mapping the receivers from prefabParent's nuclei to clonedParent's nuclei.
    private static void RestoreAllExternalReceivers(Cluster clonedCluster, ClusterPrefab prefabParent, Cluster clonedParent) {
        int clonedClusterIx = GetNucleusIndex(clonedParent.clusterNuclei, clonedCluster);
        if (clonedClusterIx < 0 || clonedClusterIx >= prefabParent.nuclei.Count)
            return;
        if (prefabParent.nuclei[clonedClusterIx] is not Cluster sourceCluster)
            return;

        int nucleusCount = Math.Min(sourceCluster.clusterNuclei.Count, clonedCluster.clusterNuclei.Count);
        for (int nucleusIx = 0; nucleusIx < nucleusCount; nucleusIx++) {
            if (sourceCluster.clusterNuclei[nucleusIx] is not Neuron sourceNeuron)
                continue;
            if (clonedCluster.clusterNuclei[nucleusIx] is not Neuron clonedNeuron)
                continue;

            // Copy the receivers (and thus synapses) from the source to the clone
            foreach (Nucleus receiver in sourceNeuron.receivers) {
                int ix = GetNucleusIndex(prefabParent.nuclei, receiver);
                if (ix < 0 || ix >= clonedParent.clusterNuclei.Count)
                    continue;
                if (clonedParent.clusterNuclei[ix] is not Nucleus clonedReceiver)
                    continue;
                clonedNeuron.AddReceiver(clonedReceiver, FindSynapseWeight(receiver, sourceNeuron));
            }
        }
    }

    /// <summary>Returns the index of <paramref name="nucleus"/> in <paramref name="nuclei"/> (reference comparison), or -1.</summary>
    protected int GetNucleusIndex(Nucleus[] nuclei, Nucleus nucleus) {
        for (int i = 0; i < nuclei.Length; i++) {
            if (nucleus == nuclei[i])
                return i;
        }
        return -1;
    }

    /// <summary>Returns the index of <paramref name="nucleus"/> in <paramref name="nuclei"/> (reference comparison), or -1.</summary>
    public static int GetNucleusIndex(List<Nucleus> nuclei, Nucleus nucleus) {
        int i = 0;
        foreach (Nucleus nucleiElement in nuclei) {
            if (nucleus == nucleiElement)
                return i;
            i++;
        }
        return -1;
    }

    #endregion Init

    /// <summary>The prefab this cluster was cloned from.</summary>
    public ClusterPrefab prefab;

    /// <summary>The nuclei directly contained in this cluster, in the same order as the prefab's nuclei.</summary>
    [SerializeReference]
    public List<Nucleus> clusterNuclei = new();

    // The nuclei sorted using topological sorting
    // to ensure that the cluster is computed in the right order
    public List<Nucleus> sortedNuclei;

    public List<Nucleus> _inputs = null;

    /// <summary>
    /// The nuclei of this cluster without synapses. The first access also computes <see cref="computeOrders"/>.
    /// </summary>
    public virtual List<Nucleus> inputs {
        get {
            if (this._inputs == null) {
                this._inputs = new();
                foreach (Nucleus nucleus in this.clusterNuclei) {
                    // inputs have no synapses
                    if (nucleus.synapses.Count == 0)
                        this._inputs.Add(nucleus);
                }
                ComputeOrders();
            }
            return this._inputs;
        }
    }

    /// <summary>For each input nucleus, the evaluation order of the nuclei of this cluster reachable from it.</summary>
    public Dictionary<Nucleus, List<Nucleus>> computeOrders = new();

    // Fills computeOrders for every input in _inputs.
    private void ComputeOrders() {
        foreach (Nucleus input in this._inputs)
            this.computeOrders[input] = TopologicalSort2(input);
    }

    /// <summary>
    /// Sorts the nuclei of this cluster that are reachable from <paramref name="startNode"/> in evaluation order;
    /// the result begins with <paramref name="startNode"/>. Receivers outside this cluster are not included.
    /// </summary>
    /// <exception cref="InvalidOperationException">The reachable connections contain a cycle.</exception>
    private List<Nucleus> TopologicalSort2(Nucleus startNode) {
        Dictionary<Nucleus, int> inDegree = new();
        foreach (Nucleus node in this.clusterNuclei)
            inDegree[node] = 0;

        // Breadth-first walk: collect the reachable nuclei and count the connections between them
        HashSet<Nucleus> visited = new() { startNode };
        Queue<Nucleus> queue = new();
        queue.Enqueue(startNode);
        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            foreach (Nucleus receiver in ReceiversOf(current)) {
                // Receivers outside this cluster are skipped; UpdateFromNucleus continues in the parent cluster
                if (!inDegree.TryGetValue(receiver, out int degree))
                    continue;
                if (visited.Add(receiver))
                    queue.Enqueue(receiver);
                inDegree[receiver] = degree + 1;
            }
        }

        // Perform topological sort on all reachable nodes
        queue.Clear();
        foreach (Nucleus node in visited) {
            if (inDegree[node] == 0)
                queue.Enqueue(node);
        }

        List<Nucleus> sortedOrder = new(visited.Count);
        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            sortedOrder.Add(current);
            foreach (Nucleus receiver in ReceiversOf(current)) {
                if (!visited.Contains(receiver))
                    continue;
                inDegree[receiver]--;
                if (inDegree[receiver] == 0) // all dependencies resolved
                    queue.Enqueue(receiver);
            }
        }

        // Nodes that were never released are part of a cycle
        if (sortedOrder.Count != visited.Count)
            throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");
        return sortedOrder;
    }

    /// <summary>The first nucleus of this cluster when it is a Neuron; otherwise null.</summary>
    public virtual Neuron defaultOutput {
        get {
            return this.clusterNuclei.Count > 0 ? this.clusterNuclei[0] as Neuron : null;
        }
    }

    protected List<Neuron> _outputs = null;

    /// <summary>All Neuron nuclei of this cluster; computed on first access and cached.</summary>
    public List<Neuron> outputs {
        get {
            if (this._outputs == null) {
                this._outputs = new();
                foreach (Nucleus nucleus in this.clusterNuclei) {
                    if (nucleus is Neuron neuron)
                        this._outputs.Add(neuron);
                }
            }
            return this._outputs;
        }
    }

    /// <summary>Finds a direct (non-null) nucleus of this cluster whose name equals <paramref name="nucleusName"/>.</summary>
    /// <returns>True when found; <paramref name="foundNucleus"/> is null otherwise.</returns>
    public bool TryGetNucleus(string nucleusName, out Nucleus foundNucleus) {
        foreach (Nucleus nucleus in this.clusterNuclei) {
            if (nucleus is not null && nucleus.name == nucleusName) {
                foundNucleus = nucleus;
                return true;
            }
        }
        foundNucleus = null;
        return false;
    }

    /// <summary>
    /// Finds a nucleus by name. A dotted name "A.B" looks up "B" in the sub-cluster named "A" (or "A: 0").
    /// Receptors also match the name with a ": 0" suffix.
    /// </summary>
    /// <returns>The nucleus, or null when it is not found.</returns>
    public Nucleus GetNucleus(string nucleusName) {
        int dotPosition = nucleusName.IndexOf('.');
        if (dotPosition >= 0) {
            string clusterName = nucleusName[..dotPosition];
            string clusterName0 = clusterName + ": 0";
            foreach (Nucleus nucleus in this.clusterNuclei) {
                if (nucleus is Cluster cluster && (cluster.name == clusterName || cluster.name == clusterName0)) {
                    string subNucleusName = nucleusName[(dotPosition + 1)..];
                    return cluster.GetNucleus(subNucleusName);
                }
            }
            return null;
        }

        string nucleusName0 = nucleusName + ": 0";
        foreach (Nucleus nucleus in this.clusterNuclei) {
            if (nucleus is IReceptor) {
                if (nucleus.name == nucleusName || nucleus.name == nucleusName0)
                    return nucleus;
            }
            else if (nucleus.name == nucleusName)
                return nucleus;
        }
        return null;
    }

    #region Receivers

    /// <summary>
    /// Collects the receivers of this cluster's output neurons whose clusterPrefab differs from this cluster's prefab.
    /// </summary>
    public virtual List<Nucleus> CollectReceivers() {
        List<Nucleus> receivers = new();
        foreach (Neuron output in this.outputs) {
            foreach (Nucleus receiver in output.receivers) {
                // Only add receivers outside this cluster
                if (receiver.clusterPrefab != this.prefab)
                    receivers.Add(receiver);
            }
        }
        return receivers;
    }

    #endregion Receivers

    #region Update

    /// <summary>
    /// Updates the nuclei reachable from <paramref name="startNucleus"/> in their precomputed order, then continues
    /// in the parent cluster and finally calls <see cref="UpdateNuclei"/>. Does nothing when
    /// <paramref name="startNucleus"/> has no compute order in this cluster.
    /// </summary>
    public void UpdateFromNucleus(Nucleus startNucleus) {
        // no bias+synapse input state calculation for now...
        if (!this.computeOrders.TryGetValue(startNucleus, out List<Nucleus> computeOrder))
            return;

        if (startNucleus.trace)
            Debug.Log($"Update from {startNucleus.name}");

        // NOTE(review): a Cluster appearing in this order will throw from UpdateStateIsolated —
        // confirm that sub-clusters are never inputs or receivers here.
        foreach (Nucleus nucleus in computeOrder) {
            nucleus.UpdateStateIsolated();
            if (startNucleus.trace && nucleus is Neuron neuron)
                Debug.Log($" {nucleus.name}[{nucleus.GetHashCode()}] = {neuron.outputValue}");
        }

        // continue in parent
        this.parent?.UpdateFromNucleus(this);
        UpdateNuclei();
    }

    /// <summary>Clusters are not updated in isolation; use <see cref="UpdateFromNucleus"/>.</summary>
    /// <exception cref="InvalidOperationException">Always.</exception>
    public override void UpdateStateIsolated() {
        throw new InvalidOperationException("Cluster should not be updated!");
    }

    /// <summary>Calls UpdateNuclei on every direct nucleus of this cluster.</summary>
    public override void UpdateNuclei() {
        foreach (Nucleus nucleus in this.clusterNuclei)
            nucleus.UpdateNuclei();
    }

    #endregion Update
}