Commit ee02b6c4 authored by yangfukui

fix the name problem

Parent 49a17a25
@@ -38,13 +38,17 @@ def merge(teacher_program,
         name_prefix(str): Name prefix added for all vars of the teacher program.
     Return(Program): Merged program.
     """
-    teacher_program = teacher_program.clone(for_test = True)
+    teacher_program = teacher_program.clone(for_test=True)
     for teacher_var in teacher_program.list_vars():
+        skip_rename = False
         if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
             if teacher_var.name in data_name_map.keys():
                 new_name = data_name_map[teacher_var.name]
+                if new_name == teacher_var.name:
+                    skip_rename = True
             else:
                 new_name = name_prefix + teacher_var.name
-            # scope var rename
-            scope_var = teacher_scope.var(teacher_var.name).get_tensor()
-            renamed_scope_var = teacher_scope.var(new_name).get_tensor()
+            if not skip_rename:
+                # scope var rename
+                scope_var = teacher_scope.var(teacher_var.name).get_tensor()
+                renamed_scope_var = teacher_scope.var(new_name).get_tensor()
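
The added skip_rename flag handles the case where data_name_map maps a teacher variable to a name it already has: renaming a scope variable onto itself re-creates the tensor under the same key and clobbers its contents. A minimal sketch of a merge() call that hits this path, assuming the fluid-era signature merge(teacher_program, student_program, data_name_map, place) used by this module; the program contents and names below are illustrative, not from the commit:

    import paddle.fluid as fluid
    from paddleslim.dist import merge

    student_program = fluid.Program()
    with fluid.program_guard(student_program):
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        s_logits = fluid.layers.fc(input=img, size=10)

    teacher_program = fluid.Program()
    with fluid.program_guard(teacher_program):
        img = fluid.layers.data(name='image', shape=[784], dtype='float32')
        t_logits = fluid.layers.fc(input=img, size=10)

    # In real use the teacher parameters would be loaded into the scope
    # first. 'image' maps to itself, so new_name == teacher_var.name and
    # the fixed code skips the scope rename instead of clobbering the tensor.
    data_name_map = {'image': 'image'}
    merge(teacher_program, student_program, data_name_map, fluid.CPUPlace())
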
@@ -148,7 +152,6 @@ def soft_label_loss(teacher_var_name,
             teacher_feature_map before softmax. default: 1.0
         student_temperature(float): Temperature used to divide
             student_feature_map before softmax. default: 1.0
-
     Return(Variable): l2 distiller loss.
     """
     student_var = program.global_block().var(student_var_name)
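
This hunk only trims the docstring, but the function it documents is easier to follow with the usual formulation spelled out: both feature maps are divided by their temperatures, pushed through softmax, and compared with a soft-label cross entropy. A sketch of that computation in fluid-era ops (a common formulation, not copied from this file):

    import paddle.fluid as fluid

    def soft_label_loss_sketch(teacher_logits, student_logits,
                               teacher_temperature=1.0,
                               student_temperature=1.0):
        # Soften both distributions with their temperatures before softmax.
        t = fluid.layers.softmax(teacher_logits / teacher_temperature)
        t.stop_gradient = True  # the teacher only supplies targets
        s = fluid.layers.softmax(student_logits / student_temperature)
        # Soft-label cross entropy, averaged over the batch.
        return fluid.layers.reduce_mean(
            fluid.layers.cross_entropy(input=s, label=t, soft_label=True))
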
@@ -162,13 +165,12 @@ def soft_label_loss(teacher_var_name,
     return soft_label_loss


-def self_defined_loss(program, loss_func, **kwargs):
+def loss(program, loss_func, **kwargs):
     """
     Combine variables from student model and teacher model by self defined loss.
     Args:
         program(Program): The input distiller program.
         loss_func(function): The user self defined loss function.
-
     Return(Variable): self defined distiller loss.
     """
     func_parameters = {}
...
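
With the rename from self_defined_loss to loss, a user-supplied distillation loss would be wired in roughly as below. Only the loss(program, loss_func, **kwargs) signature comes from the docstring above; that each string keyword argument is resolved to a variable in the program's global block (via the func_parameters dict the hunk truncates) is an assumption, and the adaptor and variable names are made up:

    import paddle.fluid as fluid
    from paddleslim.dist import loss  # assuming it is exported like merge

    def l2_adaptor(student_feature, teacher_feature):
        # Any user-defined function over program variables can be plugged in.
        return fluid.layers.reduce_mean(
            fluid.layers.square(student_feature - teacher_feature))

    # merged_program: the program returned by merge() above; the variable
    # names are hypothetical outputs of the student and (prefixed) teacher.
    distill_loss = loss(merged_program, l2_adaptor,
                        student_feature='fc_0.tmp_1',
                        teacher_feature='teacher_fc_0.tmp_1')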