未验证 提交 0e6fe6f1 编写于 作者: L littletomatodonkey 提交者: GitHub

refine distillation model (#565)

上级 de86985c
...@@ -2,7 +2,7 @@ mode: 'train' ...@@ -2,7 +2,7 @@ mode: 'train'
ARCHITECTURE: ARCHITECTURE:
name: 'ResNet50_vd_distill_MobileNetV3_large_x1_0' name: 'ResNet50_vd_distill_MobileNetV3_large_x1_0'
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained/" pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
model_save_dir: "./output/" model_save_dir: "./output/"
classes_num: 1000 classes_num: 1000
total_images: 1281167 total_images: 1281167
......
...@@ -2,7 +2,7 @@ mode: 'train' ...@@ -2,7 +2,7 @@ mode: 'train'
ARCHITECTURE: ARCHITECTURE:
name: 'ResNeXt101_32x16d_wsl_distill_ResNet50_vd' name: 'ResNeXt101_32x16d_wsl_distill_ResNet50_vd'
pretrained_model: "./pretrained/ResNeXt101_32x16d_wsl_pretrained/" pretrained_model: "./pretrained/ResNeXt101_32x16d_wsl_pretrained"
model_save_dir: "./output/" model_save_dir: "./output/"
classes_num: 1000 classes_num: 1000
total_images: 1281167 total_images: 1281167
......
...@@ -32,34 +32,34 @@ __all__ = [ ...@@ -32,34 +32,34 @@ __all__ = [
class ResNet50_vd_distill_MobileNetV3_large_x1_0(nn.Layer): class ResNet50_vd_distill_MobileNetV3_large_x1_0(nn.Layer):
def __init__(self, class_dim=1000, **args): def __init__(self, class_dim=1000, freeze_teacher=True, **args):
super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__() super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
self.teacher = ResNet50_vd(class_dim=class_dim, **args) self.teacher = ResNet50_vd(class_dim=class_dim, **args)
self.student = MobileNetV3_large_x1_0(class_dim=class_dim, **args) self.student = MobileNetV3_large_x1_0(class_dim=class_dim, **args)
def forward(self, input): if freeze_teacher:
teacher_label = self.teacher(input) for param in self.teacher.parameters():
teacher_label.stop_gradient = True param.trainable = False
student_label = self.student(input)
def forward(self, x):
teacher_label = self.teacher(x)
student_label = self.student(x)
return teacher_label, student_label return teacher_label, student_label
class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(nn.Layer): class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(nn.Layer):
def __init__(self, class_dim=1000, **args): def __init__(self, class_dim=1000, freeze_teacher=True, **args):
super(ResNeXt101_32x16d_wsl_distill_ResNet50_vd, self).__init__() super(ResNeXt101_32x16d_wsl_distill_ResNet50_vd, self).__init__()
self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args) self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args)
self.student = ResNet50_vd(class_dim=class_dim, **args) self.student = ResNet50_vd(class_dim=class_dim, **args)
def forward(self, input): if freeze_teacher:
teacher_label = self.teacher(input) for param in self.teacher.parameters():
teacher_label.stop_gradient = True param.trainable = False
student_label = self.student(input)
def forward(self, x):
teacher_label = self.teacher(x)
student_label = self.student(x)
return teacher_label, student_label return teacher_label, student_label
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册