...
 
Commits (3)
https://gitcode.net/greenplum/pytorch-widedeep/-/commit/890917aa310c6960e5bdcc870e77935ee60ebcf9
Javier <jrzaurin@gmail.com>, 2023-07-07T17:21:46+01:00
Fix #175 early stopping and model checkpoint restoring weights. Also fix #174, the finetuner single batch run.

https://gitcode.net/greenplum/pytorch-widedeep/-/commit/8406813ce4742cfd17c9225fc73f467a26700982
Javier <jrzaurin@gmail.com>, 2023-07-14T16:46:47+01:00
Fixed unit test to run on GPU and fixed a little bug for custom data loaders.

https://gitcode.net/greenplum/pytorch-widedeep/-/commit/42cfe5cc3daf2992f07893d11d4a497fdca198d6
Javier <jrzaurin@gmail.com>, 2023-07-14T18:48:29+01:00
Merge pull request #177 from jrzaurin/fix_restore_best_weights: Fix #175 early stopping and model checkpoint restoring weights.
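Taken together, the commits touch the callbacks (a new min_delta argument and deep-copied best weights in ModelCheckpoint), the imbalanced data loader, the finetune one-cycle schedule, and the tests. A minimal usage sketch of the fixed behaviour follows; the model, column names and synthetic data are illustrative assumptions, not code from the repository:

import numpy as np

from pytorch_widedeep import Trainer
from pytorch_widedeep.models import Wide, TabMlp, WideDeep
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint

# synthetic inputs: 4 label-encoded wide columns, 5 continuous tabular columns
X_wide = np.random.choice(20, (100, 4))
X_tab = np.random.rand(100, 5)
target = np.random.choice(2, 100)

colnames = ["col_" + str(i) for i in range(5)]
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
    column_idx={c: i for i, c in enumerate(colnames)},
    continuous_cols=colnames,
    mlp_hidden_dims=[16, 8],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)

# same 'min_delta' in both callbacks so they agree on what counts as an improvement
callbacks = [
    EarlyStopping(patience=3, min_delta=0.001, restore_best_weights=True),
    ModelCheckpoint(filepath="tmp/weights_out", save_best_only=True, min_delta=0.001),
]
trainer = Trainer(model, objective="binary", callbacks=callbacks, verbose=0)
trainer.fit(
    X_wide=X_wide,
    X_tab=X_tab,
    target=target,
    val_split=0.2,
    n_epochs=10,
    batch_size=16,
)
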
@@ -86,7 +86,7 @@ trainer.fit(
n_epochs=1,
batch_size=32,
custom_dataloader=DataLoaderImbalanced,
oversample_mul=5,
**{"oversample_mul": 5},
)
print(
"Training time[s]: {}".format(
......
@@ -4,9 +4,9 @@ Code here is mostly based on the code from the torchsample and Keras packages
CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
"""
import os
import copy
import datetime
import warnings
import copy
import numpy as np
import torch
@@ -349,6 +349,10 @@ class ModelCheckpoint(Callback):
monitor: str, default="loss"
quantity to monitor. Typically _'val_loss'_ or metric name
(e.g. _'val_acc'_)
min_delta: float, default=0.
minimum change in the monitored quantity to qualify as an
improvement, i.e. an absolute change of less than min_delta, will
count as no improvement.
verbose:int, default=0
verbosity mode
save_best_only: bool, default=False,
@@ -397,6 +401,7 @@ class ModelCheckpoint(Callback):
self,
filepath: Optional[str] = None,
monitor: str = "val_loss",
min_delta: float = 0.0,
verbose: int = 0,
save_best_only: bool = False,
mode: str = "auto",
@@ -407,6 +412,7 @@ class ModelCheckpoint(Callback):
self.filepath = filepath
self.monitor = monitor
self.min_delta = min_delta
self.verbose = verbose
self.save_best_only = save_best_only
self.mode = mode
@@ -450,6 +456,11 @@ class ModelCheckpoint(Callback):
self.monitor_op = np.less
self.best = np.Inf
if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
def on_epoch_end( # noqa: C901
self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
):
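The sign flip above lets a single comparison, monitor_op(current - self.min_delta, self.best), express "improved by at least min_delta" in both modes. A small standalone illustration with plain numpy (the numbers are made up):

import numpy as np

# "min" mode (e.g. monitoring val_loss): min_delta has been multiplied by -1, so
# np.less(current - min_delta, best) == np.less(current + |min_delta|, best),
# which only passes when the loss drops by more than |min_delta|
best, min_delta = 0.50, -0.01
print(np.less(0.495 - min_delta, best))  # False: dropped by only 0.005
print(np.less(0.480 - min_delta, best))  # True:  dropped by 0.020

# "max" mode (e.g. monitoring val_acc): min_delta keeps its sign, so
# np.greater(current - min_delta, best) only passes when the metric rises by more
# than min_delta
best, min_delta = 0.80, 0.01
print(np.greater(0.805 - min_delta, best))  # False: rose by only 0.005
print(np.greater(0.820 - min_delta, best))  # True:  rose by 0.020
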
@@ -468,33 +479,20 @@ class ModelCheckpoint(Callback):
RuntimeWarning,
)
else:
if self.monitor_op(current, self.best):
if self.monitor_op(current - self.min_delta, self.best):
if self.verbose > 0:
if self.filepath:
print(
"\nEpoch %05d: %s improved from %0.5f to %0.5f,"
" saving model to %s"
% (
epoch + 1,
self.monitor,
self.best,
current,
filepath,
)
f"\nEpoch {epoch + 1}: {self.monitor} improved from {self.best:.5f} to {current:.5f} "
f"Saving model to {filepath}"
)
else:
print(
"\nEpoch %05d: %s improved from %0.5f to %0.5f"
% (
epoch + 1,
self.monitor,
self.best,
current,
)
f"\nEpoch {epoch + 1}: {self.monitor} improved from {self.best:.5f} to {current:.5f} "
)
self.best = current
self.best_epoch = epoch
self.best_state_dict = self.model.state_dict()
self.best_state_dict = copy.deepcopy(self.model.state_dict())
if self.filepath:
torch.save(self.best_state_dict, filepath)
if self.max_save > 0:
@@ -508,8 +506,8 @@ class ModelCheckpoint(Callback):
else:
if self.verbose > 0:
print(
"\nEpoch %05d: %s did not improve from %0.5f"
% (epoch + 1, self.monitor, self.best)
f"\nEpoch {epoch + 1}: {self.monitor} did not improve from {self.best:.5f} "
f" considering a 'min_delta' improvement of {self.min_delta:.5f}"
)
if not self.save_best_only and self.filepath:
if self.verbose > 0:
......
@@ -85,6 +85,7 @@ class DataLoaderImbalanced(DataLoader):
self.with_lds = dataset.with_lds
if "oversample_mul" in kwargs:
oversample_mul = kwargs["oversample_mul"]
del kwargs["oversample_mul"]
else:
oversample_mul = 1
weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
......
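The added del kwargs["oversample_mul"] matters because the remaining kwargs are forwarded to the parent torch DataLoader, which would raise a TypeError on an unknown keyword. The same pattern in isolation (a generic sketch, not the library's class):

import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(dataset, **kwargs):
    # pop the custom keyword before forwarding the rest to DataLoader; without this,
    # DataLoader raises TypeError: __init__() got an unexpected keyword argument
    oversample_mul = kwargs.pop("oversample_mul", 1)
    print("oversampling factor:", oversample_mul)
    return DataLoader(dataset, **kwargs)

ds = TensorDataset(torch.arange(10).float())
loader = make_loader(ds, batch_size=4, oversample_mul=5)
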
import os
import sys
import warnings
from abc import ABC, abstractmethod
import numpy as np
@@ -130,17 +131,34 @@ class BaseTrainer(ABC):
):
raise NotImplementedError("Trainer.save method not implemented")
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
def _restore_best_weights(self): # noqa: C901
early_stopping_min_delta = None
model_checkpoint_min_delta = None
already_restored = False
for callback in self.callback_container.callbacks:
if (
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
):
early_stopping_min_delta = callback.min_delta
already_restored = True
if callback.__class__.__name__ == "ModelCheckpoint":
model_checkpoint_min_delta = callback.min_delta
if (
early_stopping_min_delta is not None
and model_checkpoint_min_delta is not None
) and (early_stopping_min_delta != model_checkpoint_min_delta):
warnings.warn(
"'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
"This implies a different definition of 'improvement' for these two callbacks",
UserWarning,
)
if already_restored:
# already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
......
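With the rewritten _restore_best_weights, the best weights are restored by EarlyStopping when restore_best_weights=True, and otherwise taken from ModelCheckpoint; a UserWarning is emitted if the two callbacks define improvement with different min_delta values. A hedged sketch of a configuration that would trigger that warning during training (paths and values are illustrative):

from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint

# different 'min_delta' values imply different notions of the "best" epoch
callbacks = [
    EarlyStopping(patience=3, min_delta=0.01, restore_best_weights=True),
    ModelCheckpoint(filepath="tmp/weights_out", save_best_only=True, min_delta=0.001),
]
# a Trainer fitted with these callbacks ends training with:
#   UserWarning: 'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint'
#   callbacks. This implies a different definition of 'improvement' for these two callbacks
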
@@ -317,6 +317,7 @@ class FineTune:
up, down: Tuple, int
number of steps increasing/decreasing the learning rate during the cycle
"""
up = round((steps * n_epochs) * 0.1)
# up = round((steps * n_epochs) * 0.1)
up = max([round((steps * n_epochs) * 0.1), 1])
down = (steps * n_epochs) - up
return up, down
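This is the #174 fix mentioned in the first commit: with a single batch and a single epoch the old formula gave up = round(1 * 0.1) = 0, which breaks the cyclic schedule, so the value is now clamped to at least one step. Worked numbers for the fixed logic (the helper name is only for illustration):

def one_cycle_split(steps: int, n_epochs: int):
    # mirrors the fixed logic: always at least one increasing step
    up = max(round((steps * n_epochs) * 0.1), 1)
    down = (steps * n_epochs) - up
    return up, down

print(one_cycle_split(steps=1, n_epochs=1))    # (1, 0)   old code returned up=0
print(one_cycle_split(steps=100, n_epochs=2))  # (20, 180)
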
@@ -481,18 +481,21 @@ def test_early_stopping_get_state():
assert no_trainer and no_model
def test_early_stopping_restore_state():
# min_delta is large, so the early stopping condition will never be met except for the first epoch.
# ##############################################################################
# Test the restore weights functionalities after bug fixed
# ##############################################################################
def test_early_stopping_restore_weights_with_metric():
# min_delta is large, so the early stopping condition will be met in the first epoch.
early_stopping = EarlyStopping(
restore_best_weights=True, min_delta=1000, patience=1000
)
trainer_tt = Trainer(
trainer = Trainer(
model,
objective="regression",
callbacks=[early_stopping],
verbose=0,
)
trainer_tt.fit(
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
@@ -501,8 +504,136 @@ def test_early_stopping_restore_state():
)
assert early_stopping.wait > 0
# so early stopping is not triggered, but is over-fitting.
pred_val = trainer_tt.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
restored_metric = trainer_tt.loss_fn(
pred_val = trainer.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
restored_metric = trainer.loss_fn(
torch.tensor(pred_val), torch.tensor(target_val)
).item()
assert np.allclose(restored_metric, early_stopping.best)
def test_early_stopping_restore_weights_with_state():
# Long, perhaps too long, test to check early_stopping restore weights
# functionality
# this is repetitive, but for now I want this unit test "self-contained"
# We first define a model and train it, with early stopping that should
# set the weights back to those after the 1st epoch. We also use
# ModelCheckpoint and save all iterations
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)
fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint = ModelCheckpoint(
filepath=fpath,
save_best_only=False,
max_save=10,
min_delta=1000, # irrelevant here
)
early_stopping = EarlyStopping(
patience=3, min_delta=1000, restore_best_weights=True
)
trainer = Trainer(
model,
objective="binary",
callbacks=[early_stopping, model_checkpoint],
verbose=0,
)
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=5,
batch_size=16,
)
# We now define a brand new model
new_wide = Wide(np.unique(X_wide).shape[0], 1)
new_deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
new_model = WideDeep(wide=new_wide, deeptabular=new_deeptabular)
# In general, the best epoch is equal to the (stopped_epoch - patience) + 1
full_best_epoch_path = "_".join(
[
model_checkpoint.filepath,
str((early_stopping.stopped_epoch - early_stopping.patience) + 1) + ".p",
]
)
# we load the weights for the best epoch and these should match those of
# the original model if early_stopping worked
new_model.load_state_dict(torch.load(full_best_epoch_path))
new_model.to(next(model.parameters()).device)
shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")
assert torch.allclose(
new_model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
)
def test_model_checkpoint_restore_weights():
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)
fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint = ModelCheckpoint(
filepath=fpath,
save_best_only=True,
min_delta=1000, # irrelevant here
)
trainer = Trainer(
model,
objective="binary",
callbacks=[model_checkpoint],
verbose=0,
)
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=5,
batch_size=16,
)
new_wide = Wide(np.unique(X_wide).shape[0], 1)
new_deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
new_model = WideDeep(wide=new_wide, deeptabular=new_deeptabular)
full_best_epoch_path = "_".join(
[model_checkpoint.filepath, str(model_checkpoint.best_epoch + 1) + ".p"]
)
new_model.load_state_dict(torch.load(full_best_epoch_path))
new_model.to(next(model.parameters()).device)
shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")
assert torch.allclose(
new_model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
)