{ "cells": [ { "cell_type": "markdown", "id": "60bf2cbb", "metadata": {}, "source": [ "# 3rd party integration - RayTune, Weights & Biases\n", "\n", "This notebook provides a guideline for integrating external library functions into the model training process through `Callback` objects, a popular pattern in which objects are passed as arguments to other objects.\n", "\n", "**[DISCLAIMER]**\n", "\n", "We show the integration of RayTune (a hyperparameter tuning framework) and Weights & Biases (an experiment tracking and versioning solution for ML projects) into the `pytorch_widedeep` model training process. We did not include `RayTuneReporter` and `WnBReportBest` in the library code to minimize dependencies on libraries that are not directly involved in model design and training." ] }, { "cell_type": "markdown", "id": "9822d48a", "metadata": {}, "source": [ "## Initial imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "d073f793", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.9/site-packages/scipy/__init__.py:138: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.1)\n", " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion} is required for this version of \"\n", "/opt/conda/lib/python3.9/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n", " warnings.warn(msg)\n" ] } ], "source": [ "from typing import Optional, Dict\n", "import os\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from torch.optim import SGD, lr_scheduler\n", "\n", "from pytorch_widedeep import Trainer\n", "from pytorch_widedeep.preprocessing import TabPreprocessor\n", "from pytorch_widedeep.models import TabMlp, WideDeep\n", "from torchmetrics import F1Score as F1_torchmetrics\n", "from torchmetrics import Accuracy as Accuracy_torchmetrics\n", "from torchmetrics import Precision as Precision_torchmetrics\n", "from torchmetrics import Recall as Recall_torchmetrics\n", "from pytorch_widedeep.metrics import Accuracy, Recall, Precision, F1Score, R2Score\n", "from pytorch_widedeep.initializers import XavierNormal\n", "from pytorch_widedeep.callbacks import (\n", "    EarlyStopping,\n", "    ModelCheckpoint,\n", "    Callback,\n", ")\n", "from pytorch_widedeep.datasets import load_bio_kdd04\n", "\n", "from sklearn.model_selection import train_test_split\n", "import warnings\n", "\n", "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n", "\n", "from ray import tune\n", "from ray.tune.schedulers import AsyncHyperBandScheduler\n", "from ray.tune import JupyterNotebookReporter\n", "from ray.tune.integration.wandb import WandbLoggerCallback, wandb_mixin\n", "import wandb\n", "\n", "import tracemalloc\n", "\n", "tracemalloc.start()\n", "\n", "# increase displayed columns in jupyter notebook\n", "pd.set_option(\"display.max_columns\", 200)\n", "pd.set_option(\"display.max_rows\", 300)" ] }, { "cell_type": "code", "execution_count": 29, "id": "3157bb9a", "metadata": {}, "outputs": [], "source": [ "class RayTuneReporter(Callback):\n", "    r\"\"\"Callback that allows reporting history and lr_history values to RayTune\n", "    during hyperparameter tuning.\n", "\n", "    Callbacks are passed as input parameters to the ``Trainer`` class. 
See\n", "    :class:`pytorch_widedeep.trainer.Trainer`\n", "\n", "    For examples see the examples folder at:\n", "\n", "    .. code-block:: bash\n", "\n", "        /examples/12_HyperParameter_tuning_w_RayTune.ipynb\n", "    \"\"\"\n", "\n", "    def on_epoch_end(\n", "        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None\n", "    ):\n", "        report_dict = {}\n", "        for k, v in self.trainer.history.items():\n", "            report_dict.update({k: v[-1]})\n", "        if hasattr(self.trainer, \"lr_history\"):\n", "            for k, v in self.trainer.lr_history.items():\n", "                report_dict.update({k: v[-1]})\n", "        tune.report(report_dict)\n", "\n", "\n", "class WnBReportBest(Callback):\n", "    r\"\"\"Callback that allows reporting the best performance of a run to WnB\n", "    during hyperparameter tuning. It is an adjusted pytorch_widedeep.callbacks.ModelCheckpoint\n", "    with WnB reporting added and checkpoint saving removed.\n", "\n", "    Callbacks are passed as input parameters to the ``Trainer`` class.\n", "\n", "    Parameters\n", "    ----------\n", "    wb: obj\n", "        Weights&Biases API interface to report the single best result, usable for\n", "        comparison of multiple parameter combinations by, for example,\n", "        `parallel coordinates\n", "        <https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates>`_.\n", "        E.g. the W&B summary report `wandb.run.summary[\"best\"]`.\n", "    monitor: str, default=\"val_loss\"\n", "        quantity to monitor. Typically `'val_loss'` or a metric name\n", "        (e.g. `'val_acc'`)\n", "    mode: str, default=\"auto\"\n", "        The decision to overwrite the current best value is made based on either\n", "        the maximization or the minimization of the monitored quantity. For\n", "        `'acc'`, this should be `'max'`, for `'loss'` this should be `'min'`,\n", "        etc. In `'auto'` mode, the direction is automatically inferred from the\n", "        name of the monitored quantity.\n", "\n", "    \"\"\"\n", "\n", "    def __init__(\n", "        self,\n", "        wb: object,\n", "        monitor: str = \"val_loss\",\n", "        mode: str = \"auto\",\n", "    ):\n", "        super(WnBReportBest, self).__init__()\n", "\n", "        self.monitor = monitor\n", "        self.mode = mode\n", "        self.wb = wb\n", "\n", "        if self.mode not in [\"auto\", \"min\", \"max\"]:\n", "            warnings.warn(\n", "                \"WnBReportBest mode %s is unknown, \"\n", "                \"fallback to auto mode.\" % (self.mode),\n", "                RuntimeWarning,\n", "            )\n", "            self.mode = \"auto\"\n", "        if self.mode == \"min\":\n", "            self.monitor_op = np.less\n", "            self.best = np.Inf\n", "        elif self.mode == \"max\":\n", "            self.monitor_op = np.greater  # type: ignore[assignment]\n", "            self.best = -np.Inf\n", "        else:\n", "            if self._is_metric(self.monitor):\n", "                self.monitor_op = np.greater  # type: ignore[assignment]\n", "                self.best = -np.Inf\n", "            else:\n", "                self.monitor_op = np.less\n", "                self.best = np.Inf\n", "\n", "    def on_epoch_end(  # noqa: C901\n", "        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None\n", "    ):\n", "        logs = logs or {}\n", "        current = logs.get(self.monitor)\n", "        if current is not None:\n", "            if self.monitor_op(current, self.best):\n", "                self.wb.run.summary[\"best\"] = current  # type: ignore[attr-defined]\n", "                self.best = current\n", "                self.best_epoch = epoch\n", "\n", "    @staticmethod\n", "    def _is_metric(monitor: str):\n", "        \"copied from pytorch_widedeep.callbacks\"\n", "        if any([s in monitor for s in [\"acc\", \"prec\", \"rec\", \"fscore\", \"f1\", \"f2\"]]):\n", "            return True\n", "        else:\n", "            return False" ] }, { "cell_type": "code", "execution_count": 16, "id": "6f0ee187", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>[HTML rendering of df.head() omitted: 5 rows x 77 columns; see the text/plain output below]</div>
" ], "text/plain": [ " EXAMPLE_ID BLOCK_ID target 4 5 6 7 8 9 10 \\\n", "0 279 261532 0 52.0 32.69 0.30 2.5 20.0 1256.8 -0.89 \n", "1 279 261533 0 58.0 33.33 0.00 16.5 9.5 608.1 0.50 \n", "2 279 261534 0 77.0 27.27 -0.91 6.0 58.5 1623.6 -1.40 \n", "3 279 261535 0 41.0 27.91 -0.35 3.0 46.0 1921.6 -1.36 \n", "4 279 261536 0 50.0 28.00 -1.32 -9.0 12.0 464.8 0.88 \n", "\n", " 11 12 13 14 15 16 17 18 19 20 21 22 \\\n", "0 0.33 11.0 -55.0 267.2 0.52 0.05 -2.36 49.6 252.0 0.43 1.16 -2.06 \n", "1 0.07 20.5 -52.5 521.6 -1.08 0.58 -0.02 -3.2 103.6 -0.95 0.23 -2.87 \n", "2 0.02 -6.5 -48.0 621.0 -1.20 0.14 -0.20 73.6 609.1 -0.44 -0.58 -0.04 \n", "3 -0.47 -32.0 -51.5 560.9 -0.29 -0.10 -1.11 124.3 791.6 0.00 0.39 -1.85 \n", "4 0.19 8.0 -51.5 98.1 1.09 -0.33 -2.16 -3.9 102.7 0.39 -1.22 -3.39 \n", "\n", " 23 24 25 26 27 28 29 30 31 32 33 34 \\\n", "0 -33.0 -123.2 1.60 -0.49 -6.06 65.0 296.1 -0.28 -0.26 -3.83 -22.6 -170.0 \n", "1 -25.9 -52.2 -0.21 0.87 -1.81 10.4 62.0 -0.28 -0.04 1.48 -17.6 -198.3 \n", "2 -23.0 -27.4 -0.72 -1.04 -1.09 91.1 635.6 -0.88 0.24 0.59 -18.7 -7.2 \n", "3 -21.7 -44.9 -0.21 0.02 0.89 133.9 797.8 -0.08 1.06 -0.26 -16.4 -74.1 \n", "4 -15.2 -42.2 -1.18 -1.11 -3.55 8.9 141.3 -0.16 -0.43 -4.15 -12.9 -13.4 \n", "\n", " 35 36 37 38 39 40 41 42 43 44 45 46 \\\n", "0 3.06 -1.05 -3.29 22.9 286.3 0.12 2.58 4.08 -33.0 -178.9 1.88 0.53 \n", "1 3.43 2.84 5.87 -16.9 72.6 -0.31 2.79 2.71 -33.5 -11.6 -1.11 4.01 \n", "2 -0.60 -2.82 -0.71 52.4 504.1 0.89 -0.67 -9.30 -20.8 -25.7 -0.77 -0.85 \n", "3 0.97 -0.80 -0.41 66.9 955.3 -1.90 1.28 -6.65 -28.1 47.5 -1.91 1.42 \n", "4 -1.32 -0.98 -3.69 8.8 136.1 -0.30 4.13 1.89 -13.0 -18.7 -1.37 -0.93 \n", "\n", " 47 48 49 50 51 52 53 54 55 56 57 58 \\\n", "0 -7.0 -44.0 1987.0 -5.41 0.95 -4.0 -57.0 722.9 -3.26 -0.55 -7.5 125.5 \n", "1 5.0 -57.0 666.3 1.13 4.38 5.0 -64.0 39.3 1.07 -0.16 32.5 100.0 \n", "2 0.0 -20.0 2259.0 -0.94 1.15 -4.0 -44.0 -22.7 0.94 -0.98 -19.0 105.0 \n", "3 1.0 -30.0 1846.7 0.76 1.10 -4.0 -52.0 -53.9 1.71 -0.22 -12.0 97.5 \n", "4 0.0 -1.0 810.1 -2.29 6.72 1.0 -23.0 -29.7 0.58 -1.10 -18.5 33.5 \n", "\n", " 59 60 61 62 63 64 65 66 67 68 69 \\\n", "0 1547.2 -0.36 1.12 9.0 -37.0 72.5 0.47 0.74 -11.0 -8.0 1595.1 \n", "1 1893.7 -2.80 -0.22 2.5 -28.5 45.0 0.58 0.41 -19.0 -6.0 762.9 \n", "2 1267.9 1.03 1.27 11.0 -39.5 82.3 0.47 -0.19 -10.0 7.0 1491.8 \n", "3 1969.8 -1.70 0.16 -1.0 -32.5 255.9 -0.46 1.57 10.0 6.0 2047.7 \n", "4 206.8 1.84 -0.13 4.0 -29.0 30.1 0.80 -0.24 5.0 -14.0 479.5 \n", "\n", " 70 71 72 73 74 75 76 77 \n", "0 -1.64 2.83 -2.0 -50.0 445.2 -0.35 0.26 0.76 \n", "1 0.29 0.82 -3.0 -35.0 140.3 1.16 0.39 0.73 \n", "2 0.32 -1.29 0.0 -34.0 658.2 -0.76 0.26 0.24 \n", "3 -0.98 1.53 0.0 -49.0 554.2 -0.83 0.39 0.73 \n", "4 0.68 -0.59 2.0 -36.0 -6.9 2.02 0.14 -0.23 " ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = load_bio_kdd04(as_frame=True)\n", "df.head()" ] }, { "cell_type": "code", "execution_count": 4, "id": "eb1768e7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 144455\n", "1 1296\n", "Name: target, dtype: int64" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# imbalance of the classes\n", "df[\"target\"].value_counts()" ] }, { "cell_type": "code", "execution_count": 5, "id": "141f0e67", "metadata": {}, "outputs": [], "source": [ "# drop columns we won't need in this example\n", "df.drop(columns=[\"EXAMPLE_ID\", \"BLOCK_ID\"], inplace=True)" ] }, { "cell_type": "code", "execution_count": 6, "id": "8159efde", "metadata": {}, 
"outputs": [], "source": [ "df_train, df_valid = train_test_split(\n", " df, test_size=0.2, stratify=df[\"target\"], random_state=1\n", ")\n", "df_valid, df_test = train_test_split(\n", " df_valid, test_size=0.5, stratify=df_valid[\"target\"], random_state=1\n", ")" ] }, { "cell_type": "markdown", "id": "f8b6d9f0", "metadata": {}, "source": [ "## Preparing the data" ] }, { "cell_type": "code", "execution_count": 7, "id": "d9bcf02a", "metadata": {}, "outputs": [], "source": [ "continuous_cols = df.drop(columns=[\"target\"]).columns.values.tolist()" ] }, { "cell_type": "code", "execution_count": 8, "id": "3618ceb5", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.9/site-packages/pytorch_widedeep/preprocessing/tab_preprocessor.py:394: UserWarning: WARNING: Scaling will be applied to quantized continuous features!\n", " warnings.warn(\n" ] } ], "source": [ "# deeptabular\n", "tab_preprocessor = TabPreprocessor(continuous_cols=continuous_cols, scale=True)\n", "X_tab_train = tab_preprocessor.fit_transform(df_train)\n", "X_tab_valid = tab_preprocessor.transform(df_valid)\n", "X_tab_test = tab_preprocessor.transform(df_test)\n", "\n", "# target\n", "y_train = df_train[\"target\"].values\n", "y_valid = df_valid[\"target\"].values\n", "y_test = df_test[\"target\"].values" ] }, { "cell_type": "markdown", "id": "4f270bf8", "metadata": {}, "source": [ "## Define the model" ] }, { "cell_type": "code", "execution_count": 9, "id": "5781ea4d", "metadata": {}, "outputs": [], "source": [ "input_layer = len(tab_preprocessor.continuous_cols)\n", "output_layer = 1\n", "hidden_layers = np.linspace(\n", " input_layer * 2, output_layer, 5, endpoint=False, dtype=int\n", ").tolist()" ] }, { "cell_type": "code", "execution_count": 10, "id": "9d7dcc60", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "WideDeep(\n", " (deeptabular): Sequential(\n", " (0): TabMlp(\n", " (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(\n", " (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " )\n", " (tab_mlp): MLP(\n", " (mlp): Sequential(\n", " (dense_layer_0): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=74, out_features=148, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_1): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=148, out_features=118, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_2): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=118, out_features=89, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_3): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=89, out_features=59, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " (dense_layer_4): Sequential(\n", " (0): Dropout(p=0.1, inplace=False)\n", " (1): Linear(in_features=59, out_features=30, bias=True)\n", " (2): ReLU(inplace=True)\n", " )\n", " )\n", " )\n", " )\n", " (1): Linear(in_features=30, out_features=1, bias=True)\n", " )\n", ")" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "deeptabular = TabMlp(\n", " mlp_hidden_dims=hidden_layers,\n", " column_idx=tab_preprocessor.column_idx,\n", " continuous_cols=tab_preprocessor.continuous_cols,\n", ")\n", "model = WideDeep(deeptabular=deeptabular)\n", "model" ] }, { "cell_type": "code", "execution_count": 11, "id": "e8293c96", "metadata": {}, "outputs": [], 
"source": [ "# Metrics from torchmetrics\n", "accuracy = Accuracy_torchmetrics(average=None, num_classes=1)\n", "precision = Precision_torchmetrics(average=\"micro\", num_classes=1)\n", "f1 = F1_torchmetrics(average=None, num_classes=1)\n", "recall = Recall_torchmetrics(average=None, num_classes=1)" ] }, { "cell_type": "markdown", "id": "e1bd7147-31f6-47da-bfae-f6194c4f25bd", "metadata": {}, "source": [ "**Note**:\n", "\n", "Following cells includes usage of both `RayTuneReporter` and `WnBReportBest` callbacks.\n", "In case you want to use just `RayTuneReporter`, remove following:\n", "* wandb from config\n", "* `WandbLoggerCallback`\n", "* `WnBReportBest`\n", "* `@wandb_mixin` decorator\n", "\n", "We do not see strong reason to use WnB without RayTune for a single paramater combination run, but it is possible:\n", "* **option01**: define paramaters in config only for a single value `tune.grid_search([1000])` (single value RayTune run)\n", "* **option02**: define WnB callback that reports currnet validation/training loss, metrics, etc. at the end of batch, ie. do not report to WnB at `epoch_end` as in `WnBReportBest` but at the `on_batch_end`, see `pytorch_widedeep.callbacks.Callback`\n" ] }, { "cell_type": "code", "execution_count": 30, "id": "b841d015", "metadata": {}, "outputs": [ { "data": { "text/html": [ "== Status ==
{ "cell_type": "code", "execution_count": 30, "id": "b841d015", "metadata": {}, "outputs": [ { "data": { "text/html": [ "== Status ==<br>Current time: 2022-07-31 14:42:12 (running for 00:00:24.76)<br>Memory usage on this node: 4.5/30.6 GiB<br>Using AsyncHyperBand: num_stopped=0\n", "Bracket: Iter 90.000: None | Iter 30.000: None | Iter 10.000: None<br>Resources requested: 0/8 CPUs, 0/0 GPUs, 0.0/16.97 GiB heap, 0.0/8.49 GiB objects<br>Result logdir: /home/jovyan/ray_results/training_function_2022-07-31_14-41-47<br>Number of trials: 2/2 (2 TERMINATED)<br>\n", "<table>\n", "<tr><th>Trial name</th><th>status</th><th>loc</th><th>batch_size</th><th>iter</th><th>total time (s)</th></tr>\n", "<tr><td>training_function_e7fce_00000</td><td>TERMINATED</td><td>10.32.44.172:6759</td><td>1000</td><td>5</td><td>12.2567</td></tr>\n", "<tr><td>training_function_e7fce_00001</td><td>TERMINATED</td><td>10.32.44.172:6924</td><td>5000</td><td>5</td><td>12.7518</td></tr>\n", "</table>
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m /opt/conda/lib/python3.9/tempfile.py:817: ResourceWarning: Implicitly cleaning up \n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m _warnings.warn(warn_message, ResourceWarning)\n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m /opt/conda/lib/python3.9/site-packages/wandb/sdk/internal/sender.py:1123: ResourceWarning: unclosed \n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m self._pusher = None\n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m ResourceWarning: Enable tracemalloc to get the object allocation traceback\n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m wandb: \n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m wandb: Run summary:\n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m wandb: best 0.00505\n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m wandb: \n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m wandb: Synced training_function_e7fce_00000: https://wandb.ai/palo/test/runs/e7fce_00000\n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m wandb: Synced 3 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n", "\u001b[2m\u001b[36m(training_function pid=6759)\u001b[0m wandb: Find logs at: ./wandb/run-20220731_144151-e7fce_00000/logs\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: Waiting for W&B process to finish... (success).\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: - 0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: \\ 0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: | 0.000 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: / 0.032 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: - 0.032 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: \\ 0.032 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: | 0.032 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: / 0.032 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: - 0.032 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: \\ 0.032 MB of 0.032 MB uploaded (0.000 MB deduped)\n", "wandb: 0 MB deduped)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m /opt/conda/lib/python3.9/tempfile.py:817: ResourceWarning: Implicitly cleaning up \n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m _warnings.warn(warn_message, ResourceWarning)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m /opt/conda/lib/python3.9/site-packages/wandb/sdk/internal/sender.py:1123: ResourceWarning: unclosed \n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m self._pusher = None\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m ResourceWarning: Enable tracemalloc to get the object allocation traceback\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: \n", 
"\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: Run summary:\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: best 0.01509\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: \n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: Synced training_function_e7fce_00001: https://wandb.ai/palo/test/runs/e7fce_00001\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: Synced 3 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)\n", "\u001b[2m\u001b[36m(training_function pid=6924)\u001b[0m wandb: Find logs at: ./wandb/run-20220731_144155-e7fce_00001/logs\n", "2022-07-31 14:42:17,787\tINFO tune.py:747 -- Total run time: 29.89 seconds (24.71 seconds for the tuning loop).\n" ] } ], "source": [ "config = {\n", " \"batch_size\": tune.grid_search([1000, 5000]),\n", " \"wandb\": {\n", " \"project\": \"test\",\n", " # \"api_key_file\": os.getcwd() + \"/wandb_api.key\",\n", " \"api_key\": \"WNB_API_KEY\", \n", " },\n", "}\n", "\n", "# Optimizers\n", "deep_opt = SGD(model.deeptabular.parameters(), lr=0.1)\n", "# LR Scheduler\n", "deep_sch = lr_scheduler.StepLR(deep_opt, step_size=3)\n", "\n", "\n", "@wandb_mixin\n", "def training_function(config, X_train, X_val):\n", " early_stopping = EarlyStopping()\n", " model_checkpoint = ModelCheckpoint(save_best_only=True)\n", " # Hyperparameters\n", " batch_size = config[\"batch_size\"]\n", " trainer = Trainer(\n", " model,\n", " objective=\"binary_focal_loss\",\n", " callbacks=[RayTuneReporter, WnBReportBest(wb=wandb), early_stopping, model_checkpoint],\n", " lr_schedulers={\"deeptabular\": deep_sch},\n", " initializers={\"deeptabular\": XavierNormal},\n", " optimizers={\"deeptabular\": deep_opt},\n", " metrics=[accuracy, precision, recall, f1],\n", " verbose=0,\n", " )\n", "\n", " trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=batch_size)\n", "\n", "\n", "X_train = {\"X_tab\": X_tab_train, \"target\": y_train}\n", "X_val = {\"X_tab\": X_tab_valid, \"target\": y_valid}\n", "\n", "asha_scheduler = AsyncHyperBandScheduler(\n", " time_attr=\"training_iteration\",\n", " metric=\"_metric/val_loss\",\n", " mode=\"min\",\n", " max_t=100,\n", " grace_period=10,\n", " reduction_factor=3,\n", " brackets=1,\n", ")\n", "\n", "analysis = tune.run(\n", " tune.with_parameters(training_function, X_train=X_train, X_val=X_val),\n", " resources_per_trial={\"cpu\": 1, \"gpu\": 0},\n", " progress_reporter=JupyterNotebookReporter(overwrite=True),\n", " scheduler=asha_scheduler,\n", " config=config,\n", " callbacks=[\n", " WandbLoggerCallback(\n", " project=config[\"wandb\"][\"project\"],\n", " # api_key_file=config[\"wandb\"][\"api_key_file\"],\n", " api_key=config[\"wandb\"][\"api_key\"],\n", " log_config=True,\n", " )\n", " ],\n", ")" ] }, { "cell_type": "code", "execution_count": 14, "id": "81d581da", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'fc9a8_00000': {'_metric': {'train_loss': 0.006297602537127896,\n", " 'train_Accuracy': 0.9925042986869812,\n", " 'train_Precision': 0.9939393997192383,\n", " 'train_Recall': 0.15814851224422455,\n", " 'train_F1Score': 0.2728785574436188,\n", " 'val_loss': 0.005045663565397263,\n", " 'val_Accuracy': 0.9946483969688416,\n", " 'val_Precision': 1.0,\n", " 'val_Recall': 0.39534884691238403,\n", " 'val_F1Score': 0.5666667222976685},\n", " 'time_this_iter_s': 2.388202428817749,\n", " 'done': True,\n", " 'timesteps_total': None,\n", " 'episodes_total': None,\n", " 'training_iteration': 5,\n", " 'trial_id': 
'fc9a8_00000',\n", " 'experiment_id': 'baad1d4f3d924b48b9ece1b9f26c80cc',\n", " 'date': '2022-07-31_14-06-51',\n", " 'timestamp': 1659276411,\n", " 'time_total_s': 12.656474113464355,\n", " 'pid': 1813,\n", " 'hostname': 'jupyter-5uperpalo',\n", " 'node_ip': '10.32.44.172',\n", " 'config': {'batch_size': 1000},\n", " 'time_since_restore': 12.656474113464355,\n", " 'timesteps_since_restore': 0,\n", " 'iterations_since_restore': 5,\n", " 'warmup_time': 0.8006253242492676,\n", " 'experiment_tag': '0_batch_size=1000'},\n", " 'fc9a8_00001': {'_metric': {'train_loss': 0.02519632239515583,\n", " 'train_Accuracy': 0.9910891652107239,\n", " 'train_Precision': 0.25,\n", " 'train_Recall': 0.0009643201483413577,\n", " 'train_F1Score': 0.0019212296465411782,\n", " 'val_loss': 0.02578434906899929,\n", " 'val_Accuracy': 0.9911492466926575,\n", " 'val_Precision': 0.0,\n", " 'val_Recall': 0.0,\n", " 'val_F1Score': 0.0},\n", " 'time_this_iter_s': 4.113586902618408,\n", " 'done': True,\n", " 'timesteps_total': None,\n", " 'episodes_total': None,\n", " 'training_iteration': 5,\n", " 'trial_id': 'fc9a8_00001',\n", " 'experiment_id': 'f2e54a6a5780429fbf0db0746853347e',\n", " 'date': '2022-07-31_14-06-56',\n", " 'timestamp': 1659276416,\n", " 'time_total_s': 12.926990509033203,\n", " 'pid': 1962,\n", " 'hostname': 'jupyter-5uperpalo',\n", " 'node_ip': '10.32.44.172',\n", " 'config': {'batch_size': 5000},\n", " 'time_since_restore': 12.926990509033203,\n", " 'timesteps_since_restore': 0,\n", " 'iterations_since_restore': 5,\n", " 'warmup_time': 0.9253025054931641,\n", " 'experiment_tag': '1_batch_size=5000'}}" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "analysis.results" ] }, { "cell_type": "markdown", "id": "56311243", "metadata": {}, "source": [ "Using Weights & Biases logging, you can create [parallel coordinates graphs](https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates) that map parameter combinations to the best (lowest) loss achieved during the training of the networks.\n", "\n", "![WNB](figures/wnb.png \"parallel coordinates\")" ] }, { "cell_type": "markdown", "id": "9b469bc2", "metadata": {}, "source": [ "Local visualization of RayTune results using TensorBoard:" ] }, { "cell_type": "code", "execution_count": 23, "id": "57b44a55", "metadata": { "scrolled": true }, "outputs": [], "source": [ "%load_ext tensorboard\n", "%tensorboard --logdir ~/ray_results" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false }, "vscode": { "interpreter": { "hash": "bee110fa72fc220f84be99700c69baf478c6696e63cfda5b1944123ebc470d26" } } }, "nbformat": 4, "nbformat_minor": 5 }