Files
ne-assignments/Abschlussprojekt/genome.py
2023-07-17 00:02:10 +02:00

281 lines
9.5 KiB
Python

from __future__ import annotations
import dataclasses
import itertools
from enum import Enum
from random import choice
import matplotlib.pyplot as plt
import numpy as np
from graphs import creates_cycle
# Module-level NumPy random generator; used below to draw random weights for
# newly added connections. It is created without an explicit seed.
rng = np.random.default_rng()
from connection import _CONNECTION_GENES, ConnectionGene
from node import NodeGene, NodeType
class MutationType(Enum):
    """Structural mutations that can be applied to a genome."""

    # Insert a new connection gene between two nodes.
    ADD_CONNECTION = 1
    # Split an existing connection by inserting a new node.
    ADD_NODE = 2
class Genome:
    """
    A genome made of node genes and the connection genes between them.

    Attributes:
        nodes: Node genes indexed by their integer key.
        connections: Connection genes indexed by ``(from_node, to_node)``.
        fitness: Fitness value of the genome; starts at 0.
    """

    def __init__(self):
        # Initialize nodes
        self.nodes: dict[int, NodeGene] = {}
        # Initialize connections
        self.connections: dict[tuple[int, int], ConnectionGene] = {}
        self.fitness = 0

    def set_node(self, key: int, node: NodeGene) -> None:
        """
        Stores the given node gene under key, replacing any existing entry.
        """
        self.nodes[key] = node

    def set_connection(self, key: tuple[int, int], connection: ConnectionGene) -> None:
        """
        Stores the given connection gene under key, replacing any existing entry.
        """
        self.connections[key] = connection

    def add_node(self, node_type: NodeType = NodeType.HIDDEN) -> int:
        """
        Adds a node of the given type to the genome and returns the identification key.

        The key is one greater than the largest key currently in use (0 for an
        empty genome).
        """
        # Keys are not guaranteed to be contiguous (set_node accepts arbitrary
        # keys), so len(self.nodes) could collide with an existing key and
        # overwrite that node. max + 1 is always free and equals len() for
        # contiguous keys.
        key = max(self.nodes, default=-1) + 1
        self.nodes[key] = NodeGene(key, node_type)
        return key

    def add_connection(self, from_node: int, to_node: int, weight: float) -> tuple[int, int]:
        """
        Adds a connection of weight between two given nodes to the genome and returns
        the identification key.

        The innovation number is reused from the module-wide connection registry
        if the same (from_node, to_node) pair was registered before; otherwise the
        new gene is registered with the next free number.

        Raises:
            ValueError: If a node key is not an int or does not exist in this genome.
        """
        if not isinstance(from_node, int) or not isinstance(to_node, int):
            raise ValueError("Nodes must be integer keys.")
        if from_node not in self.nodes or to_node not in self.nodes:
            raise ValueError("Nodes do not exist.")

        key = (from_node, to_node)
        # -1 is a placeholder; the real innovation number is assigned below.
        connection = ConnectionGene(key, weight, -1)
        if key in _CONNECTION_GENES:
            connection.innovation_no = _CONNECTION_GENES[key].innovation_no
        else:
            connection.innovation_no = len(_CONNECTION_GENES)
            _CONNECTION_GENES[key] = connection

        self.connections[key] = connection
        return key

    @staticmethod
    def new(inputs: int, outputs: int) -> Genome:
        """
        Creates a genome with the given number of input and output nodes in which
        every input node is connected to every output node with weight 1.

        Input nodes are created first, followed by the output nodes.
        """
        genome = Genome()
        # Add input nodes
        for _ in range(inputs):
            genome.add_node(node_type=NodeType.INPUT)
        # Add output nodes
        for _ in range(outputs):
            genome.add_node(node_type=NodeType.OUTPUT)
        # Fully connect
        for i in range(inputs):
            for o in range(inputs, inputs + outputs):
                genome.add_connection(i, o, weight=1)
        return genome

    @staticmethod
    def copy(genome: Genome) -> Genome:
        """
        Returns a clone of the given genome with copied node and connection genes
        and the same fitness.
        """
        clone = Genome()
        # Copy nodes (dataclasses.replace without changes yields a new object)
        for key, node in genome.nodes.items():
            clone.set_node(key, dataclasses.replace(node))
        # Copy connections
        for key, connection in genome.connections.items():
            clone.set_connection(key, dataclasses.replace(connection))
        # Set fitness
        clone.fitness = genome.fitness
        return clone
