Unverified · Commit 1989b660 · authored by cuicheng01, committed by GitHub

Merge pull request #1 from littletomatodonkey/me/add_pdemo

fix convert weight
@@ -344,15 +344,15 @@ class Engine(object):
             if self.use_dali:
                 self.train_dataloader.reset()
-            metric_msg = ", ".join([
-                self.output_info[key].avg_info for key in self.output_info
-            ])
+            metric_msg = ", ".join(
+                [self.output_info[key].avg_info for key in self.output_info])
             logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                 epoch_id, self.config["Global"]["epochs"], metric_msg))
             self.output_info.clear()
 
             # eval model and save model if possible
-            start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
+            start_eval_epoch = self.config["Global"].get("start_eval_epoch",
+                                                         0) - 1
             if self.config["Global"][
                     "eval_during_train"] and epoch_id % self.config["Global"][
                         "eval_interval"] == 0 and epoch_id > start_eval_epoch:
@@ -367,7 +367,8 @@ class Engine(object):
                         self.output_dir,
                         model_name=self.config["Arch"]["name"],
                         prefix="best_model",
-                        loss=self.train_loss_func)
+                        loss=self.train_loss_func,
+                        save_student_model=True)
                 logger.info("[Eval][Epoch {}][best metric: {}]".format(
                     epoch_id, best_metric["metric"]))
                 logger.scaler(
......
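Note: with this change, the best-model checkpoint is saved with `save_student_model=True`, so a student-only weight file is exported alongside `best_model.pdparams` (see the `save_model` diff below). The eval-gating condition is dense; here is a minimal standalone sketch of the same logic, using hypothetical config values:

# A minimal sketch of the eval gating above; the config values are made up.
config = {"Global": {"eval_during_train": True, "eval_interval": 2,
                     "start_eval_epoch": 5, "epochs": 10}}
start_eval_epoch = config["Global"].get("start_eval_epoch", 0) - 1
for epoch_id in range(1, config["Global"]["epochs"] + 1):
    # Evaluate only when enabled, on every eval_interval-th epoch,
    # and only after start_eval_epoch has passed.
    if config["Global"]["eval_during_train"] and epoch_id % config["Global"][
            "eval_interval"] == 0 and epoch_id > start_eval_epoch:
        print("eval at epoch", epoch_id)  # prints 6, 8, 10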
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import paddle


def convert_distill_weights(distill_weights_path, student_weights_path):
    assert os.path.exists(distill_weights_path), \
        "Given distill_weights_path {} does not exist.".format(
            distill_weights_path)
    # Load the combined teacher + student weights.
    all_params = paddle.load(distill_weights_path)
    # Extract the student weights, stripping the "Student." prefix.
    s_params = {
        key[len("Student."):]: all_params[key]
        for key in all_params if "Student." in key
    }
    # Save the student-only weights.
    paddle.save(s_params, student_weights_path)
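Note: a usage sketch of the script above; both paths are hypothetical and assume a distillation checkpoint whose student parameters carry the `Student.` prefix:

# Hypothetical paths; the checkpoint must contain "Student."-prefixed keys.
convert_distill_weights(
    distill_weights_path="output/DistillationModel/best_model.pdparams",
    student_weights_path="output/student.pdparams")
# The resulting file can then be loaded into the plain student architecture,
# e.g. model.set_state_dict(paddle.load("output/student.pdparams")).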
@@ -42,6 +42,14 @@ def _mkdir_if_not_exist(path):
             raise OSError('Failed to mkdir {}'.format(path))
 
 
+def _extract_student_weights(all_params, student_prefix="Student."):
+    s_params = {
+        key[len(student_prefix):]: all_params[key]
+        for key in all_params if student_prefix in key
+    }
+    return s_params
+
+
 def load_dygraph_pretrain(model, path=None):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {}.pdparams does not "
@@ -117,7 +125,7 @@ def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
         else:  # common load
             load_dygraph_pretrain(net, path=pretrained_model)
             logger.info("Finish load pretrained model from {}".format(
-                    pretrained_model))
+                pretrained_model))
 
 
 def save_model(net,
@@ -126,7 +134,8 @@ def save_model(net,
                model_path,
                model_name="",
                prefix='ppcls',
-               loss: paddle.nn.Layer=None):
+               loss: paddle.nn.Layer=None,
+               save_student_model=False):
     """
     save model to the target path
     """
@@ -137,11 +146,18 @@ def save_model(net,
     model_path = os.path.join(model_path, prefix)
 
     params_state_dict = net.state_dict()
-    loss_state_dict = loss.state_dict()
-    keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys())
-    assert len(keys_inter) == 0, \
-        f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
-    params_state_dict.update(loss_state_dict)
+    if loss is not None:
+        loss_state_dict = loss.state_dict()
+        keys_inter = set(params_state_dict.keys()) & set(loss_state_dict.keys(
+        ))
+        assert len(keys_inter) == 0, \
+            f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
+        params_state_dict.update(loss_state_dict)
+
+    if save_student_model:
+        s_params = _extract_student_weights(params_state_dict)
+        if len(s_params) > 0:
+            paddle.save(s_params, model_path + "_student.pdparams")
 
     paddle.save(params_state_dict, model_path + ".pdparams")
     paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
......
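Note: a minimal sketch of what `_extract_student_weights` does when `save_model` is called with `save_student_model=True`; the parameter names and shapes below are made up for illustration:

import paddle

# A toy combined state dict, as produced by a distillation model.
all_params = {
    "Student.fc.weight": paddle.zeros([8, 4]),
    "Teacher.fc.weight": paddle.zeros([8, 4]),
}
# Keep only the student entries and strip the "Student." prefix,
# mirroring _extract_student_weights above.
s_params = {
    key[len("Student."):]: all_params[key]
    for key in all_params if "Student." in key
}
print(list(s_params))  # ['fc.weight']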