diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py
index 88d34d3b2b2378681c9e35e10380d8514a9a54aa..7e39a9e2d9c743681320eaa70e0d75476844018c 100644
--- a/paddleslim/dist/single_distiller.py
+++ b/paddleslim/dist/single_distiller.py
@@ -38,21 +38,25 @@ def merge(teacher_program,
         name_prefix(str): Name prefix added for all vars of the teacher program.
     Return(Program): Merged program.
     """
-    teacher_program = teacher_program.clone(for_test = True)
+    teacher_program = teacher_program.clone(for_test=True)
     for teacher_var in teacher_program.list_vars():
+        skip_rename = False
         if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
             if teacher_var.name in data_name_map.keys():
                 new_name = data_name_map[teacher_var.name]
+                if new_name == teacher_var.name:
+                    skip_rename = True
             else:
                 new_name = name_prefix + teacher_var.name
-            # scope var rename
-            scope_var = teacher_scope.var(teacher_var.name).get_tensor()
-            renamed_scope_var = teacher_scope.var(new_name).get_tensor()
-            renamed_scope_var.set(np.array(scope_var), place)
-
-            # program var rename
-            renamed_var = teacher_program.global_block()._rename_var(
-                teacher_var.name, new_name)
+            if not skip_rename:
+                # scope var rename
+                scope_var = teacher_scope.var(teacher_var.name).get_tensor()
+                renamed_scope_var = teacher_scope.var(new_name).get_tensor()
+                renamed_scope_var.set(np.array(scope_var), place)
+
+                # program var rename
+                renamed_var = teacher_program.global_block()._rename_var(
+                    teacher_var.name, new_name)
 
     for teacher_var in teacher_program.list_vars():
         if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
@@ -148,7 +152,6 @@ def soft_label_loss(teacher_var_name,
             teacher_feature_map before softmax. default: 1.0
         student_temperature(float): Temperature used to divide
             student_feature_map before softmax. default: 1.0
-
     Return(Variable): l2 distiller loss.
     """
     student_var = program.global_block().var(student_var_name)
@@ -162,13 +165,12 @@
     return soft_label_loss
 
 
-def self_defined_loss(program, loss_func, **kwargs):
+def loss(program, loss_func, **kwargs):
     """
     Combine variables from student model and teacher model by self defined loss.
     Args:
         program(Program): The input distiller program.
         loss_func(function): The user self defined loss function.
-
     Return(Variable): self defined distiller loss.
     """
     func_parameters = {}
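
A minimal usage sketch of the changed behavior, not part of the patch: the tiny fc networks, the l2_loss helper, the default name_prefix of 'teacher_', and the handling of merge's return value are all illustrative assumptions. It shows the two visible effects of this change: a data_name_map entry mapping a teacher variable to the identical student name no longer triggers a rename or scope-tensor copy, and the user-defined-loss helper is now called loss instead of self_defined_loss.

import paddle.fluid as fluid
from paddleslim.dist import merge, loss

place = fluid.CPUPlace()

# Tiny stand-in networks; a real teacher/student pair would be used instead.
teacher_program, teacher_startup = fluid.Program(), fluid.Program()
with fluid.program_guard(teacher_program, teacher_startup):
    image = fluid.layers.data(name='image', shape=[16], dtype='float32')
    t_out = fluid.layers.fc(input=image, size=8)

student_program, student_startup = fluid.Program(), fluid.Program()
with fluid.program_guard(student_program, student_startup):
    image = fluid.layers.data(name='image', shape=[16], dtype='float32')
    s_out = fluid.layers.fc(input=image, size=8)

exe = fluid.Executor(place)
exe.run(teacher_startup)   # teacher parameters must exist in the scope before merge
exe.run(student_startup)

# Both programs feed the same 'image' variable; with this patch the identical
# mapping below no longer renames the teacher variable or copies its tensor.
data_name_map = {'image': 'image'}
merged = merge(teacher_program, student_program, data_name_map, place)
# Assumption: per the docstring, merge returns the merged program; some
# versions instead modify student_program in place.
main_program = merged if merged is not None else student_program

# The helper formerly named self_defined_loss is now loss(); string kwargs are
# looked up as variables in the given program and handed to the user function.
def l2_loss(student_feat, teacher_feat):
    return fluid.layers.reduce_mean(
        fluid.layers.square(
            fluid.layers.elementwise_sub(student_feat, teacher_feat)))

with fluid.program_guard(main_program, student_startup):
    # Assumes the default name_prefix 'teacher_' for renamed teacher variables.
    distill_loss = loss(main_program, l2_loss,
                        student_feat=s_out.name,
                        teacher_feat='teacher_' + t_out.name)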