From 4f41b02e056ccca73cc6b5efcf532240c525fb2b Mon Sep 17 00:00:00 2001 From: Kai Koellemann Date: Sat, 10 Jun 2023 14:27:48 +0200 Subject: [PATCH] Inital commit Aufgabe 4 --- Aufgabe 4/aufgabe04.ipynb | 301 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 301 insertions(+) create mode 100644 Aufgabe 4/aufgabe04.ipynb diff --git a/Aufgabe 4/aufgabe04.ipynb b/Aufgabe 4/aufgabe04.ipynb new file mode 100644 index 0000000..f2b3956 --- /dev/null +++ b/Aufgabe 4/aufgabe04.ipynb @@ -0,0 +1,301 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from tqdm import tqdm\n", + "\n", + "rng = np.random.default_rng()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Aufgabe 1\n", + "\n", + "relu = np.vectorize(lambda x : max(0, x))\n", + "\n", + "sigmoid = np.vectorize(lambda x : 1 / (1 + np.exp(-x)))\n", + "\n", + "def samples(n: int) -> list:\n", + " inputs = rng.integers(0, 1, size=(n, 2), endpoint=True)\n", + " outputs = inputs.T[0] ^ inputs.T[1]\n", + " return (inputs, outputs)\n", + "\n", + "# Aufgabe 3\n", + "\n", + "# Binary Cross Entropy Loss\n", + "bcel = np.vectorize(lambda y, ŷ : -(y * np.log(ŷ) + (1 - y) * np.log(1 - ŷ)))\n", + "\n", + "# Derivatives\n", + "derivatives = {\n", + " bcel : np.vectorize(lambda y, ŷ: (1 / (1 - ŷ)) if y == 0 else -(1 / ŷ)),\n", + " sigmoid : np.vectorize(lambda x: (1 / (1 + np.exp(-x))) * (1 - (1 / (1 + np.exp(-x))))),\n", + " relu : np.vectorize(lambda x: 0 if x < 0 else 1)\n", + "}\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Aufgabe 2\n", + "from typing import Callable\n", + "\n", + "\n", + "class NeuralNet:\n", + " def __init__(self, inputs: int = 2, hidden_layers: list[tuple[int, Callable]] = None):\n", + " \"\"\"\n", + " Initializes the neural network.\n", + " Hidden layers can be specified with 
the 'hidden_layers' parameter,\n", + " which takes list of the format [(layer1_size, layer1_activation_function), (layer2...), ...].\n", + " The output layer will always consist of a single neuron and use the sigmoid\n", + " activation function.\n", + " \"\"\"\n", + "\n", + " self.input_shape = (-1, inputs)\n", + " self.layers = [] if hidden_layers is None else hidden_layers\n", + " self.layers.append((1, sigmoid)) # Add output layer\n", + "\n", + " # Construct weights for hidden layer\n", + " self.activation_functions = []\n", + " self.weights = []\n", + " for index, (num_neurons, activation_function) in enumerate(self.layers):\n", + " self.activation_functions.append(activation_function)\n", + "\n", + " num_layer_inputs = inputs if index == 0 else self.layers[index - 1][0]\n", + " self.weights.append(rng.uniform(low=-1.0, high=1.0, size=(num_layer_inputs, num_neurons)))\n", + "\n", + " def forward_pass(self, x: np.array) -> tuple:\n", + " \"\"\"\n", + " Do a forward pass through the neural net.\n", + " Returns the linear and activation function results for each layer.\n", + " For the final output, see the last input in the F list.\n", + " \"\"\"\n", + "\n", + " x = np.reshape(x, self.input_shape)\n", + "\n", + " Z = [] # linear values for each layer\n", + " F = [x] # activation function values for each layer\n", + " for weights, activation_function in zip(self.weights, self.activation_functions):\n", + " Z.append(np.matmul(F[-1], weights)) # linear\n", + " F.append(activation_function(Z[-1]))\n", + "\n", + " return (Z, F)\n", + "\n", + " def classify(self, x: np.array) -> float:\n", + " \"\"\"\n", + " Executes a forward pass,\n", + " and returns the classification result ŷ.\n", + " \"\"\"\n", + " \n", + " _, F = self.forward_pass(x)\n", + " ŷ = F[-1].reshape(x.shape[0])\n", + " return ŷ\n", + "\n", + " # Aufgabe 4\n", + " def backward_pass(self, x, y: float):\n", + " \"\"\"\n", + " Do a backward pass through the neural net.\n", + " Returns the calculated weight 
difference.\n", + " \"\"\"\n", + " x = np.reshape(x, self.input_shape)\n", + " batch_size = x.shape[0]\n", + "\n", + " Z, F = self.forward_pass(x)\n", + "\n", + " layer_errors = [None for _ in range(len(self.layers))]\n", + " layer_errors[-1] = F[-1] - y.reshape((-1, 1)) # ŷ - y\n", + "\n", + " # Backpropagation\n", + " for i in reversed(range(len(self.layers) - 1)):\n", + " layer_errors[i] = np.multiply(\n", + " self.weights[i + 1].T,\n", + " np.multiply(layer_errors[i + 1], derivatives[self.activation_functions[i]](Z[i])),\n", + " )\n", + "\n", + " # Final weight updates, normalized\n", + " Δweights = [np.dot(F[i].T, error) / batch_size for i, error in enumerate(layer_errors)]\n", + " return Δweights\n", + "\n", + " # Aufgabe 5\n", + " def train(self, data: list, batch_size: int = 10, epochs: int = 50, learning_rate: float = 0.5):\n", + " \"\"\"\n", + " Train the neural network with the given input data.\n", + " \"\"\"\n", + "\n", + " x, y = data\n", + " for epoch in tqdm(range(epochs)):\n", + " for i in range(0, len(data), batch_size):\n", + " x_batch = x[i : i + batch_size]\n", + " y_batch = y[i : i + batch_size]\n", + " Δweights = self.backward_pass(x_batch, y_batch)\n", + "\n", + " self.weights = [w - learning_rate * Δw for w, Δw in zip(self.weights, Δweights, strict=True)]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Unit test to make sure neural net behaves as expected\n", + "# Uses example values from https://towardsdatascience.com/how-does-back-propagation-work-in-neural-networks-with-worked-example-bc59dfb97f48.\n", + "\n", + "nn = NeuralNet(inputs=3, hidden_layers=[(4, sigmoid)])\n", + "\n", + "nn.weights = [\n", + " np.array([[0.179, -0.186, -0.008, -0.048], [0.044, -0.028, -0.063, -0.131], [0.01, -0.035, -0.004, 0.088]]),\n", + " np.array([[0.088], [0.171], [0.005], [-0.04]]),\n", + "]\n", + "\n", + "# Make sure forward pass is correct\n", + "Z, F = nn.forward_pass(np.array([[7, 8, 
# Unit test to make sure the neural net behaves as expected.
# Uses example values from https://towardsdatascience.com/how-does-back-propagation-work-in-neural-networks-with-worked-example-bc59dfb97f48.

