viz
This commit is contained in:
@ -80,6 +80,11 @@ class Genome:
|
|||||||
for _ in range(outputs):
|
for _ in range(outputs):
|
||||||
genome.add_node(node_type=NodeType.OUTPUT)
|
genome.add_node(node_type=NodeType.OUTPUT)
|
||||||
|
|
||||||
|
# Fully connect
|
||||||
|
for i in range(inputs):
|
||||||
|
for o in range(inputs, inputs + outputs):
|
||||||
|
genome.add_connection(i, o, weight=1)
|
||||||
|
|
||||||
return genome
|
return genome
|
||||||
|
|
||||||
|
|
||||||
@ -112,7 +117,7 @@ def _mutate_add_node(genome: Genome) -> None:
|
|||||||
|
|
||||||
# Find connection to split
|
# Find connection to split
|
||||||
try:
|
try:
|
||||||
connection = choice(list(genome.connections.values()))
|
connection = choice([node for node in genome.connections.values() if not node.disabled])
|
||||||
except IndexError:
|
except IndexError:
|
||||||
return
|
return
|
||||||
connection.disabled = True
|
connection.disabled = True
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
1
Abschlussprojekt/requirements.txt
Normal file
1
Abschlussprojekt/requirements.txt
Normal file
@ -0,0 +1 @@
|
|||||||
|
pygraphviz
|
||||||
@ -1,3 +1,5 @@
|
|||||||
|
import itertools
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -5,8 +7,17 @@ import numpy as np
|
|||||||
from genome import Genome, NodeType, mutate
|
from genome import Genome, NodeType, mutate
|
||||||
|
|
||||||
|
|
||||||
|
def _find_layer(g: nx.DiGraph, hidden_node: int, inputs: list[int]) -> int:
|
||||||
|
paths = []
|
||||||
|
for input_node in inputs:
|
||||||
|
paths += list(nx.all_simple_paths(g, input_node, hidden_node))
|
||||||
|
|
||||||
|
path_lengths = [len(path) for path in paths]
|
||||||
|
return max(path_lengths)
|
||||||
|
|
||||||
|
|
||||||
def genome(genome: Genome):
|
def genome(genome: Genome):
|
||||||
graph = nx.Graph()
|
graph = nx.DiGraph()
|
||||||
|
|
||||||
# Add nodes
|
# Add nodes
|
||||||
for node in genome.nodes.keys():
|
for node in genome.nodes.keys():
|
||||||
@ -20,33 +31,37 @@ def genome(genome: Genome):
|
|||||||
from_node, to_node = connection.nodes
|
from_node, to_node = connection.nodes
|
||||||
graph.add_edge(from_node, to_node, weight=connection.weight)
|
graph.add_edge(from_node, to_node, weight=connection.weight)
|
||||||
|
|
||||||
# Make sure that input and output nodes are fixed
|
inputs = [node.id for node in genome.nodes.values() if node.type == NodeType.INPUT]
|
||||||
pos = nx.spring_layout(graph)
|
hidden = [node.id for node in genome.nodes.values() if node.type == NodeType.HIDDEN]
|
||||||
x = [v[0] for v in pos.values()]
|
outputs = [node.id for node in genome.nodes.values() if node.type == NodeType.OUTPUT]
|
||||||
min_x, max_x = min(x), max(x)
|
|
||||||
y = [v[1] for v in pos.values()]
|
|
||||||
min_y, max_y = min(y), max(y)
|
|
||||||
|
|
||||||
inputs = [node for node in genome.nodes.values() if node.type == NodeType.INPUT]
|
for input_node in inputs:
|
||||||
outputs = [node for node in genome.nodes.values() if node.type == NodeType.OUTPUT]
|
graph.nodes[input_node]["layer"] = 0
|
||||||
|
|
||||||
for node, y in zip(inputs, np.linspace(min_y, max_y, len(inputs))):
|
max_layer = 1
|
||||||
pos[node.id] = np.array([min_x * 1.5, y])
|
for hidden_node in hidden:
|
||||||
|
layer = _find_layer(graph, hidden_node, inputs)
|
||||||
|
max_layer = max(layer, max_layer)
|
||||||
|
graph.nodes[hidden_node]["layer"] = layer
|
||||||
|
|
||||||
for node, y in zip(outputs, np.linspace(min_y, max_y, len(outputs))):
|
for output_node in outputs:
|
||||||
pos[node.id] = np.array([max_x * 1.5, y])
|
graph.nodes[output_node]["layer"] = max_layer + 1
|
||||||
|
|
||||||
plt.subplot()
|
plt.subplot()
|
||||||
nx.draw_networkx(graph, pos)
|
pos = nx.multipartite_layout(graph, subset_key="layer")
|
||||||
|
nx.draw_networkx_nodes(graph, pos, nodelist=inputs, label=inputs, node_color="#ff0000")
|
||||||
|
nx.draw_networkx_nodes(graph, pos, nodelist=hidden, label=hidden, node_color="#00ff00")
|
||||||
|
nx.draw_networkx_nodes(graph, pos, nodelist=outputs, label=outputs, node_color="#0000ff")
|
||||||
|
nx.draw_networkx_edges(graph, pos)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
g1 = Genome.new(3, 2)
|
g1 = Genome.new(3, 2)
|
||||||
g1.add_node()
|
|
||||||
g1.add_node()
|
|
||||||
g1.add_node()
|
|
||||||
g1.add_connection(0, 4, 0.5)
|
g1.add_connection(0, 4, 0.5)
|
||||||
|
|
||||||
mutate(g1)
|
mutate(g1)
|
||||||
|
mutate(g1)
|
||||||
|
mutate(g1)
|
||||||
|
|
||||||
|
# mutate(g1)
|
||||||
genome(g1)
|
genome(g1)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user