Files
ne-assignments/Abschlussprojekt/genome.py
2023-07-16 19:16:43 +02:00

150 lines
4.3 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from random import choice, random
import networkx as nx
class NodeType(Enum):
    """Role a node gene plays in a genome: input, hidden, or output."""

    INPUT = 1
    HIDDEN = 2
    OUTPUT = 3
class MutationType(Enum):
    """Kinds of structural mutation that mutate() dispatches on."""

    ADD_CONNECTION = 1
    ADD_NODE = 2
@dataclass(frozen=True)
class NodeGene:
    """Immutable (and therefore hashable) node gene."""

    id: int  # key under which Genome.add_node stores this gene in Genome.nodes
    type: NodeType  # role of the node within the network
@dataclass
class ConnectionGene:
    """Mutable connection gene linking two nodes of a genome."""

    nodes: tuple[int, int]  # (from_node, to_node); also the key in Genome.connections
    weight: float
    innovation_no: int  # shared by every genome that adds the same node pair
    disabled: bool = False  # set by the add_node mutation when this gene is split
# Global innovation registry shared by all genomes: maps a (from_node, to_node)
# pair to the gene recorded when that pair was first added anywhere, so the same
# pair always receives the same innovation_no.
_CONNECTION_GENES: dict[tuple[int, int], ConnectionGene] = {}
class Genome:
    """
    A genome made of node genes and the connection genes between them.

    Innovation numbers of connections are shared between genomes through the
    module-level _CONNECTION_GENES registry, keyed by (from_node, to_node).
    """

    def __init__(self):
        # Node genes keyed by id; add_node assigns ids consecutively from 0.
        self.nodes: dict[int, NodeGene] = {}
        # Connection genes keyed by their (from_node, to_node) pair.
        self.connections: dict[tuple[int, int], ConnectionGene] = {}

    def add_node(self, node_type: NodeType = NodeType.HIDDEN) -> int:
        """
        Adds a node of the given type to the genome and returns its identification key.

        The key is the number of nodes the genome had before the call.
        """
        key = len(self.nodes)
        self.nodes[key] = NodeGene(key, node_type)
        return key

    def add_connection(self, from_node: int, to_node: int, weight: float) -> tuple[int, int]:
        """
        Adds a connection of the given weight between two nodes and returns its
        identification key, the pair (from_node, to_node).

        A pair that was already registered by any genome reuses its innovation
        number; a new pair gets the next free number. An existing connection of
        this genome between the same nodes is replaced.
        """
        key = (from_node, to_node)
        record = _CONNECTION_GENES.get(key)
        if record is None:
            # The registry keeps its own gene object: storing this genome's gene
            # would let later edits to it (weight, disabled flag) leak into the
            # shared record.
            record = ConnectionGene(key, weight, len(_CONNECTION_GENES))
            _CONNECTION_GENES[key] = record
        self.connections[key] = ConnectionGene(key, weight, record.innovation_no)
        return key

    @classmethod
    def new(cls, inputs: int, outputs: int) -> Genome:
        """
        Creates a minimal genome: input nodes get ids 0..inputs-1, output nodes get
        ids inputs..inputs+outputs-1, and every input is connected to every output
        with weight 1.0.
        """
        genome = cls()
        for _ in range(inputs):
            genome.add_node(node_type=NodeType.INPUT)
        for _ in range(outputs):
            genome.add_node(node_type=NodeType.OUTPUT)
        # Fully connect inputs to outputs.
        for i in range(inputs):
            for o in range(inputs, inputs + outputs):
                genome.add_connection(i, o, weight=1.0)
        return genome
def mutate(genome: Genome) -> None:
    """
    Applies one structural mutation to the genome in place, chosen uniformly at
    random from all MutationType members.
    """
    # Draw from every member; a hard-coded single-element list made ADD_NODE unreachable.
    mutation = choice(list(MutationType))
    if mutation is MutationType.ADD_CONNECTION:
        _mutate_add_connection(genome)
    elif mutation is MutationType.ADD_NODE:
        _mutate_add_node(genome)
def _mutate_add_connection(genome: Genome) -> None:
    """
    In the add_connection mutation, a single new connection gene with a random weight
    is added connecting two previously unconnected nodes.

    The source may be any non-output node and the target any non-input node that is
    strictly deeper than the source (see _node_depths). Only connecting "downwards"
    keeps the network free of cycles. The (source, target) pair is drawn uniformly
    among all admissible pairs that have no connection gene yet (enabled or not), and
    the weight is drawn uniformly from [0, 1). If no admissible pair exists, the
    genome is left unchanged.
    """
    inputs = [node.id for node in genome.nodes.values() if node.type == NodeType.INPUT]
    depths = _node_depths(genome, inputs)
    sources = [node.id for node in genome.nodes.values() if node.type != NodeType.OUTPUT]
    targets = [node.id for node in genome.nodes.values() if node.type != NodeType.INPUT]
    candidates = [
        (source, target)
        for source in sources
        for target in targets
        if depths[target] > depths[source] and (source, target) not in genome.connections
    ]
    if not candidates:
        return
    from_node, to_node = choice(candidates)
    genome.add_connection(from_node, to_node, weight=random())


def _node_depths(genome: Genome, inputs) -> dict[int, int]:
    """
    Returns, for every node of the genome, the number of nodes on the longest
    directed path that ends at it. Paths start at an input node (one of ``inputs``)
    or at a node without incoming connections, so such nodes have depth 1.

    All connection genes are followed, including disabled ones, so ordering by depth
    stays cycle-free even if a disabled gene becomes active again.

    Raises ValueError if the connection genes contain a cycle.
    """
    input_set = set(inputs)
    predecessors: dict[int, list[int]] = {}
    all_nodes = set(genome.nodes)
    for source, target in genome.connections:
        predecessors.setdefault(target, []).append(source)
        all_nodes.update((source, target))

    depths: dict[int, int] = {}
    in_progress: set[int] = set()

    def visit(node: int) -> int:
        if node in depths:
            return depths[node]
        if node in in_progress:
            raise ValueError(f"connection genes contain a cycle through node {node}")
        in_progress.add(node)
        parents = () if node in input_set else predecessors.get(node, ())
        depth = 1 + max((visit(parent) for parent in parents), default=0)
        in_progress.discard(node)
        depths[node] = depth
        return depth

    for node in all_nodes:
        visit(node)
    return depths


def _distance_to_input(g: Genome, node, inputs) -> int:
    """
    Returns the number of nodes on the longest path from the inputs to ``node``
    (1 for an input node itself); see _node_depths for the exact definition.
    """
    return _node_depths(g, inputs)[node]
def _mutate_add_node(genome: Genome) -> None:
    """
    In the add_node mutation, an existing enabled connection is split and a new
    hidden node is placed where the old connection used to be. The old connection
    is disabled and two new connections are added to the genome: the one leading
    into the new node receives a weight of 1, and the one leading out of it
    receives the weight of the old connection. If the genome has no enabled
    connection, nothing changes.
    """
    # NOTE(review): add_node numbers nodes per genome (len(self.nodes)) while
    # add_connection shares innovation numbers by node pair, so two genomes that
    # split different connections may give structurally different genes the same
    # innovation number — confirm this is intended.
    enabled = [gene for gene in genome.connections.values() if not gene.disabled]
    if not enabled:
        return
    split = choice(enabled)
    split.disabled = True
    source, target = split.nodes
    hidden = genome.add_node()
    genome.add_connection(source, hidden, weight=1)
    genome.add_connection(hidden, target, weight=split.weight)