未验证 提交 ee25d701 编写于 作者: W wangxinxin08 提交者: GitHub

eval with ema weight while training (#2747)

上级 f48cab63
...@@ -26,7 +26,6 @@ import paddle ...@@ -26,7 +26,6 @@ import paddle
import paddle.distributed as dist import paddle.distributed as dist
from ppdet.utils.checkpoint import save_model from ppdet.utils.checkpoint import save_model
from ppdet.optimizer import ModelEMA
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine') logger = setup_logger('ppdet.engine')
...@@ -143,20 +142,12 @@ class Checkpointer(Callback): ...@@ -143,20 +142,12 @@ class Checkpointer(Callback):
super(Checkpointer, self).__init__(model) super(Checkpointer, self).__init__(model)
cfg = self.model.cfg cfg = self.model.cfg
self.best_ap = 0. self.best_ap = 0.
self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
self.save_dir = os.path.join(self.model.cfg.save_dir, self.save_dir = os.path.join(self.model.cfg.save_dir,
self.model.cfg.filename) self.model.cfg.filename)
if hasattr(self.model.model, 'student_model'): if hasattr(self.model.model, 'student_model'):
self.weight = self.model.model.student_model self.weight = self.model.model.student_model
else: else:
self.weight = self.model.model self.weight = self.model.model
if self.use_ema:
self.ema = ModelEMA(
cfg['ema_decay'], self.weight, use_thres_step=True)
def on_step_end(self, status):
if self.use_ema:
self.ema.update(self.weight)
def on_epoch_end(self, status): def on_epoch_end(self, status):
# Checkpointer only performed during training # Checkpointer only performed during training
...@@ -170,9 +161,6 @@ class Checkpointer(Callback): ...@@ -170,9 +161,6 @@ class Checkpointer(Callback):
if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1: if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
save_name = str( save_name = str(
epoch_id) if epoch_id != end_epoch - 1 else "model_final" epoch_id) if epoch_id != end_epoch - 1 else "model_final"
if self.use_ema:
weight = self.ema.apply()
else:
weight = self.weight weight = self.weight
elif mode == 'eval': elif mode == 'eval':
if 'save_best_model' in status and status['save_best_model']: if 'save_best_model' in status and status['save_best_model']:
...@@ -187,9 +175,6 @@ class Checkpointer(Callback): ...@@ -187,9 +175,6 @@ class Checkpointer(Callback):
if map_res[key][0] > self.best_ap: if map_res[key][0] > self.best_ap:
self.best_ap = map_res[key][0] self.best_ap = map_res[key][0]
save_name = 'best_model' save_name = 'best_model'
if self.use_ema:
weight = self.ema.apply()
else:
weight = self.weight weight = self.weight
logger.info("Best test {} ap is {:0.3f}.".format( logger.info("Best test {} ap is {:0.3f}.".format(
key, self.best_ap)) key, self.best_ap))
......
...@@ -28,6 +28,7 @@ import paddle.distributed as dist ...@@ -28,6 +28,7 @@ import paddle.distributed as dist
from paddle.distributed import fleet from paddle.distributed import fleet
from paddle import amp from paddle import amp
from paddle.static import InputSpec from paddle.static import InputSpec
from ppdet.optimizer import ModelEMA
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
...@@ -61,6 +62,11 @@ class Trainer(object): ...@@ -61,6 +62,11 @@ class Trainer(object):
self.model = self.cfg.model self.model = self.cfg.model
self.is_loaded_weights = True self.is_loaded_weights = True
self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
if self.use_ema:
self.ema = ModelEMA(
cfg['ema_decay'], self.model, use_thres_step=True)
# build data loader # build data loader
self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())] self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
if self.mode == 'train': if self.mode == 'train':
...@@ -281,8 +287,15 @@ class Trainer(object): ...@@ -281,8 +287,15 @@ class Trainer(object):
self.status['batch_time'].update(time.time() - iter_tic) self.status['batch_time'].update(time.time() - iter_tic)
self._compose_callback.on_step_end(self.status) self._compose_callback.on_step_end(self.status)
if self.use_ema:
self.ema.update(self.model)
iter_tic = time.time() iter_tic = time.time()
# apply ema weight on model
if self.use_ema:
weight = self.model.state_dict()
self.model.set_dict(self.ema.apply())
self._compose_callback.on_epoch_end(self.status) self._compose_callback.on_epoch_end(self.status)
if validate and (self._nranks < 2 or self._local_rank == 0) \ if validate and (self._nranks < 2 or self._local_rank == 0) \
...@@ -303,6 +316,10 @@ class Trainer(object): ...@@ -303,6 +316,10 @@ class Trainer(object):
self.status['save_best_model'] = True self.status['save_best_model'] = True
self._eval_with_loader(self._eval_loader) self._eval_with_loader(self._eval_loader)
# restore origin weight on model
if self.use_ema:
self.model.set_dict(weight)
def _eval_with_loader(self, loader): def _eval_with_loader(self, loader):
sample_num = 0 sample_num = 0
tic = time.time() tic = time.time()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册