提交 46e68b6b 编写于 作者: L LuoXueling

The unit test for the restored state of EarlyStopping.

上级 8ce56d63
......@@ -479,3 +479,30 @@ def test_early_stopping_get_state():
shutil.rmtree("tests/test_model_functioning/early_stopping/")
assert no_trainer and no_model
def test_early_stopping_restore_state():
# min_delta is large, so the early stopping condition will never be met except for the first epoch.
early_stopping = EarlyStopping(
restore_best_weights=True, min_delta=1000, patience=1000
)
trainer_tt = Trainer(
model,
objective="regression",
callbacks=[early_stopping],
verbose=0,
)
trainer_tt.fit(
X_train={"X_wide": X_wide, "X_tab": X_tab, "target": target},
X_val={"X_wide": X_wide_val, "X_tab": X_tab_val, "target": target_val},
target=target,
n_epochs=2,
batch_size=16,
)
assert early_stopping.wait > 0
# so early stopping is not triggered, but is over-fitting.
pred_val = trainer_tt.predict(X_test={"X_wide": X_wide_val, "X_tab": X_tab_val})
restored_metric = trainer_tt.loss_fn(
torch.tensor(pred_val), torch.tensor(target_val)
).item()
assert np.allclose(restored_metric, early_stopping.best)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册