{ "cells": [ { "cell_type": "markdown", "id": "6658e02e", "metadata": {}, "source": [ "# Custom DataLoader for an Imbalanced Dataset" ] }, { "cell_type": "markdown", "id": "d2528519", "metadata": {}, "source": [ "* In this notebook we will use the highly imbalanced Protein Homology Dataset from [KDD Cup 2004](https://www.kdd.org/kdd-cup/view/kdd-cup-2004/Data). The data format, as described on the KDD Cup website, is:\n", "\n", "```\n", "* The first element of each line is a BLOCK ID that denotes to which native sequence this example belongs. There is a unique BLOCK ID for each native sequence. BLOCK IDs are integers running from 1 to 303 (one for each native sequence, i.e. for each query). BLOCK IDs were assigned before the blocks were split into the train and test sets, so they do not run consecutively in either file.\n", "* The second element of each line is an EXAMPLE ID that uniquely describes the example. You will need this EXAMPLE ID and the BLOCK ID when you submit results.\n", "* The third element is the class of the example. Proteins that are homologous to the native sequence are denoted by 1, non-homologous proteins (i.e. decoys) by 0. Test examples have a \"?\" in this position.\n", "* All following elements are feature values. There are 74 feature values in each line. The features describe the match (e.g. 
the score of a sequence alignment) between the native protein sequence and the sequence that is tested for homology.\n", "```" ] }, { "cell_type": "markdown", "id": "639ad78b", "metadata": {}, "source": [ "## Initial imports" ] }, { "cell_type": "code", "execution_count": 4, "id": "d40952c1", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from torch.optim import SGD, lr_scheduler\n", "\n", "from pytorch_widedeep import Trainer\n", "from pytorch_widedeep.preprocessing import TabPreprocessor\n", "from pytorch_widedeep.models import TabMlp, WideDeep\n", "from pytorch_widedeep.dataloaders import DataLoaderImbalanced, DataLoaderDefault\n", "from torchmetrics import F1 as F1_torchmetrics\n", "from torchmetrics import Accuracy as Accuracy_torchmetrics\n", "from torchmetrics import Precision as Precision_torchmetrics\n", "from torchmetrics import Recall as Recall_torchmetrics\n", "from pytorch_widedeep.metrics import Accuracy, Recall, Precision, F1Score, R2Score\n", "from pytorch_widedeep.initializers import XavierNormal\n", "from pytorch_widedeep.datasets import load_bio_kdd04\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import classification_report\n", "\n", "import time\n", "import datetime\n", "\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n", "\n", "# increase the number of displayed columns and rows in the notebook\n", "pd.set_option(\"display.max_columns\", 200)\n", "pd.set_option(\"display.max_rows\", 300)\n", "\n", "# Fix the random seeds so that training (weight init, dropout, batch sampling)\n", "# is repeatable; SEED matches the random_state used for the data splits below.\n", "SEED = 1\n", "torch.manual_seed(SEED)\n", "np.random.seed(SEED)" ] }, { "cell_type": "code", "execution_count": 5, "id": "fe5f9761", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EXAMPLE_IDBLOCK_IDtarget4567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
0279261532052.032.690.302.520.01256.8-0.890.3311.0-55.0267.20.520.05-2.3649.6252.00.431.16-2.06-33.0-123.21.60-0.49-6.0665.0296.1-0.28-0.26-3.83-22.6-170.03.06-1.05-3.2922.9286.30.122.584.08-33.0-178.91.880.53-7.0-44.01987.0-5.410.95-4.0-57.0722.9-3.26-0.55-7.5125.51547.2-0.361.129.0-37.072.50.470.74-11.0-8.01595.1-1.642.83-2.0-50.0445.2-0.350.260.76
1279261533058.033.330.0016.59.5608.10.500.0720.5-52.5521.6-1.080.58-0.02-3.2103.6-0.950.23-2.87-25.9-52.2-0.210.87-1.8110.462.0-0.28-0.041.48-17.6-198.33.432.845.87-16.972.6-0.312.792.71-33.5-11.6-1.114.015.0-57.0666.31.134.385.0-64.039.31.07-0.1632.5100.01893.7-2.80-0.222.5-28.545.00.580.41-19.0-6.0762.90.290.82-3.0-35.0140.31.160.390.73
2279261534077.027.27-0.916.058.51623.6-1.400.02-6.5-48.0621.0-1.200.14-0.2073.6609.1-0.44-0.58-0.04-23.0-27.4-0.72-1.04-1.0991.1635.6-0.880.240.59-18.7-7.2-0.60-2.82-0.7152.4504.10.89-0.67-9.30-20.8-25.7-0.77-0.850.0-20.02259.0-0.941.15-4.0-44.0-22.70.94-0.98-19.0105.01267.91.031.2711.0-39.582.30.47-0.19-10.07.01491.80.32-1.290.0-34.0658.2-0.760.260.24
3279261535041.027.91-0.353.046.01921.6-1.36-0.47-32.0-51.5560.9-0.29-0.10-1.11124.3791.60.000.39-1.85-21.7-44.9-0.210.020.89133.9797.8-0.081.06-0.26-16.4-74.10.97-0.80-0.4166.9955.3-1.901.28-6.65-28.147.5-1.911.421.0-30.01846.70.761.10-4.0-52.0-53.91.71-0.22-12.097.51969.8-1.700.16-1.0-32.5255.9-0.461.5710.06.02047.7-0.981.530.0-49.0554.2-0.830.390.73
4279261536050.028.00-1.32-9.012.0464.80.880.198.0-51.598.11.09-0.33-2.16-3.9102.70.39-1.22-3.39-15.2-42.2-1.18-1.11-3.558.9141.3-0.16-0.43-4.15-12.9-13.4-1.32-0.98-3.698.8136.1-0.304.131.89-13.0-18.7-1.37-0.930.0-1.0810.1-2.296.721.0-23.0-29.70.58-1.10-18.533.5206.81.84-0.134.0-29.030.10.80-0.245.0-14.0479.50.68-0.592.0-36.0-6.92.020.14-0.23
\n", "
" ], "text/plain": [ " EXAMPLE_ID BLOCK_ID target 4 5 6 7 8 9 10 \\\n", "0 279 261532 0 52.0 32.69 0.30 2.5 20.0 1256.8 -0.89 \n", "1 279 261533 0 58.0 33.33 0.00 16.5 9.5 608.1 0.50 \n", "2 279 261534 0 77.0 27.27 -0.91 6.0 58.5 1623.6 -1.40 \n", "3 279 261535 0 41.0 27.91 -0.35 3.0 46.0 1921.6 -1.36 \n", "4 279 261536 0 50.0 28.00 -1.32 -9.0 12.0 464.8 0.88 \n", "\n", " 11 12 13 14 15 16 17 18 19 20 21 22 \\\n", "0 0.33 11.0 -55.0 267.2 0.52 0.05 -2.36 49.6 252.0 0.43 1.16 -2.06 \n", "1 0.07 20.5 -52.5 521.6 -1.08 0.58 -0.02 -3.2 103.6 -0.95 0.23 -2.87 \n", "2 0.02 -6.5 -48.0 621.0 -1.20 0.14 -0.20 73.6 609.1 -0.44 -0.58 -0.04 \n", "3 -0.47 -32.0 -51.5 560.9 -0.29 -0.10 -1.11 124.3 791.6 0.00 0.39 -1.85 \n", "4 0.19 8.0 -51.5 98.1 1.09 -0.33 -2.16 -3.9 102.7 0.39 -1.22 -3.39 \n", "\n", " 23 24 25 26 27 28 29 30 31 32 33 34 \\\n", "0 -33.0 -123.2 1.60 -0.49 -6.06 65.0 296.1 -0.28 -0.26 -3.83 -22.6 -170.0 \n", "1 -25.9 -52.2 -0.21 0.87 -1.81 10.4 62.0 -0.28 -0.04 1.48 -17.6 -198.3 \n", "2 -23.0 -27.4 -0.72 -1.04 -1.09 91.1 635.6 -0.88 0.24 0.59 -18.7 -7.2 \n", "3 -21.7 -44.9 -0.21 0.02 0.89 133.9 797.8 -0.08 1.06 -0.26 -16.4 -74.1 \n", "4 -15.2 -42.2 -1.18 -1.11 -3.55 8.9 141.3 -0.16 -0.43 -4.15 -12.9 -13.4 \n", "\n", " 35 36 37 38 39 40 41 42 43 44 45 46 \\\n", "0 3.06 -1.05 -3.29 22.9 286.3 0.12 2.58 4.08 -33.0 -178.9 1.88 0.53 \n", "1 3.43 2.84 5.87 -16.9 72.6 -0.31 2.79 2.71 -33.5 -11.6 -1.11 4.01 \n", "2 -0.60 -2.82 -0.71 52.4 504.1 0.89 -0.67 -9.30 -20.8 -25.7 -0.77 -0.85 \n", "3 0.97 -0.80 -0.41 66.9 955.3 -1.90 1.28 -6.65 -28.1 47.5 -1.91 1.42 \n", "4 -1.32 -0.98 -3.69 8.8 136.1 -0.30 4.13 1.89 -13.0 -18.7 -1.37 -0.93 \n", "\n", " 47 48 49 50 51 52 53 54 55 56 57 58 \\\n", "0 -7.0 -44.0 1987.0 -5.41 0.95 -4.0 -57.0 722.9 -3.26 -0.55 -7.5 125.5 \n", "1 5.0 -57.0 666.3 1.13 4.38 5.0 -64.0 39.3 1.07 -0.16 32.5 100.0 \n", "2 0.0 -20.0 2259.0 -0.94 1.15 -4.0 -44.0 -22.7 0.94 -0.98 -19.0 105.0 \n", "3 1.0 -30.0 1846.7 0.76 1.10 -4.0 -52.0 -53.9 1.71 -0.22 -12.0 
97.5 \n", "4 0.0 -1.0 810.1 -2.29 6.72 1.0 -23.0 -29.7 0.58 -1.10 -18.5 33.5 \n", "\n", " 59 60 61 62 63 64 65 66 67 68 69 \\\n", "0 1547.2 -0.36 1.12 9.0 -37.0 72.5 0.47 0.74 -11.0 -8.0 1595.1 \n", "1 1893.7 -2.80 -0.22 2.5 -28.5 45.0 0.58 0.41 -19.0 -6.0 762.9 \n", "2 1267.9 1.03 1.27 11.0 -39.5 82.3 0.47 -0.19 -10.0 7.0 1491.8 \n", "3 1969.8 -1.70 0.16 -1.0 -32.5 255.9 -0.46 1.57 10.0 6.0 2047.7 \n", "4 206.8 1.84 -0.13 4.0 -29.0 30.1 0.80 -0.24 5.0 -14.0 479.5 \n", "\n", " 70 71 72 73 74 75 76 77 \n", "0 -1.64 2.83 -2.0 -50.0 445.2 -0.35 0.26 0.76 \n", "1 0.29 0.82 -3.0 -35.0 140.3 1.16 0.39 0.73 \n", "2 0.32 -1.29 0.0 -34.0 658.2 -0.76 0.26 0.24 \n", "3 -0.98 1.53 0.0 -49.0 554.2 -0.83 0.39 0.73 \n", "4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = load_bio_kdd04(as_frame=True)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 6, "id": "f60412e0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 144455\n", "1 1296\n", "Name: target, dtype: int64" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# imbalance of the classes\n", "df[\"target\"].value_counts()" ] }, { "cell_type": "code", "execution_count": 7, "id": "b4dafc90", "metadata": {}, "outputs": [], "source": [ "# drop columns we won't need in this example\n", "df.drop(columns=[\"EXAMPLE_ID\", \"BLOCK_ID\"], inplace=True)" ] }, { "cell_type": "code", "execution_count": 8, "id": "3f262cfa", "metadata": {}, "outputs": [], "source": [ "df_train, df_valid = train_test_split(\n", " df, test_size=0.2, stratify=df[\"target\"], random_state=1\n", ")\n", "df_valid, df_test = train_test_split(\n", " df_valid, test_size=0.5, stratify=df_valid[\"target\"], random_state=1\n", ")" ] }, { "cell_type": "markdown", "id": "ec00939c", "metadata": {}, "source": [ "## Preparing the data" ] }, { "cell_type": "code", "execution_count": 9, "id": "2c23b964", "metadata": 
{}, "outputs": [], "source": [ "continuous_cols = df.drop(columns=[\"target\"]).columns.values.tolist()" ] }, { "cell_type": "code", "execution_count": 10, "id": "713b3d11", "metadata": {}, "outputs": [], "source": [ "# deeptabular\n", "tab_preprocessor = TabPreprocessor(continuous_cols=continuous_cols, scale=True)\n", "X_tab_train = tab_preprocessor.fit_transform(df_train)\n", "X_tab_valid = tab_preprocessor.transform(df_valid)\n", "X_tab_test = tab_preprocessor.transform(df_test)\n", "\n", "# target\n", "y_train = df_train[\"target\"].values\n", "y_valid = df_valid[\"target\"].values\n", "y_test = df_test[\"target\"].values" ] }, { "cell_type": "markdown", "id": "333ee9eb", "metadata": {}, "source": [ "## Define the model" ] }, { "cell_type": "code", "execution_count": 11, "id": "eefefb2c", "metadata": {}, "outputs": [], "source": [ "input_layer = len(tab_preprocessor.continuous_cols)\n", "output_layer = 1\n", "hidden_layers = np.linspace(\n", " input_layer * 2, output_layer, 5, endpoint=False, dtype=int\n", ").tolist()" ] }, { "cell_type": "code", "execution_count": 12, "id": "3a024773", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "WideDeep(\n", " (deeptabular): Sequential(\n", " (0): TabMlp(\n", " (cat_embed_and_cont): CatEmbeddingsAndCont(\n", " (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=74, out_features=148, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_1): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=148, out_features=118, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_2): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=118, out_features=89, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_3): 
Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=89, out_features=59, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_4): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=59, out_features=30, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " )\n", " )\n", " )\n", " (1): Linear(in_features=30, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "deeptabular = TabMlp(\n", " mlp_hidden_dims=hidden_layers,\n", " column_idx=tab_preprocessor.column_idx,\n", " continuous_cols=tab_preprocessor.continuous_cols,\n", ")\n", "model = WideDeep(deeptabular=deeptabular, pred_dim=1)\n", "model" ] }, { "cell_type": "code", "execution_count": 13, "id": "c0e821d1", "metadata": {}, "outputs": [], "source": [ "# Metrics from torchmetrics\n", "# average=None reports one value per class for every metric. A micro average\n", "# of precision equals overall accuracy here and is dominated by the majority\n", "# class, hiding the low precision on the minority class (1).\n", "accuracy = Accuracy_torchmetrics(average=None, num_classes=2)\n", "precision = Precision_torchmetrics(average=None, num_classes=2)\n", "f1 = F1_torchmetrics(average=None, num_classes=2)\n", "recall = Recall_torchmetrics(average=None, num_classes=2)" ] }, { "cell_type": "code", "execution_count": 14, "id": "6b6a1157", "metadata": {}, "outputs": [], "source": [ "# # Metrics from pytorch-widedeep\n", "# accuracy = Accuracy(top_k=2)\n", "# precision = Precision(average=False)\n", "# recall = Recall(average=True)\n", "# f1 = F1Score(average=False)" ] }, { "cell_type": "code", "execution_count": 15, "id": "03407003", "metadata": {}, "outputs": [], "source": [ "# Optimizers\n", "deep_opt = SGD(model.deeptabular.parameters(), lr=0.1)\n", "# LR Scheduler\n", "deep_sch = lr_scheduler.StepLR(deep_opt, step_size=3)\n", "\n", "trainer = Trainer(\n", " model,\n", " objective=\"binary\",\n", " lr_schedulers={\"deeptabular\": deep_sch},\n", " initializers={\"deeptabular\": XavierNormal},\n", " optimizers={\"deeptabular\": deep_opt},\n", " metrics=[accuracy, precision, recall, f1],\n", 
" verbose=1,\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "76914c9a", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "epoch 1: 100%|██████████| 208/208 [00:03<00:00, 65.68it/s, loss=0.225, metrics={'Accuracy': [0.9414, 0.8749], 'Precision': 0.9079, 'Recall': [0.9414, 0.8749], 'F1': [0.9103, 0.9053]}]\n", "valid: 100%|██████████| 292/292 [00:03<00:00, 78.24it/s, loss=0.113, metrics={'Accuracy': [0.9661, 0.8682], 'Precision': 0.9653, 'Recall': [0.9661, 0.8682], 'F1': [0.9822, 0.3068]}]\n", "epoch 2: 100%|██████████| 208/208 [00:03<00:00, 65.73it/s, loss=0.154, metrics={'Accuracy': [0.9485, 0.9252], 'Precision': 0.9369, 'Recall': [0.9485, 0.9252], 'F1': [0.9381, 0.9357]}]\n", "valid: 100%|██████████| 292/292 [00:03<00:00, 78.56it/s, loss=0.0866, metrics={'Accuracy': [0.974, 0.8915], 'Precision': 0.9733, 'Recall': [0.974, 0.8915], 'F1': [0.9864, 0.3716]}]\n", "epoch 3: 100%|██████████| 208/208 [00:03<00:00, 59.23it/s, loss=0.14, metrics={'Accuracy': [0.9541, 0.9356], 'Precision': 0.9448, 'Recall': [0.9541, 0.9356], 'F1': [0.9453, 0.9444]}]\n", "valid: 100%|██████████| 292/292 [00:03<00:00, 80.98it/s, loss=0.0988, metrics={'Accuracy': [0.9713, 0.8915], 'Precision': 0.9706, 'Recall': [0.9713, 0.8915], 'F1': [0.985, 0.3495]}]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training time[s]: 0:00:21\n" ] } ], "source": [ "start = time.time()\n", "trainer.fit(\n", " X_train={\"X_tab\": X_tab_train, \"target\": y_train},\n", " X_val={\"X_tab\": X_tab_valid, \"target\": y_valid},\n", " n_epochs=3,\n", " batch_size=50,\n", " custom_dataloader=DataLoaderImbalanced,\n", " oversample_mul=5,\n", ")\n", "print(\n", " \"Training time[s]: {}\".format(\n", " datetime.timedelta(seconds=round(time.time() - start))\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": 17, "id": "28e00935", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
train_losstrain_Accuracy_0train_Accuracy_1train_Precisiontrain_Recall_0train_Recall_1train_F1_0train_F1_1val_lossval_Accuracy_0val_Accuracy_1val_Precisionval_Recall_0val_Recall_1val_F1_0val_F1_1
00.2249290.9413590.8749040.9079070.9413590.8749040.9103370.9053430.1132930.9661500.8682170.9652830.9661500.8682170.9821960.306849
10.1535910.9485270.9251560.9369330.9485270.9251560.9381150.9357060.0866450.9740410.8914730.9733100.9740410.8914730.9863660.371567
20.1397390.9540540.9356450.9448410.9540540.9356450.9452950.9443800.0987890.9713420.8914730.9706350.9713420.8914730.9849780.349544
\n", "
" ], "text/plain": [ " train_loss train_Accuracy_0 train_Accuracy_1 train_Precision \\\n", "0 0.224929 0.941359 0.874904 0.907907 \n", "1 0.153591 0.948527 0.925156 0.936933 \n", "2 0.139739 0.954054 0.935645 0.944841 \n", "\n", " train_Recall_0 train_Recall_1 train_F1_0 train_F1_1 val_loss \\\n", "0 0.941359 0.874904 0.910337 0.905343 0.113293 \n", "1 0.948527 0.925156 0.938115 0.935706 0.086645 \n", "2 0.954054 0.935645 0.945295 0.944380 0.098789 \n", "\n", " val_Accuracy_0 val_Accuracy_1 val_Precision val_Recall_0 val_Recall_1 \\\n", "0 0.966150 0.868217 0.965283 0.966150 0.868217 \n", "1 0.974041 0.891473 0.973310 0.974041 0.891473 \n", "2 0.971342 0.891473 0.970635 0.971342 0.891473 \n", "\n", " val_F1_0 val_F1_1 \n", "0 0.982196 0.306849 \n", "1 0.986366 0.371567 \n", "2 0.984978 0.349544 " ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pd.DataFrame(trainer.history)" ] }, { "cell_type": "markdown", "id": "702f5ea8", "metadata": {}, "source": [ "## \"Normal\" prediction" ] }, { "cell_type": "code", "execution_count": 18, "id": "8dadf819", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "predict: 100%|██████████| 292/292 [00:01<00:00, 213.90it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 1.00 0.97 0.99 14446\n", " 1 0.23 0.95 0.37 130\n", "\n", " accuracy 0.97 14576\n", " macro avg 0.61 0.96 0.68 14576\n", "weighted avg 0.99 0.97 0.98 14576\n", "\n", "Actual predicted values:\n", "(array([0, 1]), array([14039, 537]))\n" ] } ], "source": [ "df_pred = trainer.predict(X_tab=X_tab_test)\n", "print(classification_report(df_test[\"target\"].to_list(), df_pred))\n", "print(\"Actual predicted values:\\n{}\".format(np.unique(df_pred, return_counts=True)))" ] }, { "cell_type": "markdown", "id": "e1c2d471", "metadata": {}, "source": [ "## Prediction using uncertainty" ] }, { "cell_type": "code", 
"execution_count": 21, "id": "5411c3a0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "predict_UncertaintyIter: 100%|██████████| 10/10 [00:08<00:00, 1.16it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 1.00 0.97 0.99 14446\n", " 1 0.23 0.95 0.37 130\n", "\n", " accuracy 0.97 14576\n", " macro avg 0.61 0.96 0.68 14576\n", "weighted avg 0.99 0.97 0.98 14576\n", "\n", "Actual predicted values:\n", "(array([0.]), array([14576]))\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [ "df_pred_unc = trainer.predict_uncertainty(X_tab=X_tab_test, uncertainty_granularity=10)\n", "# The last column of df_pred_unc holds the predicted class labels. Evaluate\n", "# those, not df_pred, which comes from the normal prediction cell above.\n", "pred_unc_labels = df_pred_unc[:, -1].astype(int)\n", "print(classification_report(df_test[\"target\"].to_list(), pred_unc_labels))\n", "print(\"Actual predicted values:\\n{}\".format(np.unique(pred_unc_labels, return_counts=True)))" ] }, { "cell_type": "code", "execution_count": 20, "id": "e97a1e99", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[9.83839095e-01, 1.61609277e-02, 0.00000000e+00],\n", " [9.98977661e-01, 1.02235633e-03, 0.00000000e+00],\n", " [9.90328670e-01, 9.67130624e-03, 0.00000000e+00],\n", " ...,\n", " [9.90116656e-01, 9.88335349e-03, 0.00000000e+00],\n", " [9.99370277e-01, 6.29719463e-04, 0.00000000e+00],\n", " [9.99686420e-01, 3.13554629e-04, 0.00000000e+00]])" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_pred_unc" ] } ], "metadata": { "interpreter": { "hash": "3b99005fd577fa40f3cce433b2b92303885900e634b2b5344c07c59d06c8792d" }, "kernelspec": { "display_name": "Python 3.8.5 64-bit ('base': conda)", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": 
true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }