from __future__ import annotations

import dataclasses
import itertools
from enum import Enum
from random import choice

import matplotlib.pyplot as plt
import numpy as np

from graphs import creates_cycle
from connection import _CONNECTION_GENES, ConnectionGene
from node import NodeGene, NodeType

# Shared generator used for random connection weights.
rng = np.random.default_rng()


class MutationType(Enum):
    """Structural mutations that `mutate` can apply to a genome."""

    ADD_CONNECTION = 1
    ADD_NODE = 2


class Genome:
    """A genome made of node genes and the connection genes linking them."""

    def __init__(self):
        # Node genes, keyed by node id.
        self.nodes: dict[int, NodeGene] = dict()
        # Connection genes, keyed by their (from_node, to_node) pair.
        self.connections: dict[tuple[int, int], ConnectionGene] = dict()
        self.fitness = 0

    def set_node(self, key: int, node: NodeGene) -> None:
        """Store `node` under `key`, replacing any node already stored there."""
        self.nodes[key] = node

    def set_connection(self, key: tuple[int, int], connection: ConnectionGene) -> None:
        """Store `connection` under `key`, replacing any connection already stored there."""
        self.connections[key] = connection

    def add_node(self, node_type: NodeType = NodeType.HIDDEN) -> int:
        """
        Add a new node of the given type to the genome and return its key.

        The key is one greater than the largest key currently in use
        (0 for an empty genome).
        """
        # Not len(self.nodes): crossover can leave gaps in the key space (hidden
        # nodes inherited from a parent), and len() could then return a key that
        # is already in use, silently overwriting that node. For contiguous keys
        # both expressions give the same result.
        key = max(self.nodes, default=-1) + 1
        self.nodes[key] = NodeGene(key, node_type)
        return key

    def add_connection(self, from_node: int, to_node: int, weight: float) -> tuple[int, int]:
        """
        Add a connection with the given weight between two existing nodes and
        return its (from_node, to_node) key.

        The innovation number is looked up by node pair in the shared
        `_CONNECTION_GENES` registry, so the same pair reuses the same innovation
        number across genomes; a pair not seen before gets the next number
        (the registry's current size).

        Raises:
            ValueError: if either node key is not an int or is not in the genome.
        """
        if not isinstance(from_node, int) or not isinstance(to_node, int):
            raise ValueError("Nodes must be integer keys.")
        if from_node not in self.nodes or to_node not in self.nodes:
            raise ValueError("Nodes do not exist.")

        key = (from_node, to_node)
        connection = ConnectionGene(key, weight, -1)
        if key in _CONNECTION_GENES:
            connection.innovation_no = _CONNECTION_GENES[key].innovation_no
        else:
            connection.innovation_no = len(_CONNECTION_GENES)
            # NOTE(review): the registry stores this genome's own gene object, so
            # later in-place changes to it (e.g. `disabled = True` in
            # _mutate_add_node) are visible through the registry. Only
            # `innovation_no` is read back here.
            _CONNECTION_GENES[key] = connection

        self.connections[key] = connection
        return key

    @staticmethod
    def new(inputs: int, outputs: int) -> Genome:
        """
        Create a genome with `inputs` input nodes (keys 0..inputs-1), `outputs`
        output nodes (the keys that follow), and a weight-1 connection from every
        input node to every output node.
        """
        genome = Genome()
        for _ in range(inputs):
            genome.add_node(node_type=NodeType.INPUT)
        for _ in range(outputs):
            genome.add_node(node_type=NodeType.OUTPUT)

        # Fully connect inputs to outputs.
        for i in range(inputs):
            for o in range(inputs, inputs + outputs):
                genome.add_connection(i, o, weight=1)
        return genome

    @staticmethod
    def copy(genome: Genome) -> Genome:
        """Return a copy of `genome` holding shallow copies of every node and connection gene."""
        clone = Genome()
        for key, node in genome.nodes.items():
            clone.set_node(key, dataclasses.replace(node))
        for key, connection in genome.connections.items():
            clone.set_connection(key, dataclasses.replace(connection))
        clone.fitness = genome.fitness
        return clone


