Initial commit
This commit is contained in:
128
Abschlussprojekt/genome.py
Normal file
128
Abschlussprojekt/genome.py
Normal file
@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from random import choice
|
||||
|
||||
|
||||
class NodeType(Enum):
|
||||
INPUT = 1
|
||||
HIDDEN = 2
|
||||
OUTPUT = 3
|
||||
|
||||
|
||||
class MutationType(Enum):
|
||||
ADD_CONNECTION = 1
|
||||
ADD_NODE = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NodeGene:
|
||||
id: int
|
||||
type: NodeType
|
||||
|
||||
|
||||
@dataclass()
|
||||
class ConnectionGene:
|
||||
nodes: tuple[int, int]
|
||||
weight: float
|
||||
innovation_no: int
|
||||
disabled: bool = False
|
||||
|
||||
|
||||
_CONNECTION_GENES: dict[tuple[int, int], ConnectionGene] = dict()
|
||||
|
||||
|
||||
class Genome:
|
||||
def __init__(self):
|
||||
# Initialize nodes
|
||||
self.nodes: dict[int, NodeGene] = dict()
|
||||
|
||||
# Initialize connections
|
||||
self.connections: dict[tuple[int, int], ConnectionGene] = dict()
|
||||
|
||||
def add_node(self, node_type: NodeType = NodeType.HIDDEN) -> int:
|
||||
"""
|
||||
Adds a node of the given type to the genome and returns the identification key.
|
||||
"""
|
||||
|
||||
key = len(self.nodes)
|
||||
self.nodes[key] = NodeGene(key, node_type)
|
||||
return key
|
||||
|
||||
def add_connection(self, from_node: int, to_node: int, weight: float) -> tuple[int, int]:
|
||||
"""
|
||||
Adds a connection of weight between two given nodes to the genome and returns
|
||||
the identification key.
|
||||
"""
|
||||
|
||||
key = (from_node, to_node)
|
||||
connection = ConnectionGene(key, weight, -1)
|
||||
|
||||
if key in _CONNECTION_GENES:
|
||||
connection.innovation_no = _CONNECTION_GENES[key].innovation_no
|
||||
else:
|
||||
connection.innovation_no = len(_CONNECTION_GENES)
|
||||
_CONNECTION_GENES[key] = connection
|
||||
|
||||
self.connections[key] = connection
|
||||
return key
|
||||
|
||||
@staticmethod
|
||||
def new(inputs: int, outputs: int) -> Genome:
|
||||
genome = Genome()
|
||||
|
||||
# Add input nodes
|
||||
for _ in range(inputs):
|
||||
genome.add_node(node_type=NodeType.INPUT)
|
||||
|
||||
# Add output nodes
|
||||
for _ in range(outputs):
|
||||
genome.add_node(node_type=NodeType.OUTPUT)
|
||||
|
||||
return genome
|
||||
|
||||
|
||||
def mutate(genome: Genome) -> None:
|
||||
mutation = choice([MutationType.ADD_NODE])
|
||||
|
||||
if mutation is MutationType.ADD_CONNECTION:
|
||||
_mutate_add_connection(genome)
|
||||
elif mutation is MutationType.ADD_NODE:
|
||||
_mutate_add_node(genome)
|
||||
|
||||
|
||||
def _mutate_add_connection(genome: Genome) -> None:
|
||||
"""
|
||||
In the add_connection mutation, a single new connection gene with a random weight
|
||||
is added connecting two previously unconnected nodes.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
def _mutate_add_node(genome: Genome) -> None:
|
||||
"""
|
||||
In the add_node mutation, an existing connection is split and the new node
|
||||
placed where the old connection used to be. The old connection is disabled
|
||||
and two new conections are added to the genome. The new connection leading
|
||||
into the new node receives a weight of 1, and the new connection leading out
|
||||
receives the same weight as the old connection.
|
||||
"""
|
||||
|
||||
# Find connection to split
|
||||
try:
|
||||
connection = choice(list(genome.connections.values()))
|
||||
except IndexError:
|
||||
return
|
||||
connection.disabled = True
|
||||
|
||||
# Create new node
|
||||
new_node = genome.add_node()
|
||||
from_node, to_node = connection.nodes
|
||||
|
||||
# Connect previous from_node to new_node
|
||||
genome.add_connection(from_node, new_node, weight=1)
|
||||
|
||||
# Connection new_node to previous to_node
|
||||
genome.add_connection(new_node, to_node, weight=connection.weight)
|
||||
71
Abschlussprojekt/neat.ipynb
Normal file
71
Abschlussprojekt/neat.ipynb
Normal file
File diff suppressed because one or more lines are too long
52
Abschlussprojekt/visualization.py
Normal file
52
Abschlussprojekt/visualization.py
Normal file
@ -0,0 +1,52 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from genome import Genome, NodeType, mutate
|
||||
|
||||
|
||||
def genome(genome: Genome):
|
||||
graph = nx.Graph()
|
||||
|
||||
# Add nodes
|
||||
for node in genome.nodes.keys():
|
||||
graph.add_node(node)
|
||||
|
||||
# Add edges
|
||||
for connection in genome.connections.values():
|
||||
if connection.disabled:
|
||||
continue
|
||||
|
||||
from_node, to_node = connection.nodes
|
||||
graph.add_edge(from_node, to_node, weight=connection.weight)
|
||||
|
||||
# Make sure that input and output nodes are fixed
|
||||
pos = nx.spring_layout(graph)
|
||||
x = [v[0] for v in pos.values()]
|
||||
min_x, max_x = min(x), max(x)
|
||||
y = [v[1] for v in pos.values()]
|
||||
min_y, max_y = min(y), max(y)
|
||||
|
||||
inputs = [node for node in genome.nodes.values() if node.type == NodeType.INPUT]
|
||||
outputs = [node for node in genome.nodes.values() if node.type == NodeType.OUTPUT]
|
||||
|
||||
for node, y in zip(inputs, np.linspace(min_y, max_y, len(inputs))):
|
||||
pos[node.id] = np.array([min_x * 1.5, y])
|
||||
|
||||
for node, y in zip(outputs, np.linspace(min_y, max_y, len(outputs))):
|
||||
pos[node.id] = np.array([max_x * 1.5, y])
|
||||
|
||||
plt.subplot()
|
||||
nx.draw_networkx(graph, pos)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
g1 = Genome.new(3, 2)
|
||||
g1.add_node()
|
||||
g1.add_node()
|
||||
g1.add_node()
|
||||
g1.add_connection(0, 4, 0.5)
|
||||
|
||||
mutate(g1)
|
||||
genome(g1)
|
||||
plt.show()
|
||||
Reference in New Issue
Block a user