Commit ed820223 authored by F flytocc

add EMA code

Parent e1943f9a
@@ -17,6 +17,11 @@ Global:
   to_static: False
 
+# model ema
+EMA:
+  decay: 0.9999
+
 # model architecture
 Arch:
   name: ConvNext_tiny
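The new `EMA` block switches on a shadow copy of the weights, and `decay` controls how slowly that copy follows the live parameters. A toy sketch of the arithmetic (illustrative values only, not part of the diff):

```python
# EMA rule used throughout this commit: shadow = decay * shadow + (1 - decay) * live
decay = 0.9999                 # the EMA.decay value from the config above
shadow, live = 1.0, 0.0        # stand-in values for one parameter entry
shadow = decay * shadow + (1 - decay) * live
print(shadow)                  # 0.9999 -> each step moves 0.01% toward the live weight,
                               # so the average spans roughly 1 / (1 - decay) = 10000 steps
```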
@@ -34,6 +34,7 @@ from ppcls.arch import apply_to_static
 from ppcls.loss import build_loss
 from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
+from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
 from ppcls.utils.save_load import init_model
 from ppcls.utils import save_load
@@ -115,6 +116,9 @@ class Engine(object):
                 })
             paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
 
+        # EMA model
+        self.ema = "EMA" in self.config and self.mode == "train"
+
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -250,6 +254,11 @@ class Engine(object):
                 level=amp_level,
                 save_dtype='float32')
 
+        # build EMA model
+        if self.ema:
+            self.model_ema = ExponentialMovingAverage(
+                self.model, self.config['EMA'].get("decay", 0.9999))
+
         # for distributed
         world_size = dist.get_world_size()
         self.config["Global"]["distributed"] = world_size != 1
@@ -278,6 +287,10 @@ class Engine(object):
             "metric": 0.0,
             "epoch": 0,
         }
+        ema_module = None
+        if self.ema:
+            best_metric_ema = 0.0
+            ema_module = self.model_ema.module
         # key:
         # val: metrics list word
         self.output_info = dict()
@@ -292,7 +305,8 @@ class Engine(object):
         if self.config.Global.checkpoints is not None:
             metric_info = init_model(self.config.Global, self.model,
-                                     self.optimizer, self.train_loss_func)
+                                     self.optimizer, self.train_loss_func,
+                                     ema_module)
             if metric_info is not None:
                 best_metric.update(metric_info)
@@ -327,6 +341,7 @@ class Engine(object):
                         self.optimizer,
                         best_metric,
                         self.output_dir,
+                        ema=ema_module,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model",
                         loss=self.train_loss_func)
@@ -340,6 +355,32 @@ class Engine(object):
                 self.model.train()
 
+            if self.ema:
+                ori_model, self.model = self.model, ema_module
+                acc_ema = self.eval(epoch_id)
+                self.model = ori_model
+                ema_module.eval()
+
+                if acc_ema > best_metric_ema:
+                    best_metric_ema = acc_ema
+                    save_load.save_model(
+                        self.model,
+                        self.optimizer,
+                        {"metric": acc_ema,
+                         "epoch": epoch_id},
+                        self.output_dir,
+                        ema=ema_module,
+                        model_name=self.config["Arch"]["name"],
+                        prefix="best_model_ema",
+                        loss=self.train_loss_func)
+                logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
+                    epoch_id, best_metric_ema))
+                logger.scaler(
+                    name="eval_acc_ema",
+                    value=acc_ema,
+                    step=epoch_id,
+                    writer=self.vdl_writer)
+
             # save model
             if epoch_id % save_interval == 0:
                 save_load.save_model(
@@ -347,6 +388,7 @@ class Engine(object):
                     self.optimizer, {"metric": acc,
                                      "epoch": epoch_id},
                     self.output_dir,
+                    ema=ema_module,
                     model_name=self.config["Arch"]["name"],
                     prefix="epoch_{}".format(epoch_id),
                     loss=self.train_loss_func)
@@ -356,6 +398,7 @@ class Engine(object):
             self.optimizer, {"metric": acc,
                              "epoch": epoch_id},
             self.output_dir,
+            ema=ema_module,
             model_name=self.config["Arch"]["name"],
             prefix="latest",
             loss=self.train_loss_func)
@@ -69,6 +69,9 @@ def train_epoch(engine, epoch_id, print_batch_step):
         # step lr
         for i in range(len(engine.lr_sch)):
             engine.lr_sch[i].step()
+        # update ema
+        if engine.ema:
+            engine.model_ema.update(engine.model)
 
         # below code just for logging
         # update metric_for_logger
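For context, a minimal sketch of where `model_ema.update(model)` sits in a training loop, mirroring the hook added to `train_epoch` above. The toy `Linear` model and SGD optimizer are illustrative; only `ExponentialMovingAverage` comes from this commit:

```python
import paddle
from ppcls.utils.ema import ExponentialMovingAverage

model = paddle.nn.Linear(4, 2)
optimizer = paddle.optimizer.SGD(learning_rate=0.1,
                                 parameters=model.parameters())
model_ema = ExponentialMovingAverage(model, decay=0.9999)

for step in range(10):
    x = paddle.randn([8, 4])
    loss = model(x).mean()
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    # refresh the shadow copy after every optimizer step, as train_epoch does
    model_ema.update(model)
```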
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,52 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
+
 import paddle
-import numpy as np
 
 
 class ExponentialMovingAverage():
     """
     Exponential Moving Average
-    Code was heavily based on https://github.com/Wanger-SJTU/SegToolbox.Pytorch/blob/master/lib/utils/ema.py
+    Code was heavily based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
     """
 
-    def __init__(self, model, decay, thres_steps=True):
-        self._model = model
-        self._decay = decay
-        self._thres_steps = thres_steps
-        self._shadow = {}
-        self._backup = {}
-
-    def register(self):
-        self._update_step = 0
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                self._shadow[name] = param.numpy().copy()
-
-    def update(self):
-        decay = min(self._decay, (1 + self._update_step) / (
-            10 + self._update_step)) if self._thres_steps else self._decay
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                new_val = np.array(param.numpy().copy())
-                old_val = np.array(self._shadow[name])
-                new_average = decay * old_val + (1 - decay) * new_val
-                self._shadow[name] = new_average
-        self._update_step += 1
-        return decay
-
-    def apply(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._shadow
-                self._backup[name] = np.array(param.numpy().copy())
-                param.set_value(np.array(self._shadow[name]))
-
-    def restore(self):
-        for name, param in self._model.named_parameters():
-            if param.stop_gradient is False:
-                assert name in self._backup
-                param.set_value(self._backup[name])
-        self._backup = {}
+    def __init__(self, model, decay=0.9999):
+        super().__init__()
+        # make a copy of the model for accumulating moving average of weights
+        self.module = deepcopy(model)
+        self.module.eval()
+        self.decay = decay
+
+    @paddle.no_grad()
+    def _update(self, model, update_fn):
+        for ema_v, model_v in zip(self.module.state_dict().values(),
+                                  model.state_dict().values()):
+            ema_v.set_value(update_fn(ema_v, model_v))
+
+    def update(self, model):
+        self._update(
+            model,
+            update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
+
+    def set(self, model):
+        self._update(model, update_fn=lambda e, m: m)
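A small usage sketch of the rewritten class (the toy `Linear` layer and decay value are chosen only for illustration): `update` blends the live weights into the frozen copy held in `.module`, while `set` overwrites that copy outright.

```python
import paddle
from ppcls.utils.ema import ExponentialMovingAverage

model = paddle.nn.Linear(2, 2, bias_attr=False)
ema = ExponentialMovingAverage(model, decay=0.9)

# change the live weights, then blend them into the shadow copy once
model.weight.set_value(paddle.ones_like(model.weight))
ema.update(model)
print(ema.module.weight.numpy())   # each entry is 0.9 * old_value + 0.1 * 1.0

# ema.set(model) instead copies the live weights into .module verbatim
ema.set(model)
print(ema.module.weight.numpy())   # now all ones
```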
@@ -87,7 +87,11 @@ def load_distillation_model(model, pretrained_model):
         pretrained_model))
 
 
-def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
+def init_model(config,
+               net,
+               optimizer=None,
+               loss: paddle.nn.Layer=None,
+               ema=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -105,7 +109,12 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
         net.set_state_dict(para_dict)
         loss.set_state_dict(para_dict)
         for i in range(len(optimizer)):
-            optimizer[i].set_state_dict(opti_dict)
+            optimizer[i].set_state_dict(opti_dict[i])
+        if ema is not None:
+            assert os.path.exists(checkpoints + ".ema.pdparams"), \
+                "Given dir {}.ema.pdparams not exist.".format(checkpoints)
+            para_ema_dict = paddle.load(checkpoints + ".ema.pdparams")
+            ema.set_state_dict(para_ema_dict)
         logger.info("Finish load checkpoints from {}".format(checkpoints))
         return metric_dict
@@ -125,6 +134,7 @@ def save_model(net,
                optimizer,
                metric_info,
                model_path,
+               ema=None,
                model_name="",
                prefix='ppcls',
                loss: paddle.nn.Layer=None):
@@ -145,6 +155,8 @@ def save_model(net,
         params_state_dict.update(loss_state_dict)
 
     paddle.save(params_state_dict, model_path + ".pdparams")
+    if ema is not None:
+        paddle.save(ema.state_dict(), model_path + ".ema.pdparams")
    paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
    paddle.save(metric_info, model_path + ".pdstates")
    logger.info("Already save model in {}".format(model_path))