提交 6cae5aaf 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add more distillation models

上级 5a15c165
...@@ -42,20 +42,25 @@ class ResNet50_vd_distill_MobileNetV3_large_x1_0(fluid.dygraph.Layer): ...@@ -42,20 +42,25 @@ class ResNet50_vd_distill_MobileNetV3_large_x1_0(fluid.dygraph.Layer):
def forward(self, input): def forward(self, input):
teacher_label = self.teacher(input) teacher_label = self.teacher(input)
teacher_label.stop_gradient = True
student_label = self.student(input) student_label = self.student(input)
return teacher_label, student_label return teacher_label, student_label
class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(): class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(fluid.dygraph.Layer):
def net(self, input, class_dim=1000): def __init__(self, class_dim=1000, **args):
# student super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
student = ResNet50_vd()
out_student = student.net(input, class_dim=class_dim) self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args)
# teacher
teacher = ResNeXt101_32x16d_wsl() self.student = ResNet50_vd(class_dim=class_dim, **args)
out_teacher = teacher.net(input, class_dim=class_dim)
out_teacher.stop_gradient = True def forward(self, input):
teacher_label = self.teacher(input)
teacher_label.stop_gradient = True
student_label = self.student(input)
return out_teacher, out_student return teacher_label, student_label
\ No newline at end of file
...@@ -58,7 +58,6 @@ def load_dygraph_pretrain( ...@@ -58,7 +58,6 @@ def load_dygraph_pretrain(
model_dict = model.state_dict() model_dict = model.state_dict()
for key in model_dict.keys(): for key in model_dict.keys():
weight_name = model_dict[key].name weight_name = model_dict[key].name
print("dyg key: {}, weight_name: {}".format(key, weight_name))
if weight_name in pre_state_dict.keys(): if weight_name in pre_state_dict.keys():
print('Load weight: {}, shape: {}'.format( print('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape)) weight_name, pre_state_dict[weight_name].shape))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册