53 lines
1.4 KiB
Python
53 lines
1.4 KiB
Python
import matplotlib.pyplot as plt
|
|
import networkx as nx
|
|
import numpy as np
|
|
|
|
from genome import Genome, NodeType, mutate
|
|
|
|
|
|
def _spread(count, low, high):
    """Return ``count`` evenly spaced coordinates between ``low`` and ``high``.

    ``np.linspace(low, high, 1)`` would put a lone value at ``low``; a single
    value is centred in the interval instead so it is not drawn at the edge.
    ``count == 0`` yields an empty array.
    """
    if count == 1:
        return np.array([(low + high) / 2.0])
    return np.linspace(low, high, count)


def genome(genome: Genome):
    """Draw *genome* as a directed graph on the current matplotlib figure.

    Every key of ``genome.nodes`` becomes a graph node and every connection
    that is not ``disabled`` becomes an edge ``from_node -> to_node`` carrying
    the connection weight as the ``weight`` attribute. Nodes are placed with a
    spring layout. Input nodes are then moved to a column left of the layout
    and output nodes to a column right of it. The figure is not shown; call
    ``plt.show()`` afterwards.

    Args:
        genome: The genome to draw. (The parameter shadows the function's own
            name; it is kept as-is for caller compatibility.)

    Returns:
        dict mapping node id to the ``np.ndarray`` ``[x, y]`` position used
        for drawing, or an empty dict (and nothing drawn) when the genome has
        no nodes and no connections.
    """
    # Connections are directed (from_node -> to_node). An undirected graph
    # would drop the direction and merge two opposite connections into one.
    graph = nx.DiGraph()
    graph.add_nodes_from(genome.nodes.keys())

    for connection in genome.connections.values():
        if connection.disabled:
            continue  # disabled connections are not drawn
        from_node, to_node = connection.nodes
        graph.add_edge(from_node, to_node, weight=connection.weight)

    if graph.number_of_nodes() == 0:
        # Nothing to lay out; min()/max() over an empty layout would fail.
        return {}

    # Lay out the undirected copy so that an edge attracts both endpoints,
    # matching the layout of the previous undirected graph.
    # NOTE(review): spring_layout reads the "weight" edge attribute by
    # default, so negative connection weights weaken or invert attraction —
    # confirm that is intended.
    pos = nx.spring_layout(graph.to_undirected())

    coords = np.array(list(pos.values()))
    min_x, min_y = coords.min(axis=0)
    max_x, max_y = coords.max(axis=0)

    # Inputs and outputs are not pinned during the layout. They are moved
    # afterwards into a left column (inputs) and a right column (outputs).
    # The padding keeps the two columns apart even if every node shares one x.
    pad = 0.5 * (max_x - min_x) or 1.0
    if max_y == min_y:
        # All nodes lie on one horizontal line: give the columns some height
        # so several inputs/outputs do not land on the same point.
        min_y, max_y = min_y - 1.0, max_y + 1.0

    # Assumes node.id equals the node's key in genome.nodes — TODO confirm.
    nodes = list(genome.nodes.values())
    inputs = [node for node in nodes if node.type == NodeType.INPUT]
    outputs = [node for node in nodes if node.type == NodeType.OUTPUT]

    for column, column_x in ((inputs, min_x - pad), (outputs, max_x + pad)):
        for node, node_y in zip(column, _spread(len(column), min_y, max_y)):
            pos[node.id] = np.array([column_x, node_y])

    ax = plt.subplot()
    nx.draw_networkx(graph, pos, ax=ax)
    return pos
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: build a genome with Genome.new(3, 2) (presumably 3 inputs and
    # 2 outputs — confirm against genome.py), call add_node() three times,
    # add a 0 -> 4 connection with weight 0.5, mutate once, then draw it
    # and open the plot window.
    demo = Genome.new(3, 2)
    for _ in range(3):
        demo.add_node()
    demo.add_connection(0, 4, 0.5)

    mutate(demo)
    genome(demo)
    plt.show()
|