398 lines
14 KiB
C#
398 lines
14 KiB
C#
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using UnityEngine;
|
|
using Unity.Mathematics;
|
|
using static Unity.Mathematics.math;
|
|
|
|
[Serializable]
public class Cluster : Nucleus {

    #region Init

    /// <summary>
    /// Creates a runtime cluster from <paramref name="prefab"/>, nested in the runtime cluster
    /// <paramref name="parent"/>. When a parent is given, this cluster adds itself to the
    /// parent's <c>nuclei</c>.
    /// </summary>
    /// <param name="prefab">The prefab whose nuclei, connections and nucleus arrays are cloned.</param>
    /// <param name="parent">The enclosing runtime cluster, or null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="prefab"/> is null.</exception>
    public Cluster(ClusterPrefab prefab, Cluster parent) {
        if (prefab == null)
            throw new ArgumentNullException(nameof(prefab));

        this.prefab = prefab;
        this.name = prefab.name;

        this.parent = parent;
        this.parent?.nuclei.Add(this);

        InitializeFromPrefab();
    }

    /// <summary>
    /// Creates a runtime cluster from <paramref name="prefab"/>, optionally placed inside the
    /// cluster prefab <paramref name="parent"/>. When a parent is given, this cluster adds itself
    /// to the parent's <c>nuclei</c>.
    /// </summary>
    /// <param name="prefab">The prefab whose nuclei, connections and nucleus arrays are cloned.</param>
    /// <param name="parent">The enclosing cluster prefab, or null.</param>
    /// <exception cref="ArgumentNullException"><paramref name="prefab"/> is null.</exception>
    public Cluster(ClusterPrefab prefab, ClusterPrefab parent = null) {
        if (prefab == null)
            throw new ArgumentNullException(nameof(prefab));

        this.prefab = prefab;
        this.name = prefab.name;

        this.cluster = parent;
        if (this.cluster != null)
            this.cluster.nuclei.Add(this);

        InitializeFromPrefab();
    }

    // Shared tail of both constructors: clone the prefab content and precompute evaluation orders.
    private void InitializeFromPrefab() {
        ClonePrefab();

        // Reading the property builds the input list and the per-input compute orders.
        _ = this.inputs;
        this.sortedNuclei = TopologicalSort(this.nuclei);

        // The nuclei are deliberately not indexed by name: names are not unique, for example
        // the members of a nucleus array (such as 'Pheromone steering') share a name.
    }

    /// <summary>
    /// Fills this cluster with clones of the prefab's nuclei, then recreates their receiver
    /// connections (with synapse weights) and their nucleus arrays between the clones.
    /// Assumes every <c>ShallowCloneTo(this)</c> call appends exactly one clone to
    /// <see cref="nuclei"/>, so that a prefab nucleus and its clone share the same index.
    /// </summary>
    private void ClonePrefab() {
        Nucleus[] prefabNuclei = this.prefab.nuclei.ToArray();

        // First clone the nuclei without their connections
        foreach (Nucleus prefabNucleus in prefabNuclei)
            prefabNucleus.ShallowCloneTo(this);
        Nucleus[] clonedNuclei = this.nuclei.ToArray();

        if (clonedNuclei.Length != prefabNuclei.Length)
            Debug.LogError($"Cluster {this.name}: {clonedNuclei.Length} clones for {prefabNuclei.Length} prefab nuclei; connections may be incomplete");

        CloneConnections(prefabNuclei, clonedNuclei);
        CloneNucleusArrays(prefabNuclei, clonedNuclei);
    }

