# 3rd party integration - RayTune, Weights & Biases

This notebook shows how to integrate functions from external libraries into the model training process through `Callback` objects, a common pattern in which objects are passed as arguments to other objects.
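
To make the callback idea concrete before the real integrations, here is a minimal, dependency-free sketch. The `Callback`, `LossRecorder` and `ToyTrainer` classes below are hypothetical stand-ins written for illustration only; they are not the `pytorch_widedeep` classes, and the "loss" is a made-up number.

```python
from typing import Dict, List, Optional


class Callback:
    """Hypothetical minimal base class: every hook is a no-op by default."""

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        pass


class LossRecorder(Callback):
    """Example callback that stores the loss it receives after each epoch."""

    def __init__(self) -> None:
        self.seen: List[float] = []

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        logs = logs or {}
        self.seen.append(logs["train_loss"])


class ToyTrainer:
    """Stand-in training loop: it knows the hook names, not what the hooks do."""

    def __init__(self, callbacks: List[Callback]) -> None:
        self.callbacks = callbacks

    def fit(self, n_epochs: int) -> None:
        for epoch in range(n_epochs):
            fake_loss = 1.0 / (epoch + 1)  # placeholder for a real training step
            for cb in self.callbacks:
                cb.on_epoch_end(epoch, logs={"train_loss": fake_loss})


recorder = LossRecorder()
ToyTrainer(callbacks=[recorder]).fit(n_epochs=3)
print(recorder.seen)
```

The training loop never needs to change when a new reporting target is added; only a new callback class is written. The two callbacks defined in this notebook follow the same idea.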

**[DISCLAIMER]**

We show the integration of RayTune (a hyperparameter tuning framework) and Weights & Biases (an experiment tracking and versioning solution for ML projects) in the `pytorch_widedeep` model training process. We did not include `RayTuneReporter` and `WnBReportBest` in the library code, to minimize dependencies on libraries that are not directly involved in model design and training.

## Initial imports

In [1]:
from typing import Optional, Dict
import os
import numpy as np
import pandas as pd
import torch
from torch.optim import SGD, lr_scheduler

from pytorch_widedeep import Trainer
from pytorch_widedeep.preprocessing import TabPreprocessor
from pytorch_widedeep.models import TabMlp, WideDeep
from torchmetrics import F1Score as F1_torchmetrics
from torchmetrics import Accuracy as Accuracy_torchmetrics
from torchmetrics import Precision as Precision_torchmetrics
from torchmetrics import Recall as Recall_torchmetrics
from pytorch_widedeep.metrics import Accuracy, Recall, Precision, F1Score, R2Score
from pytorch_widedeep.initializers import XavierNormal
from pytorch_widedeep.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    Callback,
)
from pytorch_widedeep.datasets import load_bio_kdd04

from sklearn.model_selection import train_test_split
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)

from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler
from ray.tune import JupyterNotebookReporter
from ray.tune.integration.wandb import WandbLoggerCallback, wandb_mixin
import wandb

import tracemalloc

tracemalloc.start()

# increase displayed columns in jupyter notebook
pd.set_option("display.max_columns", 200)
pd.set_option("display.max_rows", 300)



In [29]:
class RayTuneReporter(Callback):
    r"""Callback that allows reporting history and lr_history values to RayTune
    during Hyperparameter tuning

    Callbacks are passed as input parameters to the ``Trainer`` class. See
    :class:`pytorch_widedeep.trainer.Trainer`

    For examples see the examples folder at:

        .. code-block:: bash

            /examples/12_HyperParameter_tuning_w_RayTune.ipynb
    """

    def on_epoch_end(
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        report_dict = {}
        for k, v in self.trainer.history.items():
            report_dict.update({k: v[-1]})
        if hasattr(self.trainer, "lr_history"):
            for k, v in self.trainer.lr_history.items():
                report_dict.update({k: v[-1]})
        tune.report(report_dict)


class WnBReportBest(Callback):
    r"""Callback that allows reporting best performance of a run to WnB
    during Hyperparameter tuning. It is an adjusted pytorch_widedeep.callbacks.ModelCheckpoint
    with added WnB and removed checkpoint saving.

    Callbacks are passed as input parameters to the ``Trainer`` class.

    Parameters
    ----------
    wb: obj
        Weights&Biases API interface used to report the single best result,
        which can then be used to compare multiple parameter combinations
        with, for example, `parallel coordinates
        <https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates>`_.
        E.g. W&B summary report `wandb.run.summary["best"]`.
    monitor: str, default="val_loss"
        quantity to monitor. Typically `'val_loss'` or metric name
        (e.g. `'val_acc'`)
    mode: str, default="auto"
        Whether the best value is the maximum or the minimum of the
        monitored quantity. For `'acc'`, this should be `'max'`, for
        `'loss'` this should be `'min'`, etc. In `'auto'` mode, the
        direction is automatically inferred from the name of the monitored
        quantity.

    """
    def __init__(
        self,
        wb: object,
        monitor: str = "val_loss",
        mode: str = "auto",
    ):
        super(WnBReportBest, self).__init__()

        self.monitor = monitor
        self.mode = mode
        self.wb = wb

        if self.mode not in ["auto", "min", "max"]:
            warnings.warn(
                "WnBReportBest mode %s is unknown, "
                "fallback to auto mode." % (self.mode),
                RuntimeWarning,
            )
            self.mode = "auto"
        # np.inf is used instead of the np.Inf alias, which NumPy 2.0 removed
        if self.mode == "min":
            self.monitor_op = np.less
            self.best = np.inf
        elif self.mode == "max":
            self.monitor_op = np.greater  # type: ignore[assignment]
            self.best = -np.inf
        else:
            if self._is_metric(self.monitor):
                self.monitor_op = np.greater  # type: ignore[assignment]
                self.best = -np.inf
            else:
                self.monitor_op = np.less
                self.best = np.inf

    def on_epoch_end(  # noqa: C901
        self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
    ):
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is not None:
            if self.monitor_op(current, self.best):
                self.wb.run.summary["best"] = current  # type: ignore[attr-defined]
                self.best = current
                self.best_epoch = epoch

    @staticmethod
    def _is_metric(monitor: str):
        "copied from pytorch_widedeep.callbacks"
        if any([s in monitor for s in ["acc", "prec", "rec", "fscore", "f1", "f2"]]):
            return True
        else:
            return False
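
The best-value tracking done by `WnBReportBest` can be tried without a W&B account. The sketch below is a simplified, dependency-free re-statement of that logic, not the class itself: `fake_wandb` is a stand-in that only mimics `wandb.run.summary`, and `report_best` is a hypothetical helper that replays a list of per-epoch values.

```python
from types import SimpleNamespace

# Stand-in for the `wandb` module: only `run.summary` is mimicked here.
fake_wandb = SimpleNamespace(run=SimpleNamespace(summary={}))


def is_metric(monitor: str) -> bool:
    """Same substring heuristic as `WnBReportBest._is_metric`."""
    return any(s in monitor for s in ["acc", "prec", "rec", "fscore", "f1", "f2"])


def report_best(monitor: str, values: list) -> None:
    """Replay per-epoch values and keep only the best one in the summary."""
    higher_is_better = is_metric(monitor)  # mirrors mode="auto"
    best = float("-inf") if higher_is_better else float("inf")
    for current in values:
        improved = current > best if higher_is_better else current < best
        if improved:
            best = current
            fake_wandb.run.summary["best"] = current


report_best("val_loss", [0.9, 0.4, 0.6])
best_loss = fake_wandb.run.summary["best"]  # lowest value is kept

report_best("val_acc", [0.7, 0.9, 0.8])
best_acc = fake_wandb.run.summary["best"]  # highest value is kept

# The heuristic is case-sensitive: capitalised names are not recognised.
capitalised_is_metric = is_metric("val_Accuracy")
```

One consequence of the case-sensitive substring check: a name such as `val_Accuracy`, as it appears in the results of this notebook, is not recognised as a metric in `'auto'` mode and would be minimized. When monitoring such a name it seems safer to pass `mode="max"` explicitly.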

In [16]:
df = load_bio_kdd04(as_frame=True)
df.head()

Unnamed: 0,EXAMPLE_ID,BLOCK_ID,target,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77
0,279,261532,0,52.0,32.69,0.3,2.5,20.0,1256.8,-0.89,0.33,11.0,-55.0,267.2,0.52,0.05,-2.36,49.6,252.0,0.43,1.16,-2.06,-33.0,-123.2,1.6,-0.49,-6.06,65.0,296.1,-0.28,-0.26,-3.83,-22.6,-170.0,3.06,-1.05,-3.29,22.9,286.3,0.12,2.58,4.08,-33.0,-178.9,1.88,0.53,-7.0,-44.0,1987.0,-5.41,0.95,-4.0,-57.0,722.9,-3.26,-0.55,-7.5,125.5,1547.2,-0.36,1.12,9.0,-37.0,72.5,0.47,0.74,-11.0,-8.0,1595.1,-1.64,2.83,-2.0,-50.0,445.2,-0.35,0.26,0.76
1,279,261533,0,58.0,33.33,0.0,16.5,9.5,608.1,0.5,0.07,20.5,-52.5,521.6,-1.08,0.58,-0.02,-3.2,103.6,-0.95,0.23,-2.87,-25.9,-52.2,-0.21,0.87,-1.81,10.4,62.0,-0.28,-0.04,1.48,-17.6,-198.3,3.43,2.84,5.87,-16.9,72.6,-0.31,2.79,2.71,-33.5,-11.6,-1.11,4.01,5.0,-57.0,666.3,1.13,4.38,5.0,-64.0,39.3,1.07,-0.16,32.5,100.0,1893.7,-2.8,-0.22,2.5,-28.5,45.0,0.58,0.41,-19.0,-6.0,762.9,0.29,0.82,-3.0,-35.0,140.3,1.16,0.39,0.73
2,279,261534,0,77.0,27.27,-0.91,6.0,58.5,1623.6,-1.4,0.02,-6.5,-48.0,621.0,-1.2,0.14,-0.2,73.6,609.1,-0.44,-0.58,-0.04,-23.0,-27.4,-0.72,-1.04,-1.09,91.1,635.6,-0.88,0.24,0.59,-18.7,-7.2,-0.6,-2.82,-0.71,52.4,504.1,0.89,-0.67,-9.3,-20.8,-25.7,-0.77,-0.85,0.0,-20.0,2259.0,-0.94,1.15,-4.0,-44.0,-22.7,0.94,-0.98,-19.0,105.0,1267.9,1.03,1.27,11.0,-39.5,82.3,0.47,-0.19,-10.0,7.0,1491.8,0.32,-1.29,0.0,-34.0,658.2,-0.76,0.26,0.24
3,279,261535,0,41.0,27.91,-0.35,3.0,46.0,1921.6,-1.36,-0.47,-32.0,-51.5,560.9,-0.29,-0.1,-1.11,124.3,791.6,0.0,0.39,-1.85,-21.7,-44.9,-0.21,0.02,0.89,133.9,797.8,-0.08,1.06,-0.26,-16.4,-74.1,0.97,-0.8,-0.41,66.9,955.3,-1.9,1.28,-6.65,-28.1,47.5,-1.91,1.42,1.0,-30.0,1846.7,0.76,1.1,-4.0,-52.0,-53.9,1.71,-0.22,-12.0,97.5,1969.8,-1.7,0.16,-1.0,-32.5,255.9,-0.46,1.57,10.0,6.0,2047.7,-0.98,1.53,0.0,-49.0,554.2,-0.83,0.39,0.73
4,279,261536,0,50.0,28.0,-1.32,-9.0,12.0,464.8,0.88,0.19,8.0,-51.5,98.1,1.09,-0.33,-2.16,-3.9,102.7,0.39,-1.22,-3.39,-15.2,-42.2,-1.18,-1.11,-3.55,8.9,141.3,-0.16,-0.43,-4.15,-12.9,-13.4,-1.32,-0.98,-3.69,8.8,136.1,-0.3,4.13,1.89,-13.0,-18.7,-1.37,-0.93,0.0,-1.0,810.1,-2.29,6.72,1.0,-23.0,-29.7,0.58,-1.1,-18.5,33.5,206.8,1.84,-0.13,4.0,-29.0,30.1,0.8,-0.24,5.0,-14.0,479.5,0.68,-0.59,2.0,-36.0,-6.9,2.02,0.14,-0.23


In [4]:
# imbalance of the classes
df["target"].value_counts()

0    144455
1      1296
Name: target, dtype: int64
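
To put the imbalance into numbers, the short calculation below uses only the two counts printed above. It also computes the accuracy of a trivial classifier that always predicts class 0, which is a useful reference point when reading the accuracy values reported later; the variable names are chosen for illustration.

```python
# Class counts copied from the `value_counts()` output above.
n_negative = 144455
n_positive = 1296
n_total = n_negative + n_positive

# Share of the minority (positive) class in the dataset.
positive_rate = n_positive / n_total

# Accuracy of a trivial classifier that always predicts the majority class (0).
majority_baseline_accuracy = n_negative / n_total

print(f"positive rate: {positive_rate:.4%}")
print(f"majority-class baseline accuracy: {majority_baseline_accuracy:.4%}")
```

Because the baseline accuracy is already above 99%, accuracy alone says little on this dataset; recall and F1 are more informative.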

In [5]:
# drop columns we won't need in this example
df.drop(columns=["EXAMPLE_ID", "BLOCK_ID"], inplace=True)

In [6]:
df_train, df_valid = train_test_split(
    df, test_size=0.2, stratify=df["target"], random_state=1
)
df_valid, df_test = train_test_split(
    df_valid, test_size=0.5, stratify=df_valid["target"], random_state=1
)

## Preparing the data

In [7]:
continuous_cols = df.drop(columns=["target"]).columns.values.tolist()

In [8]:
# deeptabular
tab_preprocessor = TabPreprocessor(continuous_cols=continuous_cols, scale=True)
X_tab_train = tab_preprocessor.fit_transform(df_train)
X_tab_valid = tab_preprocessor.transform(df_valid)
X_tab_test = tab_preprocessor.transform(df_test)

# target
y_train = df_train["target"].values
y_valid = df_valid["target"].values
y_test = df_test["target"].values



## Define the model

In [9]:
input_layer = len(tab_preprocessor.continuous_cols)
output_layer = 1
hidden_layers = np.linspace(
    input_layer * 2, output_layer, 5, endpoint=False, dtype=int
).tolist()
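
The `np.linspace` call above produces a funnel of layer widths that shrinks from twice the input size towards the output size. The sketch below reproduces the same arithmetic in plain Python so the resulting widths can be inspected without NumPy; `linspace_int` is a hypothetical helper, and `n_features = 74` is taken from the `BatchNorm1d(74, ...)` line in the model printout of this notebook.

```python
def linspace_int(start: float, stop: float, num: int) -> list:
    """Evenly spaced values from start towards stop (stop excluded), truncated to int."""
    step = (stop - start) / num
    return [int(start + i * step) for i in range(num)]


n_features = 74  # number of continuous columns in this dataset
hidden_dims = linspace_int(n_features * 2, 1, 5)
print(hidden_dims)
```

The first three widths (148, 118, 89) agree with the `out_features` values shown in the model printout.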

In [10]:
deeptabular = TabMlp(
    mlp_hidden_dims=hidden_layers,
    column_idx=tab_preprocessor.column_idx,
    continuous_cols=tab_preprocessor.continuous_cols,
)
model = WideDeep(deeptabular=deeptabular)
model

WideDeep(
  (deeptabular): Sequential(
    (0): TabMlp(
      (cat_and_cont_embed): DiffSizeCatAndContEmbeddings(
        (cont_norm): BatchNorm1d(74, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (tab_mlp): MLP(
        (mlp): Sequential(
          (dense_layer_0): Sequential(
            (0): Dropout(p=0.1, inplace=False)
            (1): Linear(in_features=74, out_features=148, bias=True)
            (2): ReLU(inplace=True)
          )
          (dense_layer_1): Sequential(
            (0): Dropout(p=0.1, inplace=False)
            (1): Linear(in_features=148, out_features=118, bias=True)
            (2): ReLU(inplace=True)
          )
          (dense_layer_2): Sequential(
            (0): Dropout(p=0.1, inplace=False)
            (1): Linear(in_features=118, out_features=89, bias=True)
            (2): ReLU(inplace=True)
          )
          (dense_layer_3): Sequential(
            (0): Dropout(p=0.1, inplace=False)
            (1): Linear(in_featu

In [11]:
# Metrics from torchmetrics
accuracy = Accuracy_torchmetrics(average=None, num_classes=1)
precision = Precision_torchmetrics(average="micro", num_classes=1)
f1 = F1_torchmetrics(average=None, num_classes=1)
recall = Recall_torchmetrics(average=None, num_classes=1)
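
As a sanity check on how these metrics relate, F1 is the harmonic mean of precision and recall. The sketch below applies that formula to the validation precision and recall that appear in the `analysis.results` output of this notebook and compares the outcome with the reported `val_F1Score`; `f1_from` is a hypothetical helper, not a torchmetrics function.

```python
def f1_from(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Validation values copied from the `analysis.results` output of this notebook.
val_precision = 1.0
val_recall = 0.39534884691238403
val_f1_reported = 0.5666667222976685

val_f1 = f1_from(val_precision, val_recall)
print(val_f1, val_f1_reported)
```

The two numbers agree up to floating-point precision, which suggests the reported metrics are internally consistent.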

**Note**:

The following cells use both the `RayTuneReporter` and `WnBReportBest` callbacks.
If you want to use only `RayTuneReporter`, remove the following:
* the `wandb` entry from `config`
* `WandbLoggerCallback`
* `WnBReportBest`
* the `@wandb_mixin` decorator

We do not see a strong reason to use WnB without RayTune for a single parameter combination run, but it is possible:
* **option01**: define each parameter in `config` with a single value, e.g. `tune.grid_search([1000])` (single-value RayTune run)
* **option02**: define a WnB callback that reports the current validation/training loss, metrics, etc. at the end of each batch, i.e. report to WnB in `on_batch_end` rather than in `on_epoch_end` as `WnBReportBest` does; see `pytorch_widedeep.callbacks.Callback`
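
A possible shape for the second option is sketched below. Everything here is illustrative: `Callback` is a local stand-in for `pytorch_widedeep.callbacks.Callback` so that the snippet runs on its own, `BatchReporter` is a hypothetical name, and `log_fn` would be something like `wandb.log` once a run has been initialised. Whether the `Trainer` passes anything in `logs` at batch end is an assumption that may not hold for a given library version; if `logs` is empty, the callback would need to read the values it wants from `self.trainer` instead.

```python
from typing import Callable, Dict, Optional


class Callback:
    """Stand-in for `pytorch_widedeep.callbacks.Callback` so the sketch runs on its own."""

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        pass


class BatchReporter(Callback):
    """Hypothetical callback that forwards the logs to `log_fn` after every batch."""

    def __init__(self, log_fn: Callable[[Dict], None]) -> None:
        self.log_fn = log_fn  # e.g. `wandb.log` with a live W&B run

    def on_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        if logs:
            self.log_fn({"batch": batch, **logs})


# Demonstration with a plain list instead of a live W&B run.
sent = []
reporter = BatchReporter(log_fn=sent.append)
reporter.on_batch_end(0, logs={"train_loss": 0.7})
reporter.on_batch_end(1, logs={"train_loss": 0.5})
reporter.on_batch_end(2, logs=None)  # nothing to report, so nothing is sent
print(sent)
```

Injecting `log_fn` rather than importing `wandb` inside the class keeps the callback testable offline and makes it easy to swap the reporting target.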


In [30]:
config = {
    "batch_size": tune.grid_search([1000, 5000]),
    "wandb": {
        "project": "test",
        # "api_key_file": os.getcwd() + "/wandb_api.key",
        "api_key": "WNB_API_KEY", 
    },
}

# Optimizers
deep_opt = SGD(model.deeptabular.parameters(), lr=0.1)
# LR Scheduler
deep_sch = lr_scheduler.StepLR(deep_opt, step_size=3)


@wandb_mixin
def training_function(config, X_train, X_val):
    early_stopping = EarlyStopping()
    model_checkpoint = ModelCheckpoint(save_best_only=True)
    # Hyperparameters
    batch_size = config["batch_size"]
    trainer = Trainer(
        model,
        objective="binary_focal_loss",
        callbacks=[RayTuneReporter, WnBReportBest(wb=wandb), early_stopping, model_checkpoint],
        lr_schedulers={"deeptabular": deep_sch},
        initializers={"deeptabular": XavierNormal},
        optimizers={"deeptabular": deep_opt},
        metrics=[accuracy, precision, recall, f1],
        verbose=0,
    )

    trainer.fit(X_train=X_train, X_val=X_val, n_epochs=5, batch_size=batch_size)


X_train = {"X_tab": X_tab_train, "target": y_train}
X_val = {"X_tab": X_tab_valid, "target": y_valid}

asha_scheduler = AsyncHyperBandScheduler(
    time_attr="training_iteration",
    metric="_metric/val_loss",
    mode="min",
    max_t=100,
    grace_period=10,
    reduction_factor=3,
    brackets=1,
)

analysis = tune.run(
    tune.with_parameters(training_function, X_train=X_train, X_val=X_val),
    resources_per_trial={"cpu": 1, "gpu": 0},
    progress_reporter=JupyterNotebookReporter(overwrite=True),
    scheduler=asha_scheduler,
    config=config,
    callbacks=[
        WandbLoggerCallback(
            project=config["wandb"]["project"],
            # api_key_file=config["wandb"]["api_key_file"],
            api_key=config["wandb"]["api_key"],
            log_config=True,
        )
    ],
)

Trial name,status,loc,batch_size,iter,total time (s)
training_function_e7fce_00000,TERMINATED,10.32.44.172:6759,1000,5,12.2567
training_function_e7fce_00001,TERMINATED,10.32.44.172:6924,5000,5,12.7518
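
Both trials ran for all 5 iterations. One way to see why is to look at the milestones ("rungs") at which a successive-halving scheduler compares trials. The sketch below is a simplified view of that idea using the `grace_period`, `reduction_factor` and `max_t` values from the cell above; `asha_rungs` is a hypothetical helper, and Ray's internal implementation may differ in detail.

```python
def asha_rungs(grace_period: int, reduction_factor: int, max_t: int) -> list:
    """Milestones at which a successive-halving bracket compares trials."""
    rungs = []
    t = grace_period
    while t <= max_t:
        rungs.append(t)
        t *= reduction_factor
    return rungs


rungs = asha_rungs(grace_period=10, reduction_factor=3, max_t=100)
n_epochs = 5  # value passed to `trainer.fit` in this notebook

# A trial can only be stopped early once it reaches the first rung.
can_be_pruned = n_epochs >= rungs[0]
print(rungs, can_be_pruned)
```

Under this simplified view, with `n_epochs=5` and `grace_period=10` no trial reaches the first rung, so the scheduler has no opportunity to stop a trial early. To see early stopping in action, either train for more epochs or lower `grace_period`.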


[2m[36m(training_function pid=6759)[0m   self._pusher = None
[2m[36m(training_function pid=6759)[0m wandb: 
[2m[36m(training_function pid=6759)[0m wandb: Run summary:
[2m[36m(training_function pid=6759)[0m wandb: best 0.00505
[2m[36m(training_function pid=6759)[0m wandb: 
[2m[36m(training_function pid=6759)[0m wandb: Synced training_function_e7fce_00000: https://wandb.ai/palo/test/runs/e7fce_00000
[2m[36m(training_function pid=6759)[0m wandb: Synced 3 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)
[2m[36m(training_function pid=6759)[0m wandb: Find logs at: ./wandb/run-20220731_144151-e7fce_00000/logs
[2m[36m(training_function pid=6924)[0m wandb: Waiting for W&B process to finish... (success).
[2m[36m(training_function pid=6924)[0m wandb: - 0.000 MB of 0.000 MB uploaded (0.000 MB deduped)
[2m[36m(training_function pid=6924)[0m wandb: \ 0.000 MB of 0.000 MB uploaded (0.000 MB deduped)
[2m[36m(training_function pid=6924)[0m wandb: 

In [14]:
analysis.results

{'fc9a8_00000': {'_metric': {'train_loss': 0.006297602537127896,
   'train_Accuracy': 0.9925042986869812,
   'train_Precision': 0.9939393997192383,
   'train_Recall': 0.15814851224422455,
   'train_F1Score': 0.2728785574436188,
   'val_loss': 0.005045663565397263,
   'val_Accuracy': 0.9946483969688416,
   'val_Precision': 1.0,
   'val_Recall': 0.39534884691238403,
   'val_F1Score': 0.5666667222976685},
  'time_this_iter_s': 2.388202428817749,
  'done': True,
  'timesteps_total': None,
  'episodes_total': None,
  'training_iteration': 5,
  'trial_id': 'fc9a8_00000',
  'experiment_id': 'baad1d4f3d924b48b9ece1b9f26c80cc',
  'date': '2022-07-31_14-06-51',
  'timestamp': 1659276411,
  'time_total_s': 12.656474113464355,
  'pid': 1813,
  'hostname': 'jupyter-5uperpalo',
  'node_ip': '10.32.44.172',
  'config': {'batch_size': 1000},
  'time_since_restore': 12.656474113464355,
  'timesteps_since_restore': 0,
  'iterations_since_restore': 5,
  'warmup_time': 0.8006253242492676,
  'experiment_ta
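
Note that the values reported by `RayTuneReporter` are nested under the `_metric` key, because the callback passes a single dictionary positionally to `tune.report`. This is why the scheduler above is configured with `metric="_metric/val_loss"`. The sketch below illustrates how such a path relates to the nested result; `flatten` is a hypothetical helper written for illustration, not Ray's own code.

```python
def flatten(d: dict, parent: str = "", sep: str = "/") -> dict:
    """Flatten nested dicts into a single level, joining keys with `sep`."""
    flat = {}
    for key, value in d.items():
        name = f"{parent}{sep}{key}" if parent else key
        if isinstance(value, dict):
            flat.update(flatten(value, name, sep))
        else:
            flat[name] = value
    return flat


# Trimmed copy of one entry from the `analysis.results` output above.
result = {
    "_metric": {
        "train_loss": 0.006297602537127896,
        "val_loss": 0.005045663565397263,
    },
    "training_iteration": 5,
}
flat_result = flatten(result)
print(sorted(flat_result))
```

The same `_metric/...` prefix is needed wherever a metric name is handed to RayTune in this setup.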

Using Weights & Biases logging, you can create [parallel coordinates graphs](https://docs.wandb.ai/ref/app/features/panels/parallel-coordinates) that map parameter combinations to the best (lowest) loss achieved during training.

![WNB](figures/wnb.png "parallel coordinates")

Local visualization of the RayTune results using TensorBoard:

In [23]:
%load_ext tensorboard
%tensorboard --logdir ~/ray_results