提交 ee02b6c4 编写于 作者: Y yangfukui

fix the name problem

上级 49a17a25
...@@ -38,21 +38,25 @@ def merge(teacher_program, ...@@ -38,21 +38,25 @@ def merge(teacher_program,
name_prefix(str): Name prefix added for all vars of the teacher program. name_prefix(str): Name prefix added for all vars of the teacher program.
Return(Program): Merged program. Return(Program): Merged program.
""" """
teacher_program = teacher_program.clone(for_test = True) teacher_program = teacher_program.clone(for_test=True)
for teacher_var in teacher_program.list_vars(): for teacher_var in teacher_program.list_vars():
skip_rename = False
if teacher_var.name != 'fetch' and teacher_var.name != 'feed': if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
if teacher_var.name in data_name_map.keys(): if teacher_var.name in data_name_map.keys():
new_name = data_name_map[teacher_var.name] new_name = data_name_map[teacher_var.name]
if new_name == teacher_var.name:
skip_rename = True
else: else:
new_name = name_prefix + teacher_var.name new_name = name_prefix + teacher_var.name
# scope var rename if not skip_rename:
scope_var = teacher_scope.var(teacher_var.name).get_tensor() # scope var rename
renamed_scope_var = teacher_scope.var(new_name).get_tensor() scope_var = teacher_scope.var(teacher_var.name).get_tensor()
renamed_scope_var.set(np.array(scope_var), place) renamed_scope_var = teacher_scope.var(new_name).get_tensor()
renamed_scope_var.set(np.array(scope_var), place)
# program var rename
renamed_var = teacher_program.global_block()._rename_var( # program var rename
teacher_var.name, new_name) renamed_var = teacher_program.global_block()._rename_var(
teacher_var.name, new_name)
for teacher_var in teacher_program.list_vars(): for teacher_var in teacher_program.list_vars():
if teacher_var.name != 'fetch' and teacher_var.name != 'feed': if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
...@@ -148,7 +152,6 @@ def soft_label_loss(teacher_var_name, ...@@ -148,7 +152,6 @@ def soft_label_loss(teacher_var_name,
teacher_feature_map before softmax. default: 1.0 teacher_feature_map before softmax. default: 1.0
student_temperature(float): Temperature used to divide student_temperature(float): Temperature used to divide
student_feature_map before softmax. default: 1.0 student_feature_map before softmax. default: 1.0
Return(Variable): l2 distiller loss. Return(Variable): l2 distiller loss.
""" """
student_var = program.global_block().var(student_var_name) student_var = program.global_block().var(student_var_name)
...@@ -162,13 +165,12 @@ def soft_label_loss(teacher_var_name, ...@@ -162,13 +165,12 @@ def soft_label_loss(teacher_var_name,
return soft_label_loss return soft_label_loss
def self_defined_loss(program, loss_func, **kwargs): def loss(program, loss_func, **kwargs):
""" """
Combine variables from student model and teacher model by self defined loss. Combine variables from student model and teacher model by self defined loss.
Args: Args:
program(Program): The input distiller program. program(Program): The input distiller program.
loss_func(function): The user self defined loss function. loss_func(function): The user self defined loss function.
Return(Variable): self defined distiller loss. Return(Variable): self defined distiller loss.
""" """
func_parameters = {} func_parameters = {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册