    // Recreates the receiver connections of the prefab nuclei between their clones.
    private void CloneConnections(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
        for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length; nucleusIx++) {
            Nucleus prefabSender = prefabNuclei[nucleusIx];
            Nucleus clonedSender = ClonedAt(clonedNuclei, nucleusIx);
            if (clonedSender == null)
                continue;

            // Copy the receivers, which will also create the synapses
            foreach (Nucleus prefabReceiver in prefabSender.receivers) {
                int receiverIx = GetNucleusIndex(prefabNuclei, prefabReceiver);
                if (ClonedAt(clonedNuclei, receiverIx) is not Nucleus clonedReceiver)
                    continue;

                float weight = FindSynapseWeight(prefabReceiver, prefabSender);
                clonedSender.AddReceiver(clonedReceiver, weight);
            }
        }
    }

    // Returns the weight of the first synapse of receiver that is fed by sender, or 1 if there is none.
    private static float FindSynapseWeight(Nucleus receiver, Nucleus sender) {
        foreach (Synapse synapse in receiver.synapses) {
            if (synapse.nucleus == sender)
                return synapse.weight;
        }
        return 1;
    }

    /// <summary>
    /// Recreates the nucleus arrays of the prefab receptors between their clones. The first
    /// member of each prefab array gets a new array that refers to the clones; the other
    /// members then share that same cloned array. This is done in two passes, so the result
    /// does not depend on where the first member appears in the prefab's nuclei list.
    /// </summary>
    private void CloneNucleusArrays(Nucleus[] prefabNuclei, Nucleus[] clonedNuclei) {
        // Pass 1: build a cloned array for the first member of every array
        for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length; nucleusIx++) {
            if (!IsArrayMember(prefabNuclei[nucleusIx], out Receptor prefabReceptor))
                continue;
            if (prefabReceptor != prefabReceptor.array.nuclei[0])
                continue;

            if (ClonedAt(clonedNuclei, nucleusIx) is not Receptor clonedReceptor) {
                Debug.LogError($"The clone of {prefabReceptor.name} is not a receptor; its nucleus array is not cloned");
                continue;
            }

            int arrayLength = prefabReceptor.array.nuclei.Length;
            NucleusArray clonedArray = new(arrayLength, "array");
            for (int arrayIx = 0; arrayIx < arrayLength; arrayIx++) {
                Nucleus prefabArrayNucleus = prefabReceptor.array.nuclei[arrayIx];
                int arrayNucleusIx = GetNucleusIndex(prefabNuclei, prefabArrayNucleus);
                if (arrayNucleusIx >= 0 && arrayNucleusIx < clonedNuclei.Length)
                    clonedArray.nuclei[arrayIx] = clonedNuclei[arrayNucleusIx];
                else
                    Debug.LogError($" Could not find prefab nucleus {prefabArrayNucleus?.name} in the clones");
            }
            clonedReceptor.array = clonedArray;
        }

        // Pass 2: the other members refer to the array created for the first member
        for (int nucleusIx = 0; nucleusIx < prefabNuclei.Length; nucleusIx++) {
            if (!IsArrayMember(prefabNuclei[nucleusIx], out Receptor prefabReceptor))
                continue;
            Nucleus prefabFirst = prefabReceptor.array.nuclei[0];
            if (prefabReceptor == prefabFirst)
                continue;

            if (ClonedAt(clonedNuclei, nucleusIx) is not Receptor clonedReceptor) {
                Debug.LogError($"The clone of {prefabReceptor.name} is not a receptor; its nucleus array is not assigned");
                continue;
            }

            int firstNucleusIx = GetNucleusIndex(prefabNuclei, prefabFirst);
            if (ClonedAt(clonedNuclei, firstNucleusIx) is Receptor clonedFirstNucleus)
                clonedReceptor.array = clonedFirstNucleus.array;
            else
                Debug.LogError($"Could not find the cloned first array nucleus {prefabFirst?.name} for {prefabReceptor.name}");
        }
    }

    // True when nucleus is a Receptor that belongs to a non-empty nucleus array.
    private static bool IsArrayMember(Nucleus nucleus, out Receptor receptor) {
        receptor = nucleus as Receptor;
        return receptor is not null
            && receptor.array != null
            && receptor.array.nuclei != null
            && receptor.array.nuclei.Length > 0;
    }

    // Bounds-checked lookup into the clone array; returns null for indices outside it (e.g. -1).
    private static Nucleus ClonedAt(Nucleus[] clonedNuclei, int ix) =>
        ix >= 0 && ix < clonedNuclei.Length ? clonedNuclei[ix] : null;

    /// <summary>
    /// Sorts <paramref name="nodes"/> into an evaluation order (Kahn's algorithm): each nucleus
    /// comes after every nucleus in <paramref name="nodes"/> that lists it as a receiver.
    /// Receivers that are not part of <paramref name="nodes"/> do not constrain the order.
    /// </summary>
    /// <exception cref="InvalidOperationException">The connections between the nodes contain a cycle.</exception>
    private List<Nucleus> TopologicalSort(List<Nucleus> nodes) {
        Dictionary<Nucleus, int> inDegree = new();
        foreach (Nucleus node in nodes)
            inDegree[node] = 0;

        // Count, for every node, how many nodes in the set feed it
        foreach (Nucleus node in nodes) {
            foreach (Nucleus receiver in node.receivers) {
                if (receiver != null && inDegree.TryGetValue(receiver, out int count))
                    inDegree[receiver] = count + 1;
            }
        }

        // Start with the nodes that nothing in the set feeds
        Queue<Nucleus> queue = new();
        foreach (Nucleus node in nodes) {
            if (inDegree[node] == 0)
                queue.Enqueue(node);
        }

        List<Nucleus> sortedOrder = new();
        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            sortedOrder.Add(current);

            foreach (Nucleus receiver in current.receivers) {
                if (receiver == null || !inDegree.TryGetValue(receiver, out int remaining))
                    continue;
                remaining--;
                inDegree[receiver] = remaining;
                if (remaining == 0) // all dependencies resolved
                    queue.Enqueue(receiver);
            }
        }

        // Nodes left unsorted are part of a cycle
        if (sortedOrder.Count != nodes.Count)
            throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");

        return sortedOrder;
    }

    /// <summary>
    /// Creates an unconnected <see cref="Neuron"/> named like this cluster, in
    /// <paramref name="prefab"/>, and copies this cluster's synapses (with weights) and
    /// receivers onto it.
    /// </summary>
    public override Nucleus Clone(ClusterPrefab prefab) {
        // NOTE(review): cloning a Cluster yields a Neuron, not a Cluster; confirm this is intended.
        Neuron clone = new(prefab, this.name);

        foreach (Synapse synapse in this.synapses) {
            Synapse clonedSynapse = clone.AddSynapse(synapse.nucleus);
            clonedSynapse.weight = synapse.weight;
        }

        foreach (Nucleus receiver in this.receivers)
            clone.AddReceiver(receiver);

        return clone;
    }

    /// <summary>
    /// Creates a new cluster from the same prefab inside <paramref name="parent"/>, keeping this
    /// cluster's name. Connections are not copied.
    /// </summary>
    public override Nucleus ShallowCloneTo(Cluster parent) {
        Cluster clone = new(this.prefab, parent) {
            name = this.name,
        };
        return clone;
    }

    // Index of nucleus in nucleiArray (compared with ==), or -1 when it is not present.
    private static int GetNucleusIndex(Nucleus[] nucleiArray, Nucleus nucleus) {
        for (int i = 0; i < nucleiArray.Length; i++) {
            if (nucleus == nucleiArray[i])
                return i;
        }
        return -1;
    }

    #endregion Init

    public ClusterPrefab prefab;

    /// <summary>The runtime nuclei of this cluster, cloned from the prefab.</summary>
    [SerializeReference]
    public List<Nucleus> nuclei = new();

    /// <summary>
    /// The nuclei in topological order, so that the cluster is computed in the right order.
    /// </summary>
    public List<Nucleus> sortedNuclei;

    public List<Nucleus> _inputs = null;

    /// <summary>
    /// The nuclei without synapses. Built on first access; building it also computes the
    /// evaluation order (<see cref="computeOrders"/>) for each input.
    /// </summary>
    public virtual List<Nucleus> inputs {
        get {
            if (this._inputs == null) {
                this._inputs = new();
                foreach (Nucleus nucleus in this.nuclei) {
                    // inputs have no synapses
                    if (nucleus.synapses.Count == 0)
                        this._inputs.Add(nucleus);
                }
                ComputeOrders();
            }
            return this._inputs;
        }
    }

    /// <summary>
    /// For each start nucleus, the nuclei reachable from it in evaluation order
    /// (the start nucleus first).
    /// </summary>
    public Dictionary<Nucleus, List<Nucleus>> computeOrders = new();

    // Precomputes the evaluation order starting from every input nucleus.
    private void ComputeOrders() {
        foreach (Nucleus input in this._inputs)
            computeOrders[input] = TopologicalSort2(input);
    }

    /// <summary>
    /// Topologically sorts the nuclei reachable from <paramref name="startNode"/> through
    /// receiver connections. Only edges between the start node and the nuclei of this cluster
    /// are followed.
    /// </summary>
    /// <exception cref="InvalidOperationException">The reachable part of the graph contains a cycle.</exception>
    private List<Nucleus> TopologicalSort2(Nucleus startNode) {
        Dictionary<Nucleus, int> inDegree = new();
        foreach (Nucleus node in this.nuclei)
            inDegree[node] = 0;
        if (!inDegree.ContainsKey(startNode))
            inDegree[startNode] = 0;

        // Find the nodes reachable from the start node, counting the in-degrees
        // contributed by those reachable nodes only
        HashSet<Nucleus> visited = new() { startNode };
        Queue<Nucleus> queue = new();
        queue.Enqueue(startNode);

        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            foreach (Nucleus receiver in current.receivers) {
                if (receiver == null || !inDegree.TryGetValue(receiver, out int count))
                    continue;
                inDegree[receiver] = count + 1;
                if (visited.Add(receiver))
                    queue.Enqueue(receiver);
            }
        }

        // Perform the topological sort on the reachable nodes
        queue.Clear();
        foreach (Nucleus node in visited) {
            if (inDegree[node] == 0)
                queue.Enqueue(node);
        }

        List<Nucleus> sortedOrder = new();
        while (queue.Count > 0) {
            Nucleus current = queue.Dequeue();
            sortedOrder.Add(current);

            foreach (Nucleus receiver in current.receivers) {
                if (receiver == null || !visited.Contains(receiver))
                    continue;
                int remaining = inDegree[receiver] - 1;
                inDegree[receiver] = remaining;
                if (remaining == 0) // all dependencies resolved
                    queue.Enqueue(receiver);
            }
        }

        // Reachable nodes left unsorted are part of a cycle
        if (sortedOrder.Count != visited.Count)
            throw new InvalidOperationException("Graph is not a DAG; a cycle exists.");

        return sortedOrder;
    }

    /// <summary>The first nucleus in <see cref="nuclei"/>, or null when the cluster is empty.</summary>
    public virtual Nucleus output => this.nuclei.Count > 0 ? this.nuclei[0] : null;

    public List<Nucleus> _outputs = null;

    /// <summary>The nuclei without receivers. Built on first access and cached.</summary>
    public List<Nucleus> outputs {
        get {
            if (this._outputs == null) {
                this._outputs = new();
                foreach (Nucleus nucleus in this.nuclei) {
                    // outputs have no receivers
                    if (nucleus.receivers.Count == 0)
                        this._outputs.Add(nucleus);
                }
            }
            return this._outputs;
        }
    }

    /// <summary>Finds the first nucleus in this cluster with the given name.</summary>
    /// <param name="nucleusName">The name to look for.</param>
    /// <param name="foundNucleus">The nucleus found, or null.</param>
    /// <returns>True when a nucleus with that name exists.</returns>
    public bool TryGetNucleus(string nucleusName, out Nucleus foundNucleus) {
        foreach (Nucleus nucleus in this.nuclei) {
            if (nucleus is not null && nucleus.name == nucleusName) {
                foundNucleus = nucleus;
                return true;
            }
        }
        foundNucleus = null;
        return false;
    }

    /// <summary>Returns the first nucleus in this cluster with the given name, or null.</summary>
    public Nucleus GetNucleus(string nucleusName) =>
        TryGetNucleus(nucleusName, out Nucleus nucleus) ? nucleus : null;

    #region Update

    /// <summary>
    /// Updates <paramref name="startNucleus"/> and all nuclei reachable from it in evaluation
    /// order, then takes over the output value and updates all nuclei. The order is taken from
    /// <see cref="computeOrders"/>; for a start nucleus that has no precomputed order, one is
    /// computed and cached.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="startNucleus"/> is null.</exception>
    public void UpdateFromNucleus(Nucleus startNucleus) {
        if (startNucleus == null)
            throw new ArgumentNullException(nameof(startNucleus));

        // no bias+synapse input state calculation for now...

        if (!this.computeOrders.TryGetValue(startNucleus, out List<Nucleus> computeOrder)) {
            computeOrder = TopologicalSort2(startNucleus);
            this.computeOrders[startNucleus] = computeOrder;
        }

        if (startNucleus.trace)
            Debug.Log($"Update from {startNucleus.name}");
        foreach (Nucleus nucleus in computeOrder) {
            nucleus.UpdateStateIsolated();
            if (startNucleus.trace)
                Debug.Log($" {nucleus.name} = {nucleus.outputValue}");
        }

        Nucleus outputNucleus = this.output;
        if (outputNucleus != null)
            this.outputValue = outputNucleus.outputValue;
        this.stale = 0;

        UpdateNuclei();
    }

    /// <summary>
    /// Updates every nucleus of the cluster in topological order, then takes over the output
    /// value and updates all nuclei.
    /// </summary>
    public override void UpdateStateIsolated() {
        // TODO: this cluster's own bias and synapse inputs are not applied here.

        foreach (Nucleus nucleus in this.sortedNuclei)
            nucleus.UpdateStateIsolated();

        Nucleus outputNucleus = this.output;
        if (outputNucleus != null)
            this.outputValue = outputNucleus.outputValue;
        this.stale = 0;

        UpdateNuclei();
    }

    /// <summary>Calls <c>UpdateNuclei</c> on every nucleus of the cluster.</summary>
    public override void UpdateNuclei() {
        foreach (Nucleus nucleus in this.nuclei)
            nucleus.UpdateNuclei();
    }

    #endregion Update

}
|