diff --git a/configs/Distillation/R50_vd_distill_MV3_large_x1_0.yaml b/configs/Distillation/R50_vd_distill_MV3_large_x1_0.yaml
index fef7c51fbec7f51811354b42ed9895544e8fd836..6f266432ffede1254f1f6c41fa879727f884b03b 100644
--- a/configs/Distillation/R50_vd_distill_MV3_large_x1_0.yaml
+++ b/configs/Distillation/R50_vd_distill_MV3_large_x1_0.yaml
@@ -2,7 +2,7 @@ mode: 'train'
 ARCHITECTURE:
     name: 'ResNet50_vd_distill_MobileNetV3_large_x1_0'
 
-pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained/"
+pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
 model_save_dir: "./output/"
 classes_num: 1000
 total_images: 1281167
diff --git a/configs/Distillation/ResX101_32x16d_wsl_distill_R50_vd.yaml b/configs/Distillation/ResX101_32x16d_wsl_distill_R50_vd.yaml
index 9e5b060f199f88bce1e7fa4b706cc77c3c53d892..619442a6849048c8ac2feeb0ab68bd4a6bbdd2fb 100644
--- a/configs/Distillation/ResX101_32x16d_wsl_distill_R50_vd.yaml
+++ b/configs/Distillation/ResX101_32x16d_wsl_distill_R50_vd.yaml
@@ -2,7 +2,7 @@ mode: 'train'
 ARCHITECTURE:
     name: 'ResNeXt101_32x16d_wsl_distill_ResNet50_vd'
 
-pretrained_model: "./pretrained/ResNeXt101_32x16d_wsl_pretrained/"
+pretrained_model: "./pretrained/ResNeXt101_32x16d_wsl_pretrained"
 model_save_dir: "./output/"
 classes_num: 1000
 total_images: 1281167
diff --git a/ppcls/modeling/architectures/distillation_models.py b/ppcls/modeling/architectures/distillation_models.py
index 55d165dcd8d67018865699b7d06715f071ea493d..f9a36dd310aa6e8c33699f686ccd762d49a4283c 100644
--- a/ppcls/modeling/architectures/distillation_models.py
+++ b/ppcls/modeling/architectures/distillation_models.py
@@ -32,26 +32,30 @@ __all__ = [
 
 
 class ResNet50_vd_distill_MobileNetV3_large_x1_0(nn.Layer):
-    def __init__(self, class_dim=1000, **args):
+    def __init__(self, class_dim=1000, freeze_teacher=True, **args):
         super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
         self.teacher = ResNet50_vd(class_dim=class_dim, **args)
         self.student = MobileNetV3_large_x1_0(class_dim=class_dim, **args)
-    def forward(self, input):
-        teacher_label = self.teacher(input)
-        teacher_label.stop_gradient = True
-
-        student_label = self.student(input)
+        if freeze_teacher:
+            for param in self.teacher.parameters():
+                param.trainable = False
+
+    def forward(self, x):
+        teacher_label = self.teacher(x)
+        student_label = self.student(x)
         return teacher_label, student_label
 
 
 class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(nn.Layer):
-    def __init__(self, class_dim=1000, **args):
+    def __init__(self, class_dim=1000, freeze_teacher=True, **args):
         super(ResNeXt101_32x16d_wsl_distill_ResNet50_vd, self).__init__()
         self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args)
         self.student = ResNet50_vd(class_dim=class_dim, **args)
-    def forward(self, input):
-        teacher_label = self.teacher(input)
-        teacher_label.stop_gradient = True
-
-        student_label = self.student(input)
+        if freeze_teacher:
+            for param in self.teacher.parameters():
+                param.trainable = False
+
+    def forward(self, x):
+        teacher_label = self.teacher(x)
+        student_label = self.student(x)
         return teacher_label, student_label