Commit ee02b6c4 authored by yangfukui

fix the name problem

Parent 49a17a25
@@ -38,21 +38,25 @@ def merge(teacher_program,
         name_prefix(str): Name prefix added for all vars of the teacher program.
     Return(Program): Merged program.
     """
-    teacher_program = teacher_program.clone(for_test = True)
+    teacher_program = teacher_program.clone(for_test=True)
     for teacher_var in teacher_program.list_vars():
+        skip_rename = False
         if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
             if teacher_var.name in data_name_map.keys():
                 new_name = data_name_map[teacher_var.name]
+                if new_name == teacher_var.name:
+                    skip_rename = True
             else:
                 new_name = name_prefix + teacher_var.name
-            # scope var rename
-            scope_var = teacher_scope.var(teacher_var.name).get_tensor()
-            renamed_scope_var = teacher_scope.var(new_name).get_tensor()
-            renamed_scope_var.set(np.array(scope_var), place)
-            # program var rename
-            renamed_var = teacher_program.global_block()._rename_var(
-                teacher_var.name, new_name)
+            if not skip_rename:
+                # scope var rename
+                scope_var = teacher_scope.var(teacher_var.name).get_tensor()
+                renamed_scope_var = teacher_scope.var(new_name).get_tensor()
+                renamed_scope_var.set(np.array(scope_var), place)
+                # program var rename
+                renamed_var = teacher_program.global_block()._rename_var(
+                    teacher_var.name, new_name)
     for teacher_var in teacher_program.list_vars():
         if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
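The guard added in this hunk means a teacher variable whose data_name_map entry is its own name is left untouched instead of being renamed onto itself, which is presumably the "name problem" the commit message refers to. Below is a minimal standalone sketch of that decision rule; the helper name and the sample data_name_map are illustrative only, not part of the commit.

```python
# Standalone sketch of the renaming rule introduced above (illustrative names).
def plan_renames(var_names, data_name_map, name_prefix='teacher_'):
    """Return {old_name: new_name} for the vars that actually need renaming."""
    renames = {}
    for name in var_names:
        if name in ('feed', 'fetch'):          # feed/fetch vars are never touched
            continue
        if name in data_name_map:
            new_name = data_name_map[name]
            if new_name == name:               # mapped to itself -> skip the rename
                continue
        else:
            new_name = name_prefix + name      # everything else gets the prefix
        renames[name] = new_name
    return renames

# 'image' maps to itself, so it keeps its name; only 'conv1.w' is prefixed.
print(plan_renames(['image', 'conv1.w', 'feed'], {'image': 'image'}))
# {'conv1.w': 'teacher_conv1.w'}
```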
@@ -148,7 +152,6 @@ def soft_label_loss(teacher_var_name,
             teacher_feature_map before softmax. default: 1.0
         student_temperature(float): Temperature used to divide
             student_feature_map before softmax. default: 1.0
-    Return(Variable): l2 distiller loss.
     """
     student_var = program.global_block().var(student_var_name)
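For context, the temperatures documented above divide the feature maps before the softmax, so higher values soften both distributions. The numpy sketch below shows that computation in isolation; it is not the Paddle code from this file, and the mean cross-entropy reduction is an assumption made for illustration.

```python
# Rough numpy sketch of a temperature-scaled soft-label loss (assumed reduction).
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)      # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def soft_label_loss_np(teacher_logits, student_logits,
                       teacher_temperature=1.0, student_temperature=1.0):
    # Each feature map is divided by its temperature before the softmax,
    # as described in the docstring above.
    t = softmax(teacher_logits / teacher_temperature)
    s = softmax(student_logits / student_temperature)
    # Cross entropy between the student distribution and the soft teacher labels.
    return -np.mean(np.sum(t * np.log(s + 1e-12), axis=-1))

logits_t = np.array([[2.0, 0.5, 0.1]])
logits_s = np.array([[1.5, 0.8, 0.2]])
print(soft_label_loss_np(logits_t, logits_s,
                         teacher_temperature=4.0, student_temperature=4.0))
```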
@@ -162,13 +165,12 @@ def soft_label_loss(teacher_var_name,
     return soft_label_loss


-def self_defined_loss(program, loss_func, **kwargs):
+def loss(program, loss_func, **kwargs):
     """
     Combine variables from student model and teacher model by self defined loss.
     Args:
         program(Program): The input distiller program.
         loss_func(function): The user self defined loss function.
-    Return(Variable): self defined distiller loss.
     """
     func_parameters = {}
......
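The rename from self_defined_loss to loss is the other half of this commit. The docstring and the visible func_parameters line suggest a calling convention in which keyword arguments are resolved to program variables and forwarded to the user's loss function; the sketch below illustrates that pattern under the assumption that string-valued keyword arguments name variables. The dict stand-in for the Program and the l2_distance helper are hypothetical, not PaddleSlim API.

```python
# Standalone sketch of the assumed calling convention for a self-defined loss.
import numpy as np

def resolve_and_call(program_vars, loss_func, **kwargs):
    func_parameters = {}
    for key, value in kwargs.items():
        # Assumption: string values name variables to be looked up in the program.
        func_parameters[key] = program_vars[value] if isinstance(value, str) else value
    return loss_func(**func_parameters)

def l2_distance(teacher_feature, student_feature, weight=1.0):
    # A user-defined distillation loss: weighted mean squared difference.
    return weight * np.mean((teacher_feature - student_feature) ** 2)

# Hypothetical variable names standing in for fetched program variables.
program_vars = {
    'teacher_fc_0.tmp_0': np.array([1.0, 2.0, 3.0]),
    'fc_0.tmp_0': np.array([1.5, 1.5, 2.5]),
}
print(resolve_and_call(program_vars, l2_distance,
                       teacher_feature='teacher_fc_0.tmp_0',
                       student_feature='fc_0.tmp_0',
                       weight=0.5))
# 0.125
```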