diff --git a/docs/en/quick_start/distillation_tutorial_en.md b/docs/en/quick_start/distillation_tutorial_en.md index d948f117b9fb188571b0fc07b677824ab237156e..7fb410655fe086fcee84e1f84a49c1e38609cde1 100755 --- a/docs/en/quick_start/distillation_tutorial_en.md +++ b/docs/en/quick_start/distillation_tutorial_en.md @@ -25,7 +25,7 @@ This tutorial trains and verifies distillation model on the MNIST dataset. The i Select `ResNet50` as the teacher to perform distillation training on the students of the` MobileNet` architecture. ```python -model = models.__dict__['MobileNet']() +model = slim.models.MobileNet() student_program = fluid.Program() student_startup = fluid.Program() with fluid.program_guard(student_program, student_startup): @@ -42,7 +42,7 @@ with fluid.program_guard(student_program, student_startup): ```python -teacher_model = models.__dict__['ResNet50']() +model = slim.models.ResNet50() teacher_program = fluid.Program() teacher_startup = fluid.Program() with fluid.program_guard(teacher_program, teacher_startup): diff --git a/docs/zh_cn/quick_start/distillation_tutorial.md b/docs/zh_cn/quick_start/distillation_tutorial.md index a9b989760c1609cea40faad6482ef2056218db02..aa2dd3e9f33d6cb8dc87251c600eab1e88bfac59 100755 --- a/docs/zh_cn/quick_start/distillation_tutorial.md +++ b/docs/zh_cn/quick_start/distillation_tutorial.md @@ -27,7 +27,7 @@ import paddleslim as slim 选择`ResNet50`作为teacher对`MobileNet`结构的student进行蒸馏训练。 ```python -model = models.__dict__['MobileNet']() +model = slim.models.MobileNet() student_program = fluid.Program() student_startup = fluid.Program() with fluid.program_guard(student_program, student_startup): @@ -44,7 +44,7 @@ with fluid.program_guard(student_program, student_startup): ```python -teacher_model = models.__dict__['ResNet50']() +model = slim.models.ResNet50() teacher_program = fluid.Program() teacher_startup = fluid.Program() with fluid.program_guard(teacher_program, teacher_startup): diff --git a/paddleslim/models/__init__.py b/paddleslim/models/__init__.py index 14ea9f3d15fa953f0c4dba47aee6bc45a0e1ee62..bb308a71d30309ac893be8032acf34a661e35c5d 100644 --- a/paddleslim/models/__init__.py +++ b/paddleslim/models/__init__.py @@ -16,4 +16,7 @@ from __future__ import absolute_import from .util import image_classification from .slimfacenet import SlimFaceNet_A_x0_60, SlimFaceNet_B_x0_75, SlimFaceNet_C_x0_75 from .slim_mobilenet import SlimMobileNet_v1, SlimMobileNet_v2, SlimMobileNet_v3, SlimMobileNet_v4, SlimMobileNet_v5 -__all__ = ["image_classification"] +from .mobilenet import MobileNet +from .resnet import ResNet50 + +__all__ = ["image_classification", "MobileNet", "ResNet50"]