nn = NeuralNet(inputs=3, hidden_layers=[(4, sigmoid)])

nn.weights = [
    np.array([[0.179, -0.186, -0.008, -0.048], [0.044, -0.028, -0.063, -0.131], [0.01, -0.035, -0.004, 0.088]]),
    np.array([[0.088], [0.171], [0.005], [-0.04]]),
]

# Make sure the forward pass reproduces the worked example.
Z, F = nn.forward_pass(np.array([[7, 8, 10]]))
assert np.allclose(F[1], np.array([[0.845, 0.132, 0.354, 0.377]]), atol=0.01)
assert np.allclose(F[2], np.array([[0.521]]), atol=0.01)

# Make sure the backward pass is correct; the duplicated sample checks that
# gradients are averaged over the batch.
Δweights = nn.backward_pass(np.array([[7, 8, 10], [7, 8, 10]]), np.array([[1], [1]]))
assert np.allclose(
    Δweights[0],
    np.array([[-0.039, -0.066, -0.004, 0.032], [-0.044, -0.075, -0.004, 0.036], [-0.055, -0.094, -0.005, 0.046]]),
    atol=0.01,
)
assert np.allclose(Δweights[1], np.array([[-0.405], [-0.063], [-0.169], [-0.181]]), atol=0.01)

