Files
ne-assignments/Abschlussprojekt/visualization.py
2023-07-16 00:40:47 +02:00

68 lines
1.9 KiB
Python

import itertools
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from genome import Genome, NodeType, mutate
def _find_layer(g: nx.DiGraph, hidden_node: int, inputs: list[int]) -> int:
    """Return the drawing layer (column) for ``hidden_node``.

    The layer is the number of nodes on the longest simple path from any of
    ``inputs`` to ``hidden_node``. A node fed directly by an input therefore
    gets layer 2.

    Args:
        g: Graph whose edges are the connections to follow.
        hidden_node: Node whose layer is computed.
        inputs: Ids of the nodes that paths may start from.

    Returns:
        The longest path length counted in nodes. Returns 1 when
        ``hidden_node`` is not reachable from any input, for example when
        every connection leading to it was left out of ``g``. In that case the
        node is still placed, instead of ``max`` raising ``ValueError`` on an
        empty sequence.
    """
    # Stream the path lengths rather than materialising every path in a
    # list: the number of simple paths can grow exponentially with graph size.
    path_lengths = (
        len(path)
        for input_node in inputs
        for path in nx.all_simple_paths(g, input_node, hidden_node)
    )
    return max(path_lengths, default=1)
def genome(genome: Genome):
    """Draw ``genome`` as a layered network on the current matplotlib figure.

    Layout:
        * Input nodes go in layer 0.
        * Each hidden node goes in the layer returned by ``_find_layer``.
        * All output nodes share one layer to the right of the deepest hidden
          layer.

    Disabled connections are not drawn. Each node group gets its own colour
    and a string legend label, and every node is annotated with its id. The
    caller is responsible for showing or saving the figure (for example with
    ``plt.show()``).

    Args:
        genome: The genome to visualise. ``genome.nodes`` supplies the node
            ids and types, and ``genome.connections`` supplies the edges.
    """
    graph = nx.DiGraph()

    # Add every node, even those without an enabled connection, so they
    # still appear in the drawing.
    graph.add_nodes_from(genome.nodes.keys())

    # Only enabled connections become (weighted) edges.
    for connection in genome.connections.values():
        if connection.disabled:
            continue
        from_node, to_node = connection.nodes
        graph.add_edge(from_node, to_node, weight=connection.weight)

    inputs = [node.id for node in genome.nodes.values() if node.type == NodeType.INPUT]
    hidden = [node.id for node in genome.nodes.values() if node.type == NodeType.HIDDEN]
    outputs = [node.id for node in genome.nodes.values() if node.type == NodeType.OUTPUT]

    # multipartite_layout needs a "layer" attribute on every node it places.
    for input_node in inputs:
        graph.nodes[input_node]["layer"] = 0
    max_layer = 1
    for hidden_node in hidden:
        layer = _find_layer(graph, hidden_node, inputs)
        max_layer = max(layer, max_layer)
        graph.nodes[hidden_node]["layer"] = layer
    # Outputs sit one column past the deepest hidden node.
    for output_node in outputs:
        graph.nodes[output_node]["layer"] = max_layer + 1

    ax = plt.subplot()
    pos = nx.multipartite_layout(graph, subset_key="layer")

    # ``label`` is the legend entry for a whole node group, so it must be a
    # string. The individual node ids are drawn by draw_networkx_labels below.
    node_groups = (
        (inputs, "input", "#ff0000"),
        (hidden, "hidden", "#00ff00"),
        (outputs, "output", "#0000ff"),
    )
    for nodelist, group_label, color in node_groups:
        nx.draw_networkx_nodes(
            graph, pos, nodelist=nodelist, label=group_label, node_color=color, ax=ax
        )
    nx.draw_networkx_edges(graph, pos, ax=ax)
    nx.draw_networkx_labels(graph, pos, ax=ax)
if __name__ == "__main__":
    # Demo: build a small genome, add one explicit connection, apply three
    # rounds of mutation and plot the result.
    # NOTE(review): the meaning of Genome.new's arguments (3, 2) and of node
    # id 4 is defined in the genome module; confirm there.
    demo_genome = Genome.new(3, 2)
    demo_genome.add_connection(0, 4, 0.5)
    for _ in range(3):
        mutate(demo_genome)
    genome(demo_genome)
    plt.show()