...
 
Commits (6)
    https://gitcode.net/greenplum/pytorch-widedeep/-/commit/8ce56d63f6a8a0405dac410c51b75e701cd20f20 Fix: EarlyStopping does not store and restore the model #175 2023-07-05T23:41:09+08:00 LuoXueling sharrin@sjtu.edu.cn https://gitcode.net/greenplum/pytorch-widedeep/-/commit/46e68b6b87ac5ff895fee19a990643af98a58bf3 The unit test for the restored state of EarlyStopping. 2023-07-06T01:47:06+08:00 LuoXueling sharrin@sjtu.edu.cn https://gitcode.net/greenplum/pytorch-widedeep/-/commit/29972a970cbeb1018e0c3e1a6a69f2e3630d3913 Update build.yml 2023-07-07T09:59:20+01:00 Krishan Davda krishandavda92@gmail.com Get fork PR tests running https://gitcode.net/greenplum/pytorch-widedeep/-/commit/30e8033e017f325a6ff168a803b04c3306516f88 Update build.yml 2023-07-07T10:01:22+01:00 Krishan Davda krishandavda92@gmail.com manual trigger on fork PR https://gitcode.net/greenplum/pytorch-widedeep/-/commit/b28426ca59a7b3d18b0f0d2ee643cc91cfdf134d Merge pull request #176 from LuoXueling/master 2023-07-07T10:20:18+01:00 Javier jrzaurin@gmail.com Fix: EarlyStopping does not store and restore the model #175 https://gitcode.net/greenplum/pytorch-widedeep/-/commit/26e1985921e63c44a9cf47338a5fccf7f1d81ec1 Update build.yml 2023-07-07T10:26:37+01:00 Krishan Davda krishandavda92@gmail.com revert
......@@ -83,4 +83,4 @@ jobs:
- name: upload coverage to Codecov
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: true
\ No newline at end of file
fail_ci_if_error: true
......@@ -6,6 +6,7 @@ CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
import os
import datetime
import warnings
import copy
import numpy as np
import torch
......@@ -657,20 +658,20 @@ class EarlyStopping(Callback):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.state_dict = self.model.state_dict()
self.state_dict = copy.deepcopy(self.model.state_dict())
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.trainer.early_stop = True
if self.restore_best_weights:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def on_train_end(self, logs: Optional[Dict] = None):
    """Wrap up training: report early stopping and restore the best weights.

    If training was stopped early, announce the stopping epoch (1-based in
    the message, hence the ``+ 1``). Independently of whether early stopping
    actually fired, reload the best snapshot captured during training when
    ``restore_best_weights`` is set and a snapshot exists.
    """
    if self.verbose > 0 and self.stopped_epoch > 0:
        print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
    restore = self.restore_best_weights and self.state_dict is not None
    if restore:
        if self.verbose > 0:
            print("Restoring model weights from the end of the best epoch")
        self.model.load_state_dict(self.state_dict)
def get_monitor_value(self, logs):
monitor_value = logs.get(self.monitor)
......
......@@ -479,3 +479,30 @@ def test_early_stopping_get_state():
shutil.rmtree("tests/test_model_functioning/early_stopping/")
assert no_trainer and no_model
def test_early_stopping_restore_state():
    """Regression test for #175: ``EarlyStopping(restore_best_weights=True)``
    must leave the model holding the weights of its best epoch after ``fit``
    returns, even when training ends without the patience limit being hit.

    NOTE(review): ``model``, ``X_wide``/``X_tab`` (+ ``_val`` splits) and
    ``target``/``target_val`` are assumed to be module-level fixtures defined
    earlier in this test file — confirm against the surrounding module.
    """
    # min_delta is large, so the early stopping condition will never be met except for the first epoch.
    # patience is equally large, so training still runs the full n_epochs.
    early_stopping = EarlyStopping(
        restore_best_weights=True, min_delta=1000, patience=1000
    )
    trainer_tt = Trainer(
        model,
        objective="regression",
        callbacks=[early_stopping],
        verbose=0,
    )
    trainer_tt.fit(
        X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
        X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
        target=target,
        n_epochs=2,
        batch_size=16,
    )
    # At least one epoch counted as "no improvement", so the callback waited.
    assert early_stopping.wait > 0
    # so early stopping is not triggered, but is over-fitting.
    # If the best weights were restored, re-evaluating on the validation set
    # must reproduce the best validation loss recorded by the callback.
    pred_val = trainer_tt.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
    restored_metric = trainer_tt.loss_fn(
        torch.tensor(pred_val), torch.tensor(target_val)
    ).item()
    assert np.allclose(restored_metric, early_stopping.best)