未验证 提交 b28426ca 编写于 作者: J Javier 提交者: GitHub

Merge pull request #176 from LuoXueling/master

Fix: EarlyStopping does not store and restore the model #175
......@@ -6,6 +6,7 @@ CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
import os
import datetime
import warnings
import copy
import numpy as np
import torch
......@@ -657,20 +658,20 @@ class EarlyStopping(Callback):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.state_dict = self.model.state_dict()
self.state_dict = copy.deepcopy(self.model.state_dict())
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.trainer.early_stop = True
if self.restore_best_weights:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def on_train_end(self, logs: Optional[Dict] = None):
if self.stopped_epoch > 0 and self.verbose > 0:
print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
if self.restore_best_weights and self.state_dict is not None:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def get_monitor_value(self, logs):
monitor_value = logs.get(self.monitor)
......
......@@ -479,3 +479,30 @@ def test_early_stopping_get_state():
shutil.rmtree("tests/test_model_functioning/early_stopping/")
assert no_trainer and no_model
def test_early_stopping_restore_state():
# min_delta is large, so the early stopping condition will never be met except for the first epoch.
early_stopping = EarlyStopping(
restore_best_weights=True, min_delta=1000, patience=1000
)
trainer_tt = Trainer(
model,
objective="regression",
callbacks=[early_stopping],
verbose=0,
)
trainer_tt.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=2,
batch_size=16,
)
assert early_stopping.wait > 0
# so early stopping is not triggered, but is over-fitting.
pred_val = trainer_tt.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
restored_metric = trainer_tt.loss_fn(
torch.tensor(pred_val), torch.tensor(target_val)
).item()
assert np.allclose(restored_metric, early_stopping.best)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册