未验证 提交 34a74b39 编写于 作者: S shangliang Xu 提交者: GitHub

refine ema_model save (#5314)

上级 c27233d3
...@@ -182,7 +182,7 @@ class Checkpointer(Callback): ...@@ -182,7 +182,7 @@ class Checkpointer(Callback):
) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1: ) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
save_name = str( save_name = str(
epoch_id) if epoch_id != end_epoch - 1 else "model_final" epoch_id) if epoch_id != end_epoch - 1 else "model_final"
weight = self.weight weight = self.weight.state_dict()
elif mode == 'eval': elif mode == 'eval':
if 'save_best_model' in status and status['save_best_model']: if 'save_best_model' in status and status['save_best_model']:
for metric in self.model._metrics: for metric in self.model._metrics:
...@@ -201,18 +201,22 @@ class Checkpointer(Callback): ...@@ -201,18 +201,22 @@ class Checkpointer(Callback):
if map_res[key][0] > self.best_ap: if map_res[key][0] > self.best_ap:
self.best_ap = map_res[key][0] self.best_ap = map_res[key][0]
save_name = 'best_model' save_name = 'best_model'
weight = self.weight weight = self.weight.state_dict()
logger.info("Best test {} ap is {:0.3f}.".format( logger.info("Best test {} ap is {:0.3f}.".format(
key, self.best_ap)) key, self.best_ap))
if weight: if weight:
if self.model.use_ema: if self.model.use_ema:
save_model(status['weight'], self.save_dir, save_name, # save model and ema_model
epoch_id + 1, self.model.optimizer) save_model(
save_model(weight, self.save_dir, status['weight'],
'{}_ema'.format(save_name), epoch_id + 1) self.model.optimizer,
self.save_dir,
save_name,
epoch_id + 1,
ema_model=weight)
else: else:
save_model(weight, self.save_dir, save_name, epoch_id + 1, save_model(weight, self.model.optimizer, self.save_dir,
self.model.optimizer) save_name, epoch_id + 1)
class WiferFaceEval(Callback): class WiferFaceEval(Callback):
......
...@@ -332,7 +332,7 @@ class ModelEMA(object): ...@@ -332,7 +332,7 @@ class ModelEMA(object):
for k, v in self.state_dict.items(): for k, v in self.state_dict.items():
self.state_dict[k] = paddle.zeros_like(v) self.state_dict[k] = paddle.zeros_like(v)
def resume(self, state_dict, step): def resume(self, state_dict, step=0):
for k, v in state_dict.items(): for k, v in state_dict.items():
self.state_dict[k] = v self.state_dict[k] = v
self.step = step self.step = step
......
...@@ -72,7 +72,14 @@ def load_weight(model, weight, optimizer=None, ema=None): ...@@ -72,7 +72,14 @@ def load_weight(model, weight, optimizer=None, ema=None):
raise ValueError("Model pretrain path {} does not " raise ValueError("Model pretrain path {} does not "
"exists.".format(pdparam_path)) "exists.".format(pdparam_path))
if ema is not None and os.path.exists(path + '.pdema'):
# Exchange model and ema_model to load
ema_state_dict = paddle.load(pdparam_path)
param_state_dict = paddle.load(path + '.pdema')
else:
ema_state_dict = None
param_state_dict = paddle.load(pdparam_path) param_state_dict = paddle.load(pdparam_path)
model_dict = model.state_dict() model_dict = model.state_dict()
model_weight = {} model_weight = {}
incorrect_keys = 0 incorrect_keys = 0
...@@ -102,10 +109,11 @@ def load_weight(model, weight, optimizer=None, ema=None): ...@@ -102,10 +109,11 @@ def load_weight(model, weight, optimizer=None, ema=None):
last_epoch = optim_state_dict.pop('last_epoch') last_epoch = optim_state_dict.pop('last_epoch')
optimizer.set_state_dict(optim_state_dict) optimizer.set_state_dict(optim_state_dict)
if ema is not None and os.path.exists(path + '_ema.pdparams'): if ema_state_dict is not None:
ema_state_dict = paddle.load(path + '_ema.pdparams')
ema.resume(ema_state_dict, ema.resume(ema_state_dict,
optim_state_dict['LR_Scheduler']['last_epoch']) optim_state_dict['LR_Scheduler']['last_epoch'])
elif ema_state_dict is not None:
ema.resume(ema_state_dict)
return last_epoch return last_epoch
...@@ -205,30 +213,42 @@ def load_pretrain_weight(model, pretrain_weight): ...@@ -205,30 +213,42 @@ def load_pretrain_weight(model, pretrain_weight):
logger.info('Finish loading model weights: {}'.format(weights_path)) logger.info('Finish loading model weights: {}'.format(weights_path))
def save_model(model, save_dir, save_name, last_epoch, optimizer=None): def save_model(model,
optimizer,
save_dir,
save_name,
last_epoch,
ema_model=None):
""" """
save model into disk. save model into disk.
Args: Args:
model (paddle.nn.Layer): the Layer instalce to save parameters. model (dict): the model state_dict to save parameters.
optimizer (paddle.optimizer.Optimizer): the Optimizer instance to optimizer (paddle.optimizer.Optimizer): the Optimizer instance to
save optimizer states. save optimizer states.
save_dir (str): the directory to be saved. save_dir (str): the directory to be saved.
save_name (str): the path to be saved. save_name (str): the path to be saved.
last_epoch (int): the epoch index. last_epoch (int): the epoch index.
ema_model (dict|None): the ema_model state_dict to save parameters.
""" """
if paddle.distributed.get_rank() != 0: if paddle.distributed.get_rank() != 0:
return return
assert isinstance(model, dict), ("model is not a instance of dict, "
"please call model.state_dict() to get.")
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
os.makedirs(save_dir) os.makedirs(save_dir)
save_path = os.path.join(save_dir, save_name) save_path = os.path.join(save_dir, save_name)
if isinstance(model, nn.Layer): # save model
paddle.save(model.state_dict(), save_path + ".pdparams") if ema_model is None:
else:
assert isinstance(model,
dict), 'model is not a instance of nn.layer or dict'
paddle.save(model, save_path + ".pdparams") paddle.save(model, save_path + ".pdparams")
if optimizer is not None: else:
assert isinstance(ema_model,
dict), ("ema_model is not a instance of dict, "
"please call model.state_dict() to get.")
# Exchange model and ema_model to save
paddle.save(ema_model, save_path + ".pdparams")
paddle.save(model, save_path + ".pdema")
# save optimizer
state_dict = optimizer.state_dict() state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt") paddle.save(state_dict, save_path + ".pdopt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册