提交 6f631e43 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix convert weight

上级 5a816278
...@@ -16,16 +16,27 @@ from __future__ import absolute_import ...@@ -16,16 +16,27 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
__all__ == ["extract_subnet_weights"]
import os import os
import paddle import paddle
def convert_distill_weights(distill_weights_path, student_weights_path): def extract_subnet_weights(distill_weights_path,
student_weights_path,
student_name="Student"):
assert os.path.exists(distill_weights_path), \ assert os.path.exists(distill_weights_path), \
"Given distill_weights_path {} not exist.".format(distill_weights_path) "Given distill_weights_path {} not exist.".format(distill_weights_path)
# Load teacher and student weights # Load teacher and student weights
all_params = paddle.load(distill_weights_path) all_params = paddle.load(distill_weights_path)
# Extract student weights # Extract student weights
s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key} student_prefix = student_name + "."
# Save student weights s_params = {
key[len(student_prefix):]: all_params[key]
for key in all_params if student_prefix in key
}
assert len(
s_params
) > 0, f"extracted params length must be > 0 but got {len(s_params)}"
# Save subnet weights
paddle.save(s_params, student_weights_path) paddle.save(s_params, student_weights_path)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册