424 lines
14 KiB
C#

using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using Unity.Mathematics;
using static Unity.Mathematics.math;
/// <summary>
/// A nucleus that contains a network of other nuclei, instantiated from a <see cref="ClusterPrefab"/>.
/// On construction the prefab's nuclei are shallow-cloned into <see cref="clusterNuclei"/>, the
/// inputs and their compute orders are determined, and the nuclei are sorted topologically.
/// </summary>
[Serializable]
public class Cluster : Nucleus {
    #region Init

    /// <summary>
    /// Creates a cluster from <paramref name="prefab"/> as a child of another cluster.
    /// </summary>
    /// <param name="prefab">The prefab whose nuclei are cloned into this cluster.</param>
    /// <param name="parent">Enclosing cluster; when not null this cluster is appended to its <see cref="clusterNuclei"/>.</param>
    public Cluster(ClusterPrefab prefab, Cluster parent) {
        this.prefab = prefab;
        this.name = prefab.name;
        this.parent = parent;
        this.parent?.clusterNuclei.Add(this);
        ClonePrefab();
        // Reading the property computes the inputs and fills computeOrders.
        _ = this.inputs;
        this.sortedNuclei = TopologicalSort(this.clusterNuclei);
    }

    /// <summary>
    /// Creates a cluster from <paramref name="prefab"/> that is (optionally) part of another prefab.
    /// </summary>
    /// <param name="prefab">The prefab whose nuclei are cloned into this cluster.</param>
    /// <param name="parent">Prefab this cluster belongs to; when not null this cluster is appended to its nuclei.</param>
    public Cluster(ClusterPrefab prefab, ClusterPrefab parent = null) {
        this.prefab = prefab;
        this.name = prefab.name;
        this.clusterPrefab = parent;
        // Explicit != null (not ?.) in case ClusterPrefab is a UnityEngine.Object with an overloaded ==.
        if (this.clusterPrefab != null)
            this.clusterPrefab.nuclei.Add(this);
        ClonePrefab();
        // Reading the property computes the inputs and fills computeOrders.
        _ = this.inputs;
        this.sortedNuclei = TopologicalSort(this.clusterNuclei);
    }

    /// <summary>
    /// Shallow-clones every nucleus of <see cref="prefab"/> into this cluster and then rebuilds the
    /// nucleus arrays of the cloned receptors so they refer to the clones instead of the prefab nuclei.
    /// </summary>
    private void ClonePrefab() {
        // Snapshot, so cloning cannot disturb the enumeration of the prefab's list.
        Nucleus[] prefabNuclei = this.prefab.nuclei.ToArray();

        // First clone the nuclei without their connections.
        // Assumes each ShallowCloneTo call appends exactly one clone to this.clusterNuclei
        // (Cluster.ShallowCloneTo does so through its constructor), so clone i belongs to prefab nucleus i.
        int firstCloneIx = this.clusterNuclei.Count;
        foreach (Nucleus nucleus in prefabNuclei)
            nucleus.ShallowCloneTo(this);
        int cloneCount = this.clusterNuclei.Count - firstCloneIx;
        if (cloneCount != prefabNuclei.Length) {
            // Index alignment is broken; mapping arrays would connect the wrong nuclei.
            Debug.LogError($"Cluster {this.name}: expected {prefabNuclei.Length} cloned nuclei but got {cloneCount}; nucleus arrays are not cloned");
            return;
        }
        Nucleus[] clonedNuclei = this.clusterNuclei.GetRange(firstCloneIx, cloneCount).ToArray();

        // NOTE(review): connections (receivers/synapses) between the cloned nuclei are not cloned here;
        // confirm that is intended.

        CloneNucleusArrays(prefabNuclei, clonedNuclei);
    }

