diff --git a/doc/doc_ch/knowledge_distillation.md b/doc/doc_ch/knowledge_distillation.md index 6eaaf3099f0e38ba4e99aea455a6d671393ffc6d..769401dbf6eec08a0afc56d9de8e58b9078df3c9 100644 --- a/doc/doc_ch/knowledge_distillation.md +++ b/doc/doc_ch/knowledge_distillation.md @@ -569,7 +569,7 @@ all_params = paddle.load("ch_PP-OCRv2_det_distill_train/best_accuracy.pdparams") # 查看权重参数的keys print(all_params.keys()) # 学生模型的权重提取 -s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key} +s_params = {key[len("student_model."):]: all_params[key] for key in all_params if "student_model." in key} # 查看学生模型权重参数的keys print(s_params.keys()) # 保存