def mutate(genome: Genome) -> None:
    """Apply one structural mutation, chosen uniformly at random, to `genome` in place."""
    mutation = choice([MutationType.ADD_NODE, MutationType.ADD_CONNECTION])
    if mutation is MutationType.ADD_CONNECTION:
        _mutate_add_connection(genome)
    elif mutation is MutationType.ADD_NODE:
        _mutate_add_node(genome)


def crossover(mother: Genome, father: Genome) -> Genome:
    """
    Produce a child genome from two parents, aligning connection genes by
    innovation number.

    - Matching genes (present in both parents) come from a randomly chosen parent.
    - Disjoint and excess genes (present in only one parent) come only from the
      fitter parent. When both parents have equal fitness, each such gene is
      inherited with probability 1/2.

    The child is built with `Genome.new` from the mother's input/output counts,
    so every input->output connection exists (weight 1) unless an inherited gene
    overwrites it. Hidden nodes referenced by inherited connections are added to
    the child.
    """
    mother_connections = _innovation_map(mother)
    father_connections = _innovation_map(father)

    child_connections: dict[int, ConnectionGene] = {}
    for i in set(mother_connections) | set(father_connections):
        in_mother = i in mother_connections
        in_father = i in father_connections

        if in_mother and in_father:
            # Matching gene: inherit from either parent at random.
            child_connections[i] = choice((mother_connections[i], father_connections[i]))
        elif mother.fitness > father.fitness:
            # Disjoint/excess gene: keep it only if the fitter parent has it.
            if in_mother:
                child_connections[i] = mother_connections[i]
        elif father.fitness > mother.fitness:
            if in_father:
                child_connections[i] = father_connections[i]
        else:
            # Equal fitness: the gene is inherited with probability 1/2.
            connection = choice((mother_connections.get(i), father_connections.get(i)))
            if connection is not None:
                child_connections[i] = connection

    # Determine input/output dimensions from the mother.
    inputs = sum(node.type == NodeType.INPUT for node in mother.nodes.values())
    outputs = sum(node.type == NodeType.OUTPUT for node in mother.nodes.values())

    child = Genome.new(inputs, outputs)
    for connection in child_connections.values():
        child.set_connection(connection.nodes, dataclasses.replace(connection))
        # NOTE(review): node keys are assigned per genome by add_node, so the same
        # hidden-node key in both parents may stand for different splits.
        for node_key in connection.nodes:
            if node_key not in child.nodes:
                child.set_node(node_key, NodeGene(node_key, NodeType.HIDDEN))
    return child


def _mutate_add_connection(genome: Genome) -> None:
    """
    In the add_connection mutation, a single new connection gene with a random
    weight is added connecting two previously unconnected nodes.

    The genome is left unchanged when there is no valid source or target node,
    or when the chosen connection would create a cycle.
    """
    sources = [key for key, node in genome.nodes.items() if node.type != NodeType.OUTPUT]
    if not sources:
        return
    from_node = choice(sources)

    # Self-loops are excluded up front: a self-loop is always a cycle, so picking
    # one would only waste the mutation at the cycle check below.
    targets = [
        key
        for key, node in genome.nodes.items()
        if node.type != NodeType.INPUT
        and key != from_node
        and (from_node, key) not in genome.connections
    ]
    if not targets:
        return
    to_node = choice(targets)

    # Reject connections that would introduce a cycle.
    if creates_cycle(genome.connections.keys(), (from_node, to_node)):
        return

    genome.add_connection(from_node, to_node, weight=rng.uniform(0, 1))


