提交 8ce56d63 编写于 作者: L LuoXueling

Fix: EarlyStopping does not store and restore the model #175

上级 9acfceca
......@@ -6,6 +6,7 @@ CREDIT TO THE TORCHSAMPLE AND KERAS TEAMS
import os
import datetime
import warnings
import copy
import numpy as np
import torch
......@@ -657,20 +658,20 @@ class EarlyStopping(Callback):
self.best = current
self.wait = 0
if self.restore_best_weights:
self.state_dict = self.model.state_dict()
self.state_dict = copy.deepcopy(self.model.state_dict())
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.trainer.early_stop = True
if self.restore_best_weights:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def on_train_end(self, logs: Optional[Dict] = None):
if self.stopped_epoch > 0 and self.verbose > 0:
print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))
if self.restore_best_weights and self.state_dict is not None:
if self.verbose > 0:
print("Restoring model weights from the end of the best epoch")
self.model.load_state_dict(self.state_dict)
def get_monitor_value(self, logs):
monitor_value = logs.get(self.monitor)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册