未验证 提交 34a74b39 编写于 作者: S shangliang Xu 提交者: GitHub

refine ema_model save (#5314)

上级 c27233d3
......@@ -182,7 +182,7 @@ class Checkpointer(Callback):
) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
save_name = str(
epoch_id) if epoch_id != end_epoch - 1 else "model_final"
weight = self.weight
weight = self.weight.state_dict()
elif mode == 'eval':
if 'save_best_model' in status and status['save_best_model']:
for metric in self.model._metrics:
......@@ -201,18 +201,22 @@ class Checkpointer(Callback):
if map_res[key][0] > self.best_ap:
self.best_ap = map_res[key][0]
save_name = 'best_model'
weight = self.weight
weight = self.weight.state_dict()
logger.info("Best test {} ap is {:0.3f}.".format(
key, self.best_ap))
if weight:
if self.model.use_ema:
save_model(status['weight'], self.save_dir, save_name,
epoch_id + 1, self.model.optimizer)
save_model(weight, self.save_dir,
'{}_ema'.format(save_name), epoch_id + 1)
# save model and ema_model
save_model(
status['weight'],
self.model.optimizer,
self.save_dir,
save_name,
epoch_id + 1,
ema_model=weight)
else:
save_model(weight, self.save_dir, save_name, epoch_id + 1,
self.model.optimizer)
save_model(weight, self.model.optimizer, self.save_dir,
save_name, epoch_id + 1)
class WiferFaceEval(Callback):
......
......@@ -332,7 +332,7 @@ class ModelEMA(object):
for k, v in self.state_dict.items():
self.state_dict[k] = paddle.zeros_like(v)
def resume(self, state_dict, step):
def resume(self, state_dict, step=0):
for k, v in state_dict.items():
self.state_dict[k] = v
self.step = step
......
......@@ -72,7 +72,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pdparam_path))
if ema is not None and os.path.exists(path + '.pdema'):
# Exchange model and ema_model to load
ema_state_dict = paddle.load(pdparam_path)
param_state_dict = paddle.load(path + '.pdema')
else:
ema_state_dict = None
param_state_dict = paddle.load(pdparam_path)
model_dict = model.state_dict()
model_weight = {}
incorrect_keys = 0
......@@ -102,10 +109,11 @@ def load_weight(model, weight, optimizer=None, ema=None):
last_epoch = optim_state_dict.pop('last_epoch')
optimizer.set_state_dict(optim_state_dict)
if ema is not None and os.path.exists(path + '_ema.pdparams'):
ema_state_dict = paddle.load(path + '_ema.pdparams')
if ema_state_dict is not None:
ema.resume(ema_state_dict,
optim_state_dict['LR_Scheduler']['last_epoch'])
elif ema_state_dict is not None:
ema.resume(ema_state_dict)
return last_epoch
......@@ -205,30 +213,42 @@ def load_pretrain_weight(model, pretrain_weight):
logger.info('Finish loading model weights: {}'.format(weights_path))
def save_model(model, save_dir, save_name, last_epoch, optimizer=None):
def save_model(model,
optimizer,
save_dir,
save_name,
last_epoch,
ema_model=None):
"""
save model into disk.
Args:
model (paddle.nn.Layer): the Layer instalce to save parameters.
model (dict): the model state_dict to save parameters.
optimizer (paddle.optimizer.Optimizer): the Optimizer instance to
save optimizer states.
save_dir (str): the directory to be saved.
save_name (str): the path to be saved.
last_epoch (int): the epoch index.
ema_model (dict|None): the ema_model state_dict to save parameters.
"""
if paddle.distributed.get_rank() != 0:
return
assert isinstance(model, dict), ("model is not a instance of dict, "
"please call model.state_dict() to get.")
if not os.path.exists(save_dir):
os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name)
if isinstance(model, nn.Layer):
paddle.save(model.state_dict(), save_path + ".pdparams")
else:
assert isinstance(model,
dict), 'model is not a instance of nn.layer or dict'
# save model
if ema_model is None:
paddle.save(model, save_path + ".pdparams")
if optimizer is not None:
else:
assert isinstance(ema_model,
dict), ("ema_model is not a instance of dict, "
"please call model.state_dict() to get.")
# Exchange model and ema_model to save
paddle.save(ema_model, save_path + ".pdparams")
paddle.save(model, save_path + ".pdema")
# save optimizer
state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册