import os
import pickle
import shutil
import string
from pathlib import Path
from itertools import chain

import numpy as np
import torch
import pytest
from ray import tune
from torch.optim.lr_scheduler import StepLR, CyclicLR

from pytorch_widedeep.optim import RAdam
from pytorch_widedeep.models import Wide, TabMlp, WideDeep, TabTransformer
from pytorch_widedeep.training import Trainer
from pytorch_widedeep.callbacks import (
    LRHistory,
    EarlyStopping,
    ModelCheckpoint,
    RayTuneReporter,
)

# Wide array
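# 32 rows, 10 categorical columns with integer categories in [0, 50)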
X_wide = np.random.choice(50, (32, 10))

# Deep Array
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 32) for _ in range(5)]
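# each embed_input entry is (column name, number of unique values, embedding dim)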
embed_input = [(u, i, j) for u, i, j in zip(colnames[:5], [5] * 5, [16] * 5)]
cont_cols = [np.random.rand(32) for _ in range(5)]
column_idx = {k: v for v, k in enumerate(colnames)}
X_tab = np.vstack(embed_cols + cont_cols).transpose()

# target
target = np.random.choice(2, 32)


###############################################################################
# Test that history saves the information adequately
###############################################################################
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
    mlp_hidden_dims=[32, 16],
    mlp_dropout=[0.5, 0.5],
    column_idx=column_idx,
    embed_input=embed_input,
    continuous_cols=colnames[-5:],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)

# 1. Single optimizer, single scheduler, not cyclic and both passed directly
optimizers_1 = RAdam(model.parameters())
lr_schedulers_1 = StepLR(optimizers_1, step_size=4)
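# StepLR's default gamma is 0.1, so after stepping at epoch 4 the lr should end at init_lr / 10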

# 2. Multiple optimizers, single scheduler, cyclic and passed via a one-item
# dictionary
wide_opt_2 = torch.optim.Adam(model.wide.parameters())
deep_opt_2 = RAdam(model.deeptabular.parameters())
deep_sch_2 = CyclicLR(
    deep_opt_2, base_lr=0.001, max_lr=0.01, step_size_up=5, cycle_momentum=False
)
optimizers_2 = {"wide": wide_opt_2, "deeptabular": deep_opt_2}
lr_schedulers_2 = {"deeptabular": deep_sch_2}
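# only the deeptabular component gets a scheduler here; the wide optimizer's lr is
# expected to stay constant throughout training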

# 3. Multiple schedulers no cyclic
wide_opt_3 = torch.optim.Adam(model.wide.parameters())
deep_opt_3 = RAdam(model.deeptabular.parameters())
wide_sch_3 = StepLR(wide_opt_3, step_size=4)
deep_sch_3 = StepLR(deep_opt_3, step_size=4)
optimizers_3 = {"wide": wide_opt_3, "deeptabular": deep_opt_3}
lr_schedulers_3 = {"wide": wide_sch_3, "deeptabular": deep_sch_3}

# 4. Multiple schedulers with cyclic
wide_opt_4 = torch.optim.Adam(model.wide.parameters())
deep_opt_4 = torch.optim.Adam(model.deeptabular.parameters())
wide_sch_4 = StepLR(wide_opt_4, step_size=4)
deep_sch_4 = CyclicLR(
    deep_opt_4, base_lr=0.001, max_lr=0.01, step_size_up=5, cycle_momentum=False
)
optimizers_4 = {"wide": wide_opt_4, "deeptabular": deep_opt_4}
lr_schedulers_4 = {"wide": wide_sch_4, "deeptabular": deep_sch_4}

# 5. Single optimizer, single scheduler, cyclic and both passed directly
optimizers_5 = RAdam(model.parameters())
lr_schedulers_5 = CyclicLR(
    optimizers_5, base_lr=0.001, max_lr=0.01, step_size_up=5, cycle_momentum=False
)
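
# Expected lr-history lengths in the parametrization below: per-epoch schedulers
# (StepLR) should leave 5 recorded lrs over 5 epochs, while cyclic schedulers are
# assumed to be recorded per batch (2 batches/epoch x 5 epochs + the initial lr = 11)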


@pytest.mark.parametrize(
    "optimizers, schedulers, len_loss_output, len_lr_output, init_lr, schedulers_type",
    [
        (optimizers_1, lr_schedulers_1, 5, 5, 0.001, "step"),
        (optimizers_2, lr_schedulers_2, 5, 11, 0.001, "cyclic"),
        (optimizers_3, lr_schedulers_3, 5, 5, None, None),
        (optimizers_4, lr_schedulers_4, 5, 11, None, None),
        (optimizers_5, lr_schedulers_5, 5, 11, 0.001, "cyclic"),
    ],
)
def test_history_callback(
    optimizers, schedulers, len_loss_output, len_lr_output, init_lr, schedulers_type
):

    trainer = Trainer(
        model=model,
        objective="binary",
        optimizers=optimizers,
        lr_schedulers=schedulers,
        callbacks=[LRHistory(n_epochs=5)],
        verbose=0,
    )
    trainer.fit(
        X_wide=X_wide,
        X_tab=X_tab,
        target=target,
        n_epochs=5,
        batch_size=16,
    )
    out = []
    out.append(len(trainer.history["train_loss"]) == len_loss_output)
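    # lr_history keys depend on the setup: "lr_<component>_<group>" (e.g.
    # "lr_deeptabular_0") with multiple optimizers, "lr_0" with a single one; cyclic
    # schedulers store nested lists, hence the chain.from_iterable flattening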

    try:
        lr_list = list(chain.from_iterable(trainer.lr_history["lr_deeptabular_0"]))
    except Exception:
        try:
            lr_list = trainer.lr_history["lr_deeptabular_0"]
        except Exception:
            lr_list = trainer.lr_history["lr_0"]

    out.append(len(lr_list) == len_lr_output)
    if init_lr is not None and schedulers_type == "step":
        out.append(lr_list[-1] == init_lr / 10)
    elif init_lr is not None and schedulers_type == "cyclic":
        out.append(lr_list[-1] == init_lr)
    assert all(out)


###############################################################################
# Test that EarlyStopping stops as expected
###############################################################################
def test_early_stop():
    wide = Wide(np.unique(X_wide).shape[0], 1)
    deeptabular = TabMlp(
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=deeptabular)
    trainer = Trainer(
        model=model,
        objective="binary",
        callbacks=[
            EarlyStopping(
                min_delta=5.0, patience=3, restore_best_weights=True, verbose=1
            )
        ],
        verbose=1,
    )
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, val_split=0.2, n_epochs=5)
    # with min_delta=5.0 the monitored loss never improves "enough", so training
    # stops early and the history length equals patience + 1
    assert len(trainer.history["train_loss"]) == 3 + 1


###############################################################################
# Test that ModelCheckpoint behaves as expected
###############################################################################
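# The parametrization encodes the expected upper bound on saved weight files:
# max_save caps how many files are kept on disk, max_save=0 is treated as no cap
# (up to one file per epoch here), and filepath=None should save nothing at all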
@pytest.mark.parametrize(
    "fpath, save_best_only, max_save, n_files",
    [
        ("tests/test_model_functioning/weights/test_weights", True, 2, 2),
        ("tests/test_model_functioning/weights/test_weights", False, 2, 2),
        ("tests/test_model_functioning/weights/test_weights", False, 0, 5),
        (None, False, 0, 0),
    ],
)
def test_model_checkpoint(fpath, save_best_only, max_save, n_files):
    wide = Wide(np.unique(X_wide).shape[0], 1)
    deeptabular = TabMlp(
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=deeptabular)
    trainer = Trainer(
        model=model,
        objective="binary",
        callbacks=[
            ModelCheckpoint(
                filepath=fpath,
                save_best_only=save_best_only,
                max_save=max_save,
            )
        ],
        verbose=0,
    )
    trainer.fit(X_wide=X_wide, X_tab=X_tab, target=target, n_epochs=5, val_split=0.2)
    if fpath:
        n_saved = len(os.listdir("tests/test_model_functioning/weights/"))
        shutil.rmtree("tests/test_model_functioning/weights/")
    else:
        n_saved = 0
    assert n_saved <= n_files


def test_filepath_error():
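    # a filepath with no parent directory is expected to make ModelCheckpoint raise
    # a ValueError at construction time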
    wide = Wide(np.unique(X_wide).shape[0], 1)
    deeptabular = TabMlp(
        mlp_hidden_dims=[16, 4],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    model = WideDeep(wide=wide, deeptabular=deeptabular)
    with pytest.raises(ValueError):
        trainer = Trainer(  # noqa: F841
            model=model,
            objective="binary",
            callbacks=[ModelCheckpoint(filepath="wrong_file_path")],
            verbose=0,
        )


###############################################################################
# Repeat 1st set of tests for TabTransformer
###############################################################################

# Wide array
X_wide = np.random.choice(50, (32, 10))

# Tab Array
colnames = list(string.ascii_lowercase)[:10]
embed_cols = [np.random.choice(np.arange(5), 32) for _ in range(5)]
embeds_input = [(i, j) for i, j in zip(colnames[:5], [5] * 5)]  # type: ignore[misc]
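# for TabTransformer, embed_input entries are (column name, number of unique values);
# the embedding dim is set by the model itself rather than per column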
cont_cols = [np.random.rand(32) for _ in range(5)]
column_idx = {k: v for v, k in enumerate(colnames)}
X_tab = np.vstack(embed_cols + cont_cols).transpose()

# target
target = np.random.choice(2, 32)

wide = Wide(np.unique(X_wide).shape[0], 1)
tab_transformer = TabTransformer(
    column_idx={k: v for v, k in enumerate(colnames)},
    embed_input=embeds_input,
    continuous_cols=colnames[5:],
)
model_tt = WideDeep(wide=wide, deeptabular=tab_transformer)

# 1. Single optimizer, single scheduler, not cyclic and both passed directly
optimizers_1 = RAdam(model_tt.parameters())
lr_schedulers_1 = StepLR(optimizers_1, step_size=4)

# 2. Multiple optimizers, single scheduler, cyclic and passed via a one-item
# dictionary
wide_opt_2 = torch.optim.Adam(model_tt.wide.parameters())
deep_opt_2 = RAdam(model_tt.deeptabular.parameters())
deep_sch_2 = CyclicLR(
    deep_opt_2, base_lr=0.001, max_lr=0.01, step_size_up=5, cycle_momentum=False
)
optimizers_2 = {"wide": wide_opt_2, "deeptabular": deep_opt_2}
lr_schedulers_2 = {"deeptabular": deep_sch_2}

# 3. Multiple schedulers no cyclic
wide_opt_3 = torch.optim.Adam(model_tt.wide.parameters())
deep_opt_3 = RAdam(model_tt.deeptabular.parameters())
wide_sch_3 = StepLR(wide_opt_3, step_size=4)
deep_sch_3 = StepLR(deep_opt_3, step_size=4)
optimizers_3 = {"wide": wide_opt_3, "deeptabular": deep_opt_3}
lr_schedulers_3 = {"wide": wide_sch_3, "deeptabular": deep_sch_3}

# 4. Multiple schedulers with cyclic
wide_opt_4 = torch.optim.Adam(model_tt.wide.parameters())
deep_opt_4 = torch.optim.Adam(model_tt.deeptabular.parameters())
wide_sch_4 = StepLR(wide_opt_4, step_size=4)
deep_sch_4 = CyclicLR(
    deep_opt_4, base_lr=0.001, max_lr=0.01, step_size_up=5, cycle_momentum=False
)
optimizers_4 = {"wide": wide_opt_4, "deeptabular": deep_opt_4}
lr_schedulers_4 = {"wide": wide_sch_4, "deeptabular": deep_sch_4}


@pytest.mark.parametrize(
    "optimizers, schedulers, len_loss_output, len_lr_output, init_lr, schedulers_type",
    [
        (optimizers_1, lr_schedulers_1, 5, 5, 0.001, "step"),
        (optimizers_2, lr_schedulers_2, 5, 11, 0.001, "cyclic"),
        (optimizers_3, lr_schedulers_3, 5, 5, None, None),
        (optimizers_4, lr_schedulers_4, 5, 11, None, None),
    ],
)
def test_history_callback_w_tabtransformer(
    optimizers, schedulers, len_loss_output, len_lr_output, init_lr, schedulers_type
):
    trainer_tt = Trainer(
        model_tt,
        objective="binary",
        optimizers=optimizers,
        lr_schedulers=schedulers,
        callbacks=[LRHistory(n_epochs=5)],
        verbose=0,
    )
    trainer_tt.fit(
        X_wide=X_wide,
        X_tab=X_tab,
        target=target,
        n_epochs=5,
        batch_size=16,
    )
    out = []
    out.append(len(trainer_tt.history["train_loss"]) == len_loss_output)
    try:
        lr_list = list(chain.from_iterable(trainer_tt.lr_history["lr_deeptabular_0"]))
    except TypeError:
        lr_list = trainer_tt.lr_history["lr_deeptabular_0"]
    except Exception:
        lr_list = trainer_tt.lr_history["lr_0"]
    out.append(len(lr_list) == len_lr_output)
    if init_lr is not None and schedulers_type == "step":
        out.append(lr_list[-1] == init_lr / 10)
    elif init_lr is not None and schedulers_type == "cyclic":
        out.append(lr_list[-1] == init_lr)
    assert all(out)


def test_modelcheckpoint_mode_warning():

    fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"

    with pytest.warns(RuntimeWarning):
        model_checkpoint = ModelCheckpoint(  # noqa: F841
            filepath=fpath, monitor="val_loss", mode="unknown"
        )

    shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")


def test_modelcheckpoint_mode_options():

    fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"

    model_checkpoint_1 = ModelCheckpoint(filepath=fpath, monitor="val_loss", mode="min")
    model_checkpoint_2 = ModelCheckpoint(filepath=fpath, monitor="val_loss")
    model_checkpoint_3 = ModelCheckpoint(filepath=fpath, monitor="acc", mode="max")
    model_checkpoint_4 = ModelCheckpoint(filepath=fpath, monitor="acc")
    model_checkpoint_5 = ModelCheckpoint(filepath=None, monitor="acc")

    is_min = model_checkpoint_1.monitor_op is np.less
    best_inf = model_checkpoint_1.best is np.Inf
    auto_is_min = model_checkpoint_2.monitor_op is np.less
    auto_best_inf = model_checkpoint_2.best is np.Inf
    is_max = model_checkpoint_3.monitor_op is np.greater
    best_minus_inf = -model_checkpoint_3.best == np.Inf
    auto_is_max = model_checkpoint_4.monitor_op is np.greater
    auto_best_minus_inf = -model_checkpoint_4.best == np.Inf
    auto_is_max_no_fpath = model_checkpoint_5.monitor_op is np.greater
    auto_best_minus_inf_no_fpath = -model_checkpoint_5.best == np.Inf

    shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")

    assert all(
        [
            is_min,
            best_inf,
            is_max,
            best_minus_inf,
            auto_is_min,
            auto_best_inf,
            auto_is_max,
            auto_best_minus_inf,
            auto_is_max_no_fpath,
            auto_best_minus_inf_no_fpath,
        ]
    )


def test_modelcheckpoint_get_state():
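    # the callback is expected to drop its references to the trainer and the model
    # when pickled, so it can be serialised on its own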

    fpath = "tests/test_model_functioning/modelcheckpoint/"

    model_checkpoint = ModelCheckpoint(
        filepath="/".join([fpath, "weights_out"]), monitor="val_loss"
    )

    trainer = Trainer(
        model,
        objective="binary",
        callbacks=[model_checkpoint],
        verbose=0,
    )
    trainer.fit(
        X_wide=X_wide,
        X_tab=X_tab,
        target=target,
        n_epochs=1,
        batch_size=16,
    )

    with open("/".join([fpath, "checkpoint.p"]), "wb") as f:
        pickle.dump(model_checkpoint, f)

    with open("/".join([fpath, "checkpoint.p"]), "rb") as f:
        model_checkpoint = pickle.load(f)

    self_dict_keys = model_checkpoint.__dict__.keys()

    no_trainer = "trainer" not in self_dict_keys
    no_model = "model" not in self_dict_keys

    shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")

    assert no_trainer and no_model


def test_early_stop_mode_warning():

    with pytest.warns(RuntimeWarning):
        early_stopping = EarlyStopping(  # noqa: F841
            monitor="val_loss", mode="unknown"
        )


def test_early_stop_mode_options():

    early_stopping_1 = EarlyStopping(monitor="val_loss", mode="min")
    early_stopping_2 = EarlyStopping(monitor="val_loss")
    early_stopping_3 = EarlyStopping(monitor="acc", mode="max")
    early_stopping_4 = EarlyStopping(monitor="acc")

    is_min = early_stopping_1.monitor_op is np.less
    auto_is_min = early_stopping_2.monitor_op is np.less
    is_max = early_stopping_3.monitor_op is np.greater
    auto_is_max = early_stopping_4.monitor_op is np.greater

    assert all(
        [
            is_min,
            is_max,
            auto_is_min,
            auto_is_max,
        ]
    )


def test_early_stopping_get_state():
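    # as with ModelCheckpoint above, the pickled EarlyStopping callback should not
    # carry references to the trainer or the model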

    early_stopping_path = Path("tests/test_model_functioning/early_stopping")
    early_stopping_path.mkdir()

    early_stopping = EarlyStopping()

    trainer_tt = Trainer(
        model,
        objective="binary",
        callbacks=[early_stopping],
        verbose=0,
    )
    trainer_tt.fit(
        X_wide=X_wide,
        X_tab=X_tab,
        target=target,
        n_epochs=1,
        batch_size=16,
    )

    with open(early_stopping_path / "early_stopping.p", "wb") as f:
        pickle.dump(early_stopping, f)

    with open(early_stopping_path / "early_stopping.p", "rb") as f:
        early_stopping = pickle.load(f)

    self_dict_keys = early_stopping.__dict__.keys()

    no_trainer = "trainer" not in self_dict_keys
    no_model = "model" not in self_dict_keys

    shutil.rmtree("tests/test_model_functioning/early_stopping/")

    assert no_trainer and no_model


###############################################################################
# Test RayTuneReporter
###############################################################################
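# RayTuneReporter is expected to report the Trainer's epoch metrics (e.g. train_loss)
# to Ray Tune, so they should appear as columns in analysis.results_df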


def test_ray_tune_reporter():

    rt_wide = Wide(np.unique(X_wide).shape[0], 1)
    rt_deeptabular = TabMlp(
        mlp_hidden_dims=[32, 16],
        mlp_dropout=[0.5, 0.5],
        column_idx=column_idx,
        embed_input=embed_input,
        continuous_cols=colnames[-5:],
    )
    rt_model = WideDeep(wide=rt_wide, deeptabular=rt_deeptabular)

    config = {
        "batch_size": tune.grid_search([8, 16]),
    }
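    # grid search over batch_size -> one Ray Tune trial per value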

    def training_function(config):
        batch_size = config["batch_size"]

        trainer = Trainer(
            rt_model,
            objective="binary",
            callbacks=[RayTuneReporter],
            verbose=0,
        )

        trainer.fit(
            X_wide=X_wide,
            X_tab=X_tab,
            target=target,
            n_epochs=1,
            batch_size=batch_size,
        )

    analysis = tune.run(
        tune.with_parameters(training_function),
        config=config,
        resources_per_trial={"cpu": 1, "gpu": 0}
        if not torch.cuda.is_available()
        else {"cpu": 0, "gpu": 1},
        verbose=0,
    )

    assert any(["train_loss" in el for el in analysis.results_df.keys()])