def mutate(genome: Genome) -> None:
    """
    Applies one randomly selected structural mutation to the genome in place.

    Either a connection or a node is added; both options are equally likely.
    """
    handlers = {
        MutationType.ADD_CONNECTION: _mutate_add_connection,
        MutationType.ADD_NODE: _mutate_add_node,
    }
    selected = choice([MutationType.ADD_NODE, MutationType.ADD_CONNECTION])
    handlers[selected](genome)
def crossover(mother: Genome, father: Genome) -> Genome:
    """
    Creates a child genome by recombining the connection genes of two parents.

    Connection genes are aligned by innovation number:

    * Matching genes (present in both parents) are taken from a randomly
      chosen parent.
    * Disjoint and excess genes (present in one parent only) are taken from
      the parent with the higher fitness. If both parents have equal fitness,
      each such gene is inherited with a probability of one half.

    The child starts as a fully connected genome with the same number of input
    and output nodes as the mother; hidden nodes referenced by inherited
    connections are added as needed.
    """
    mother_connections = {conn.innovation_no: conn for conn in mother.connections.values()}
    father_connections = {conn.innovation_no: conn for conn in father.connections.values()}

    child_connections: dict[int, ConnectionGene] = {}
    for i in mother_connections.keys() | father_connections.keys():
        mother_gene = mother_connections.get(i)
        father_gene = father_connections.get(i)

        if mother_gene is not None and father_gene is not None:
            # Matching genes
            inherited = choice((mother_gene, father_gene))
        elif mother.fitness > father.fitness:
            # Disjoint or excess: only the fitter parent (mother) contributes,
            # so a gene that exists only in the father is dropped.
            inherited = mother_gene
        elif father.fitness > mother.fitness:
            # Disjoint or excess: only the fitter parent (father) contributes.
            inherited = father_gene
        else:
            # Equal fitness: exactly one of the two is None, so the gene is
            # kept or dropped with equal probability.
            inherited = choice((mother_gene, father_gene))

        if inherited is not None:
            child_connections[i] = inherited

    # Determine input/output dimensions
    inputs = sum(node.type == NodeType.INPUT for node in mother.nodes.values())
    outputs = sum(node.type == NodeType.OUTPUT for node in mother.nodes.values())

    # Create child and set nodes & connections
    child = Genome.new(inputs, outputs)
    for connection in child_connections.values():
        # Set connections (as copies, so the child does not share gene objects
        # with its parents)
        child.set_connection(connection.nodes, dataclasses.replace(connection))
        from_node, to_node = connection.nodes
        # Add nodes if required
        if from_node not in child.nodes:
            child.set_node(from_node, NodeGene(from_node, NodeType.HIDDEN))
        if to_node not in child.nodes:
            child.set_node(to_node, NodeGene(to_node, NodeType.HIDDEN))
    return child
def _mutate_add_connection(genome: Genome) -> None:
    """
    Adds one new connection gene with a random weight to the genome.

    The source is picked at random among all non-output nodes, the target among
    all non-input nodes that the source is not yet connected to. Nothing is
    changed if no such target exists or if the new connection would create a
    cycle.
    """
    sources = [key for key, gene in genome.nodes.items() if gene.type != NodeType.OUTPUT]
    from_node = choice(sources)

    targets = [
        key
        for key, gene in genome.nodes.items()
        if gene.type != NodeType.INPUT and (from_node, key) not in genome.connections
    ]
    # No admissible target for the chosen source: leave the genome untouched
    if not targets:
        return
    to_node = choice(targets)

    # Checking for cycles
    if creates_cycle(genome.connections.keys(), (from_node, to_node)):
        return

    # Weight is drawn uniformly from [0, 1)
    genome.add_connection(from_node, to_node, weight=rng.uniform(0, 1))
def _mutate_add_node(genome: Genome) -> None:
    """
    Splits a randomly chosen enabled connection by inserting a new node.

    The chosen connection is disabled and replaced by two new connections: one
    from the old source into the new node with a weight of 1, and one from the
    new node to the old target carrying the weight of the old connection.
    Nothing is changed if the genome has no enabled connection.
    """
    # Find connection to split
    enabled = [gene for gene in genome.connections.values() if not gene.disabled]
    if not enabled:
        return
    split = choice(enabled)
    split.disabled = True

    # Create new node
    middle = genome.add_node()
    source, target = split.nodes

    # Old source -> new node
    genome.add_connection(source, middle, weight=1)
    # New node -> old target
    genome.add_connection(middle, target, weight=split.weight)
