From e2cf20e1893ee09577f01d866f595c7e8b84991b Mon Sep 17 00:00:00 2001 From: Pavol Mulinka Date: Thu, 4 Nov 2021 01:47:36 +0100 Subject: [PATCH] added embedding rules, MonteCarlo(uncertainty) prediction and removed running of the tests on draf requests --- .github/workflows/build.yml | 3 + VERSION | 2 +- .../14_Model_Uncertainty_prediction.ipynb | 1188 +++++++++++++++++ .../preprocessing/tab_preprocessor.py | 37 +- pytorch_widedeep/training/trainer.py | 131 +- pytorch_widedeep/version.py | 2 +- 6 files changed, 1339 insertions(+), 24 deletions(-) create mode 100644 examples/14_Model_Uncertainty_prediction.ipynb diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 480655f..5754d96 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -11,6 +11,7 @@ on: jobs: codestyle: runs-on: ubuntu-latest + if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }} steps: - uses: actions/checkout@v2 - name: Set up Python 3.9 @@ -32,6 +33,7 @@ jobs: test: runs-on: ubuntu-latest + if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }} strategy: fail-fast: true matrix: @@ -59,6 +61,7 @@ jobs: finish: needs: test runs-on: ubuntu-latest + if: ${{ github.event_name == 'push' || !github.event.pull_request.draft }} steps: - uses: actions/checkout@v2 - name: Set up Python 3.9 diff --git a/VERSION b/VERSION index 8684498..492b167 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0.11 \ No newline at end of file +1.0.12 \ No newline at end of file diff --git a/examples/14_Model_Uncertainty_prediction.ipynb b/examples/14_Model_Uncertainty_prediction.ipynb new file mode 100644 index 0000000..279dbb4 --- /dev/null +++ b/examples/14_Model_Uncertainty_prediction.ipynb @@ -0,0 +1,1188 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Custom DataLoader for Imbalanced dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* In this notebook we will 
use the higly imbalanced Protein Homology Dataset from [KDD cup 2004](https://www.kdd.org/kdd-cup/view/kdd-cup-2004/Data)\n", + "\n", + "```\n", + "* The first element of each line is a BLOCK ID that denotes to which native sequence this example belongs. There is a unique BLOCK ID for each native sequence. BLOCK IDs are integers running from 1 to 303 (one for each native sequence, i.e. for each query). BLOCK IDs were assigned before the blocks were split into the train and test sets, so they do not run consecutively in either file.\n", + "* The second element of each line is an EXAMPLE ID that uniquely describes the example. You will need this EXAMPLE ID and the BLOCK ID when you submit results.\n", + "* The third element is the class of the example. Proteins that are homologous to the native sequence are denoted by 1, non-homologous proteins (i.e. decoys) by 0. Test examples have a \"?\" in this position.\n", + "* All following elements are feature values. There are 74 feature values in each line. The features describe the match (e.g. 
the score of a sequence alignment) between the native protein sequence and the sequence that is tested for homology.\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initial imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from torch.optim import SGD, lr_scheduler\n", + "\n", + "from pytorch_widedeep import Trainer\n", + "from pytorch_widedeep.preprocessing import TabPreprocessor\n", + "from pytorch_widedeep.models import TabMlp, WideDeep\n", + "from pytorch_widedeep.dataloaders import DataLoaderImbalanced, DataLoaderDefault\n", + "from torchmetrics import F1 as F1_torchmetrics\n", + "from torchmetrics import Accuracy as Accuracy_torchmetrics\n", + "from torchmetrics import Precision as Precision_torchmetrics\n", + "from torchmetrics import Recall as Recall_torchmetrics\n", + "from pytorch_widedeep.metrics import Accuracy, Recall, Precision, F1Score, R2Score\n", + "from pytorch_widedeep.initializers import XavierNormal\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.metrics import classification_report\n", + "\n", + "import time\n", + "import datetime\n", + "\n", + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n", + "\n", + "# increase displayed columns in jupyter notebook\n", + "pd.set_option(\"display.max_columns\", 200)\n", + "pd.set_option(\"display.max_rows\", 300)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EXAMPLE_IDBLOCK_IDtarget4567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
0279261532052.032.690.302.520.01256.8-0.890.3311.0-55.0267.20.520.05-2.3649.6252.00.431.16-2.06-33.0-123.21.60-0.49-6.0665.0296.1-0.28-0.26-3.83-22.6-170.03.06-1.05-3.2922.9286.30.122.584.08-33.0-178.91.880.53-7.0-44.01987.0-5.410.95-4.0-57.0722.9-3.26-0.55-7.5125.51547.2-0.361.129.0-37.072.50.470.74-11.0-8.01595.1-1.642.83-2.0-50.0445.2-0.350.260.76
1279261533058.033.330.0016.59.5608.10.500.0720.5-52.5521.6-1.080.58-0.02-3.2103.6-0.950.23-2.87-25.9-52.2-0.210.87-1.8110.462.0-0.28-0.041.48-17.6-198.33.432.845.87-16.972.6-0.312.792.71-33.5-11.6-1.114.015.0-57.0666.31.134.385.0-64.039.31.07-0.1632.5100.01893.7-2.80-0.222.5-28.545.00.580.41-19.0-6.0762.90.290.82-3.0-35.0140.31.160.390.73
2279261534077.027.27-0.916.058.51623.6-1.400.02-6.5-48.0621.0-1.200.14-0.2073.6609.1-0.44-0.58-0.04-23.0-27.4-0.72-1.04-1.0991.1635.6-0.880.240.59-18.7-7.2-0.60-2.82-0.7152.4504.10.89-0.67-9.30-20.8-25.7-0.77-0.850.0-20.02259.0-0.941.15-4.0-44.0-22.70.94-0.98-19.0105.01267.91.031.2711.0-39.582.30.47-0.19-10.07.01491.80.32-1.290.0-34.0658.2-0.760.260.24
3279261535041.027.91-0.353.046.01921.6-1.36-0.47-32.0-51.5560.9-0.29-0.10-1.11124.3791.60.000.39-1.85-21.7-44.9-0.210.020.89133.9797.8-0.081.06-0.26-16.4-74.10.97-0.80-0.4166.9955.3-1.901.28-6.65-28.147.5-1.911.421.0-30.01846.70.761.10-4.0-52.0-53.91.71-0.22-12.097.51969.8-1.700.16-1.0-32.5255.9-0.461.5710.06.02047.7-0.981.530.0-49.0554.2-0.830.390.73
4279261536050.028.00-1.32-9.012.0464.80.880.198.0-51.598.11.09-0.33-2.16-3.9102.70.39-1.22-3.39-15.2-42.2-1.18-1.11-3.558.9141.3-0.16-0.43-4.15-12.9-13.4-1.32-0.98-3.698.8136.1-0.304.131.89-13.0-18.7-1.37-0.930.0-1.0810.1-2.296.721.0-23.0-29.70.58-1.10-18.533.5206.81.84-0.134.0-29.030.10.80-0.245.0-14.0479.50.68-0.592.0-36.0-6.92.020.14-0.23
\n", + "
" + ], + "text/plain": [ + " EXAMPLE_ID BLOCK_ID target 4 5 6 7 8 9 10 \\\n", + "0 279 261532 0 52.0 32.69 0.30 2.5 20.0 1256.8 -0.89 \n", + "1 279 261533 0 58.0 33.33 0.00 16.5 9.5 608.1 0.50 \n", + "2 279 261534 0 77.0 27.27 -0.91 6.0 58.5 1623.6 -1.40 \n", + "3 279 261535 0 41.0 27.91 -0.35 3.0 46.0 1921.6 -1.36 \n", + "4 279 261536 0 50.0 28.00 -1.32 -9.0 12.0 464.8 0.88 \n", + "\n", + " 11 12 13 14 15 16 17 18 19 20 21 22 \\\n", + "0 0.33 11.0 -55.0 267.2 0.52 0.05 -2.36 49.6 252.0 0.43 1.16 -2.06 \n", + "1 0.07 20.5 -52.5 521.6 -1.08 0.58 -0.02 -3.2 103.6 -0.95 0.23 -2.87 \n", + "2 0.02 -6.5 -48.0 621.0 -1.20 0.14 -0.20 73.6 609.1 -0.44 -0.58 -0.04 \n", + "3 -0.47 -32.0 -51.5 560.9 -0.29 -0.10 -1.11 124.3 791.6 0.00 0.39 -1.85 \n", + "4 0.19 8.0 -51.5 98.1 1.09 -0.33 -2.16 -3.9 102.7 0.39 -1.22 -3.39 \n", + "\n", + " 23 24 25 26 27 28 29 30 31 32 33 34 \\\n", + "0 -33.0 -123.2 1.60 -0.49 -6.06 65.0 296.1 -0.28 -0.26 -3.83 -22.6 -170.0 \n", + "1 -25.9 -52.2 -0.21 0.87 -1.81 10.4 62.0 -0.28 -0.04 1.48 -17.6 -198.3 \n", + "2 -23.0 -27.4 -0.72 -1.04 -1.09 91.1 635.6 -0.88 0.24 0.59 -18.7 -7.2 \n", + "3 -21.7 -44.9 -0.21 0.02 0.89 133.9 797.8 -0.08 1.06 -0.26 -16.4 -74.1 \n", + "4 -15.2 -42.2 -1.18 -1.11 -3.55 8.9 141.3 -0.16 -0.43 -4.15 -12.9 -13.4 \n", + "\n", + " 35 36 37 38 39 40 41 42 43 44 45 46 \\\n", + "0 3.06 -1.05 -3.29 22.9 286.3 0.12 2.58 4.08 -33.0 -178.9 1.88 0.53 \n", + "1 3.43 2.84 5.87 -16.9 72.6 -0.31 2.79 2.71 -33.5 -11.6 -1.11 4.01 \n", + "2 -0.60 -2.82 -0.71 52.4 504.1 0.89 -0.67 -9.30 -20.8 -25.7 -0.77 -0.85 \n", + "3 0.97 -0.80 -0.41 66.9 955.3 -1.90 1.28 -6.65 -28.1 47.5 -1.91 1.42 \n", + "4 -1.32 -0.98 -3.69 8.8 136.1 -0.30 4.13 1.89 -13.0 -18.7 -1.37 -0.93 \n", + "\n", + " 47 48 49 50 51 52 53 54 55 56 57 58 \\\n", + "0 -7.0 -44.0 1987.0 -5.41 0.95 -4.0 -57.0 722.9 -3.26 -0.55 -7.5 125.5 \n", + "1 5.0 -57.0 666.3 1.13 4.38 5.0 -64.0 39.3 1.07 -0.16 32.5 100.0 \n", + "2 0.0 -20.0 2259.0 -0.94 1.15 -4.0 -44.0 -22.7 0.94 -0.98 -19.0 105.0 
\n", + "3 1.0 -30.0 1846.7 0.76 1.10 -4.0 -52.0 -53.9 1.71 -0.22 -12.0 97.5 \n", + "4 0.0 -1.0 810.1 -2.29 6.72 1.0 -23.0 -29.7 0.58 -1.10 -18.5 33.5 \n", + "\n", + " 59 60 61 62 63 64 65 66 67 68 69 \\\n", + "0 1547.2 -0.36 1.12 9.0 -37.0 72.5 0.47 0.74 -11.0 -8.0 1595.1 \n", + "1 1893.7 -2.80 -0.22 2.5 -28.5 45.0 0.58 0.41 -19.0 -6.0 762.9 \n", + "2 1267.9 1.03 1.27 11.0 -39.5 82.3 0.47 -0.19 -10.0 7.0 1491.8 \n", + "3 1969.8 -1.70 0.16 -1.0 -32.5 255.9 -0.46 1.57 10.0 6.0 2047.7 \n", + "4 206.8 1.84 -0.13 4.0 -29.0 30.1 0.80 -0.24 5.0 -14.0 479.5 \n", + "\n", + " 70 71 72 73 74 75 76 77 \n", + "0 -1.64 2.83 -2.0 -50.0 445.2 -0.35 0.26 0.76 \n", + "1 0.29 0.82 -3.0 -35.0 140.3 1.16 0.39 0.73 \n", + "2 0.32 -1.29 0.0 -34.0 658.2 -0.76 0.26 0.24 \n", + "3 -0.98 1.53 0.0 -49.0 554.2 -0.83 0.39 0.73 \n", + "4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 " + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "header_list = [\"EXAMPLE_ID\", \"BLOCK_ID\", \"target\"] + [str(i) for i in range(4, 78)]\n", + "df = pd.read_csv(\"data/kddcup04/bio_train.dat\", sep=\"\\t\", names=header_list)\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "0 144455\n", + "1 1296\n", + "Name: target, dtype: int64" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# imbalance of the classes\n", + "df[\"target\"].value_counts()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# drop columns we won't need in this example\n", + "df.drop(columns=[\"EXAMPLE_ID\", \"BLOCK_ID\"], inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "df_train, df_valid = train_test_split(\n", + " df, test_size=0.2, stratify=df[\"target\"], random_state=1\n", + ")\n", + 
"df_valid, df_test = train_test_split(\n", + " df_valid, test_size=0.5, stratify=df_valid[\"target\"], random_state=1\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preparing the data" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "continuous_cols = df.drop(columns=[\"target\"]).columns.values.tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "# deeptabular\n", + "tab_preprocessor = TabPreprocessor(continuous_cols=continuous_cols, scale=True)\n", + "X_tab_train = tab_preprocessor.fit_transform(df_train)\n", + "X_tab_valid = tab_preprocessor.transform(df_valid)\n", + "X_tab_test = tab_preprocessor.transform(df_test)\n", + "\n", + "# target\n", + "y_train = df_train[\"target\"].values\n", + "y_valid = df_valid[\"target\"].values\n", + "y_test = df_test[\"target\"].values" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define the model" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "input_layer = len(tab_preprocessor.continuous_cols)\n", + "output_layer = 1\n", + "hidden_layers = np.linspace(\n", + " input_layer * 2, output_layer, 5, endpoint=False, dtype=int\n", + ").tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "WideDeep(\n", + " (deeptabular): Sequential(\n", + " (0): TabMlp(\n", + " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", + " (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (tab_mlp): MLP(\n", + " (mlp): Sequential(\n", + " (dense_layer_0): Sequential(\n", + " (0): Dropout(p=0.1, inplace=False)\n", + " (1): Linear(in_features=74, out_features=148, bias=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (dense_layer_1): 
Sequential(\n", + " (0): Dropout(p=0.1, inplace=False)\n", + " (1): Linear(in_features=148, out_features=118, bias=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (dense_layer_2): Sequential(\n", + " (0): Dropout(p=0.1, inplace=False)\n", + " (1): Linear(in_features=118, out_features=89, bias=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (dense_layer_3): Sequential(\n", + " (0): Dropout(p=0.1, inplace=False)\n", + " (1): Linear(in_features=89, out_features=59, bias=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " (dense_layer_4): Sequential(\n", + " (0): Dropout(p=0.1, inplace=False)\n", + " (1): Linear(in_features=59, out_features=30, bias=True)\n", + " (2): ReLU(inplace=True)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (1): Linear(in_features=30, out_features=1, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "deeptabular = TabMlp(\n", + " mlp_hidden_dims=hidden_layers,\n", + " column_idx=tab_preprocessor.column_idx,\n", + " continuous_cols=tab_preprocessor.continuous_cols,\n", + ")\n", + "model = WideDeep(deeptabular=deeptabular, pred_dim=1)\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "# Metrics from torchmetrics\n", + "accuracy = Accuracy_torchmetrics(average=None, num_classes=2)\n", + "precision = Precision_torchmetrics(average=\"micro\", num_classes=2)\n", + "f1 = F1_torchmetrics(average=None, num_classes=2)\n", + "recall = Recall_torchmetrics(average=None, num_classes=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# # Metrics from pytorch-widedeep\n", + "# accuracy = Accuracy(top_k=2)\n", + "# precision = Precision(average=False)\n", + "# recall = Recall(average=True)\n", + "# f1 = F1Score(average=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + 
"outputs": [], + "source": [ + "# Optimizers\n", + "deep_opt = SGD(model.deeptabular.parameters(), lr=0.1)\n", + "# LR Scheduler\n", + "deep_sch = lr_scheduler.StepLR(deep_opt, step_size=3)\n", + "\n", + "trainer = Trainer(\n", + " model,\n", + " objective=\"binary\",\n", + " lr_schedulers={\"deeptabular\": deep_sch},\n", + " initializers={\"deeptabular\": XavierNormal},\n", + " optimizers={\"deeptabular\": deep_opt},\n", + " metrics=[accuracy, precision, recall, f1],\n", + " verbose=1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "epoch 1: 100%|█| 208/208 [00:02<00:00, 101.99it/s, loss=0.22, metrics={'Accuracy': [0.937, 0.885], 'Pr\n", + "valid: 100%|█| 292/292 [00:02<00:00, 118.71it/s, loss=0.0839, metrics={'Accuracy': [0.9724, 0.876], 'P\n", + "epoch 2: 100%|█| 208/208 [00:02<00:00, 98.63it/s, loss=0.148, metrics={'Accuracy': [0.9435, 0.9386], '\n", + "valid: 100%|█| 292/292 [00:02<00:00, 113.80it/s, loss=0.0993, metrics={'Accuracy': [0.9655, 0.8915], '\n", + "epoch 3: 100%|█| 208/208 [00:02<00:00, 93.96it/s, loss=0.145, metrics={'Accuracy': [0.9463, 0.9361], '\n", + "valid: 100%|█| 292/292 [00:02<00:00, 115.44it/s, loss=0.107, metrics={'Accuracy': [0.9583, 0.938], 'Pr\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Training time[s]: 0:00:14\n" + ] + } + ], + "source": [ + "start = time.time()\n", + "trainer.fit(\n", + " X_train={\"X_tab\": X_tab_train, \"target\": y_train},\n", + " X_val={\"X_tab\": X_tab_valid, \"target\": y_valid},\n", + " n_epochs=3,\n", + " batch_size=50,\n", + " custom_dataloader=DataLoaderImbalanced,\n", + " oversample_mul=5,\n", + ")\n", + "print(\n", + " \"Training time[s]: {}\".format(\n", + " datetime.timedelta(seconds=round(time.time() - start))\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + 
"text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
train_losstrain_Accuracy_0train_Accuracy_1train_Precisiontrain_Recall_0train_Recall_1train_F1_0train_F1_1val_lossval_Accuracy_0val_Accuracy_1val_Precisionval_Recall_0val_Recall_1val_F1_0val_F1_1
00.2200920.9369660.8849700.9113790.9369660.8849700.9148210.9076470.0838650.9723800.8759690.9715270.9723800.8759690.9854430.352574
10.1477380.9435160.9386250.9410800.9435160.9386250.9414360.9407200.0993470.9655270.8914730.9648710.9655270.8914730.9819770.309973
20.1445070.9463430.9361000.9412730.9463430.9361000.9421160.9404050.1067790.9583280.9379840.9581480.9583280.9379840.9784440.284038
\n", + "
" + ], + "text/plain": [ + " train_loss train_Accuracy_0 train_Accuracy_1 train_Precision \\\n", + "0 0.220092 0.936966 0.884970 0.911379 \n", + "1 0.147738 0.943516 0.938625 0.941080 \n", + "2 0.144507 0.946343 0.936100 0.941273 \n", + "\n", + " train_Recall_0 train_Recall_1 train_F1_0 train_F1_1 val_loss \\\n", + "0 0.936966 0.884970 0.914821 0.907647 0.083865 \n", + "1 0.943516 0.938625 0.941436 0.940720 0.099347 \n", + "2 0.946343 0.936100 0.942116 0.940405 0.106779 \n", + "\n", + " val_Accuracy_0 val_Accuracy_1 val_Precision val_Recall_0 val_Recall_1 \\\n", + "0 0.972380 0.875969 0.971527 0.972380 0.875969 \n", + "1 0.965527 0.891473 0.964871 0.965527 0.891473 \n", + "2 0.958328 0.937984 0.958148 0.958328 0.937984 \n", + "\n", + " val_F1_0 val_F1_1 \n", + "0 0.985443 0.352574 \n", + "1 0.981977 0.309973 \n", + "2 0.978444 0.284038 " + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.DataFrame(trainer.history)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## \"Normal\" prediction" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "predict: 100%|█████████████████████████████████████████████████████| 292/292 [00:01<00:00, 274.55it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 1.00 0.96 0.98 14446\n", + " 1 0.17 0.96 0.29 130\n", + "\n", + " accuracy 0.96 14576\n", + " macro avg 0.59 0.96 0.64 14576\n", + "weighted avg 0.99 0.96 0.97 14576\n", + "\n", + "Actual predicted values:\n", + "(array([0, 1]), array([13851, 725]))\n" + ] + } + ], + "source": [ + "df_pred = trainer.predict(X_tab=X_tab_test)\n", + "print(classification_report(df_test[\"target\"].to_list(), df_pred))\n", + "print(\"Actual predicted values:\\n{}\".format(np.unique(df_pred, 
return_counts=True)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prediction using uncertainty" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "predict_UncertaintyIter: 100%|████████████████████████████████████████| 10/10 [00:07<00:00, 1.29it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 1.00 0.96 0.98 14446\n", + " 1 0.17 0.96 0.29 130\n", + "\n", + " accuracy 0.96 14576\n", + " macro avg 0.59 0.96 0.64 14576\n", + "weighted avg 0.99 0.96 0.97 14576\n", + "\n", + "Actual predicted values:\n", + "(array([0., 1.]), array([13852, 724]))\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "df_pred_unc = trainer.predict_uncertainty(X_tab=X_tab_test, uncertainty_granularity=10)\n", + "print(classification_report(df_test[\"target\"].to_list(), df_pred))\n", + "print(\"Actual predicted values:\\n{}\".format(np.unique(df_pred_unc[:,-1], return_counts=True)))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[9.68546510e-01, 3.14534754e-02, 0.00000000e+00],\n", + " [9.99679565e-01, 3.20456806e-04, 0.00000000e+00],\n", + " [9.91507292e-01, 8.49271845e-03, 0.00000000e+00],\n", + " ...,\n", + " [9.98316586e-01, 1.68343820e-03, 0.00000000e+00],\n", + " [9.99173999e-01, 8.26018688e-04, 0.00000000e+00],\n", + " [9.99960661e-01, 3.93257578e-05, 0.00000000e+00]])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_pred_unc" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pytorch_widedeep/preprocessing/tab_preprocessor.py b/pytorch_widedeep/preprocessing/tab_preprocessor.py index 0cc3089..152a26a 100644 --- a/pytorch_widedeep/preprocessing/tab_preprocessor.py +++ b/pytorch_widedeep/preprocessing/tab_preprocessor.py @@ -12,10 +12,24 @@ from pytorch_widedeep.preprocessing.base_preprocessor import ( ) -def embed_sz_rule(n_cat): - r"""Rule of thumb to pick embedding size corresponding to ``n_cat``. Taken - from fastai's Tabular API""" - return min(600, round(1.6 * n_cat ** 0.56)) +def embed_sz_rule(n_cat: int, embedding_rule: str="fastai_new") -> int: + r"""Rule of thumb to pick embedding size corresponding to ``n_cat``. Default rule is taken + from recent fastai's Tabular API. 
The function also includes previously used rule by fastai + and rule included in the Google's Tensorflow documentation + + Parameters + ---------- + n_cat: int + number of unique categorical values in a feature + embedding_rule: str, default = fastai_new + rule of thumb to be used for embedding vector size + """ + if embedding_rule == 'google': + return int(round(n_cat**0.25)) + elif embedding_rule == 'fastai_old': + return int(min(50, (n_cat//2) + 1)) + else: + return int(min(600, round(1.6 * n_cat ** 0.56))) class TabPreprocessor(BasePreprocessor): @@ -38,8 +52,15 @@ class TabPreprocessor(BasePreprocessor): :obj:`pytorch_widedeep.models.transformers._embedding_layers` auto_embed_dim: bool, default = True Boolean indicating whether the embedding dimensions will be - automatically defined via fastai's rule of thumb': - :math:`min(600, int(1.6 \times n_{cat}^{0.56}))` + automatically defined via rule of thumb + embedding_rule: str, default = 'fastai_new' + choice of embedding rule of thumb + 'fastai_new': + :math:`min(600, round(1.6 \times n_{cat}^{0.56}))` + 'fastai_old': + :math:`min(50, (n_{cat}//{2})+1)` + 'google': + :math:`round(n_{cat}^{0.25})` default_embed_dim: int, default=16 Dimension for the embeddings used for the ``deeptabular`` component if the embed_dim is not provided in the ``embed_cols`` @@ -118,6 +139,7 @@ class TabPreprocessor(BasePreprocessor): continuous_cols: List[str] = None, scale: bool = True, auto_embed_dim: bool = True, + embedding_rule: str = "fastai_new", default_embed_dim: int = 16, already_standard: List[str] = None, for_transformer: bool = False, @@ -131,6 +153,7 @@ class TabPreprocessor(BasePreprocessor): self.continuous_cols = continuous_cols self.scale = scale self.auto_embed_dim = auto_embed_dim + self.embedding_rule = embedding_rule self.default_embed_dim = default_embed_dim self.already_standard = already_standard self.for_transformer = for_transformer @@ -250,7 +273,7 @@ class TabPreprocessor(BasePreprocessor): 
embed_colname = [emb[0] for emb in self.embed_cols] elif self.auto_embed_dim: n_cats = {col: df[col].nunique() for col in self.embed_cols} - self.embed_dim = {col: embed_sz_rule(n_cat) for col, n_cat in n_cats.items()} # type: ignore[misc] + self.embed_dim = {col: embed_sz_rule(n_cat, self.embedding_rule) for col, n_cat in n_cats.items()} # type: ignore[misc] embed_colname = self.embed_cols # type: ignore else: self.embed_dim = {e: self.default_embed_dim for e in self.embed_cols} # type: ignore diff --git a/pytorch_widedeep/training/trainer.py b/pytorch_widedeep/training/trainer.py index 1225c9d..cc57653 100644 --- a/pytorch_widedeep/training/trainer.py +++ b/pytorch_widedeep/training/trainer.py @@ -27,7 +27,7 @@ from pytorch_widedeep.callbacks import ( from pytorch_widedeep.dataloaders import DataLoaderDefault from pytorch_widedeep.initializers import Initializer, MultipleInitializer from pytorch_widedeep.training._finetune import FineTune -from pytorch_widedeep.utils.general_utils import Alias +from pytorch_widedeep.utils.general_utils import Alias, set_default_attr from pytorch_widedeep.models.tabnet._utils import create_explain_matrix from pytorch_widedeep.training._wd_dataset import WideDeepDataset from pytorch_widedeep.training._trainer_utils import ( @@ -685,8 +685,14 @@ class Trainer: If a trainer is used to predict after having trained a model, the ``batch_size`` needs to be defined as it will not be defined as the :obj:`Trainer` is instantiated + uncertainty: bool, default = False + If set to True the model activates the dropout layers and predicts + each sample N times (uncertainty_granularity times) and returns + {max, min, mean, stdev} value for each sample + uncertainty_granularity: int, default = 1000 + number of times the model does prediction for each sample if uncertainty + is set to True """ - preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test, batch_size) if self.method == "regression": return np.vstack(preds_l).squeeze(1) @@ -697,6 
+703,86 @@ class Trainer: preds = np.vstack(preds_l) return np.argmax(preds, 1) # type: ignore[return-value] + def predict_uncertainty( # type: ignore[return] + self, + X_wide: Optional[np.ndarray] = None, + X_tab: Optional[np.ndarray] = None, + X_text: Optional[np.ndarray] = None, + X_img: Optional[np.ndarray] = None, + X_test: Optional[Dict[str, np.ndarray]] = None, + batch_size: int = 256, + uncertainty_granularity = 1000, + ) -> np.ndarray: + r"""Returns the predicted uncertainty of the model for the test dataset using a + Monte Carlo method during which dropout layers are activated in the evaluation/prediction + phase and each sample is predicted N times (uncertainty_granularity times). Based on [1]. + + [1] Gal Y. & Ghahramani Z., 2016, Dropout as a Bayesian Approximation: Representing Model + Uncertainty in Deep Learning, Proceedings of the 33rd International Conference on Machine Learning + + Parameters + ---------- + X_wide: np.ndarray, Optional. default=None + Input for the ``wide`` model component. + See :class:`pytorch_widedeep.preprocessing.WidePreprocessor` + X_tab: np.ndarray, Optional. default=None + Input for the ``deeptabular`` model component. + See :class:`pytorch_widedeep.preprocessing.TabPreprocessor` + X_text: np.ndarray, Optional. default=None + Input for the ``deeptext`` model component. + See :class:`pytorch_widedeep.preprocessing.TextPreprocessor` + X_img : np.ndarray, Optional. default=None + Input for the ``deepimage`` model component. + See :class:`pytorch_widedeep.preprocessing.ImagePreprocessor` + X_test: Dict, Optional. default=None + The test dataset can also be passed in a dictionary. Keys are + `X_wide`, `'X_tab'`, `'X_text'`, `'X_img'` and `'target'`. Values + are the corresponding matrices. 
+ batch_size: int, default = 256 + If a trainer is used to predict after having trained a model, the + ``batch_size`` needs to be defined as it will not be defined as + the :obj:`Trainer` is instantiated + uncertainty_granularity: int, default = 1000 + number of times the model does prediction for each sample if uncertainty + is set to True + + Returns + ------- + method == regression : np.ndarray + {max, min, mean, stdev} values for each sample + method == binary : np.ndarray + {mean_cls_0_prob, mean_cls_1_prob, predicted_cls} values for each sample + method == multiclass : np.ndarray + {mean_cls_0_prob, mean_cls_1_prob, mean_cls_2_prob, ... , predicted_cls} values for each sample + + """ + preds_l = self._predict(X_wide, X_tab, X_text, X_img, X_test, batch_size, + uncertainty_granularity, uncertainty=True) + preds = np.vstack(preds_l) + samples_num = int(preds.shape[0]/uncertainty_granularity) + if self.method == "regression": + preds = preds.squeeze(1) + preds = preds.reshape((uncertainty_granularity, samples_num)) + return np.array(( + preds.max(axis=0), + preds.min(axis=0), + preds.mean(axis=0), + preds.std(axis=0))).T + if self.method == "binary": + preds = preds.squeeze(1) + preds = preds.reshape((uncertainty_granularity, samples_num)) + preds = preds.mean(axis=0) + probs = np.zeros([preds.shape[0], 3]) + probs[:, 0] = 1 - preds + probs[:, 1] = preds + probs[:, 2] = (preds > 0.5).astype("int") + return probs + if self.method == "multiclass": + preds = preds.reshape(uncertainty_granularity, samples_num, preds.shape[1]) + preds = preds.mean(axis=0) + preds = np.hstack((preds, np.vstack(np.argmax(preds, 1)))) + return preds + def predict_proba( # type: ignore[return] self, X_wide: Optional[np.ndarray] = None, @@ -1112,6 +1198,8 @@ class Trainer: X_img: Optional[np.ndarray] = None, X_test: Optional[Dict[str, np.ndarray]] = None, batch_size: int = 256, + uncertainty_granularity = 1000, + uncertainty: bool = False, ) -> List: r"""Private method to avoid 
code repetition in predict and predict_proba. For parameter information, please, see the .predict() @@ -1144,20 +1232,33 @@ class Trainer: self.model.eval() preds_l = [] + + if uncertainty: + for m in self.model.modules(): + if m.__class__.__name__.startswith('Dropout'): + m.train() + prediction_iters = uncertainty_granularity + else: + prediction_iters = 1 + with torch.no_grad(): - with trange(test_steps, disable=self.verbose != 1) as t: - for i, data in zip(t, test_loader): - t.set_description("predict") - X = {k: v.cuda() for k, v in data.items()} if use_cuda else data - preds = ( - self.model(X) if not self.model.is_tabnet else self.model(X)[0] - ) - if self.method == "binary": - preds = torch.sigmoid(preds) - if self.method == "multiclass": - preds = F.softmax(preds, dim=1) - preds = preds.cpu().data.numpy() - preds_l.append(preds) + with trange(uncertainty_granularity, disable=uncertainty is False) as t: + for i, k in zip(t, range(prediction_iters)): + t.set_description("predict_UncertaintyIter") + + with trange(test_steps, disable=self.verbose != 1 or uncertainty is True) as tt: + for j, data in zip(tt, test_loader): + tt.set_description("predict") + X = {k: v.cuda() for k, v in data.items()} if use_cuda else data + preds = ( + self.model(X) if not self.model.is_tabnet else self.model(X)[0] + ) + if self.method == "binary": + preds = torch.sigmoid(preds) + if self.method == "multiclass": + preds = F.softmax(preds, dim=1) + preds = preds.cpu().data.numpy() + preds_l.append(preds) self.model.train() return preds_l diff --git a/pytorch_widedeep/version.py b/pytorch_widedeep/version.py index 9eb1ebe..bd538f7 100644 --- a/pytorch_widedeep/version.py +++ b/pytorch_widedeep/version.py @@ -1 +1 @@ -__version__ = "1.0.11" +__version__ = "1.0.12" -- GitLab