提交 6cae5aaf 编写于 作者: littletomatodonkey's avatar littletomatodonkey

add more distillation models

上级 5a15c165
......@@ -42,20 +42,25 @@ class ResNet50_vd_distill_MobileNetV3_large_x1_0(fluid.dygraph.Layer):
def forward(self, input):
teacher_label = self.teacher(input)
teacher_label.stop_gradient = True
student_label = self.student(input)
return teacher_label, student_label
class ResNeXt101_32x16d_wsl_distill_ResNet50_vd():
def net(self, input, class_dim=1000):
# student
student = ResNet50_vd()
out_student = student.net(input, class_dim=class_dim)
# teacher
teacher = ResNeXt101_32x16d_wsl()
out_teacher = teacher.net(input, class_dim=class_dim)
out_teacher.stop_gradient = True
class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(fluid.dygraph.Layer):
def __init__(self, class_dim=1000, **args):
super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args)
self.student = ResNet50_vd(class_dim=class_dim, **args)
def forward(self, input):
teacher_label = self.teacher(input)
teacher_label.stop_gradient = True
student_label = self.student(input)
return out_teacher, out_student
return teacher_label, student_label
\ No newline at end of file
......@@ -58,7 +58,6 @@ def load_dygraph_pretrain(
model_dict = model.state_dict()
for key in model_dict.keys():
weight_name = model_dict[key].name
print("dyg key: {}, weight_name: {}".format(key, weight_name))
if weight_name in pre_state_dict.keys():
print('Load weight: {}, shape: {}'.format(
weight_name, pre_state_dict[weight_name].shape))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册