Initial commit

This commit is contained in:
2023-07-15 14:25:10 +02:00
parent 58f1e7b1ad
commit 439f995eae
3 changed files with 251 additions and 0 deletions

128
Abschlussprojekt/genome.py Normal file
View File

@ -0,0 +1,128 @@
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from random import choice
class NodeType(Enum):
INPUT = 1
HIDDEN = 2
OUTPUT = 3
class MutationType(Enum):
ADD_CONNECTION = 1
ADD_NODE = 2
@dataclass(frozen=True)
class NodeGene:
id: int
type: NodeType
@dataclass()
class ConnectionGene:
nodes: tuple[int, int]
weight: float
innovation_no: int
disabled: bool = False
_CONNECTION_GENES: dict[tuple[int, int], ConnectionGene] = dict()
class Genome:
def __init__(self):
# Initialize nodes
self.nodes: dict[int, NodeGene] = dict()
# Initialize connections
self.connections: dict[tuple[int, int], ConnectionGene] = dict()
def add_node(self, node_type: NodeType = NodeType.HIDDEN) -> int:
"""
Adds a node of the given type to the genome and returns the identification key.
"""
key = len(self.nodes)
self.nodes[key] = NodeGene(key, node_type)
return key
def add_connection(self, from_node: int, to_node: int, weight: float) -> tuple[int, int]:
"""
Adds a connection of weight between two given nodes to the genome and returns
the identification key.
"""
key = (from_node, to_node)
connection = ConnectionGene(key, weight, -1)
if key in _CONNECTION_GENES:
connection.innovation_no = _CONNECTION_GENES[key].innovation_no
else:
connection.innovation_no = len(_CONNECTION_GENES)
_CONNECTION_GENES[key] = connection
self.connections[key] = connection
return key
@staticmethod
def new(inputs: int, outputs: int) -> Genome:
genome = Genome()
# Add input nodes
for _ in range(inputs):
genome.add_node(node_type=NodeType.INPUT)
# Add output nodes
for _ in range(outputs):
genome.add_node(node_type=NodeType.OUTPUT)
return genome
def mutate(genome: Genome) -> None:
mutation = choice([MutationType.ADD_NODE])
if mutation is MutationType.ADD_CONNECTION:
_mutate_add_connection(genome)
elif mutation is MutationType.ADD_NODE:
_mutate_add_node(genome)
def _mutate_add_connection(genome: Genome) -> None:
"""
In the add_connection mutation, a single new connection gene with a random weight
is added connecting two previously unconnected nodes.
"""
...
def _mutate_add_node(genome: Genome) -> None:
"""
In the add_node mutation, an existing connection is split and the new node
placed where the old connection used to be. The old connection is disabled
and two new conections are added to the genome. The new connection leading
into the new node receives a weight of 1, and the new connection leading out
receives the same weight as the old connection.
"""
# Find connection to split
try:
connection = choice(list(genome.connections.values()))
except IndexError:
return
connection.disabled = True
# Create new node
new_node = genome.add_node()
from_node, to_node = connection.nodes
# Connect previous from_node to new_node
genome.add_connection(from_node, new_node, weight=1)
# Connection new_node to previous to_node
genome.add_connection(new_node, to_node, weight=connection.weight)

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,52 @@
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from genome import Genome, NodeType, mutate
def genome(genome: Genome):
graph = nx.Graph()
# Add nodes
for node in genome.nodes.keys():
graph.add_node(node)
# Add edges
for connection in genome.connections.values():
if connection.disabled:
continue
from_node, to_node = connection.nodes
graph.add_edge(from_node, to_node, weight=connection.weight)
# Make sure that input and output nodes are fixed
pos = nx.spring_layout(graph)
x = [v[0] for v in pos.values()]
min_x, max_x = min(x), max(x)
y = [v[1] for v in pos.values()]
min_y, max_y = min(y), max(y)
inputs = [node for node in genome.nodes.values() if node.type == NodeType.INPUT]
outputs = [node for node in genome.nodes.values() if node.type == NodeType.OUTPUT]
for node, y in zip(inputs, np.linspace(min_y, max_y, len(inputs))):
pos[node.id] = np.array([min_x * 1.5, y])
for node, y in zip(outputs, np.linspace(min_y, max_y, len(outputs))):
pos[node.id] = np.array([max_x * 1.5, y])
plt.subplot()
nx.draw_networkx(graph, pos)
if __name__ == "__main__":
g1 = Genome.new(3, 2)
g1.add_node()
g1.add_node()
g1.add_node()
g1.add_connection(0, 4, 0.5)
mutate(g1)
genome(g1)
plt.show()