Unverified commit ee25d701, authored by wangxinxin08, committed by GitHub

eval with ema weight while training (#2747)

Parent f48cab63
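This commit moves exponential moving average (EMA) handling out of the Checkpointer callback and into the Trainer itself: the Trainer updates the EMA shadow weights after every training step, loads the averaged weights into the model before the in-training evaluation (so epoch snapshots and best_model are saved with EMA weights), and restores the original weights afterwards. Below is a minimal, self-contained sketch of the EMA bookkeeping this relies on; `SimpleEMA` is an illustrative stand-in, not ppdet's actual `ModelEMA` (which here is also constructed with `use_thres_step=True`, a warm-up style option for the decay).

```python
import paddle


class SimpleEMA(object):
    """Illustrative stand-in for ppdet.optimizer.ModelEMA (not the real API)."""

    def __init__(self, model, decay=0.9998):
        self.decay = decay
        # Keep a detached shadow copy of every parameter/buffer so optimizer
        # updates cannot change the averaged values in place.
        self.shadow = {k: v.clone() for k, v in model.state_dict().items()}

    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current, per tensor.
        for k, v in model.state_dict().items():
            self.shadow[k] = self.decay * self.shadow[k] + (1.0 - self.decay) * v

    def apply(self):
        # Averaged weights, in a form that Layer.set_dict() accepts.
        return {k: v.clone() for k, v in self.shadow.items()}


if __name__ == '__main__':
    # Tiny smoke test: one layer, a few fake "training" steps.
    layer = paddle.nn.Linear(4, 2)
    ema = SimpleEMA(layer, decay=0.9)
    for _ in range(3):
        # Pretend an optimizer changed the weights, then track them.
        layer.weight.set_value(layer.weight + 0.1)
        ema.update(layer)
    print(ema.apply()['weight'])
```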
@@ -26,7 +26,6 @@ import paddle
 import paddle.distributed as dist
 from ppdet.utils.checkpoint import save_model
-from ppdet.optimizer import ModelEMA
 from ppdet.utils.logger import setup_logger
 logger = setup_logger('ppdet.engine')
@@ -143,20 +142,12 @@ class Checkpointer(Callback):
         super(Checkpointer, self).__init__(model)
         cfg = self.model.cfg
         self.best_ap = 0.
-        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
         self.save_dir = os.path.join(self.model.cfg.save_dir,
                                      self.model.cfg.filename)
         if hasattr(self.model.model, 'student_model'):
             self.weight = self.model.model.student_model
         else:
             self.weight = self.model.model
-        if self.use_ema:
-            self.ema = ModelEMA(
-                cfg['ema_decay'], self.weight, use_thres_step=True)
-
-    def on_step_end(self, status):
-        if self.use_ema:
-            self.ema.update(self.weight)

     def on_epoch_end(self, status):
         # Checkpointer only performed during training
@@ -170,10 +161,7 @@ class Checkpointer(Callback):
             if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
                 save_name = str(
                     epoch_id) if epoch_id != end_epoch - 1 else "model_final"
-                if self.use_ema:
-                    weight = self.ema.apply()
-                else:
-                    weight = self.weight
+                weight = self.weight
         elif mode == 'eval':
             if 'save_best_model' in status and status['save_best_model']:
                 for metric in self.model._metrics:
@@ -187,10 +175,7 @@ class Checkpointer(Callback):
                         if map_res[key][0] > self.best_ap:
                             self.best_ap = map_res[key][0]
                             save_name = 'best_model'
-                            if self.use_ema:
-                                weight = self.ema.apply()
-                            else:
-                                weight = self.weight
+                            weight = self.weight
                             logger.info("Best test {} ap is {:0.3f}.".format(
                                 key, self.best_ap))
         if weight:
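Note on the Checkpointer side of this change: with the blocks above removed, the callback no longer knows about EMA at all and simply saves self.weight (the student model when distillation is configured, otherwise the raw model). Because the Trainer hunks below load the EMA weights into the model before on_epoch_end fires, and only restore the original weights after the in-training evaluation, the epoch snapshots and best_model written here still contain the averaged weights.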
@@ -28,6 +28,7 @@ import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle import amp
 from paddle.static import InputSpec

+from ppdet.optimizer import ModelEMA
 from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
@@ -61,6 +62,11 @@ class Trainer(object):
             self.model = self.cfg.model
             self.is_loaded_weights = True

+        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
+        if self.use_ema:
+            self.ema = ModelEMA(
+                cfg['ema_decay'], self.model, use_thres_step=True)
+
         # build data loader
         self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
         if self.mode == 'train':
@@ -281,8 +287,15 @@ class Trainer(object):
                 self.status['batch_time'].update(time.time() - iter_tic)
                 self._compose_callback.on_step_end(self.status)
+                if self.use_ema:
+                    self.ema.update(self.model)
                 iter_tic = time.time()

+            # apply ema weight on model
+            if self.use_ema:
+                weight = self.model.state_dict()
+                self.model.set_dict(self.ema.apply())
+
             self._compose_callback.on_epoch_end(self.status)

             if validate and (self._nranks < 2 or self._local_rank == 0) \
@@ -303,6 +316,10 @@ class Trainer(object):
                     self.status['save_best_model'] = True
                     self._eval_with_loader(self._eval_loader)

+            # restore origin weight on model
+            if self.use_ema:
+                self.model.set_dict(weight)
+
     def _eval_with_loader(self, loader):
         sample_num = 0
         tic = time.time()
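The two Trainer hunks above are the heart of the feature: EMA weights are swapped in right before on_epoch_end (so checkpoints and the in-training evaluation both see the averaged weights) and the raw training weights are put back once evaluation finishes. One caveat when reproducing this pattern: Paddle's state_dict() generally returns the live parameter tensors rather than copies, and set_dict() writes values into those same tensors, so a backup taken as a bare state_dict() can end up aliasing the EMA values after the swap. The sketch below is a defensive, illustrative version of the swap/restore step, assuming an `ema` object like the SimpleEMA sketch above; it is not the Trainer's exact code, and `run_eval` is a hypothetical callback.

```python
def evaluate_with_ema(model, ema, run_eval):
    """Run run_eval(model) with EMA weights loaded, then restore the originals.

    Illustrative helper, not part of ppdet: `model` is a paddle.nn.Layer and
    `ema` exposes apply() -> dict of averaged tensors (see SimpleEMA above).
    """
    # Clone the current weights so the backup cannot alias the EMA values
    # that set_dict() is about to write into the live parameters.
    backup = {k: v.clone() for k, v in model.state_dict().items()}
    model.set_dict(ema.apply())   # evaluation / checkpoints now see EMA weights
    try:
        return run_eval(model)
    finally:
        model.set_dict(backup)    # training continues from the raw weights
```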