# No output means success


# Aufgabe 4 / 5: train on XOR and measure accuracy before and after.
# (To update the weights after every single backward pass, use batch_size=1;
# the run below uses mini-batches of 20.)

test_data = samples(1_000)
training_data = samples(100)

nn = NeuralNet(hidden_layers=[(8, relu)])

# np.count_nonzero counts the True entries of the boolean comparison in one
# C-level pass instead of Python-level sum() over the array.
correct = np.count_nonzero(np.around(nn.classify(test_data[0])) == test_data[1])
print(f"Correct classifications before training: {correct / test_data[0].shape[0]}")

nn.train(training_data, batch_size=20, learning_rate=1, epochs=80)

correct = np.count_nonzero(np.around(nn.classify(test_data[0])) == test_data[1])
print(f"Correct classifications after training: {correct / test_data[0].shape[0]}")
"metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAW3ElEQVR4nO3df6zfV13H8eeb2911LWNd1hJIf7ASi1DBZOOyIShMAelqXBNB0hKUyUIBGSGAhOEMYDEaRCESKqzGBUayjUGMXkNxRhhMBx3tsjFoYeRS0LWgE1insrBSePvH97vw7fuc23N6+vn+al+PZMk9n+/5fL7n3rXvfr6ve875mLsjIjLoMeMegIhMHhUGEUmoMIhIQoVBRBIqDCKSUGEQkUSxMJjZ9Wb2gJl9dZHXzcw+YGYLZnavmV3c/TBFZJRq7hg+Amw6weuXAxv6/20HPnTqwxKRcSoWBne/HfjBCbpsAW7wnj3ACjN7YlcDFJHRW9LBNVYD9w+0D/WPfTd2NLPt9O4qWD7DM5963sCLuRIVj5XaADMVfeKxlnNqxtJyjlmhUxxs7YVL32TuunEssU98PXfd0vdT26d03Zax5NSMpXROzdhqxnKy75O6664vf8/dV53sO3VRGKq5+y5gF8DcBeb7Xjzw4rLMCeeG9vLQbjknd17sE69Rc87Sjs6ZPSsceGyhXdsnvnnNOXGAsU/uG4jHzm44p6ZP/KPbck6uGJaum/srE4+V2osd6+K6xzNb9e8Vb5To4rcSh4G1A+01/WMiMqW6KAzzwO/2fzvxbOAhd08+RojI9Cjei5jZTcBlwEozOwS8EzgLwN0/DOwGNgMLwMPA7zWN5D8zx/43tGs+JpTOyZ0Xb/l/2HBOzUeWhwuvAyw9Gq4bct/ZI5mTVoR2zceN+E3+X8U5sU/Lx49HMuf8qHCN3LHYPtZw3dwf/9J1az4WdPHxI6dlLG2KV3H3bYXXHXh9J6MRkYmgmY8iklBhEJGECoOIJEY6j+GEHpc5Fn/pGYPFmnkM8ZxcnxgctpyTG8tQAsufpucsDQHlOUcyFy7NW2iZHzGqwBLKQWJNYFkTcpYCvtxfmThXo4vAsua9c4FrN3+ldccgIgkVBhFJqDCISGJyMoac1aH99dBuneBU6pPLC1rOKa3BiPlB7jpVk6Jin0wOsfx/wjmhbbmQpyWXiBOn4g+hJpeIWQCUM4X4Ob/mnJaJVLm/MqVMITe2lrygqzUYZbpjEJGECoOIJFQYRCQx2RlD9NTQ3pvp0zInoSaXiJ/tR3VOF/MloJxVxMwB0lzC4km5jGFFoU/unJq5DqW9IHLnxAVbXSzWyp1Tum5LfgBpNqGMQUTGSIVBRBIqDCKSUGEQkcR0hY/RszLHPhvaLZOVcoFlvE7LYq1SGFlzTs3Cq1yoGcPHZFJUxTnLwoGlmRlaS0pBYsvuUlAOKGsmTo1qsVYMG2smbA1r4lQb3TGISEKFQUQSKgwikpjujCHn10L7HzJ9RjVZaVyLtWoWZ3UyKSo3lnBSzCWWHMmcFD/rr6joU1qsBcPJJbLfdKFPzW7ULZvGaKMWERkhFQYRSagwiEhChUFEEqdf+BhtSQ/95Kbj2zMx8BvWZKW4gDG3aVIXweiwHrHXRWC5PDyCD9JdrmdrQsGWiVM1qzhLE5xazmnZXaqmT83EqTa6YxCRhAqDiCRUGEQkcfpnDBkz4fndD4bMYXnmSe2zw5islMslSjtLZzZaSrKKYT1Jq+Z7LmUMuclXyYKuTA6xPOYQR0KHlt2kahZr1WQMw9hdKnedlt2k2uiOQUQSKgwiklBhEJHEGZkxROeHzOHwTWmfZT8+vp185K3Z3GVci7VyfeL71MypqFm3NKpdrs8OT9uKO1oDnNPFpjHj2kQmd6y0WGuxYydPdwwikqgqDGa2yczuM7MFM7sm8/o6M
7vNzO42s3vNbHP3QxWRUSkWBjObAXYClwMbgW1mtjF0+yPgFne/CNgK/HXXAxWR0am5Y7gEWHD3g+5+FLiZdAWC87NPqecB3+luiCIyajXh42rg/oH2IeDS0OddwD+b2RvoJSQvzF3IzLYD2wHW5QKmCbF6W3rs6yGQTDK2EE4CLHvo+PY5MQCsmazUslirZYLTsB7l18VEqtx710ycWhYCyvgYvlxgmTyGb0Vot+xy3bK7FJQXTeWCxtyCrZPXVfi4DfiIu68BNgMfM7Pk2u6+y93n3H1uVTfhqYgMQU1hOAysHWiv6R8bdBVwC4C7f5FeKVvZxQBFZPRqCsNeYIOZrTezWXrh4nzo8x/ACwDM7Gn0CsN/dzlQERmdYsbg7sfM7GrgVmAGuN7d95vZDmCfu88DbwH+xszeRC+IvNLdfZgDH7WnhtzhrpA5VK0NinNyMvNe4ibLVRuq1GwAEz+jt3yu7+LpWzWTolo2mmnZGTt7TuFpW0tyIckwNpHJ9anJGLr5jF4189HddwO7w7F3DHx9AHhuJyMSkbHTzEcRSagwiEhCi6gaPTNkDv+aWXhVeqB07iPuDws5xLLMR9HkV+8t8xhaNoBpmW+QO2dYT9Iaxga4uSd8F5+2tSLzRjU5RAxbWuY+tNEdg4gkVBhEJKHCICIJFQYRSSh87MivZBZe7S4svKqaFBXa2bwsXCiZJEXmaVulMBK6meDUxWKtXJ+WIDFeoybkjH1y2V7paVtLH0jPqdrlumXiVO4HcfJ0xyAiCRUGEUmoMIhIQhnDEG0OucMnWhZehXbLx2KAZeHBR/FpW0kGAeV8oKsJTl3shF0zwam0QC13TkuWUfM+S0MOEbcdh0wOsSK0a3KJNrpjEJGECoOIJFQYRCShjGGEfjtkDh+rWHgVl9Hk8oO4/WfTA6IyT/iOOUTytK1xzmOIi75y3/S4NrPtIpeAzGa2IYc450jmJGUMIjIkKgwiklBhEJGECoOIJBQ+jtHvZBZeXRcCyRhGNk9wKvSpWhsUnrYVn7QFcE7LLtfjmjjV8iSwlidpdfU/Lb7P2SGchPzTtRrojkFEEioMIpJQYRCRhDKGCfOakDu8L2QOuY/oNc8nKk2UqskYqtYgFXKI5InfuQuN82ndLRvalJ6k1TLjrKZP7n9ALqtooDsGEUmoMIhIQoVBRBLKGCbcm0Pm8KcVC69qnoFckxeUconcOaWHasd1QVD5tK2WjVq62My29A3VnFPz9K2W+RHKGERklFQYRCShwiAiCRUGEUkofJwyf5hZePXOEEjGCU9QDihrzqna/LjhnKqJU6FTfPqW5U4qTZzKhYKlnaFaJkW1TNiCbiZONaq6YzCzTWZ2n5ktmNk1i/R5mZkdMLP9ZnZjN8MTkXEo3jGY2QywE3gRcAjYa2bz7n5goM8G4O3Ac939QTN7/LAGLCLDV3PHcAmw4O4H3f0ocDOwJfR5NbDT3R8EcPfMUzxFZFrUZAyrgfsH2oeAS0OfpwCY2R3ADPAud/+neCEz2w5sB1jX0WchgT8OucPbMpOg4ibQNYuo4rG4G3X8yAttD3+qmUNUXE+UmdgTc4nkaVstT9Vu2cE6bv6Su07NJKiWcxp1FT4uATYAlwFrgNvN7BnufmSwk7vvAnYBzF1g3tF7i0jHaj5KHAbWDrTX9I8NOgTMu/uP3f1bwDfoFQoRmUI1hWEvsMHM1pvZLLAVmA99/p7e3QJmtpLeR4uD3Q1TREapWBjc/RhwNXAr8DXgFnffb2Y7zOyKfrdbge+b2QHgNuCt7v79YQ1aRIbL3MfzUX/uAvN9Lx7LW5+R3hgCybPC6zXhY80qzhhQ1pwT87KqHasL75vrU7Np0mzpeX9drYJsmeBU2sE6c127lrvcfS7T84Q0JVpEEioMIpJQYRCRhBZRnSH+KkyCek3IHMLmzkB59+mYU+T61GQMXewmlcsYStfNPvwpzARbHtqzLbtc1yy86mpnqNz4GuiOQUQSKgwik
lBhEJGEMoYz1HUhc3hVZuFVzBBqnnhV6lMzX6KrTWNKfXIfx4tzHzJhTPFpWy1P1YZuNoBppDsGEUmoMIhIQoVBRBIqDCKSGF/4uP5pcOMNP2u//FljG4rAxZljXwjtUhgJ5SCxJnxsmThVs5tUzSbLLZOi4ljiY/jOzTxy75yqrbFDu2VnqEa6YxCRhAqDiCRUGEQkMTkTnG7cmx5T7jA0H8xMaIqeE9qfD+1cFhDn/8Q+cbdqSHOI2UyfeF4Xk62yi6hCu7R4K3fdqsVaIYdYnskhloVjydO2NMFJREZJhUFEEioMIpKYnIwhJ+YOyhya1WQKJc8P7VszfWKmULPp7NHQzmUMpT4t2cUk5xJQftpWfNIWwEzuzRrojkFEEioMIpJQYRCRhAqDiCQmO3yMFEZW6SJorJF7kNg/hnYM/HK7UZcCy9x1WnZ9qglCS+Fjbjfq0uKslsAyd17VxCmFjyIyLCoMIpJQYRCRxHRlDJEWXgGjyxRq/GZo/11o1+QHNYuzaiZFlfKCmlyi5UlaNQu8YqbQ8oSumoditdIdg4gkVBhEJKHCICKJ6c4Yck7zuQ6TlCfU+K3Q/mSmT2lzFyjnELlFVHHOQc37tDxJq+WJXTXXjdepySWUMYjI0FQVBjPbZGb3mdmCmV1zgn4vMTM3s7nuhigio1YsDGY2A+wELgc2AtvMbGOm37nAG4E7ux6kiIxWzR3DJcCCux9096PAzcCWTL93A+8h/+thEZkiNeHjauD+gfYh4NLBDmZ2MbDW3T9lZm9d7EJmth3YDrBu3RNOfrQtpmwS1LSFiyfrpZljN4Z2bqFVKTicyZxTmgSVCx/jsfi+w3r6VsuCrpqJU61OOXw0s8cA7wPeUurr7rvcfc7d51atOv9U31pEhqSmMBwG1g601/SPPepc4OnA58zs28CzgXkFkCLTq6Yw7AU2mNl6M5sFtgLzj77o7g+5+0p3v9DdLwT2AFe4+76hjFhEhq6YMbj7MTO7mt6mwDPA9e6+38x2APvcff7EV5hAEzIJ6nTPE2q9PLQ/munzk9CumawU+7RMpOpiUlRNLpFbBBYnaNVkDLljLapmPrr7bmB3OPaORfpedurDEpFx0sxHEUmoMIhIQoVBRBKn3+rKFiMKIxU21nll5tj1oR0nNNXsPt0SWMY+uZCw5X3iatDc7tOlx/DVhJqtdMcgIgkVBhFJqDCISEIZQ05HC6+UKXTnVaF9XWjnFlHFSVGxT3wdyvlALsuI1615slbNU7FixlCzOCu3k1UL3TGISEKFQUQSKgwiklDGUKsw10F5wmi9JrQ/lOnTkjHEPjVPvIp5Qcs5NfMwap4c3tX2abpjEJGECoOIJFQYRCShwiAiCYWPrW583vHtm24fzzgEgNdljn0wtH8a2jXhY/yXM14DymFjS7CY6xOvE993seu00B2DiCRUGEQkocIgIgllDNVO/Dydq/15ybEPmnKHcbo6tN8f2rl/FUsLonIZQ7xOFztY586L+UFu/MoYRGRoVBhEJKHCICIJZQxZxefzVom5gzKH8XpTaP9Fpk9prkNuQ5hSLlEzX6Jm05jYzo0ll1W00B2DiCRUGEQkocIgIgkVBhFJKHwEugobSxRGTpY/yBx7T2jHgC83wSkGhzWBZfwXuSagrJk4lbtOC90xiEhChUFEEioMIpI4QzOG0WQKJVp4NXneFtp/Fto1WUDLhjA1i7NqJk6NdIKTmW0ys/vMbMHMrsm8/mYzO2Bm95rZZ8zsSR2NT0TGoFgYzGwG2AlcDmwEtpnZxtDtbmDO3X8R+CTw510PVERGp+aO4RJgwd0PuvtR4GZgy2AHd7/N3R/uN/cAa7odpoiMUk3GsBq4f6B9CLj0BP2vAj6de8HMtgPbAdate0LlEE/VZOQJtTTXYbK8PbT/JNOnZeFVFzlEzUYzrTr9rYSZvQKYA96be93dd7n7nLvPrVp1fpdvLSIdqrljOAysHWiv6R87jpm9ELgWeL67P
9LN8ERkHGruGPYCG8xsvZnNAluB+cEOZnYRcB1whbs/0P0wRWSUioXB3Y/R21fzVuBrwC3uvt/MdpjZFf1u7wUeC3zCzO4xs/lFLiciU8DcfSxvPDe30fftu+EUrzJdwWIXFEZOnh2hHf+1rVlEVbMzVMs5B+Aud5/LdD0hTYkWkYQKg4gkVBhEJDFli6jOvEwh0sKryfOO0I6ZQ80CqZoNYEobt+T6tNIdg4gkVBhEJKHCICKJCc8YlCnU0MKryRInDdxZcU5LDpH7V12bwYrI0KgwiEhChUFEEioMIpKYoPBRQWNXNAlqtHYXXs8FgjW7PMU+U7uDk4icHlQYRCShwiAiiTFmDPejXGF0NAmqG6U8Iec5mWNfCO2WiUm5c3ITpVrojkFEEioMIpJQYRCRhAqDiCQmaIKTjJLCyLyWcLFFDCT/bUTvW0t3DCKSUGEQkYQKg4gkxpgx/Dzw+YH288c1EOHMXHg1qjyhxi9njsXcIe4KnVswpQlOIjI0KgwiklBhEJHEBM1j+HzmmHKHcTrd5jpMUqZQ46zQjoumutoROkd3DCKSUGEQkYQKg4gkVBhEJDFB4WNODCQVRo7TtIWRkxw21jy2Lloa2j/qYiCLqLpjMLNNZnafmS2Y2TWZ1882s4/3X7/TzC7sfKQiMjLFwmBmM8BO4HJgI7DNzDaGblcBD7r7zwHvB97T9UBFZHRq7hguARbc/aC7HwVuBraEPluAj/a//iTwAjOz7oYpIqNUkzGsprel86MOAZcu1sfdj5nZQ8AFwPcGO5nZdmB7v/mImX21ZdBjspLw/UywaRorTNd4p2ms0FuteNJGGj66+y5gF4CZ7XP3uVG+/6mYpvFO01hhusY7TWOF3nhbzqv5KHEYWDvQXtM/lu1jZkuA84DvtwxIRMavpjDsBTaY2XozmwW2AvOhzzzwyv7XLwU+6+7e3TBFZJSKHyX6mcHVwK309oa43t33m9kOYJ+7zwN/C3zMzBaAH9ArHiW7TmHc4zBN452mscJ0jXeaxgqN4zX9wy4ikaZEi0hChUFEEkMvDNM0nbpirG82swNmdq+ZfcbMnjSOcQ6M54TjHej3EjNzMxvbr9lqxmpmL+v/fPeb2Y2jHmMYS+nPwjozu83M7u7/edg8jnH2x3K9mT2w2Lwg6/lA/3u518wuLl7U3Yf2H72w8pvAk4FZ4MvAxtDn94EP97/eCnx8mGM6xbH+KrCs//XrxjXW2vH2+50L3A7sAeYmdazABuBu4Px++/GT/LOlF+q9rv/1RuDbYxzv84CLga8u8vpm4NOAAc8G7ixdc9h3DNM0nbo4Vne/zd0f7jf30JvTMS41P1uAd9NbuzLMxXglNWN9NbDT3R8EcPcHRjzGQTXjdeBx/a/PA74zwvEdPxD32+n9NnAxW4AbvGcPsMLMnniiaw67MOSmU69erI+7HwMenU49ajVjHXQVvSo8LsXx9m8Z17r7p0Y5sIyan+1TgKeY2R1mtsfMNo1sdKma8b4LeIWZHaK3wvsNoxlak5P9sz3p+zFMJjN7BTDHBG8QYWaPAd4HXDnmodRaQu/jxGX07sRuN7NnuPuRcQ7qBLYBH3H3vzSzX6I3j+fp7t7VM1/Gath3DNM0nbpmrJjZC4FrgSvc/ZERjS2nNN5zgacDnzOzb9P7bDk/pgCy5md7CJh39x+7+7eAb9ArFONQM96rgFsA3P2L9PZRWTmS0Z28qj/bxxlyKLIEOAis52chzi+EPq/n+PDxljEFODVjvYheKLVhHGM82fGG/p9jfOFjzc92E/DR/tcr6d36XjDB4/00cGX/66fRyxhsjH8eLmTx8PE3OD58/FLxeiMY8GZ61f+bwLX9Yzvo/YsLvUr7CWAB+BLw5DH+cEtj/Rfgv4B7+v/Nj2usNeMNfcdWGCp/tkbvo88B4CvA1kn+2dL7TcQd/aJxD/DrYxzrTcB36T3e8hC9u5nXAq8d+Nnu7H8vX6n5c6Ap0SKS0MxHEUmoM
IhIQoVBRBIqDCKSUGEQkYQKg4gkVBhEJPH/C1K5/psvhc4AAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# Visualize learned separation\n", + "\n", + "def visualize(nn: NeuralNet):\n", + " # left, right, bottom, top \n", + " dim = (0, 1, 0, 1)\n", + " resolution = 50\n", + "\n", + " data = []\n", + " for px in np.linspace(dim[0], dim[1], resolution):\n", + " col = []\n", + " for py in np.linspace(dim[2], dim[3], resolution):\n", + " col.append(nn.classify(np.array([[px, py]])))\n", + " data.append(col)\n", + " data = np.array(data)\n", + "\n", + " import matplotlib.pyplot as plt\n", + "\n", + " _, ax = plt.subplots()\n", + " ax.imshow(data, cmap=\"hot\", interpolation=\"nearest\", extent=dim)\n", + "visualize(nn)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9 (main, Dec 15 2022, 17:11:09) [Clang 14.0.0 (clang-1400.0.29.202)]" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}