From 9132801de3bd5a808acaedc5e4efaa8a120ae42f Mon Sep 17 00:00:00 2001 From: Varuna Jayasiri Date: Sun, 17 Jan 2021 09:14:41 +0530 Subject: [PATCH] capsule net notebook --- labml_nn/capsule_networks/__init__.py | 5 + labml_nn/capsule_networks/mnist.ipynb | 328 ++++++++++++++++++++++++++ 2 files changed, 333 insertions(+) create mode 100644 labml_nn/capsule_networks/mnist.ipynb diff --git a/labml_nn/capsule_networks/__init__.py b/labml_nn/capsule_networks/__init__.py index 95ff1ca4..8a0650f4 100644 --- a/labml_nn/capsule_networks/__init__.py +++ b/labml_nn/capsule_networks/__init__.py @@ -22,6 +22,11 @@ This file holds the implementations of the core modules of Capsule Networks. I used [jindongwang/Pytorch-CapsuleNet](https://github.com/jindongwang/Pytorch-CapsuleNet) to clarify some confusions I had with the paper. + +Here's a notebook for training a Capsule Networks on MNIST dataset. + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/capsule_networks/mnist.ipynb) +[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://web.lab-ml.com/run?uuid=e7c08e08586711ebb3e30242ac1c0002) """ import torch.nn as nn diff --git a/labml_nn/capsule_networks/mnist.ipynb b/labml_nn/capsule_networks/mnist.ipynb new file mode 100644 index 00000000..c0e24938 --- /dev/null +++ b/labml_nn/capsule_networks/mnist.ipynb @@ -0,0 +1,328 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "Capsule Networks", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "AYV_dMVDxyc2" + }, + "source": [ + "[![Github](https://img.shields.io/github/stars/lab-ml/nn?style=social)](https://github.com/lab-ml/nn)\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/capsule_networks/mnist.ipynb) \n", + "\n", + "## Training a Capsule Network to classify MNIST digits\n", + "\n", + "This is an experiment to train a Capsule Network to classify MNIST digits using PyTorch." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AahG_i2y5tY9" + }, + "source": [ + "Install the `labml-nn` package" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZCzmCrAIVg0L", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7ab15f72-c99f-4097-ecd2-5740ee9ed61c" + }, + "source": [ + "!pip install labml-nn" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Collecting labml-nn\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/f5/92/c454c38d613449e9cfee59809b83589bfc5463ebcf39a72126c268e31a77/labml_nn-0.4.78-py3-none-any.whl (111kB)\n", + "\r\u001b[K |███ | 10kB 23.8MB/s eta 0:00:01\r\u001b[K |██████ | 20kB 27.8MB/s eta 0:00:01\r\u001b[K |████████▉ | 30kB 22.4MB/s eta 0:00:01\r\u001b[K |███████████▉ | 40kB 18.7MB/s eta 0:00:01\r\u001b[K |██████████████▊ | 51kB 17.4MB/s eta 0:00:01\r\u001b[K |█████████████████▊ | 61kB 13.9MB/s eta 0:00:01\r\u001b[K |████████████████████▋ | 71kB 14.2MB/s eta 0:00:01\r\u001b[K |███████████████████████▋ | 81kB 14.1MB/s eta 0:00:01\r\u001b[K |██████████████████████████▋ | 92kB 14.3MB/s eta 0:00:01\r\u001b[K |█████████████████████████████▌ | 102kB 14.3MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 112kB 14.3MB/s \n", + "\u001b[?25hCollecting einops\n", + " Downloading https://files.pythonhosted.org/packages/5d/a0/9935e030634bf60ecd572c775f64ace82ceddf2f504a5fd3902438f07090/einops-0.3.0-py2.py3-none-any.whl\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.7.0+cu101)\n", + "Collecting labml>=0.4.86\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/87/4c/30c05318a66f4297babef5c3d11a34700ba2c79afc261a0c632eb8225871/labml-0.4.91-py3-none-any.whl (98kB)\n", + "\u001b[K |████████████████████████████████| 102kB 11.2MB/s \n", + "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from labml-nn) (1.19.5)\n", + "Collecting labml-helpers>=0.4.72\n", + " Downloading https://files.pythonhosted.org/packages/ec/58/2b7dcfde4565134ad97cdfe96ad7070fef95c37be2cbc066b608c9ae5c1d/labml_helpers-0.4.72-py3-none-any.whl\n", + "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.16.0)\n", + "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (0.8)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->labml-nn) (3.7.4.3)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.6/dist-packages (from labml>=0.4.86->labml-nn) (3.13)\n", + "Collecting gitpython\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/cb/ec98155c501b68dcb11314c7992cd3df6dce193fd763084338a117967d53/GitPython-3.1.12-py3-none-any.whl (159kB)\n", + "\u001b[K |████████████████████████████████| 163kB 46.5MB/s \n", + "\u001b[?25hCollecting gitdb<5,>=4.0.1\n", + "\u001b[?25l Downloading https://files.pythonhosted.org/packages/48/11/d1800bca0a3bae820b84b7d813ad1eff15a48a64caea9c823fc8c1b119e8/gitdb-4.0.5-py3-none-any.whl (63kB)\n", + "\u001b[K |████████████████████████████████| 71kB 10.3MB/s \n", + "\u001b[?25hCollecting smmap<4,>=3.0.1\n", + " Downloading https://files.pythonhosted.org/packages/b0/9a/4d409a6234eb940e6a78dfdfc66156e7522262f5f2fecca07dc55915952d/smmap-3.0.4-py2.py3-none-any.whl\n", + "Installing collected packages: einops, smmap, gitdb, gitpython, labml, labml-helpers, labml-nn\n", + "Successfully installed einops-0.3.0 gitdb-4.0.5 gitpython-3.1.12 labml-0.4.91 labml-helpers-0.4.72 labml-nn-0.4.78 smmap-3.0.4\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SE2VUQ6L5zxI" + }, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0hJXx_g0wS2C" + }, + "source": [ + "import torch\n", + "\n", + "from labml import experiment\n", + "from labml_nn.capsule_networks.mnist import Configs" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Lpggo0wM6qb-" + }, + "source": [ + "Create an experiment" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bFcr9k-l4cAg" + }, + "source": [ + "experiment.create(name=\"capsule_networks\")" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-OnHLi626tJt" + }, + "source": [ + "Initialize [Capsule Network configurations](https://lab-ml.com/labml_nn/capsule_networks/mnist.html)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Piz0c5f44hRo" + }, + "source": [ + "conf = Configs()" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wwMzCqpD6vkL" + }, + "source": [ + "Set experiment configurations and assign a configurations dictionary to override configurations" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 17 + }, + "id": "e6hmQhTw4nks", + "outputId": "ebefa8fa-93d2-4131-db95-e27f15aa3aa0" + }, + "source": [ + "experiment.configs(conf, {'optimizer.optimizer': 'Adam',\n", + " 'optimizer.learning_rate': 1e-3,\n", + " 'inner_iterations': 5})" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "
"
+            ],
+            "text/plain": [
+              ""
+            ]
+          },
+          "metadata": {
+            "tags": []
+          }
+        }
+      ]
+    },
+    {
+      "cell_type": "markdown",
+      "metadata": {
+        "id": "EvI7MtgJ61w5"
+      },
+      "source": [
+        "Set PyTorch models for loading and saving"
+      ]
+    },
+    {
+      "cell_type": "code",
+      "metadata": {
+        "colab": {
+          "base_uri": "https://localhost:8080/",
+          "height": 102
+        },
+        "id": "GDlt7dp-5ALt",
+        "outputId": "9701092b-c88a-4687-c90e-b193c369e59e"
+      },
+      "source": [
+        "experiment.add_pytorch_models({'model': conf.model})"
+      ],
+      "execution_count": 5,
+      "outputs": [
+        {
+          "output_type": "display_data",
+          "data": {
+            "text/html": [
+              "
Prepare model...\n",
+              "  Prepare device...\n",
+              "    Prepare device_info...[DONE]\t47.85ms\n",
+              "  Prepare device...[DONE]\t52.83ms\n",
+              "Prepare model...[DONE]\t4,606.99ms\n",
+              "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KJZRf8527GxL" + }, + "source": [ + "Start the experiment and run the training loop." + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 646 + }, + "id": "aIAWo7Fw5DR8", + "outputId": "5ddbfce3-91f8-4506-e483-1640cb5a14b3" + }, + "source": [ + "with experiment.start():\n", + " conf.run()" + ], + "execution_count": 6, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/html": [ + "
\n",
+              "capsule_networks: e7c08e08586711ebb3e30242ac1c0002\n",
+              "\t[dirty]: \"\"\n",
+              "Initialize...[DONE]\t27.73ms\n",
+              "Prepare validator...\n",
+              "  Prepare mode...[DONE]\t6.00ms\n",
+              "  Prepare valid_loader...\n",
+              "    Prepare valid_dataset...\n",
+              "      Prepare dataset_transforms...[DONE]\t4.04ms\n",
+              "    Prepare valid_dataset...[DONE]\t42.26ms\n",
+              "\n",
+              "--------------------------------------------------\n",
+              "LABML WARNING\n",
+              "LabML App Warning: empty_token: Please create a valid token at https://web.lab-ml.com.\n",
+              "Click on the experiment link to monitor the experiment and add it to your experiments list.\n",
+              "--------------------------------------------------\n",
+              "Monitor experiment at https://web.lab-ml.com/run?uuid=e7c08e08586711ebb3e30242ac1c0002\n",
+              "  Prepare valid_loader...[DONE]\t127.68ms\n",
+              "Prepare validator...[DONE]\t295.57ms\n",
+              "Prepare trainer...\n",
+              "  Prepare train_loader...\n",
+              "    Prepare train_dataset...[DONE]\t36.64ms\n",
+              "  Prepare train_loader...[DONE]\t126.53ms\n",
+              "Prepare trainer...[DONE]\t159.96ms\n",
+              "Prepare training_loop...\n",
+              "  Prepare loop_count...[DONE]\t34.24ms\n",
+              "Prepare training_loop...[DONE]\t214.47ms\n",
+              "  60,000:  Train: 100%    67,954ms  Valid: 100% 7,768ms   loss.train: 0.036759 loss.valid: 0.018877 accuracy.train: 0.962317 accuracy.valid: 0.979800  78,355ms  0:01m/  0:11m  \n",
+              " 120,000:  Train: 100%    67,571ms  Valid: 100% 8,267ms   loss.train: 0.016659 loss.valid: 0.017786 accuracy.train: 0.989217 accuracy.valid: 0.987000  77,000ms  0:02m/  0:10m  \n",
+              " 180,000:  Train: 100%    67,421ms  Valid: 100% 8,458ms   loss.train: 0.010699 loss.valid: 0.011324 accuracy.train: 0.993017 accuracy.valid: 0.990400  76,496ms  0:03m/  0:08m  \n",
+              " 240,000:  Train: 100%    67,333ms  Valid: 100% 8,544ms   loss.train: 0.001724 loss.valid: 0.010312 accuracy.train: 0.995183 accuracy.valid: 0.992500  76,241ms  0:05m/  0:07m  \n",
+              " 300,000:  Train: 100%    67,393ms  Valid: 100% 8,584ms   loss.train: 0.025503 loss.valid: 0.009328 accuracy.train: 0.996467 accuracy.valid: 0.992300  76,131ms  0:06m/  0:06m  \n",
+              " 360,000:  Train: 100%    67,243ms  Valid: 100% 8,614ms   loss.train: 0.002150 loss.valid: 0.009803 accuracy.train: 0.997183 accuracy.valid: 0.992100  76,030ms  0:07m/  0:05m  \n",
+              " 420,000:  Train: 100%    67,368ms  Valid: 100% 8,655ms   loss.train: 0.000345 loss.valid: 0.011668 accuracy.train: 0.997750 accuracy.valid: 0.992400  75,969ms  0:08m/  0:03m  \n",
+              " 480,000:  Train: 100%    67,265ms  Valid: 100% 8,646ms   loss.train: 0.008524 loss.valid: 0.009893 accuracy.train: 0.998200 accuracy.valid: 0.992000  75,889ms  0:10m/  0:02m  \n",
+              " 540,000:  Train: 100%    67,440ms  Valid: 100% 8,660ms   loss.train: 0.000430 loss.valid: 0.010111 accuracy.train: 0.998383 accuracy.valid: 0.991400  75,870ms  0:11m/  0:01m  \n",
+              " 600,000:   loss.train: 0.000784 loss.valid: 0.009602 accuracy.train: 0.998817 accuracy.valid: 0.992500\n",
+              "Still updating LabML App, please wait for it to complete...
" + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "oBXXlP2b7XZO" + }, + "source": [ + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file -- GitLab