提交 6f631e43 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix convert weight

上级 5a816278
......@@ -16,16 +16,27 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
__all__ == ["extract_subnet_weights"]
import os
import paddle
def convert_distill_weights(distill_weights_path, student_weights_path):
def extract_subnet_weights(distill_weights_path,
student_weights_path,
student_name="Student"):
assert os.path.exists(distill_weights_path), \
"Given distill_weights_path {} not exist.".format(distill_weights_path)
# Load teacher and student weights
all_params = paddle.load(distill_weights_path)
# Extract student weights
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
# Save student weights
student_prefix = student_name + "."
s_params = {
key[len(student_prefix):]: all_params[key]
for key in all_params if student_prefix in key
}
assert len(
s_params
) > 0, f"extracted params length must be > 0 but got {len(s_params)}"
# Save subnet weights
paddle.save(s_params, student_weights_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册