Commit afc6636c authored by J Javier

Making it consistent with master so I don't have to rebase or anything...

Making it consistent with master so I don't have to rebase or anything similar. Also fixing bugs related to the other forms of training where I had not fixed the restore-weights functionality.
Parent 76360359
......@@ -86,7 +86,7 @@ trainer.fit(
n_epochs=1,
batch_size=32,
custom_dataloader=DataLoaderImbalanced,
oversample_mul=5,
**{"oversample_mul": 5},
)
print(
"Training time[s]: {}".format(
......
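Note on the hunk above: it only changes how the extra keyword is spelled. `oversample_mul=5` and `**{"oversample_mul": 5}` are equivalent at the call site; either way the argument lands in `trainer.fit`'s extra keyword arguments and is forwarded to the custom dataloader. A minimal standalone sketch of that forwarding (the `fit` signature and the stub names are illustrative, not the library code):

```python
def fit(custom_dataloader=None, **dataloader_kwargs):
    # Hypothetical forwarding: the trainer hands unknown keywords to the dataloader.
    return custom_dataloader(dataset=None, **dataloader_kwargs)

class DataLoaderStub:
    def __init__(self, dataset=None, **kwargs):
        # The custom loader consumes its own keyword and ignores the rest.
        self.oversample_mul = kwargs.pop("oversample_mul", 1)

fit(custom_dataloader=DataLoaderStub, oversample_mul=5)           # keyword form
fit(custom_dataloader=DataLoaderStub, **{"oversample_mul": 5})    # dict-unpacking form
```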
......@@ -4,6 +4,7 @@ Code here is mostly based on the code from the torchsample and Keras packages
CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
"""
import os
import copy
import datetime
import warnings
......@@ -348,6 +349,10 @@ class ModelCheckpoint(Callback):
monitor: str, default="loss"
quantity to monitor. Typically _'val_loss'_ or metric name
(e.g. _'val_acc'_)
min_delta: float, default=0.
minimum change in the monitored quantity to qualify as an
improvement, i.e. an absolute change of less than min_delta, will
count as no improvement.
verbose:int, default=0
verbosity mode
save_best_only: bool, default=False,
......@@ -396,6 +401,7 @@ class ModelCheckpoint(Callback):
self,
filepath: Optional[str] = None,
monitor: str = "val_loss",
min_delta: float = 0.0,
verbose: int = 0,
save_best_only: bool = False,
mode: str = "auto",
......@@ -406,6 +412,7 @@ class ModelCheckpoint(Callback):
self.filepath = filepath
self.monitor = monitor
self.min_delta = min_delta
self.verbose = verbose
self.save_best_only = save_best_only
self.mode = mode
......@@ -449,6 +456,11 @@ class ModelCheckpoint(Callback):
self.monitor_op = np.less
self.best = np.Inf
if self.monitor_op == np.greater:
self.min_delta *= 1
else:
self.min_delta *= -1
def on_epoch_end( # noqa: C901
self, epoch: int, logs: Optional[Dict] = None, metric: Optional[float] = None
):
......@@ -467,33 +479,20 @@ class ModelCheckpoint(Callback):
RuntimeWarning,
)
else:
if self.monitor_op(current, self.best):
if self.monitor_op(current - self.min_delta, self.best):
if self.verbose > 0:
if self.filepath:
print(
"\nEpoch %05d: %s improved from %0.5f to %0.5f,"
" saving model to %s"
% (
epoch + 1,
self.monitor,
self.best,
current,
filepath,
)
f"\nEpoch {epoch + 1}: {self.monitor} improved from {self.best:.5f} to {current:.5f} "
f"Saving model to {filepath}"
)
else:
print(
"\nEpoch %05d: %s improved from %0.5f to %0.5f"
% (
epoch + 1,
self.monitor,
self.best,
current,
)
f"\nEpoch {epoch + 1}: {self.monitor} improved from {self.best:.5f} to {current:.5f} "
)
self.best = current
self.best_epoch = epoch
self.best_state_dict = self.model.state_dict()
self.best_state_dict = copy.deepcopy(self.model.state_dict())
if self.filepath:
torch.save(self.best_state_dict, filepath)
if self.max_save > 0:
......@@ -507,8 +506,8 @@ class ModelCheckpoint(Callback):
else:
if self.verbose > 0:
print(
"\nEpoch %05d: %s did not improve from %0.5f"
% (epoch + 1, self.monitor, self.best)
f"\nEpoch {epoch + 1}: {self.monitor} did not improve from {self.best:.5f} "
f" considering a 'min_delta' improvement of {self.min_delta:.5f}"
)
if not self.save_best_only and self.filepath:
if self.verbose > 0:
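What the new `min_delta` does in `ModelCheckpoint`: only a change larger than `min_delta` in the right direction counts as an improvement. Because the check is always written as `monitor_op(current - self.min_delta, self.best)`, the sign of `min_delta` is flipped when the monitored quantity should decrease (`np.less`), so subtracting the negative value adds the required margin. A self-contained sketch of that comparison, outside the library:

```python
import numpy as np

def is_improvement(current, best, min_delta, mode):
    """Standalone mirror of the callback's comparison logic."""
    if mode == "max":            # e.g. monitoring accuracy
        monitor_op, delta = np.greater, min_delta
    else:                        # e.g. monitoring a loss
        monitor_op, delta = np.less, -min_delta
    return monitor_op(current - delta, best)

# Monitoring a loss (mode="min"): 0.95 is lower than 1.0, but not by 0.1.
print(is_improvement(0.95, 1.0, min_delta=0.1, mode="min"))  # False
print(is_improvement(0.85, 1.0, min_delta=0.1, mode="min"))  # True
```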
......@@ -657,20 +656,20 @@ class EarlyStopping(Callback):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.state_dict = self.model.state_dict()
self.state_dict = copy.deepcopy(self.model.state_dict())
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.trainer.early_stop = True
if self.restore_best_weights:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def on_train_end(self, logs: Optional[Dict] = None):
if self.stopped_epoch > 0 and self.verbose > 0:
print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
if self.restore_best_weights and self.state_dict is not None:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def get_monitor_value(self, logs):
monitor_value = logs.get(self.monitor)
......
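The central fix in this file is the `copy.deepcopy` around `state_dict()`: `model.state_dict()` returns a mapping whose values reference the live parameter tensors, so a snapshot taken without copying silently tracks every subsequent update and "restoring" it restores nothing. The EarlyStopping restoration is also moved to `on_train_end`, so the best weights are put back even when training finishes without the stop triggering. A minimal PyTorch demonstration of the aliasing problem (not library code):

```python
import copy
import torch

model = torch.nn.Linear(2, 1)

ref_snapshot = model.state_dict()                   # references the live tensors
deep_snapshot = copy.deepcopy(model.state_dict())   # detached, independent copy

with torch.no_grad():
    model.weight.add_(1.0)  # simulate further training updates

print(torch.equal(ref_snapshot["weight"], model.weight))   # True: the snapshot drifted
print(torch.equal(deep_snapshot["weight"], model.weight))  # False: the copy is preserved
```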
......@@ -85,6 +85,7 @@ class DataLoaderImbalanced(DataLoader):
self.with_lds = dataset.with_lds
if "oversample_mul" in kwargs:
oversample_mul = kwargs["oversample_mul"]
del kwargs["oversample_mul"]
else:
oversample_mul = 1
weights, minor_cls_cnt, num_clss = get_class_weights(dataset)
......
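The added `del kwargs["oversample_mul"]` ensures the custom keyword never reaches `torch.utils.data.DataLoader.__init__`, which would otherwise raise a `TypeError` for an unexpected keyword argument. A hedged sketch of the same pattern using `dict.pop` in one step (class and parameter names are illustrative, not the actual implementation):

```python
from torch.utils.data import DataLoader

class DataLoaderImbalancedSketch(DataLoader):
    """Illustrative only: consume custom kwargs before delegating to DataLoader."""

    def __init__(self, dataset, batch_size, num_workers=0, **kwargs):
        # pop() both reads and removes the key, so DataLoader never sees it
        self.oversample_mul = kwargs.pop("oversample_mul", 1)
        super().__init__(dataset, batch_size, num_workers=num_workers, **kwargs)
```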
import os
import sys
import warnings
from abc import ABC, abstractmethod
import numpy as np
......@@ -174,17 +175,34 @@ class BaseContrastiveDenoisingTrainer(ABC):
self.callback_container.set_model(self.cd_model)
self.callback_container.set_trainer(self)
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
def _restore_best_weights(self): # noqa: C901
early_stopping_min_delta = None
model_checkpoint_min_delta = None
already_restored = False
for callback in self.callback_container.callbacks:
if (
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
):
early_stopping_min_delta = callback.min_delta
already_restored = True
if callback.__class__.__name__ == "ModelCheckpoint":
model_checkpoint_min_delta = callback.min_delta
if (
early_stopping_min_delta is not None
and model_checkpoint_min_delta is not None
) and (early_stopping_min_delta != model_checkpoint_min_delta):
warnings.warn(
"'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
"This implies a different definition of 'improvement' for these two callbacks",
UserWarning,
)
if already_restored:
# already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
......
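The rewritten `_restore_best_weights` additionally cross-checks `min_delta` between `EarlyStopping` and `ModelCheckpoint`, since different values mean the two callbacks can disagree on which epoch was "best". A quick, hedged way to provoke the warning (assuming both callbacks are importable from `pytorch_widedeep.callbacks` and passed to a Trainer; the values are illustrative):

```python
from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    EarlyStopping(min_delta=0.01, restore_best_weights=True),
    ModelCheckpoint(filepath=None, min_delta=0.0),  # different min_delta on purpose
]
# A Trainer built with these callbacks should emit the UserWarning added above
# when _restore_best_weights runs, because 0.01 != 0.0.
```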
import os
import sys
import warnings
from abc import ABC, abstractmethod
import numpy as np
......@@ -121,17 +122,34 @@ class BaseEncoderDecoderTrainer(ABC):
self.callback_container.set_model(self.ed_model)
self.callback_container.set_trainer(self)
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
def _restore_best_weights(self): # noqa: C901
early_stopping_min_delta = None
model_checkpoint_min_delta = None
already_restored = False
for callback in self.callback_container.callbacks:
if (
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
):
early_stopping_min_delta = callback.min_delta
already_restored = True
if callback.__class__.__name__ == "ModelCheckpoint":
model_checkpoint_min_delta = callback.min_delta
if (
early_stopping_min_delta is not None
and model_checkpoint_min_delta is not None
) and (early_stopping_min_delta != model_checkpoint_min_delta):
warnings.warn(
"'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
"This implies a different definition of 'improvement' for these two callbacks",
UserWarning,
)
if already_restored:
# already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
......
import os
import sys
import warnings
from abc import ABC, abstractmethod
import numpy as np
......@@ -120,17 +121,34 @@ class BaseBayesianTrainer(ABC):
):
raise NotImplementedError("Trainer.save method not implemented")
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
def _restore_best_weights(self): # noqa: C901
early_stopping_min_delta = None
model_checkpoint_min_delta = None
already_restored = False
for callback in self.callback_container.callbacks:
if (
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
):
early_stopping_min_delta = callback.min_delta
already_restored = True
if callback.__class__.__name__ == "ModelCheckpoint":
model_checkpoint_min_delta = callback.min_delta
if (
early_stopping_min_delta is not None
and model_checkpoint_min_delta is not None
) and (early_stopping_min_delta != model_checkpoint_min_delta):
warnings.warn(
"'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
"This implies a different definition of 'improvement' for these two callbacks",
UserWarning,
)
if already_restored:
# already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
......
import os
import sys
import warnings
from abc import ABC, abstractmethod
import numpy as np
......@@ -128,17 +129,34 @@ class BaseTrainer(ABC):
):
raise NotImplementedError("Trainer.save method not implemented")
def _restore_best_weights(self):
already_restored = any(
[
(
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
)
for callback in self.callback_container.callbacks
]
)
def _restore_best_weights(self): # noqa: C901
early_stopping_min_delta = None
model_checkpoint_min_delta = None
already_restored = False
for callback in self.callback_container.callbacks:
if (
callback.__class__.__name__ == "EarlyStopping"
and callback.restore_best_weights
):
early_stopping_min_delta = callback.min_delta
already_restored = True
if callback.__class__.__name__ == "ModelCheckpoint":
model_checkpoint_min_delta = callback.min_delta
if (
early_stopping_min_delta is not None
and model_checkpoint_min_delta is not None
) and (early_stopping_min_delta != model_checkpoint_min_delta):
warnings.warn(
"'min_delta' is different in the 'EarlyStopping' and 'ModelCheckpoint' callbacks. "
"This implies a different definition of 'improvement' for these two callbacks",
UserWarning,
)
if already_restored:
# already restored via EarlyStopping
pass
else:
for callback in self.callback_container.callbacks:
......
......@@ -317,6 +317,7 @@ class FineTune:
up, down: Tuple, int
number of steps increasing/decreasing the learning rate during the cycle
"""
up = round((steps * n_epochs) * 0.1)
# up = round((steps * n_epochs) * 0.1)
up = max([round((steps * n_epochs) * 0.1), 1])
down = (steps * n_epochs) - up
return up, down
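The FineTune change guards against `up == 0`: for very short runs (e.g. three total optimisation steps), `round((steps * n_epochs) * 0.1)` evaluates to 0, leaving the cyclic schedule with no increasing phase. A plain-arithmetic check of the new formula, independent of the library:

```python
def n_cycle_steps(steps, n_epochs):
    # new behaviour: guarantee at least one increasing step
    up = max([round((steps * n_epochs) * 0.1), 1])
    down = (steps * n_epochs) - up
    return up, down

print(n_cycle_steps(steps=3, n_epochs=1))    # old formula gave up=0; now (1, 2)
print(n_cycle_steps(steps=100, n_epochs=2))  # (20, 180), unchanged by the fix
```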
import numpy as np
import pandas as pd
import pytest
from pytorch_widedeep import Trainer
from pytorch_widedeep.models import (
SAINT,
......@@ -152,7 +152,6 @@ def test_fttransformer_valueerror():
def test_feature_importances_tabnet():
tab_preprocessor = TabPreprocessor(
cat_embed_cols=cat_cols,
continuous_cols=cont_cols,
......@@ -164,7 +163,7 @@ def test_feature_importances_tabnet():
column_idx=tab_preprocessor.column_idx,
cat_embed_input=tab_preprocessor.cat_embed_input,
continuous_cols=tab_preprocessor.continuous_cols,
embed_continuous=True
embed_continuous=True,
)
model = WideDeep(deeptabular=tabnet)
......
......@@ -479,3 +479,161 @@ def test_early_stopping_get_state():
shutil.rmtree("tests/test_model_functioning/early_stopping/")
assert no_trainer and no_model
# ##############################################################################
# Test the restore weights functionalities after bug fixed
# ##############################################################################
def test_early_stopping_restore_weights_with_metric():
# min_delta is large, so the early stopping condition will be met in the first epoch.
early_stopping = EarlyStopping(
restore_best_weights=True, min_delta=1000, patience=1000
)
trainer = Trainer(
model,
objective="regression",
callbacks=[early_stopping],
verbose=0,
)
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=2,
batch_size=16,
)
assert early_stopping.wait > 0
# so early stopping is not triggered, but is over-fitting.
pred_val = trainer.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
restored_metric = trainer.loss_fn(
torch.tensor(pred_val), torch.tensor(target_val)
).item()
assert np.allclose(restored_metric, early_stopping.best)
def test_early_stopping_restore_weights_with_state():
# Long, perhaps too long, test to check early_stopping restore weights
# functionality
# this is repetitive, but for now I want this unit test "self-contained"
# We first define a model and train it, with early stopping that should
# set the weights back to those after the 1st epoch. We also use
# ModelCheckpoint and save all iterations
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)
fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint = ModelCheckpoint(
filepath=fpath,
save_best_only=False,
max_save=10,
min_delta=1000, # irrelevant here
)
early_stopping = EarlyStopping(
patience=3, min_delta=1000, restore_best_weights=True
)
trainer = Trainer(
model,
objective="binary",
callbacks=[early_stopping, model_checkpoint],
verbose=0,
)
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=5,
batch_size=16,
)
# We now define a brand new model
new_wide = Wide(np.unique(X_wide).shape[0], 1)
new_deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
new_model = WideDeep(wide=new_wide, deeptabular=new_deeptabular)
# In general, the best epoch is equal to the (stopped_epoch - patience) + 1
full_best_epoch_path = "_".join(
[
model_checkpoint.filepath,
str((early_stopping.stopped_epoch - early_stopping.patience) + 1) + ".p",
]
)
# we load the weights for the best epoch and these should match those of
# the original model if early_stopping worked
new_model.load_state_dict(torch.load(full_best_epoch_path))
new_model.to(next(model.parameters()).device)
shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")
assert torch.allclose(
new_model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
)
def test_model_checkpoint_restore_weights():
wide = Wide(np.unique(X_wide).shape[0], 1)
deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
model = WideDeep(wide=wide, deeptabular=deeptabular)
fpath = "tests/test_model_functioning/modelcheckpoint/weights_out"
model_checkpoint = ModelCheckpoint(
filepath=fpath,
save_best_only=True,
min_delta=1000, # irrelevant here
)
trainer = Trainer(
model,
objective="binary",
callbacks=[model_checkpoint],
verbose=0,
)
trainer.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=5,
batch_size=16,
)
new_wide = Wide(np.unique(X_wide).shape[0], 1)
new_deeptabular = TabMlp(
column_idx=column_idx,
cat_embed_input=embed_input,
continuous_cols=colnames[-5:],
mlp_hidden_dims=[16, 8],
)
new_model = WideDeep(wide=new_wide, deeptabular=new_deeptabular)
full_best_epoch_path = "_".join(
[model_checkpoint.filepath, str(model_checkpoint.best_epoch + 1) + ".p"]
)
new_model.load_state_dict(torch.load(full_best_epoch_path))
new_model.to(next(model.parameters()).device)
shutil.rmtree("tests/test_model_functioning/modelcheckpoint/")
assert torch.allclose(
new_model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
model.state_dict()["deeptabular.0.encoder.mlp.dense_layer_1.1.weight"],
)
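Both new tests rely on two conventions visible in the diff: `ModelCheckpoint` writes epoch `e` (0-based) to `<filepath>_<e + 1>.p`, and when `min_delta` is unreachably large the first epoch remains the best one, so `EarlyStopping` stops after `patience` fruitless epochs and the best checkpoint's 1-based suffix is `(stopped_epoch - patience) + 1`. A small numeric recap with illustrative values:

```python
# Illustrative numbers for the EarlyStopping test above (patience=3, huge min_delta):
stopped_epoch = 3                                   # wait reaches patience at epoch index 3
patience = 3
best_file_suffix = (stopped_epoch - patience) + 1   # -> 1, i.e. "<filepath>_1.p"
# ModelCheckpoint saved epoch 0 (the best one) as f"{filepath}_{0 + 1}.p",
# which is exactly the file the test reloads into the fresh model.
```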