未验证 提交 97886e04 编写于 作者: L littletomatodonkey 提交者: GitHub

att (#566)

* fix vgg stop grad

* beautify code

* refine distillation model
上级 c8c9abf0
......@@ -2,7 +2,7 @@ mode: 'train'
ARCHITECTURE:
name: 'ResNet50_vd_distill_MobileNetV3_large_x1_0'
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained/"
pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
......
......@@ -2,7 +2,7 @@ mode: 'train'
ARCHITECTURE:
name: 'ResNeXt101_32x16d_wsl_distill_ResNet50_vd'
pretrained_model: "./pretrained/ResNeXt101_32x16d_wsl_pretrained/"
pretrained_model: "./pretrained/ResNeXt101_32x16d_wsl_pretrained"
model_save_dir: "./output/"
classes_num: 1000
total_images: 1281167
......
......@@ -32,34 +32,34 @@ __all__ = [
class ResNet50_vd_distill_MobileNetV3_large_x1_0(nn.Layer):
def __init__(self, class_dim=1000, **args):
def __init__(self, class_dim=1000, freeze_teacher=True, **args):
super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
self.teacher = ResNet50_vd(class_dim=class_dim, **args)
self.student = MobileNetV3_large_x1_0(class_dim=class_dim, **args)
def forward(self, input):
teacher_label = self.teacher(input)
teacher_label.stop_gradient = True
student_label = self.student(input)
if freeze_teacher:
for param in self.teacher.parameters():
param.trainable = False
def forward(self, x):
teacher_label = self.teacher(x)
student_label = self.student(x)
return teacher_label, student_label
class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(nn.Layer):
def __init__(self, class_dim=1000, **args):
def __init__(self, class_dim=1000, freeze_teacher=True, **args):
super(ResNeXt101_32x16d_wsl_distill_ResNet50_vd, self).__init__()
self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args)
self.student = ResNet50_vd(class_dim=class_dim, **args)
def forward(self, input):
teacher_label = self.teacher(input)
teacher_label.stop_gradient = True
student_label = self.student(input)
if freeze_teacher:
for param in self.teacher.parameters():
param.trainable = False
def forward(self, x):
teacher_label = self.teacher(x)
student_label = self.student(x)
return teacher_label, student_label
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册