Commit ed820223 authored by F flytocc

add EMA code

Parent e1943f9a
......@@ -17,6 +17,11 @@ Global:
to_static: False
# model ema
EMA:
  decay: 0.9999
# model architecture
Arch:
name: ConvNext_tiny
......
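The EMA section above turns on weight averaging with a decay of 0.9999: at every training step the shadow copy keeps `decay` of its old value and mixes in `(1 - decay)` of the current parameter. A minimal sketch of that update rule; the function and toy loop below are illustrative only, not part of this commit:

```python
# Minimal sketch of the EMA update rule configured above.
def ema_step(shadow, param, decay=0.9999):
    """Keep `decay` of the old shadow value, mix in (1 - decay) of the live one."""
    return decay * shadow + (1.0 - decay) * param

shadow = 0.0
for _ in range(10000):
    shadow = ema_step(shadow, 1.0)  # pretend the trained weight sits at 1.0
print(round(shadow, 4))  # ~0.6321: the shadow trails the live weight slowly
```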
......@@ -34,6 +34,7 @@ from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.ema import ExponentialMovingAverage
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load
......@@ -115,6 +116,9 @@ class Engine(object):
})
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
# EMA model
self.ema = "EMA" in self.config and self.mode == "train"
if "class_num" in config["Global"]:
global_class_num = config["Global"]["class_num"]
if "class_num" not in config["Arch"]:
......@@ -250,6 +254,11 @@ class Engine(object):
level=amp_level,
save_dtype='float32')
# build EMA model
if self.ema:
self.model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
# for distributed
world_size = dist.get_world_size()
self.config["Global"]["distributed"] = world_size != 1
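In the engine, the wrapper is built from the live model and the configured decay, falling back to 0.9999 when the key is missing. A hedged, standalone sketch of the same construction; the `Linear` layer is only a stand-in for the real architecture:

```python
import paddle

from ppcls.utils.ema import ExponentialMovingAverage

# Stand-in for the network built from the Arch config.
model = paddle.nn.Linear(8, 4)

# Mirrors the engine code above: read decay from the EMA config section,
# defaulting to 0.9999 when it is absent.
ema_config = {"decay": 0.9999}
model_ema = ExponentialMovingAverage(model, ema_config.get("decay", 0.9999))
```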
......@@ -278,6 +287,10 @@ class Engine(object):
"metric": 0.0,
"epoch": 0,
}
ema_module = None
if self.ema:
best_metric_ema = 0.0
ema_module = self.model_ema.module
# key:
# val: metrics list word
self.output_info = dict()
......@@ -292,7 +305,8 @@ class Engine(object):
if self.config.Global.checkpoints is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func)
self.optimizer, self.train_loss_func,
ema_module)
if metric_info is not None:
best_metric.update(metric_info)
......@@ -327,6 +341,7 @@ class Engine(object):
self.optimizer,
best_metric,
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="best_model",
loss=self.train_loss_func)
......@@ -340,6 +355,32 @@ class Engine(object):
self.model.train()
if self.ema:
ori_model, self.model = self.model, ema_module
acc_ema = self.eval(epoch_id)
self.model = ori_model
ema_module.eval()
if acc_ema > best_metric_ema:
best_metric_ema = acc_ema
save_load.save_model(
self.model,
self.optimizer,
{"metric": acc_ema,
"epoch": epoch_id},
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="best_model_ema",
loss=self.train_loss_func)
logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
epoch_id, best_metric_ema))
logger.scaler(
name="eval_acc_ema",
value=acc_ema,
step=epoch_id,
writer=self.vdl_writer)
# save model
if epoch_id % save_interval == 0:
save_load.save_model(
......@@ -347,6 +388,7 @@ class Engine(object):
self.optimizer, {"metric": acc,
"epoch": epoch_id},
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="epoch_{}".format(epoch_id),
loss=self.train_loss_func)
......@@ -356,6 +398,7 @@ class Engine(object):
self.optimizer, {"metric": acc,
"epoch": epoch_id},
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="latest",
loss=self.train_loss_func)
......
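During evaluation epochs the loop above temporarily points `self.model` at the EMA copy so the unchanged `eval()` path scores the averaged weights, then restores the live model and tracks `best_metric_ema` separately from the regular best metric. A hedged sketch of that swap-evaluate-restore pattern; `engine` and `evaluate` are placeholders for the Engine instance and its `eval` method:

```python
# Hedged sketch of the swap-evaluate-restore pattern used in the training loop.
def eval_with_ema(engine, epoch_id, evaluate):
    # point the engine at the averaged copy so the existing eval path is reused
    ori_model, engine.model = engine.model, engine.model_ema.module
    acc_ema = evaluate(epoch_id)
    # put the live training weights back and keep the EMA copy in eval mode
    engine.model = ori_model
    engine.model_ema.module.eval()
    return acc_ema
```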
......@@ -69,6 +69,9 @@ def train_epoch(engine, epoch_id, print_batch_step):
# step lr
for i in range(len(engine.lr_sch)):
engine.lr_sch[i].step()
# update ema
if engine.ema:
engine.model_ema.update(engine.model)
# below code just for logging
# update metric_for_logger
......
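The hook in `train_epoch` refreshes the shadow weights once per iteration, after the optimizer and learning-rate scheduler have stepped. A hedged sketch of one iteration with that ordering; all objects are placeholders for the engine's real ones:

```python
def train_step(model, loss_fn, optimizer, lr_scheduler, model_ema, batch,
               use_ema=True):
    images, labels = batch
    loss = loss_fn(model(images), labels)
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()
    lr_scheduler.step()          # step lr, as in train_epoch above
    if use_ema:
        model_ema.update(model)  # refresh the EMA shadow weights
    return loss
```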
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -12,52 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import paddle
import numpy as np
class ExponentialMovingAverage():
    """
    Exponential Moving Average
    Code was heavily based on https://github.com/Wanger-SJTU/SegToolbox.Pytorch/blob/master/lib/utils/ema.py
    Code was heavily based on https://github.com/rwightman/pytorch-image-models/blob/master/timm/utils/model_ema.py
    """
    def __init__(self, model, decay, thres_steps=True):
        self._model = model
        self._decay = decay
        self._thres_steps = thres_steps
        self._shadow = {}
        self._backup = {}

    def register(self):
        self._update_step = 0
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                self._shadow[name] = param.numpy().copy()

    def update(self):
        decay = min(self._decay, (1 + self._update_step) / (
            10 + self._update_step)) if self._thres_steps else self._decay
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._shadow
                new_val = np.array(param.numpy().copy())
                old_val = np.array(self._shadow[name])
                new_average = decay * old_val + (1 - decay) * new_val
                self._shadow[name] = new_average
        self._update_step += 1
        return decay

    def apply(self):
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._shadow
                self._backup[name] = np.array(param.numpy().copy())
                param.set_value(np.array(self._shadow[name]))

    def restore(self):
        for name, param in self._model.named_parameters():
            if param.stop_gradient is False:
                assert name in self._backup
                param.set_value(self._backup[name])
        self._backup = {}
    def __init__(self, model, decay=0.9999):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    @paddle.no_grad()
    def _update(self, model, update_fn):
        for ema_v, model_v in zip(self.module.state_dict().values(),
                                  model.state_dict().values()):
            ema_v.set_value(update_fn(ema_v, model_v))

    def update(self, model):
        self._update(
            model,
            update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        self._update(model, update_fn=lambda e, m: m)
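The rewritten class keeps a `deepcopy` of the model in eval mode and updates its `state_dict` tensors in place under `paddle.no_grad()`. A short usage sketch for the class above; the `Linear` model, data, and optimizer are stand-ins and assume a working Paddle install:

```python
import paddle

model = paddle.nn.Linear(4, 2)                    # stand-in network
ema = ExponentialMovingAverage(model, decay=0.99)

opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=model.parameters())
x, y = paddle.randn([16, 4]), paddle.randn([16, 2])

for _ in range(5):
    loss = paddle.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    opt.clear_grad()
    ema.update(model)  # shadow weights trail the live weights by the decay factor

# ema.module is the averaged copy; it can be evaluated or saved like any Layer
print(type(ema.module).__name__, ema.decay)
```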
......@@ -87,7 +87,11 @@ def load_distillation_model(model, pretrained_model):
pretrained_model))
def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
def init_model(config,
net,
optimizer=None,
loss: paddle.nn.Layer=None,
ema=None):
"""
load model from checkpoint or pretrained_model
"""
......@@ -105,7 +109,12 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
net.set_state_dict(para_dict)
loss.set_state_dict(para_dict)
for i in range(len(optimizer)):
optimizer[i].set_state_dict(opti_dict)
optimizer[i].set_state_dict(opti_dict[i])
if ema is not None:
assert os.path.exists(checkpoints + ".ema.pdparams"), \
"Given dir {}.ema.pdparams not exist.".format(checkpoints)
para_ema_dict = paddle.load(checkpoints + ".ema.pdparams")
ema.set_state_dict(para_ema_dict)
logger.info("Finish load checkpoints from {}".format(checkpoints))
return metric_dict
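With `ema` passed in, resuming from `Global.checkpoints` now also expects the averaged weights in a sibling `<checkpoints>.ema.pdparams` file; the hunk additionally indexes `opti_dict` per optimizer, matching the list that `save_model` writes. A hedged sketch of the files a resume reads; the path below is illustrative:

```python
import paddle

checkpoints = "./output/ConvNext_tiny/latest"            # illustrative path

para_dict = paddle.load(checkpoints + ".pdparams")       # live model (and loss) weights
opti_dict = paddle.load(checkpoints + ".pdopt")          # list of optimizer state dicts
ema_dict = paddle.load(checkpoints + ".ema.pdparams")    # EMA shadow weights
metric_dict = paddle.load(checkpoints + ".pdstates")     # {"metric": ..., "epoch": ...}
```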
......@@ -125,6 +134,7 @@ def save_model(net,
optimizer,
metric_info,
model_path,
ema=None,
model_name="",
prefix='ppcls',
loss: paddle.nn.Layer=None):
......@@ -145,6 +155,8 @@ def save_model(net,
params_state_dict.update(loss_state_dict)
paddle.save(params_state_dict, model_path + ".pdparams")
if ema is not None:
paddle.save(ema.state_dict(), model_path + ".ema.pdparams")
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
paddle.save(metric_info, model_path + ".pdstates")
logger.info("Already save model in {}".format(model_path))
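On the saving side, a prefix such as `best_model` now expands to four files when EMA is enabled. A hedged sketch of what `save_model` writes in that case; the objects and path are placeholders:

```python
import paddle

def save_with_ema(net, ema, optimizer_list, metric_info, model_path):
    # mirrors save_model above when `ema` is not None
    paddle.save(net.state_dict(), model_path + ".pdparams")      # live weights
    paddle.save(ema.state_dict(), model_path + ".ema.pdparams")  # averaged weights
    paddle.save([opt.state_dict() for opt in optimizer_list], model_path + ".pdopt")
    paddle.save(metric_info, model_path + ".pdstates")           # metric / epoch info
```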