Commit fad5c8e3 authored by gaotingquan, committed by Wei Shengyu

refactor: simplify engine.train()

1. extract model saving into ModelSaver (see the sketch below);
2. extract EMA model construction into _build_ema_model();
3. extract checkpoint restoring into _init_checkpoints();
4. other minor cleanups.
Parent a38e42f6
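Note: ppcls/utils/model_saver.py itself is not included in this diff, so the following is only a minimal sketch of what the new ModelSaver might look like, inferred from how engine.train() constructs and calls it below. The constructor arguments mirror the call site; every attribute name and the body of save() are assumptions, loosely based on the save_model() helper removed from save_load.py.

```python
# Hypothetical sketch only: the real ppcls/utils/model_saver.py is not shown in this commit.
import os

import paddle

from ppcls.utils import logger


class ModelSaver(object):
    def __init__(self,
                 engine,
                 net_name="model",
                 loss_name="train_loss_func",
                 opt_name="optimizer",
                 model_ema_name="model_ema"):
        # Keep a reference to the engine and remember which attributes hold
        # the network, loss, optimizer and EMA model to be checkpointed.
        self.engine = engine
        self.net_name = net_name
        self.loss_name = loss_name
        self.opt_name = opt_name
        self.model_ema_name = model_ema_name

    def save(self, metric_info, prefix="ppcls", save_student_model=False):
        # Only rank 0 writes checkpoints in distributed training.
        if paddle.distributed.get_rank() != 0:
            return

        engine = self.engine
        save_dir = os.path.join(engine.output_dir, engine.config["Arch"]["name"])
        os.makedirs(save_dir, exist_ok=True)
        model_path = os.path.join(save_dir, prefix)

        # Merge network and (optional) trainable loss parameters, as the old
        # save_load.save_model() did.
        params_state_dict = getattr(engine, self.net_name).state_dict()
        loss = getattr(engine, self.loss_name)
        if loss is not None:
            params_state_dict.update(loss.state_dict())

        if save_student_model:
            # For distillation configs, also dump the student sub-network alone.
            s_params = {
                key[len("Student."):]: value
                for key, value in params_state_dict.items()
                if key.startswith("Student.")
            }
            if len(s_params) > 0:
                paddle.save(s_params, model_path + "_student.pdparams")

        paddle.save(params_state_dict, model_path + ".pdparams")

        model_ema = getattr(engine, self.model_ema_name)
        if model_ema is not None:
            paddle.save(model_ema.module.state_dict(),
                        model_path + ".ema.pdparams")

        optimizer = getattr(engine, self.opt_name)
        paddle.save([opt.state_dict() for opt in optimizer],
                    model_path + ".pdopt")
        paddle.save(metric_info, model_path + ".pdstates")
        logger.info("Already save model in {}".format(model_path))
```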
......@@ -34,7 +34,7 @@ from ppcls.optimizer import build_optimizer
from ppcls.utils.ema import ExponentialMovingAverage
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load
from ppcls.utils.model_saver import ModelSaver
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
......@@ -56,7 +56,10 @@ class Engine(object):
self._init_seed()
# init logger
init_logger(self.config, mode=mode)
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
f"{mode}.log")
init_logger(log_file=log_file)
# for visualdl
self.vdl_writer = self._init_vdl()
......@@ -103,21 +106,14 @@ class Engine(object):
assert self.mode == "train"
print_batch_step = self.config['Global']['print_batch_step']
save_interval = self.config["Global"]["save_interval"]
start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
epochs = self.config["Global"]["epochs"]
best_metric = {
"metric": -1.0,
"epoch": 0,
}
# build EMA model
self.ema = "EMA" in self.config and self.mode == "train"
if self.ema:
self.model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
best_metric_ema = 0.0
ema_module = self.model_ema.module
else:
ema_module = None
# key: metric name
# val: AverageMeter tracking that metric
self.output_info = dict()
......@@ -127,31 +123,35 @@ class Engine(object):
"reader_cost": AverageMeter(
"reader_cost", ".5f", postfix=" s,"),
}
# global iter counter
self.global_step = 0
if self.config.Global.checkpoints is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func,
ema_module)
if metric_info is not None:
best_metric.update(metric_info)
# build EMA model
self.model_ema = self._build_ema_model()
# TODO: mv best_metric_ema to best_metric dict
best_metric_ema = 0
for epoch_id in range(best_metric["epoch"] + 1,
self.config["Global"]["epochs"] + 1):
acc = 0.0
# build model saver
model_saver = ModelSaver(
self,
net_name="model",
loss_name="train_loss_func",
opt_name="optimizer",
model_ema_name="model_ema")
self._init_checkpoints(best_metric)
# global iter counter
self.global_step = 0
for epoch_id in range(best_metric["epoch"] + 1, epochs + 1):
# for one epoch train
self.train_epoch_func(self, epoch_id, print_batch_step)
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.config["Global"]["epochs"], metric_msg))
logger.info("[Train][Epoch {}/{}][Avg]{}".format(epoch_id, epochs,
metric_msg))
self.output_info.clear()
# eval model and save model if possible
start_eval_epoch = self.config["Global"].get("start_eval_epoch",
0) - 1
acc = 0.0
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > start_eval_epoch:
......@@ -166,16 +166,11 @@ class Engine(object):
if acc > best_metric["metric"]:
best_metric["metric"] = acc
best_metric["epoch"] = epoch_id
save_load.save_model(
self.model,
self.optimizer,
model_saver.save(
best_metric,
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="best_model",
loss=self.train_loss_func,
save_student_model=True)
logger.info("[Eval][Epoch {}][best metric: {}]".format(
epoch_id, best_metric["metric"]))
logger.scaler(
......@@ -186,24 +181,20 @@ class Engine(object):
self.model.train()
if self.ema:
ori_model, self.model = self.model, ema_module
if self.model_ema:
ori_model, self.model = self.model, self.model_ema.module
acc_ema = self.eval(epoch_id)
self.model = ori_model
ema_module.eval()
self.model_ema.module.eval()
if acc_ema > best_metric_ema:
best_metric_ema = acc_ema
save_load.save_model(
self.model,
self.optimizer,
{"metric": acc_ema,
"epoch": epoch_id},
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="best_model_ema",
loss=self.train_loss_func)
model_saver.save(
{
"metric": acc_ema,
"epoch": epoch_id
},
prefix="best_model_ema")
logger.info("[Eval][Epoch {}][best metric ema: {}]".format(
epoch_id, best_metric_ema))
logger.scaler(
......@@ -214,25 +205,19 @@ class Engine(object):
# save model
if save_interval > 0 and epoch_id % save_interval == 0:
save_load.save_model(
self.model,
self.optimizer, {"metric": acc,
"epoch": epoch_id},
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="epoch_{}".format(epoch_id),
loss=self.train_loss_func)
model_saver.save(
{
"metric": acc,
"epoch": epoch_id
},
prefix=f"epoch_{epoch_id}")
# save the latest model
save_load.save_model(
self.model,
self.optimizer, {"metric": acc,
"epoch": epoch_id},
self.output_dir,
ema=ema_module,
model_name=self.config["Arch"]["name"],
prefix="latest",
loss=self.train_loss_func)
model_saver.save(
{
"metric": acc,
"epoch": epoch_id
}, prefix="latest")
if self.vdl_writer is not None:
self.vdl_writer.close()
......@@ -483,6 +468,23 @@ class Engine(object):
self.train_loss_func = paddle.DataParallel(
self.train_loss_func)
def _build_ema_model(self):
if "EMA" in self.config and self.mode == "train":
model_ema = ExponentialMovingAverage(
self.model, self.config['EMA'].get("decay", 0.9999))
return model_ema
else:
return None
def _init_checkpoints(self, best_metric):
if self.config["Global"].get("checkpoints", None) is not None:
metric_info = init_model(self.config.Global, self.model,
self.optimizer, self.train_loss_func,
self.model_ema)
if metric_info is not None:
best_metric.update(metric_info)
return best_metric
class ExportModel(TheseusLayer):
"""
......
......@@ -75,7 +75,7 @@ def regular_train_epoch(engine, epoch_id, print_batch_step):
if not getattr(engine.lr_sch[i], "by_epoch", False):
engine.lr_sch[i].step()
# update ema
if engine.ema:
if engine.model_ema:
engine.model_ema.update(engine.model)
# below code just for logging
......
......@@ -25,4 +25,4 @@ from .metrics import mean_average_precision
from .metrics import multi_hot_encode
from .metrics import precision_recall_fscore
from .misc import AverageMeter
from .save_load import init_model, save_model
from .save_load import init_model
......@@ -22,15 +22,16 @@ import paddle.distributed as dist
_logger = None
def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
def init_logger(name='ppcls', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added.
added. If `log_file` is specified a FileHandler will also be added.
Args:
config(dict): Training config.
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
......@@ -62,8 +63,6 @@ def init_logger(config, mode="train", name='ppcls', log_level=logging.INFO):
if init_flag:
_logger.addHandler(stream_handler)
log_file = os.path.join(config['Global']['output_dir'],
config["Arch"]["name"], f"{mode}.log")
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
......
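For reference, a minimal usage sketch of the new init_logger signature; the import path and the literal values are assumptions for illustration, while the path composition mirrors what Engine.__init__ does in this diff.

```python
import os

# Assumed import location: init_logger is defined in ppcls/utils/logger.py per this diff.
from ppcls.utils.logger import init_logger

output_dir = "./output"    # illustrative stand-in for config['Global']['output_dir']
arch_name = "ResNet50"     # illustrative stand-in for config['Arch']['name']

# The caller now composes the log path itself; init_logger adds a StreamHandler
# and, on rank 0, a FileHandler writing to log_file.
log_file = os.path.join(output_dir, arch_name, "train.log")
init_logger(log_file=log_file)
```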
......@@ -26,30 +26,6 @@ from .download import get_weights_path_from_url
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
def _mkdir_if_not_exist(path):
"""
mkdir if not exists, ignore the exception when multiprocess mkdir together
"""
if not os.path.exists(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno == errno.EEXIST and os.path.isdir(path):
logger.warning(
'be happy if some process has already created {}'.format(
path))
else:
raise OSError('Failed to mkdir {}'.format(path))
def _extract_student_weights(all_params, student_prefix="Student."):
s_params = {
key[len(student_prefix):]: all_params[key]
for key in all_params if student_prefix in key
}
return s_params
def load_dygraph_pretrain(model, path=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {}.pdparams does not "
......@@ -110,7 +86,7 @@ def init_model(config,
net,
optimizer=None,
loss: paddle.nn.Layer=None,
ema=None):
model_ema=None):
"""
load model from checkpoint or pretrained_model
"""
......@@ -130,11 +106,11 @@ def init_model(config,
for i in range(len(optimizer)):
optimizer[i].set_state_dict(opti_dict[i] if isinstance(
opti_dict, list) else opti_dict)
if ema is not None:
if model_ema is not None:
assert os.path.exists(checkpoints + ".ema.pdparams"), \
"Given dir {}.ema.pdparams not exist.".format(checkpoints)
para_ema_dict = paddle.load(checkpoints + ".ema.pdparams")
ema.set_state_dict(para_ema_dict)
model_ema.module.set_state_dict(para_ema_dict)
logger.info("Finish load checkpoints from {}".format(checkpoints))
return metric_dict
......@@ -147,43 +123,3 @@ def init_model(config,
load_dygraph_pretrain(net, path=pretrained_model)
logger.info("Finish load pretrained model from {}".format(
pretrained_model))
def save_model(net,
optimizer,
metric_info,
model_path,
ema=None,
model_name="",
prefix='ppcls',
loss: paddle.nn.Layer=None,
save_student_model=False):
"""
save model to the target path
"""
if paddle.distributed.get_rank() != 0:
return
model_path = os.path.join(model_path, model_name)
_mkdir_if_not_exist(model_path)
model_path = os.path.join(model_path, prefix)
params_state_dict = net.state_dict()
if loss is not None:
loss_state_dict = loss.state_dict()
keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys(
))
assert len(keys_inter) == 0, \
f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
params_state_dict.update(loss_state_dict)
if save_student_model:
s_params = _extract_student_weights(params_state_dict)
if len(s_params) > 0:
paddle.save(s_params, model_path + "_student.pdparams")
paddle.save(params_state_dict, model_path + ".pdparams")
if ema is not None:
paddle.save(ema.state_dict(), model_path + ".ema.pdparams")
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
paddle.save(metric_info, model_path + ".pdstates")
logger.info("Already save model in {}".format(model_path))