def _mutate_add_node(genome: Genome) -> None:
    """
    In the add_node mutation, an existing connection is split and the new node
    placed where the old connection used to be. The old connection is disabled
    and two new connections are added to the genome. The new connection leading
    into the new node receives a weight of 1, and the new connection leading out
    receives the same weight as the old connection.

    Does nothing if the genome has no enabled connection to split.
    """
    enabled = [conn for conn in genome.connections.values() if not conn.disabled]
    if not enabled:
        return
    connection = choice(enabled)
    connection.disabled = True

    new_node = genome.add_node()
    from_node, to_node = connection.nodes
    # Previous from_node -> new node, weight 1.
    genome.add_connection(from_node, new_node, weight=1)
    # New node -> previous to_node, keeping the old connection's weight.
    genome.add_connection(new_node, to_node, weight=connection.weight)


def _innovation_map(genome: Genome) -> dict[int, ConnectionGene]:
    """Index a genome's connection genes by innovation number."""
    return {conn.innovation_no: conn for conn in genome.connections.values()}


def _split_by_max_innovation(
    g1: Genome, g2: Genome
) -> tuple[dict[int, ConnectionGene], dict[int, ConnectionGene]]:
    """
    Return the two genomes' innovation maps as (lower, higher), ordered by their
    highest innovation number. Ties keep g1 first; an empty map sorts first.
    """
    lower, higher = sorted(
        (_innovation_map(g1), _innovation_map(g2)),
        key=lambda conns: max(conns, default=-1),
    )
    return lower, higher


def _excess(g1: Genome, g2: Genome) -> list[int]:
    """
    Return the innovation numbers of excess genes: genes in one genome whose
    innovation number is beyond the other genome's highest innovation number.
    """
    lower, higher = _split_by_max_innovation(g1, g2)
    cutoff = max(lower, default=-1)
    return [k for k in higher if k > cutoff]


def _disjoint(g1: Genome, g2: Genome) -> list[int]:
    """
    Return the innovation numbers of disjoint genes: genes present in only one
    genome whose innovation number is within the range of the other genome.
    """
    lower, higher = _split_by_max_innovation(g1, g2)
    cutoff = max(lower, default=-1)
    return list(
        {i for i in lower if i not in higher}
        | {i for i in higher if i not in lower and i <= cutoff}
    )


def _get_delta(g1: Genome, g2: Genome, c1: float, c2: float, c3: float) -> float:
    """
    Compatibility distance between two genomes:

        delta = c1 * E / N + c2 * D / N + c3 * W

    where E is the number of excess genes, D the number of disjoint genes, W the
    average absolute weight difference of matching genes (0 when none match), and
    N the node count of the larger genome (at least 1).
    """
    # NOTE(review): the NEAT paper normalises by the number of connection genes in
    # the larger genome; this uses the node count - confirm intended.
    n = max(len(g1.nodes), len(g2.nodes), 1)

    e = len(_excess(g1, g2))
    d = len(_disjoint(g1, g2))

    # Average weight difference of matching genes.
    g1_connections = _innovation_map(g1)
    g2_connections = _innovation_map(g2)
    weight_diffs = [
        abs(conn.weight - g2_connections[i].weight)
        for i, conn in g1_connections.items()
        if i in g2_connections
    ]
    w = sum(weight_diffs) / len(weight_diffs) if weight_diffs else 0.0

    return (c1 * e) / n + (c2 * d) / n + c3 * w


def specify(
    genomes: list[Genome], c1: float, c2: float, c3: float, threshold: float = 1
) -> list[list[Genome]]:
    """
    Partition genomes into species.

    Each genome joins the first existing species whose representative (its first
    member) is at a compatibility distance below `threshold`. Otherwise it founds
    a new species.
    """
    species: list[list[Genome]] = []
    for genome in genomes:
        for members in species:
            if _get_delta(genome, members[0], c1, c2, c3) < threshold:
                members.append(genome)
                break
        else:
            species.append([genome])
    return species