def _excess(g1: Genome, g2: Genome) -> list[int]:
    """
    Returns the innovation numbers of the excess genes between two genomes.

    Excess genes are the connection genes of one genome whose innovation number
    is greater than the highest innovation number of the other genome. If one
    genome has no connections, all connection genes of the other are excess.
    """
    g1_connections = {conn.innovation_no: conn for conn in g1.connections.values()}
    g2_connections = {conn.innovation_no: conn for conn in g2.connections.values()}
    # Order the two genomes by their highest innovation number; -1 stands in
    # for a genome without connections so that max() does not raise.
    less_connections, more_connections = sorted(
        (g1_connections, g2_connections), key=lambda c: max(c, default=-1)
    )
    # Computed once instead of once per candidate gene
    less_max = max(less_connections, default=-1)
    return [k for k in more_connections if k > less_max]
def _disjoint(g1: Genome, g2: Genome) -> list[int]:
    """
    Returns the innovation numbers of the disjoint genes between two genomes,
    in ascending order.

    Disjoint genes are connection genes present in only one of the genomes
    whose innovation number does not exceed the highest innovation number of
    the other genome (those beyond it count as excess instead).
    """
    g1_connections = {conn.innovation_no: conn for conn in g1.connections.values()}
    g2_connections = {conn.innovation_no: conn for conn in g2.connections.values()}
    # Order the two genomes by their highest innovation number; -1 stands in
    # for a genome without connections so that max() does not raise.
    less_connections, more_connections = sorted(
        (g1_connections, g2_connections), key=lambda c: max(c, default=-1)
    )
    # Computed once instead of once per candidate gene
    less_max = max(less_connections, default=-1)
    only_in_less = {i for i in less_connections if i not in more_connections}
    only_in_more = {
        i for i in more_connections if i not in less_connections and i <= less_max
    }
    return sorted(only_in_less | only_in_more)
def _get_delta(g1: Genome, g2: Genome, c1: float, c2: float, c3: float) -> float:
    """
    Computes the compatibility distance between two genomes:

        delta = c1 * E / N + c2 * D / N + c3 * W

    where E is the number of excess genes, D the number of disjoint genes, W the
    average absolute weight difference of the matching genes (0 if there are
    none) and N the larger node count of the two genomes (at least 1).

    c1, c2 and c3 weight the three terms against each other.
    """
    # NOTE(review): N is taken from the node counts here, whereas the NEAT paper
    # normalizes by the number of genes in the larger genome - confirm intent.
    # The lower bound of 1 avoids a division by zero for node-less genomes.
    n = max(len(g1.nodes), len(g2.nodes), 1)

    g1_connections = {conn.innovation_no: conn for conn in g1.connections.values()}
    g2_connections = {conn.innovation_no: conn for conn in g2.connections.values()}

    # Order the two genomes by their highest innovation number; -1 stands in
    # for a genome without connections so that max() does not raise.
    less_connections, more_connections = sorted(
        (g1_connections, g2_connections), key=lambda c: max(c, default=-1)
    )
    less_max = max(less_connections, default=-1)

    # Calculate number of excess genes
    e = sum(1 for k in more_connections if k > less_max)

    # Calculate number of disjoint genes
    d = sum(1 for k in less_connections if k not in more_connections) + sum(
        1 for k in more_connections if k not in less_connections and k <= less_max
    )

    # Average weight difference of matching genes
    matching = g1_connections.keys() & g2_connections.keys()
    if matching:
        w = sum(
            abs(g1_connections[i].weight - g2_connections[i].weight) for i in matching
        ) / len(matching)
    else:
        w = 0.0

    delta = ((c1 * e) / n) + ((c2 * d) / n) + (c3 * w)
    return delta
def specify(genomes: list, c1: float, c2: float, c3: float, threshold: float = 1) -> list[list]:
    """
    Partitions the given genomes into species by compatibility distance.

    Each genome is compared with the first member of every existing species, in
    the order the species were created, and joins the first species for which
    the distance is below threshold. If there is no such species, the genome
    founds a new one.

    c1, c2 and c3 are passed on to the distance computation. Returns the list
    of species, each being a list of genomes.
    """
    species: list[list] = []
    for genome in genomes:
        for members in species:
            # members[0] acts as the representative of its species
            if _get_delta(genome, members[0], c1, c2, c3) < threshold:
                members.append(genome)
                break
        else:
            # No compatible species found (or none exists yet)
            species.append([genome])
    return species