diff --git a/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml b/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
index e5f55cb894f406addf5c6ea43aa1576d95676f86..b8c865dd32e2ea213c53f7902cc7eb830705ddf6 100644
--- a/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
+++ b/ppcls/configs/ImageNet/ConvNeXt/convnext_tiny.yaml
@@ -17,6 +17,11 @@ Global:
   to_static: False
 
+# model ema
+EMA:
+  decay: 0.9999
+
+
 # model architecture
 Arch:
   name: ConvNext_tiny
diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py
index b36aeb70cf5ceb1917e50a7c51d4abcc9c8d1a65..772bf8ed8f95de19b1282cd4686c06337fed32a8 100644
--- a/ppcls/engine/engine.py
+++ b/ppcls/engine/engine.py
@@ -34,6 +34,7 @@ from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
+from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 from ppcls.utils.save_load import init_model
 from ppcls.utils import save_load
@@ -115,6 +116,9 @@ class Engine(object):
             })
             paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
 
+        # EMA model
+        self.ema = "EMA" in self.config and self.mode == "train"
+
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -250,6 +254,11 @@ class Engine(object):
                     level=amp_level,
                     save_dtype='float32')
 
+        # build EMA model
+        if self.ema:
+            self.model_ema = ExponentialMovingAverage(
+                self.model, self.config['EMA'].get("decay", 0.9999))
+
         # for distributed
         world_size = dist.get_world_size()
         self.config["Global"]["distributed"] = world_size != 1
@@ -278,6 +287,10 @@ class Engine(object):
             "metric": 0.0,
             "epoch": 0,
         }
+        ema_module = None
+        if self.ema:
+            best_metric_ema = 0.0
+            ema_module = self.model_ema.module
         # key:
         # val: metrics list word
         self.output_info = dict()
@@ -292,7 +305,8 @@ class Engine(object):
 
         if self.config.Global.checkpoints is not None:
             metric_info = init_model(self.config.Global, self.model,
-                                     self.optimizer, self.train_loss_func)
+                                     self.optimizer, self.train_loss_func,
+                                     ema_module)
             if metric_info is not None:
                 best_metric.update(metric_info)
@@ -327,6 +341,7 @@ class Engine(object):
                         self.optimizer,
                         best_metric,
                         self.output_dir,
+                        ema=ema_module,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model",
                         loss=self.train_loss_func)
@@ -340,6 +355,32 @@ class Engine(object):
 
                 self.model.train()
 
+                if self.ema:
+                    ori_model, self.model = self.model, ema_module
+                    acc_ema = self.eval(epoch_id)
+                    self.model = ori_model
+                    ema_module.eval()
+
+                    if acc_ema > best_metric_ema:
+                        best_metric_ema = acc_ema
+                        save_load.save_model(
+                            self.model,
+                            self.optimizer,
+                            {"metric": acc_ema,
+                             "epoch": epoch_id},
+                            self.output_dir,
+                            ema=ema_module,
+                            model_name=self.config["Arch"]["name"],
+                            prefix="best_model_ema",
+                            loss=self.train_loss_func)
+                    logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
+                        epoch_id, best_metric_ema))
+                    logger.scaler(
+                        name="eval_acc_ema",
+                        value=acc_ema,
+                        step=epoch_id,
+                        writer=self.vdl_writer)
+
             # save model
             if epoch_id % save_interval == 0:
                 save_load.save_model(
@@ -347,6 +388,7 @@ class Engine(object):
                     self.optimizer, {"metric": acc,
                                      "epoch": epoch_id},
                     self.output_dir,
+                    ema=ema_module,
                     model_name=self.config["Arch"]["name"],
                     prefix="epoch_{}".format(epoch_id),
                     loss=self.train_loss_func)
@@ -356,6 +398,7 @@ class Engine(object):
                 self.optimizer, {"metric": acc,
                                  "epoch": epoch_id},
                 self.output_dir,
+                ema=ema_module,
                 model_name=self.config["Arch"]["name"],
                 prefix="latest",
                 loss=self.train_loss_func)
diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py
index 1e944a609d066a6a193c5af55ce56bc931c82eeb..c46650a0973407ca4feadea34e619ff55f4c83fb 100644
--- a/ppcls/engine/train/train.py
+++ b/ppcls/engine/train/train.py
@@ -69,6 +69,9 @@ def train_epoch(engine, epoch_id, print_batch_step):
         # step lr
         for i in range(len(engine.lr_sch)):
             engine.lr_sch[i].step()
+        # update ema
+        if engine.ema:
+            engine.model_ema.update(engine.model)
 
         # below code just for logging
         # update metric_for_logger
diff --git a/ppcls/utils/ema.py b/ppcls/utils/ema.py
index b54cdb1b2030dc0a70394816a433e7e715e12996..8292781955210d68cea119b2fd887b534b3a6c04 100644
--- a/ppcls/utils/ema.py
+++ b/ppcls/utils/ema.py
@@ -1,10 +1,10 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#    http://www.apache.org/licenses/LICENSE-2.0
+#     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -12,52 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
+
 import paddle
-import numpy as np
 
 
 class ExponentialMovingAverage():
     """
     Exponential Moving Average
-    Code was heavily based on https://github.com/Wanger-SJTU/SegToolbox.Pytorch/blob/master/lib/utils/ema.py
+    Code was heavily based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
     """
 
-    def __init__(self, model, decay, thres_steps=True):
-        self._model = model
-        self._decay = decay
-        self._thres_steps = thres_steps
-        self._shadow = {}
-        self._backup = {}
-
-    def register(self):
-        self._update_step = 0
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                self._shadow[name] = param.numpy().copy()
-
-    def update(self):
-        decay = min(self._decay, (1 + self._update_step) / (
-            10 + self._update_step)) if self._thres_steps else self._decay
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                new_val = np.array(param.numpy().copy())
-                old_val = np.array(self._shadow[name])
-                new_average = decay * old_val + (1 - decay) * new_val
-                self._shadow[name] = new_average
-        self._update_step += 1
-        return decay
-
-    def apply(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                self._backup[name] = np.array(param.numpy().copy())
-                param.set_value(np.array(self._shadow[name]))
-
-    def restore(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._backup
-                param.set_value(self._backup[name])
-        self._backup = {}
+    def __init__(self, model, decay=0.9999):
+        super().__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
+
+    @paddle.no_grad()
+    def _update(self, model, update_fn):
+        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+            ema_v.set_value(update_fn(ema_v, model_v))
+
+    def update(self, model):
+        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 093255379cd35875fbaf06282e391017bf7f14a3..d588cbdc185de257404b2d871b87e86b847f0019 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -87,7 +87,11 @@ def load_distillation_model(model, pretrained_model):
         pretrained_model))
 
 
-def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
+def init_model(config,
+               net,
+               optimizer=None,
+               loss: paddle.nn.Layer=None,
+               ema=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -105,7 +109,12 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
         net.set_state_dict(para_dict)
         loss.set_state_dict(para_dict)
         for i in range(len(optimizer)):
-            optimizer[i].set_state_dict(opti_dict)
+            optimizer[i].set_state_dict(opti_dict[i])
+        if ema is not None:
+            assert os.path.exists(checkpoints + ".ema.pdparams"), \
+                "Given dir {}.ema.pdparams not exist.".format(checkpoints)
+            para_ema_dict = paddle.load(checkpoints + ".ema.pdparams")
+            ema.set_state_dict(para_ema_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict
@@ -125,6 +134,7 @@ def save_model(net,
                optimizer,
                metric_info,
                model_path,
+               ema=None,
                model_name="",
                prefix='ppcls',
                loss: paddle.nn.Layer=None):
@@ -145,6 +155,8 @@ def save_model(net,
         params_state_dict.update(loss_state_dict)
 
     paddle.save(params_state_dict, model_path + ".pdparams")
+    if ema is not None:
+        paddle.save(ema.state_dict(), model_path + ".ema.pdparams")
     paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
     paddle.save(metric_info, model_path + ".pdstates")
     logger.info("Already save model in {}".format(model_path))
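For reference, a minimal standalone sketch of how the new `ExponentialMovingAverage` helper is driven, following the same pattern as the `Engine` changes above. The `LeNet` model, optimizer settings, random batch, and output filename are illustrative assumptions and are not part of this change:

```python
# Hypothetical standalone example of the new EMA helper; not part of the diff above.
import paddle
import paddle.nn.functional as F
from paddle.vision.models import LeNet

from ppcls.utils.ema import ExponentialMovingAverage

model = LeNet()
model_ema = ExponentialMovingAverage(model, decay=0.9999)
opt = paddle.optimizer.Momentum(
    learning_rate=0.1, parameters=model.parameters())

for step in range(10):
    images = paddle.randn([8, 1, 28, 28])   # dummy batch for illustration
    labels = paddle.randint(0, 10, [8, 1])
    loss = F.cross_entropy(model(images), labels)
    loss.backward()
    opt.step()
    opt.clear_grad()
    # shadow weights follow e = decay * e + (1 - decay) * m after each step
    model_ema.update(model)

# evaluate or export with the averaged weights held in model_ema.module
paddle.save(model_ema.module.state_dict(), "model_ema.pdparams")
```

Only `update()` touches the shadow copy during training; evaluation and export read `model_ema.module`, which is the object the engine passes to `save_load.save_model` and stores as `*.ema.pdparams`.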