From ee25d70189a98db4310312d5190216cb610f74b5 Mon Sep 17 00:00:00 2001
From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com>
Date: Sun, 25 Apr 2021 11:27:19 +0800
Subject: [PATCH] eval with ema weight while training (#2747)

---
 ppdet/engine/callbacks.py | 19 ++-----------------
 ppdet/engine/trainer.py   | 17 +++++++++++++++++
 2 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py
index 9d418119f..0798b91f3 100644
--- a/ppdet/engine/callbacks.py
+++ b/ppdet/engine/callbacks.py
@@ -26,7 +26,6 @@ import paddle
 import paddle.distributed as dist
 
 from ppdet.utils.checkpoint import save_model
-from ppdet.optimizer import ModelEMA
 
 from ppdet.utils.logger import setup_logger
 logger = setup_logger('ppdet.engine')
@@ -143,20 +142,12 @@ class Checkpointer(Callback):
         super(Checkpointer, self).__init__(model)
         cfg = self.model.cfg
         self.best_ap = 0.
-        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
         self.save_dir = os.path.join(self.model.cfg.save_dir,
                                      self.model.cfg.filename)
         if hasattr(self.model.model, 'student_model'):
             self.weight = self.model.model.student_model
         else:
             self.weight = self.model.model
-        if self.use_ema:
-            self.ema = ModelEMA(
-                cfg['ema_decay'], self.weight, use_thres_step=True)
-
-    def on_step_end(self, status):
-        if self.use_ema:
-            self.ema.update(self.weight)
 
     def on_epoch_end(self, status):
         # Checkpointer only performed during training
@@ -170,10 +161,7 @@ class Checkpointer(Callback):
             if epoch_id % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
                 save_name = str(
                     epoch_id) if epoch_id != end_epoch - 1 else "model_final"
-                if self.use_ema:
-                    weight = self.ema.apply()
-                else:
-                    weight = self.weight
+                weight = self.weight
         elif mode == 'eval':
             if 'save_best_model' in status and status['save_best_model']:
                 for metric in self.model._metrics:
@@ -187,10 +175,7 @@ class Checkpointer(Callback):
                     if map_res[key][0] > self.best_ap:
                         self.best_ap = map_res[key][0]
                         save_name = 'best_model'
-                        if self.use_ema:
-                            weight = self.ema.apply()
-                        else:
-                            weight = self.weight
+                        weight = self.weight
                     logger.info("Best test {} ap is {:0.3f}.".format(
                         key, self.best_ap))
         if weight:
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index a25554be3..2b17cdea9 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -28,6 +28,7 @@ import paddle.distributed as dist
 from paddle.distributed import fleet
 from paddle import amp
 from paddle.static import InputSpec
+from ppdet.optimizer import ModelEMA
 
 from ppdet.core.workspace import create
 from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
@@ -61,6 +62,11 @@ class Trainer(object):
             self.model = self.cfg.model
             self.is_loaded_weights = True
 
+        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
+        if self.use_ema:
+            self.ema = ModelEMA(
+                cfg['ema_decay'], self.model, use_thres_step=True)
+
         # build data loader
         self.dataset = cfg['{}Dataset'.format(self.mode.capitalize())]
         if self.mode == 'train':
@@ -281,8 +287,15 @@ class Trainer(object):
 
                 self.status['batch_time'].update(time.time() - iter_tic)
                 self._compose_callback.on_step_end(self.status)
+                if self.use_ema:
+                    self.ema.update(self.model)
                 iter_tic = time.time()
 
+            # apply ema weight on model
+            if self.use_ema:
+                weight = self.model.state_dict()
+                self.model.set_dict(self.ema.apply())
+
             self._compose_callback.on_epoch_end(self.status)
 
             if validate and (self._nranks < 2 or self._local_rank == 0) \
@@ -303,6 +316,10 @@
                     self.status['save_best_model'] = True
                     self._eval_with_loader(self._eval_loader)
 
+            # restore origin weight on model
+            if self.use_ema:
+                self.model.set_dict(weight)
+
     def _eval_with_loader(self, loader):
         sample_num = 0
         tic = time.time()
-- 
GitLab
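
Note on the change: with this patch the trainer keeps an exponential moving average (EMA) of the weights, updates it after every step, loads it onto the model before the in-training evaluation, and puts the raw weights back afterwards so optimization continues unchanged. The snippet below is a minimal, self-contained sketch of that apply-then-restore pattern; SimpleEMA and the toy paddle.nn.Linear model are illustrative assumptions for the example only, not ppdet's actual ModelEMA API.

import paddle


class SimpleEMA(object):
    """Illustrative EMA helper (not ppdet's ModelEMA)."""

    def __init__(self, model, decay=0.9998):
        self.decay = decay
        # Start the moving average from the current weights (numpy copies).
        self.state = {k: v.numpy().copy() for k, v in model.state_dict().items()}

    def update(self, model):
        # ema = decay * ema + (1 - decay) * current, applied tensor by tensor.
        for k, v in model.state_dict().items():
            self.state[k] = self.decay * self.state[k] + (1.0 - self.decay) * v.numpy()

    def apply(self):
        # Return the averaged weights as a state_dict-like mapping of tensors.
        return {k: paddle.to_tensor(v) for k, v in self.state.items()}


model = paddle.nn.Linear(4, 2)
ema = SimpleEMA(model)

# ... inside the training loop, after each optimizer step:
ema.update(model)

# Before evaluation: keep a copy of the raw weights, then load the EMA weights.
raw_weight = {k: v.numpy().copy() for k, v in model.state_dict().items()}
model.set_dict(ema.apply())
# ... run evaluation with the EMA weights ...

# After evaluation: restore the raw weights so training is unaffected.
model.set_dict({k: paddle.to_tensor(v) for k, v in raw_weight.items()})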