From 439f995eaed3ef7c6cfaf3caeef1bec6215bbadd Mon Sep 17 00:00:00 2001 From: paumann Date: Sat, 15 Jul 2023 14:25:10 +0200 Subject: [PATCH] Initial commit --- Abschlussprojekt/genome.py | 128 ++++++++++++++++++++++++++++++ Abschlussprojekt/neat.ipynb | 71 +++++++++++++++++ Abschlussprojekt/visualization.py | 52 ++++++++++++ 3 files changed, 251 insertions(+) create mode 100644 Abschlussprojekt/genome.py create mode 100644 Abschlussprojekt/neat.ipynb create mode 100644 Abschlussprojekt/visualization.py diff --git a/Abschlussprojekt/genome.py b/Abschlussprojekt/genome.py new file mode 100644 index 0000000..645d3cd --- /dev/null +++ b/Abschlussprojekt/genome.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from random import choice + + +class NodeType(Enum): + INPUT = 1 + HIDDEN = 2 + OUTPUT = 3 + + +class MutationType(Enum): + ADD_CONNECTION = 1 + ADD_NODE = 2 + + +@dataclass(frozen=True) +class NodeGene: + id: int + type: NodeType + + +@dataclass() +class ConnectionGene: + nodes: tuple[int, int] + weight: float + innovation_no: int + disabled: bool = False + + +_CONNECTION_GENES: dict[tuple[int, int], ConnectionGene] = dict() + + +class Genome: + def __init__(self): + # Initialize nodes + self.nodes: dict[int, NodeGene] = dict() + + # Initialize connections + self.connections: dict[tuple[int, int], ConnectionGene] = dict() + + def add_node(self, node_type: NodeType = NodeType.HIDDEN) -> int: + """ + Adds a node of the given type to the genome and returns the identification key. + """ + + key = len(self.nodes) + self.nodes[key] = NodeGene(key, node_type) + return key + + def add_connection(self, from_node: int, to_node: int, weight: float) -> tuple[int, int]: + """ + Adds a connection of weight between two given nodes to the genome and returns + the identification key. + """ + + key = (from_node, to_node) + connection = ConnectionGene(key, weight, -1) + + if key in _CONNECTION_GENES: + connection.innovation_no = _CONNECTION_GENES[key].innovation_no + else: + connection.innovation_no = len(_CONNECTION_GENES) + _CONNECTION_GENES[key] = connection + + self.connections[key] = connection + return key + + @staticmethod + def new(inputs: int, outputs: int) -> Genome: + genome = Genome() + + # Add input nodes + for _ in range(inputs): + genome.add_node(node_type=NodeType.INPUT) + + # Add output nodes + for _ in range(outputs): + genome.add_node(node_type=NodeType.OUTPUT) + + return genome + + +def mutate(genome: Genome) -> None: + mutation = choice([MutationType.ADD_NODE]) + + if mutation is MutationType.ADD_CONNECTION: + _mutate_add_connection(genome) + elif mutation is MutationType.ADD_NODE: + _mutate_add_node(genome) + + +def _mutate_add_connection(genome: Genome) -> None: + """ + In the add_connection mutation, a single new connection gene with a random weight + is added connecting two previously unconnected nodes. + """ + + ... + + +def _mutate_add_node(genome: Genome) -> None: + """ + In the add_node mutation, an existing connection is split and the new node + placed where the old connection used to be. The old connection is disabled + and two new conections are added to the genome. The new connection leading + into the new node receives a weight of 1, and the new connection leading out + receives the same weight as the old connection. + """ + + # Find connection to split + try: + connection = choice(list(genome.connections.values())) + except IndexError: + return + connection.disabled = True + + # Create new node + new_node = genome.add_node() + from_node, to_node = connection.nodes + + # Connect previous from_node to new_node + genome.add_connection(from_node, new_node, weight=1) + + # Connection new_node to previous to_node + genome.add_connection(new_node, to_node, weight=connection.weight) diff --git a/Abschlussprojekt/neat.ipynb b/Abschlussprojekt/neat.ipynb new file mode 100644 index 0000000..57d1ec3 --- /dev/null +++ b/Abschlussprojekt/neat.ipynb @@ -0,0 +1,71 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from genome import Genome\n", + "import visualization" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ConnectionGene(nodes=(0, 4), weight=0.5, innovation_no=0, disabled=False)\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAV0AAADnCAYAAAC9roUQAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAARwElEQVR4nO3df0yc92HH8c9zPOc7MHecjYkhhgQ11L4sCe5sd3GaJsbJlkys3aaMbF3D1KRql4xU06p2WiVX2h8TUqttytbWljupU6skmraxaVsq2rppjZ01Szfj1IljDkp+gmNssEOOw3fH/Xj2B4GY3B3YD/A9jnu/JEtwz3Pf58kT+83D9557znIcRwAAMzzF3gEAKCdEFwAMIroAYBDRBQCDiC4AGGQvtnDLli1Oc3OzoV0BgPWhv79/wnGcunzLFo1uc3OzTpw4sTp7BeQxEUuqp39UkbGooom0gn5b4fqgHtzdqNpqX7F3D7gqlmW9WWjZotEFTDk1MqmDfcM6NjQuSUqms/PL/PaYnnh2SG076tS1r0U7m0JF2ktg+Yguiu6pF95Qd29EiXRG+d6rk3gvwEfOnNfxoQkdaA+rc2+z2Z0EVgjRRVHNBndA8VR2yXUdR4qnMuruHZAkwouSxNULKJpTI5Pq7o1cVXCvFE9l1d0b0Uujk6uzY8Aq4kwXRXOwb1iJdCbvsukzxzT5s39SJjquio2bVPtbfyZ/063zyxPpjA71Detw5x5TuwusCKKLopiIJXVsaDzvHG789Rf1Tt93Vfc7f6EN129XJnYpZx3HkY4OjutiLMlVDSgpTC+gKHr6Rwsue/e/n1bNnX8o37awLMsjO7BFdmBLznqWpJ6ThccB1iLOdFEUkbHogsvC5jjZjJLnhlXZcrvOHv68nMyMqj68V6H9n5XHu/CMNpHOKnJuytQuAyuCM10URTSRzvt4ZnpSyqZ1efBn2tr5dTU88g3NnH9N7z7/zwXGSa3iXgIrj+iiKIL+/L9kWe+dzQZ2f1J29WZVVNUo8NHfVfzV/O+MDPq9q7aPwGoguiiKcH1QPjv3r1+Fv1oVH5i/tSwr7xh+26NwQ2BV9g9YLUQXRdGxu7Hgsurbfl1T/d9XZnpSmURM0f/7D1W1fDRnPUdSx67C4wBrES+koSi2VPu0b3udfjxwPueysZo7P6VMPKqz//CoLNurjeG7VPOxP1iwjmVJ+3fUcbkYSg7RRdE83tai5345oXhq4RskrApbtfd3qfb+roLP9dsV6mprWe1dBFYc0wsomp1NIR1oD6vSe21/DSu9Hh1oD6u1MbQ6OwasIs50UVRzN61Z7C5jcyxr9gyXu4yhlBFdFF3n3ma1NoZ0qG9YRwfHZen92zlKs1cpOJqdw+1qa+EMFyWN6GJNaG0M6XDnHl2MJdVzclSRc1OKJlIK+r0KNwTUsYtPjsD6QHSxptRW+/To3TcVezeAVcMLaQBgENEFAIOILgAYRHQBwCCiCwAGEV0AMIjoAoBBRBcADCK6AGAQ0QUAg4guABhEdAHAIKILAAYRXQAwiOgCgEFEFwAMIroAYBDRBQCDiC4AGER0AcAgogsABhFdADCI6AKAQUQXAAwiugBgENEFAIOILgAYRHQBwCCiCwAGEV0AMIjoAoBBRBcADCK6AGAQ0QUAg4guABhkr+RgE7GkevpHFRmLKppIK+i3Fa4P6sHdjaqt9q3kpgCgJK1IdE+NTOpg37CODY1LkpLp7Pwyvz2mJ54dUtuOOnXta9HOptBKbBIAStKyo/vUC2+ouzeiRDojx8ldnngvwEfOnNfxoQkdaA+rc2/zcjcLACVpWdGdDe6A4qnskus6jhRPZdTdOyBJhBdAWXId3VMjk+rujeQN7tjTX1Hy7UFZngpJUkWgVtv++NuSpHgqq+7eiFobQ2ptDLndPACUJNfRPdg3rEQ6U3D55vseU2Dn/XmXJdIZHeob1uHOPW43DwAlydUlYxOxpI4Njeedw70ajiMdHRzXxVjS3QAAUKJcRbenf3TJdSb7vqeRv/+0xp78cyXefClnuSWp5+TS4wDAeuJqeiEyFl1wWdgHbdr/iLy1TbIqvJoeOK4L//ZXanjkG/JuaphfJ5HOKnJuys3mAaBkuTrTjSbSiy73Xb9DHl+VLNur6tvulW/bzYq/eiLPOCk3mweAkuUqukH/NZ4gW5ak3AngoN/rZvMAULJcRTdcH5TPzv/UbCKm+Gv9ctIzcrIZxV45quTIaVV+aPeC9fy2R+GGgJvNA0DJcjWn27G7UU88O5R3mZPNaPL4U0pdGpUsj7y1jap74Kvybt62cD1JHbsa3WweAEqWq+huqfZp3/Y6/XjgfM5lYxVVNWp4+IlFn29Z0v4dddwEB0DZcX1rx8fbWuS3K1w9129XqKutxe2mAaBkuY7uzqaQDrSHVem9tiEqvR4daA/zFmAAZWlZN7yZu2nNYncZm2NZs2e43GUMQDlb9q0dO/c2q7UxpEN9wzo6OC5L79/OUZq9SsHR7BxuV1sLZ7gAytqK3MS8tTGkw517dDGWVM/JUUXOTSmaSCno9yrcEFDHLj45AgCkFf64ntpqnx69+6aVHBIA1hU+mBIADCK6AGAQ0QUAg4guABhEdAHAIKILAAYRXQAwiOgCgEFEFwAMIroAYBDRBQCDiC4AGER0AcAgogsABhFdADCI6AKAQUQXAAwiugBgENEFAIOILgAYRHQBwCCiCwAGEV0AMIjoAoBBRBcADCK6AGAQ0QUAg4guABhEdAHAIKILAAYRXQAwiOgCgEFEFwAMIroAYBDRBQCD7GLvAACsFROxpHr6RxUZiyqaSCvotxWuD+rB3Y2qrfatyDaILoCyd2pkUgf7hnVsaFySlExn55f57TE98eyQ2nbUqWtfi3Y2hZa1LaILoKw99cIb6u6NKJHOyHFylyfeC/CRM+d1fGhCB9rD6tzb7Hp7RBdA2ZoN7oDiqeyS6zqOFE9l1N07IEmuw8sLaQDK0qmRSXX3Rq4quFeKp7Lq7o3opdFJV9vlTBdAWTrYN6xEOrPgsbf+tmPB9056RoFfbdfm+x5b8HgindGhvmEd7txzzdslugDKzkQsqWND4zlzuDd8qWf+6+xMXKPf/CNVhT+e83zHkY4OjutiLHnNVzUwvQCg7PT0jy65zuXB51VRVSNf0y15l1uSek4uPc4HEV0AZScyFl1wWVg+sZd/oo233iPLsvIuT6SzipybuuZtE10AZSeaSC+6PP3uBSVHTmvjbfcuMU7qmrdNdAGUnaB/8ZezYqd/Kl/jr8gbql9iHO81b5voAig74fqgfHbh/E2f/qmqb71n0TH8tkfhhsA1b5voAig7HbsbCy5LjA4oE7uY96qFKzmSOnYVHqcQogug7Gyp9mnf9jrle41s+vRPVLX9Y/L4qgo+37Kk/TvqXN0Eh+t0AZSlx9ta9NwvJxRPLXyDRO1vfmHJ5/rtCnW1tbjaLme6AMrSzqaQDrSHVem9tgxWej060B5Wa2PI1XY50wVQtuZuWrPYXcbmWNbsGS53GQOAZejc26zWxpAO9Q3r6OC4LL1/O0dp9ioFR7NzuF1tLa7PcOcQXQBlr7UxpMOde3QxllTPyVFFzk0pmkgp6Pcq3BBQxy4+OQIAVlxttU+P3n3Tqm6DF9IAwCCiCwAGEV0AMIjoAoBBRBcADCK6AGAQ0QUAg4guABhEdAHAIKILAAYRXQAwiOgCgEFEFwAMIroAYBDRBQCDiC4AGER0AcAgogsABhFdADCI6AKAQUQXAAwiugBgENEFAIOILgAYRHQBwCCiCwAGEV0AMIjoAoBBRBcADCK6AGAQ0QUAg4guABhEdAHAIKILAAYRXQAwiOgCgEFEFwAMIroAYBDRBQCDiC4AGER0AcAgogsABhFdADCI6AKAQUQXAAwiugBgENEFAIOILgAYRHQBwCCiCwAGEV0AMIjoAoBBRBcADCK6AGAQ0QUAg4guABhEdAHAIKILAAYRXQAwiOgCgEF2sXdgPZmIJdXTP6rIWFTRRFpBv61wfVAP7m5UbbWv2LsHYA0guivg1MikDvYN69jQuCQpmc7OL/PbY3ri2SG17ahT174W7WwKFWkvAawFRHeZnnrhDXX3RpRIZ+Q4ucsT7wX4yJnzOj40oQPtYXXubTa7kwDWDKK7DLPBHVA8lV1yXceR4qmMunsHJInwAmWK6Lp0amRS3b2RBcF10ildPHJIiTd+oWwiJjtUr037PqPKm/bMrxNPZdXdG1FrY0itjaEi7PnaxZw4ygHRdelg37AS6cyCx5xsRnZgi+o//TVV1NQp/uoJjf/n13X9Z78lO7R1fr1EOqNDfcM63Lnng8OWJebEUU6IrgsTsaSODY3nzOF6NvgVuuuh+e+rWn5Nds1WJceGF0TXcaSjg+O6GEuW/Rkcc+IoN1yn60JP/+hVrZeZfkepS2e1oe6GnGWWpJ6TVzfOevX+nHj+4F7pyjnxp154w8j+AauBM10XImPRBb8CXyna/4ymX/6JZi68Lk9lUNW33StvbVPOeol0VpFzU6u9q2tWvjnxOalLZ/X2d76gjeE7teWTX16wjDlxlDrOdF2IJtIFl9nVtQre8fuyg9fJsjza/BuPLTJOajV2ryTkmxOfc+nIYfkaPlzwuXNz4kApIrou+OzCh61y+x2KD/+vnGxGvhtuk1VR+JeJxcZZzwrNiUvS9Jlj8vg3yn/jzoLPv3JOHCg15fmvfpniM/nP0CTp0o8OKnVxRFU3f1yWp2LxcVKFx1nPCs2JZ5OXNfnc09p0z+eWHIM5cZQq5nRdqPTm/1mVfveCYr/4oWRZmnl7UJI0PXBcte1/qupb9ucZZ/Eol7pkMqmpqSnFYjFNTU3Nf33kVFzJdO5fvcnjT8ry+jX2vS8qE5+SZXs1depHCuy8P2fdcp8TR+kiui4kM/lfardrrlPVzXdLjiO7pk4z428qeXZAG677UP5xCrwYVwyO4yiRSOREMl80l/p67nvHcRQIBOb/VFdXKxAI6Hz4Aal64RUdM+dfU+LNU6r9xJe0YcsNevd//kUz51/V5PEntWHrTfLVt+TscznPiaN0EV0XCs3FZmcSujz4vK7/3EHFTv9UFVUhVbXcrulXjmpD28NXPc7VcBxHly9fdhXDQl/btj0fxg+G8sqva2pqtG3btrzrXPm9z5f/GuQ/ebpfPzg9tuCxxFsvK/3ueV3417+c/e+bScjJZiQnq/Q75/JGt1znxFHaiK4LheZ005fOSpZHdrBOei8Ydm2TkqNn8q4/OnZBzzzzjKuzyenpafl8voJhvPLr2tpaNTc3L7me1+tdzcM2L9/xq/7I/dp4892SpEt939XlV/okJyNvXfOCt1EvGKdM58RR2oiuC4XmdLOpuOSx9NbfPLDg8YpgXd71hwde0beP9uUEcOvWrWppaSkYyUAgoI0bN8q2S/N/X77j5/H6Ja9fklT3iS/qnWCdZt4elK/pFlkV+X8YrPc5caxPpfmvtsgKzel6vJWyHOmGr3x//rHoz/9dibdezrv+HXft03c+8+W8y9azQsfvSpvu7pQkXfzhtzT1Yq+Ce347d5w1NCcOXC0mxVwI+vP/rLI3b5OTzSh16ez8YzMXXpe37sYC45j5dX6tKXT88srOzunmH6c8jx9KG9F1IVwfzPsijmeDX1U77tDkc08rO5NQYvSMLg//XBvzXC7mtz0KNwRM7O6aU+j4ZaYnNX3mmLIzcTnZjOKv9Wt64Jj8zR/JWbecjx9KG9F1oWN3Y8Flm+/rkpOe0eg3H9LEf/21au/r0oY8Z7qOpI5dhcdZzwoeP8vS1Is/0OjBhzXyd5/SO0f/UZvu/byqPnx7zqrlfPxQ2pjTdWFLtU/7ttfpxwPnc97KWlEZ0HW/99VFn29Z0v4ddWV7W8dCx6+iqkb1D31tyeeX+/FDaeNM16XH21rkt929eu63K9TVlnvdaTnh+KFcEV2XdjaFdKA9XPDysUIqvR4daA+X/W0JOX4oV0wvLMPcJxgs9skHcyxr9gyNTz54H8cP5YjoLlPn3ma1NoZ0qG9YRwfHZen9j5iRZl9ldzQ7B9nV1sIZ2gdw/FBuLGeR04s9e/Y4J06cMLg7pe1iLKmek6OKnJtSNJFS0O9VuCGgjl18mu3V4PhhvbAsq99xnLzvX180upZljUt6c7V2DADWqRsdx8n7/v9FowsAWFlcvQAABhFdADCI6AKAQUQXAAwiugBg0P8DvFAVDwG+JXoAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "g1 = Genome.new(3, 2)\n", + "g1.add_node()\n", + "g1.add_node()\n", + "g1.add_node()\n", + "g1.add_connection(g1.nodes[0], g1.nodes[4], 0.5)\n", + "\n", + "visualization.genome(g1)\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.1" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Abschlussprojekt/visualization.py b/Abschlussprojekt/visualization.py new file mode 100644 index 0000000..1b25709 --- /dev/null +++ b/Abschlussprojekt/visualization.py @@ -0,0 +1,52 @@ +import matplotlib.pyplot as plt +import networkx as nx +import numpy as np + +from genome import Genome, NodeType, mutate + + +def genome(genome: Genome): + graph = nx.Graph() + + # Add nodes + for node in genome.nodes.keys(): + graph.add_node(node) + + # Add edges + for connection in genome.connections.values(): + if connection.disabled: + continue + + from_node, to_node = connection.nodes + graph.add_edge(from_node, to_node, weight=connection.weight) + + # Make sure that input and output nodes are fixed + pos = nx.spring_layout(graph) + x = [v[0] for v in pos.values()] + min_x, max_x = min(x), max(x) + y = [v[1] for v in pos.values()] + min_y, max_y = min(y), max(y) + + inputs = [node for node in genome.nodes.values() if node.type == NodeType.INPUT] + outputs = [node for node in genome.nodes.values() if node.type == NodeType.OUTPUT] + + for node, y in zip(inputs, np.linspace(min_y, max_y, len(inputs))): + pos[node.id] = np.array([min_x * 1.5, y]) + + for node, y in zip(outputs, np.linspace(min_y, max_y, len(outputs))): + pos[node.id] = np.array([max_x * 1.5, y]) + + plt.subplot() + nx.draw_networkx(graph, pos) + + +if __name__ == "__main__": + g1 = Genome.new(3, 2) + g1.add_node() + g1.add_node() + g1.add_node() + g1.add_connection(0, 4, 0.5) + + mutate(g1) + genome(g1) + plt.show()