Commit 0d7e595f authored by gaotingquan, committed by Wei Shengyu

mv model_saver to __init__()

Parent 6e77bd6c
......@@ -33,8 +33,7 @@ from ppcls.metric import build_metrics
 from ppcls.optimizer import build_optimizer
 from ppcls.utils.ema import ExponentialMovingAverage
 from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
-from ppcls.utils.save_load import init_model
-from ppcls.utils.model_saver import ModelSaver
+from ppcls.utils.save_load import init_model, ModelSaver
 from ppcls.data.utils.get_image_list import get_image_list
 from ppcls.data.postprocess import build_postprocess
......@@ -100,6 +99,14 @@ class Engine(object):
         # for distributed
         self._init_dist()

+        # build model saver
+        self.model_saver = ModelSaver(
+            self,
+            net_name="model",
+            loss_name="train_loss_func",
+            opt_name="optimizer",
+            model_ema_name="model_ema")
+
         print_config(config)

     def train(self):
......@@ -129,14 +136,6 @@ class Engine(object):
         # TODO: mv best_metric_ema to best_metric dict
         best_metric_ema = 0

-        # build model saver
-        model_saver = ModelSaver(
-            self,
-            net_name="model",
-            loss_name="train_loss_func",
-            opt_name="optimizer",
-            model_ema_name="model_ema")
-
         self._init_checkpoints(best_metric)

         # global iter counter
......@@ -166,7 +165,7 @@ class Engine(object):
                 if acc > best_metric["metric"]:
                     best_metric["metric"] = acc
                     best_metric["epoch"] = epoch_id
-                    model_saver.save(
+                    self.model_saver.save(
                         best_metric,
                         prefix="best_model",
                         save_student_model=True)
......@@ -189,7 +188,7 @@ class Engine(object):
                 if acc_ema > best_metric_ema:
                     best_metric_ema = acc_ema
-                    model_saver.save(
+                    self.model_saver.save(
                         {
                             "metric": acc_ema,
                             "epoch": epoch_id
......@@ -205,7 +204,7 @@ class Engine(object):
             # save model
             if save_interval > 0 and epoch_id % save_interval == 0:
-                model_saver.save(
+                self.model_saver.save(
                     {
                         "metric": acc,
                         "epoch": epoch_id
......@@ -213,7 +212,7 @@ class Engine(object):
                     prefix=f"epoch_{epoch_id}")

             # save the latest model
-            model_saver.save(
+            self.model_saver.save(
                 {
                     "metric": acc,
                     "epoch": epoch_id
......
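The engine change above is the whole point of the commit: the saver is now built once in Engine.__init__() and stored as self.model_saver, so train() (and any other method) reuses a single instance instead of constructing a fresh saver on every training run. A minimal, self-contained sketch of that pattern, using toy stand-ins rather than the real ppcls classes:

# Toy sketch of the pattern (not the actual ppcls classes): build a
# long-lived helper once in __init__ and reuse it from every call site.
class ToyModelSaver(object):
    def __init__(self, engine, net_name="model"):
        self.engine = engine
        self.net_name = net_name

    def save(self, metric_info, prefix="ppcls"):
        # Stand-in for the real paddle.save calls.
        print("saving {} as {}: {}".format(self.net_name, prefix, metric_info))


class ToyEngine(object):
    def __init__(self):
        # Built once; train() and any future eval/export hooks share it.
        self.model_saver = ToyModelSaver(self)

    def train(self):
        # Before this commit, the saver was constructed here, so it existed
        # only for the duration of train().
        self.model_saver.save({"metric": 0.99, "epoch": 3}, prefix="best_model")


ToyEngine().train()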
import errno
import os

import paddle

from . import logger


def _mkdir_if_not_exist(path):
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'path {} already exists, possibly created by another '
                    'process'.format(path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))


def _extract_student_weights(all_params, student_prefix="Student."):
    # Keep only the student sub-model's parameters and strip the prefix,
    # so the saved student checkpoint can be loaded standalone.
    s_params = {
        key[len(student_prefix):]: all_params[key]
        for key in all_params if key.startswith(student_prefix)
    }
    return s_params
class ModelSaver(object):
    def __init__(self,
                 engine,
                 net_name="model",
                 loss_name="train_loss_func",
                 opt_name="optimizer",
                 model_ema_name="model_ema"):
        # Attribute names under which the engine stores the network, loss,
        # optimizer, and EMA model.
        self.engine = engine
        self.net_name = net_name
        self.loss_name = loss_name
        self.opt_name = opt_name
        self.model_ema_name = model_ema_name

        arch_name = engine.config["Arch"]["name"]
        self.output_dir = os.path.join(engine.output_dir, arch_name)
        _mkdir_if_not_exist(self.output_dir)

    def save(self, metric_info, prefix='ppcls', save_student_model=False):
        # Only rank 0 writes checkpoints in distributed training.
        if paddle.distributed.get_rank() != 0:
            return

        save_dir = os.path.join(self.output_dir, prefix)

        params_state_dict = getattr(self.engine, self.net_name).state_dict()
        loss = getattr(self.engine, self.loss_name)
        if loss is not None:
            loss_state_dict = loss.state_dict()
            keys_inter = set(params_state_dict.keys()) & set(
                loss_state_dict.keys())
            assert len(keys_inter) == 0, \
                f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
            params_state_dict.update(loss_state_dict)

        if save_student_model:
            s_params = _extract_student_weights(params_state_dict)
            if len(s_params) > 0:
                paddle.save(s_params, save_dir + "_student.pdparams")

        paddle.save(params_state_dict, save_dir + ".pdparams")
        model_ema = getattr(self.engine, self.model_ema_name)
        if model_ema is not None:
            paddle.save(model_ema.module.state_dict(),
                        save_dir + ".ema.pdparams")
        optimizer = getattr(self.engine, self.opt_name)
        paddle.save([opt.state_dict() for opt in optimizer],
                    save_dir + ".pdopt")
        paddle.save(metric_info, save_dir + ".pdstates")
        logger.info("Model saved in {}".format(save_dir))
......@@ -123,3 +123,79 @@ def init_model(config,
        load_dygraph_pretrain(net, path=pretrained_model)
        logger.info("Finished loading pretrained model from {}".format(
            pretrained_model))
def _mkdir_if_not_exist(path):
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'path {} already exists, possibly created by another '
                    'process'.format(path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))


def _extract_student_weights(all_params, student_prefix="Student."):
    # Keep only the student sub-model's parameters and strip the prefix,
    # so the saved student checkpoint can be loaded standalone.
    s_params = {
        key[len(student_prefix):]: all_params[key]
        for key in all_params if key.startswith(student_prefix)
    }
    return s_params


class ModelSaver(object):
    def __init__(self,
                 engine,
                 net_name="model",
                 loss_name="train_loss_func",
                 opt_name="optimizer",
                 model_ema_name="model_ema"):
        # Attribute names under which the engine stores the network, loss,
        # optimizer, and EMA model.
        self.engine = engine
        self.net_name = net_name
        self.loss_name = loss_name
        self.opt_name = opt_name
        self.model_ema_name = model_ema_name

        arch_name = engine.config["Arch"]["name"]
        self.output_dir = os.path.join(engine.output_dir, arch_name)
        _mkdir_if_not_exist(self.output_dir)

    def save(self, metric_info, prefix='ppcls', save_student_model=False):
        # Only rank 0 writes checkpoints in distributed training.
        if paddle.distributed.get_rank() != 0:
            return

        save_dir = os.path.join(self.output_dir, prefix)

        params_state_dict = getattr(self.engine, self.net_name).state_dict()
        loss = getattr(self.engine, self.loss_name)
        if loss is not None:
            loss_state_dict = loss.state_dict()
            keys_inter = set(params_state_dict.keys()) & set(
                loss_state_dict.keys())
            assert len(keys_inter) == 0, \
                f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
            params_state_dict.update(loss_state_dict)

        if save_student_model:
            s_params = _extract_student_weights(params_state_dict)
            if len(s_params) > 0:
                paddle.save(s_params, save_dir + "_student.pdparams")

        paddle.save(params_state_dict, save_dir + ".pdparams")
        model_ema = getattr(self.engine, self.model_ema_name)
        if model_ema is not None:
            paddle.save(model_ema.module.state_dict(),
                        save_dir + ".ema.pdparams")
        optimizer = getattr(self.engine, self.opt_name)
        paddle.save([opt.state_dict() for opt in optimizer],
                    save_dir + ".pdopt")
        paddle.save(metric_info, save_dir + ".pdstates")
        logger.info("Model saved in {}".format(save_dir))