    /// <summary>
    /// Gives every cloned receptor a nucleus array that refers to the cloned nuclei.
    /// Receptors that share one prefab array get one shared cloned array.
    /// </summary>
    /// <param name="prefabNuclei">The prefab's nuclei.</param>
    /// <param name="clonedNuclei">The clones, index-aligned with <paramref name="prefabNuclei"/>.</param>
    private void CloneNucleusArrays(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
        // Keyed by the prefab array, so it does not matter in which order the array members appear in the prefab.
        Dictionary<NucleusArray, NucleusArray> clonedArrays = new();
        for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length; nucleusIx++) {
            if (prefabNuclei[nucleusIx] is not Receptor prefabReceptor)
                continue;
            NucleusArray prefabArray = prefabReceptor.array;
            if (prefabArray == null || prefabArray.nuclei == null || prefabArray.nuclei.Length == 0)
                continue;
            if (clonedNuclei[nucleusIx] is not Receptor clonedReceptor) {
                Debug.LogError($"Cluster {this.name}: the clone of receptor {prefabReceptor.name} is not a receptor");
                continue;
            }
            if (!clonedArrays.TryGetValue(prefabArray, out NucleusArray clonedArray)) {
                clonedArray = CloneNucleusArray(prefabArray, prefabNuclei, clonedNuclei);
                clonedArrays[prefabArray] = clonedArray;
            }
            clonedReceptor.array = clonedArray;
        }
    }

    /// <summary>
    /// Creates a copy of <paramref name="prefabArray"/> whose entries are the clones of its prefab entries.
    /// Entries that are not part of the prefab are logged and left unset.
    /// </summary>
    private NucleusArray CloneNucleusArray(NucleusArray prefabArray, Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
        NucleusArray clonedArray = new(prefabArray.nuclei.Length, "array");
        int arrayIx = 0;
        foreach (Nucleus prefabArrayNucleus in prefabArray.nuclei) {
            int arrayNucleusIx = GetNucleusIndex(prefabNuclei, prefabArrayNucleus);
            if (arrayNucleusIx >= 0)
                clonedArray.nuclei[arrayIx] = clonedNuclei[arrayNucleusIx];
            else
                Debug.LogError($" Could not find prefab nucleus {prefabArrayNucleus?.name} in the clones");
            arrayIx++;
        }
        return clonedArray;
    }

    /// <summary>
    /// Orders <paramref name="nodes"/> so that every neuron comes before its receivers (Kahn's algorithm).
    /// Receivers that are not in <paramref name="nodes"/> are ignored.
    /// </summary>
    /// <param name="nodes">The nuclei to sort.</param>
    /// <returns>The nuclei in evaluation order.</returns>
    /// <exception cref="InvalidOperationException">The receiver links among <paramref name="nodes"/> contain a cycle.</exception>
    private List<Nucleus> TopologicalSort(List<Nucleus> nodes) {
        Dictionary<Nucleus, int> inDegree = new();
        foreach (Nucleus node in nodes)
            inDegree[node] = 0;
        foreach (Nucleus node in nodes) {
            if (node is not Neuron neuron)
                continue;
            foreach (Nucleus receiver in neuron.receivers) {
                // Links leaving the node set (e.g. to receivers outside the cluster) do not constrain the order.
                if (inDegree.TryGetValue(receiver, out int degree))
                    inDegree[receiver] = degree + 1;
            }
        }
        // Start with the nodes that nothing inside the set feeds into.
        Queue<Nucleus> queue = new();
        foreach (Nucleus node in nodes) {
            if (inDegree[node] == 0)
                queue.Enqueue(node);
        }
        List<Nucleus> sortedOrder = new();
        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            sortedOrder.Add(current);
            if (current is not Neuron neuron)
                continue;
            foreach (Nucleus receiver in neuron.receivers) {
                if (!inDegree.TryGetValue(receiver, out int degree))
                    continue;
                degree--;
                inDegree[receiver] = degree;
                if (degree == 0) // all dependencies resolved
                    queue.Enqueue(receiver);
            }
        }
        // Nodes on a cycle never reach in-degree zero and are missing from the result.
        if (sortedOrder.Count != nodes.Count)
            throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");
        return sortedOrder;
    }

    /// <summary>
    /// Creates a new <see cref="Neuron"/> with this cluster's name and copies this cluster's synapses
    /// (source nucleus and weight) onto it.
    /// </summary>
    /// <param name="prefab">Prefab passed to the new neuron's constructor.</param>
    /// <remarks>NOTE(review): this returns a Neuron rather than a Cluster; confirm that is intended.</remarks>
    public override Nucleus Clone(ClusterPrefab prefab) {
        Neuron clone = new(prefab, this.name);
        foreach (Synapse synapse in this.synapses) {
            Synapse clonedSynapse = clone.AddSynapse(synapse.nucleus);
            clonedSynapse.weight = synapse.weight;
        }
        return clone;
    }

    /// <summary>
    /// Creates a new cluster from the same prefab inside <paramref name="parent"/> (its constructor
    /// appends it to the parent's nuclei), keeping this cluster's name and cluster prefab.
    /// </summary>
    public override Nucleus ShallowCloneTo(Cluster parent) {
        Cluster clone = new(this.prefab, parent) {
            name = this.name,
            clusterPrefab = this.clusterPrefab,
        };
        return clone;
    }

    /// <summary>Returns the index of <paramref name="nucleus"/> in <paramref name="nucleiArray"/>, or -1 when absent.</summary>
    private int GetNucleusIndex(Nucleus[] nucleiArray, Nucleus nucleus) {
        for (int i = 0; i < nucleiArray.Length; i++) {
            if (nucleus == nucleiArray[i])
                return i;
        }
        return -1;
    }

    #endregion Init

    /// <summary>The prefab this cluster was created from.</summary>
    public ClusterPrefab prefab;

    /// <summary>The nuclei contained in this cluster.</summary>
    [SerializeReference]
    public List<Nucleus> clusterNuclei = new();

    /// <summary>
    /// The nuclei sorted topologically, so that the cluster is computed in the right order
    /// (every neuron before its receivers).
    /// </summary>
    public List<Nucleus> sortedNuclei;

    /// <summary>Backing field of <see cref="inputs"/>; null until the inputs are first requested.</summary>
    public List<Nucleus> _inputs = null;

    /// <summary>
    /// The nuclei of this cluster without synapses. Computed once and cached; computing it also
    /// fills <see cref="computeOrders"/> for every input.
    /// </summary>
    public virtual List<Nucleus> inputs {
        get {
            if (this._inputs == null) {
                this._inputs = new();
                foreach (Nucleus nucleus in this.clusterNuclei) {
                    // inputs have no synapses
                    if (nucleus.synapses.Count == 0)
                        this._inputs.Add(nucleus);
                }
                ComputeOrders();
            }
            return this._inputs;
        }
    }

    /// <summary>Per start nucleus, the order in which the nuclei reachable from it are updated.</summary>
    public Dictionary<Nucleus, List<Nucleus>> computeOrders = new();

    /// <summary>Computes the compute order for every input nucleus.</summary>
    private void ComputeOrders() {
        foreach (Nucleus input in this._inputs) {
            computeOrders[input] = TopologicalSort2(input);
        }
    }

    /// <summary>
    /// Topologically sorts the cluster nuclei reachable from <paramref name="startNode"/> through
    /// receiver links. Only links between cluster nuclei are followed.
    /// </summary>
    /// <param name="startNode">The nucleus the update starts from; expected to be one of <see cref="clusterNuclei"/>.</param>
    /// <returns>The reachable nuclei, starting node included, in evaluation order.</returns>
    /// <exception cref="InvalidOperationException">The reachable subgraph contains a cycle.</exception>
    private List<Nucleus> TopologicalSort2(Nucleus startNode) {
        Dictionary<Nucleus, int> inDegree = new();
        foreach (Nucleus node in this.clusterNuclei)
            inDegree[node] = 0;

        // Breadth-first walk from the start node, counting in-degrees within the reachable subgraph.
        HashSet<Nucleus> visited = new() { startNode };
        Queue<Nucleus> queue = new();
        queue.Enqueue(startNode);
        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            if (current is not Neuron neuron)
                continue;
            foreach (Nucleus receiver in neuron.receivers) {
                // Receivers outside this cluster are not part of its compute order.
                if (!inDegree.TryGetValue(receiver, out int degree))
                    continue;
                inDegree[receiver] = degree + 1;
                if (visited.Add(receiver))
                    queue.Enqueue(receiver);
            }
        }

        // Perform the topological sort on the reachable nodes.
        queue.Clear();
        foreach (Nucleus node in visited) {
            if (inDegree[node] == 0)
                queue.Enqueue(node);
        }
        List<Nucleus> sortedOrder = new();
        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            sortedOrder.Add(current);
            if (current is not Neuron neuron)
                continue;
            foreach (Nucleus receiver in neuron.receivers) {
                if (!visited.Contains(receiver))
                    continue;
                inDegree[receiver]--;
                if (inDegree[receiver] == 0) // all dependencies resolved
                    queue.Enqueue(receiver);
            }
        }
        if (sortedOrder.Count != visited.Count)
            throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");
        return sortedOrder;
    }

    /// <summary>The first nucleus of the cluster if it is a <see cref="Neuron"/>; otherwise null.</summary>
    public virtual Neuron defaultOutput {
        get {
            if (this.clusterNuclei.Count > 0)
                return this.clusterNuclei[0] as Neuron;
            return null;
        }
    }

    private List<Neuron> _outputs = null;

    /// <summary>All neurons of the cluster, collected on first access and cached.</summary>
    public List<Neuron> outputs {
        get {
            if (this._outputs == null) {
                this._outputs = new();
                foreach (Nucleus nucleus in this.clusterNuclei) {
                    if (nucleus is Neuron neuron)
                        this._outputs.Add(neuron);
                }
            }
            return this._outputs;
        }
    }

    /// <summary>Finds the first non-null cluster nucleus named <paramref name="nucleusName"/>.</summary>
    /// <returns>True when found; <paramref name="foundNucleus"/> is null otherwise.</returns>
    public bool TryGetNucleus(string nucleusName, out Nucleus foundNucleus) {
        foreach (Nucleus nucleus in this.clusterNuclei) {
            if (nucleus is not null && nucleus.name == nucleusName) {
                foundNucleus = nucleus;
                return true;
            }
        }
        foundNucleus = null;
        return false;
    }

    /// <summary>Returns the first cluster nucleus named <paramref name="nucleusName"/>, or null.</summary>
    public Nucleus GetNucleus(string nucleusName) {
        foreach (Nucleus nucleus in this.clusterNuclei) {
            if (nucleus.name == nucleusName)
                return nucleus;
        }
        return null;
    }

    /// <summary>Returns the first <see cref="Receptor"/> in the cluster named <paramref name="receptorName"/>, or null.</summary>
    public Receptor GetReceptor(string receptorName) {
        foreach (Nucleus nucleus in this.clusterNuclei) {
            if (nucleus is Receptor receptor && receptor.name == receptorName)
                return receptor;
        }
        return null;
    }

    #region Receivers

    /// <summary>Returns the receivers of all output neurons of this cluster.</summary>
    public virtual List<Nucleus> CollectReceivers() {
        List<Nucleus> receivers = new();
        foreach (Neuron output in this.outputs) {
            receivers.AddRange(output.receivers);
        }
        return receivers;
    }

    #endregion Receivers

    #region Update

    /// <summary>
    /// Updates the nuclei reachable from <paramref name="startNucleus"/> in their compute order,
    /// then takes over the default output's value and updates the nuclei.
    /// </summary>
    /// <param name="startNucleus">A nucleus of this cluster, normally one of <see cref="inputs"/>.</param>
    /// <exception cref="KeyNotFoundException"><paramref name="startNucleus"/> is not a nucleus of this cluster.</exception>
    public void UpdateFromNucleus(Nucleus startNucleus) {
        // no bias+synapse input state calculation for now...
        List<Nucleus> computeOrder = GetComputeOrder(startNucleus);
        if (startNucleus.trace)
            Debug.Log($"Update from {startNucleus.name}");
        foreach (Nucleus nucleus in computeOrder) {
            nucleus.UpdateStateIsolated();
            if (startNucleus.trace)
                Debug.Log($" {nucleus.name} = {nucleus.outputValue}");
        }
        CopyDefaultOutputValue();
        this.stale = 0;
        UpdateNuclei();
    }

    /// <summary>
    /// Returns the cached compute order for <paramref name="startNucleus"/>, computing and caching it
    /// when it is a cluster nucleus without one (e.g. when computeOrders was not filled by a constructor).
    /// </summary>
    private List<Nucleus> GetComputeOrder(Nucleus startNucleus) {
        if (this.computeOrders.TryGetValue(startNucleus, out List<Nucleus> computeOrder))
            return computeOrder;
        if (!this.clusterNuclei.Contains(startNucleus))
            throw new KeyNotFoundException($"{startNucleus.name} is not a nucleus of cluster {this.name}");
        computeOrder = TopologicalSort2(startNucleus);
        this.computeOrders[startNucleus] = computeOrder;
        return computeOrder;
    }

    /// <summary>
    /// Updates all cluster nuclei in topological order, then takes over the default output's value
    /// and updates the nuclei.
    /// </summary>
    public override void UpdateStateIsolated() {
        // NOTE(review): the cluster's bias and synapse inputs are not applied to its nuclei; the weighted
        // sum previously computed here was never used. Confirm whether they should feed the inputs.
        foreach (Nucleus nucleus in this.sortedNuclei)
            nucleus.UpdateStateIsolated();
        CopyDefaultOutputValue();
        this.stale = 0;
        UpdateNuclei();
    }

    /// <summary>
    /// Copies the default output's value to this cluster's output value; leaves it unchanged when the
    /// cluster has no default output (empty, or first nucleus is not a Neuron).
    /// </summary>
    private void CopyDefaultOutputValue() {
        Neuron output = this.defaultOutput;
        if (output is not null)
            this.outputValue = output.outputValue;
    }

    /// <summary>Calls <c>UpdateNuclei</c> on every nucleus in the cluster.</summary>
    public override void UpdateNuclei() {
        foreach (Nucleus nucleus in this.clusterNuclei)
            nucleus.UpdateNuclei();
    }

    #endregion Update
}