diff --git a/paddleslim/dist/single_distiller.py b/paddleslim/dist/single_distiller.py index 70b843c90fec6bdf906045dbac3097f8dfba3ff1..8f5dcaeb14a0f6a7aadd5c99de7bc3c144f21414 100644 --- a/paddleslim/dist/single_distiller.py +++ b/paddleslim/dist/single_distiller.py @@ -20,8 +20,7 @@ def merge(teacher_program, student_program, data_name_map, place, - teacher_scope=fluid.global_scope(), - student_scope=fluid.global_scope(), + scope=fluid.global_scope(), name_prefix='teacher_'): """ Merge teacher program into student program and add a uniform prefix to the @@ -33,8 +32,7 @@ def merge(teacher_program, and the student var name place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents paddle run on which device. - student_scope(Scope): The input student scope - teacher_scope(Scope): The input teacher scope + scope(Scope): The input scope name_prefix(str): Name prefix added for all vars of the teacher program. Return(Program): Merged program. """ @@ -50,9 +48,9 @@ def merge(teacher_program, new_name = name_prefix + teacher_var.name if not skip_rename: # scope var rename - scope_var = teacher_scope.var(teacher_var.name).get_tensor() - renamed_scope_var = teacher_scope.var(new_name).get_tensor() - renamed_scope_var.set(np.array(scope_var), place) + old_var = scope.var(teacher_var.name).get_tensor() + renamed_var = scope.var(new_name).get_tensor() + renamed_var.set(np.array(old_var), place) # program var rename renamed_var = teacher_program.global_block()._rename_var( @@ -60,11 +58,6 @@ def merge(teacher_program, for teacher_var in teacher_program.list_vars(): if teacher_var.name != 'fetch' and teacher_var.name != 'feed': - # student scope add var - student_scope_var = student_scope.var(teacher_var.name).get_tensor() - teacher_scope_var = teacher_scope.var(teacher_var.name).get_tensor() - student_scope_var.set(np.array(teacher_scope_var), place) - # student program add var new_var = student_program.global_block()._clone_variable( teacher_var, force_persistable=False)