未验证 提交 9ebbd78b 编写于 作者: L littletomatodonkey 提交者: GitHub

fix save distillation model (#578)

* fix save distillation model

* add note
上级 63a10e56
......@@ -78,7 +78,7 @@ def load_distillation_model(model, pretrained_model, load_static_weights):
pretrained_model = [pretrained_model]
if not isinstance(load_static_weights, list):
load_static_weights = [load_static_weights]
load_static_weights = [load_static_weights] * len(pretrained_model)
teacher = model.teacher if hasattr(model,
"teacher") else model._layers.teacher
......@@ -114,9 +114,7 @@ def init_model(config, net, optimizer=None):
opti_dict = paddle.load(checkpoints + ".pdopt")
net.set_dict(para_dict)
optimizer.set_state_dict(opti_dict)
logger.info(
logger.coloring("Finish initing model from {}".format(checkpoints),
"HEADER"))
logger.info("Finish initing model from {}".format(checkpoints))
return
pretrained_model = config.get('pretrained_model')
......@@ -135,6 +133,19 @@ def init_model(config, net, optimizer=None):
pretrained_model), "HEADER"))
def _save_student_model(net, model_prefix):
"""
save student model if the net is the network contains student
"""
student_model_prefix = model_prefix + "_student.pdparams"
if hasattr(net, "_layers"):
net = net._layers
if hasattr(net, "student"):
paddle.save(net.student.state_dict(), student_model_prefix)
logger.info("Already save student model in {}".format(
student_model_prefix))
def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
"""
save model to the target path
......@@ -145,8 +156,8 @@ def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
_mkdir_if_not_exist(model_path)
model_prefix = os.path.join(model_path, prefix)
_save_student_model(net, model_prefix)
paddle.save(net.state_dict(), model_prefix + ".pdparams")
paddle.save(optimizer.state_dict(), model_prefix + ".pdopt")
logger.info(
logger.coloring("Already save model in {}".format(model_path),
"HEADER"))
logger.info("Already save model in {}".format(model_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册