提交 cf7fb3c5 编写于 作者: W wanghaoshuang

Merge branch 'develop' into 'develop'

Refine dist.merge(), remove some redundant interface

See merge request !62
......@@ -140,8 +140,6 @@ def compress(args):
# define teacher program
teacher_program = fluid.Program()
t_startup = fluid.Program()
teacher_scope = fluid.Scope()
with fluid.scope_guard(teacher_scope):
with fluid.program_guard(teacher_program, t_startup):
with fluid.unique_name.guard():
image = fluid.layers.data(
......@@ -173,8 +171,7 @@ def compress(args):
teacher_program,
student_program,
data_name_map,
place,
teacher_scope=teacher_scope)
place)
#print("="*50+"teacher_vars"+"="*50)
#for v in teacher_program.list_vars():
......
......@@ -20,8 +20,7 @@ def merge(teacher_program,
student_program,
data_name_map,
place,
teacher_scope=fluid.global_scope(),
student_scope=fluid.global_scope(),
scope=fluid.global_scope(),
name_prefix='teacher_'):
"""
Merge teacher program into student program and add a uniform prefix to the
......@@ -33,8 +32,7 @@ def merge(teacher_program,
and the student var name
place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents
paddle run on which device.
student_scope(Scope): The input student scope
teacher_scope(Scope): The input teacher scope
scope(Scope): The input scope
name_prefix(str): Name prefix added for all vars of the teacher program.
Return(Program): Merged program.
"""
......@@ -50,9 +48,9 @@ def merge(teacher_program,
new_name = name_prefix + teacher_var.name
if not skip_rename:
# scope var rename
scope_var = teacher_scope.var(teacher_var.name).get_tensor()
renamed_scope_var = teacher_scope.var(new_name).get_tensor()
renamed_scope_var.set(np.array(scope_var), place)
old_var = scope.var(teacher_var.name).get_tensor()
renamed_var = scope.var(new_name).get_tensor()
renamed_var.set(np.array(old_var), place)
# program var rename
renamed_var = teacher_program.global_block()._rename_var(
......@@ -60,11 +58,6 @@ def merge(teacher_program,
for teacher_var in teacher_program.list_vars():
if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
# student scope add var
student_scope_var = student_scope.var(teacher_var.name).get_tensor()
teacher_scope_var = teacher_scope.var(teacher_var.name).get_tensor()
student_scope_var.set(np.array(teacher_scope_var), place)
# student program add var
new_var = student_program.global_block()._clone_variable(
teacher_var, force_persistable=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册