提交 cf7fb3c5 编写于 作者: W wanghaoshuang

Merge branch 'develop' into 'develop'

Refine dist.merge() and remove some redundant interfaces

See merge request !62
...@@ -140,41 +140,38 @@ def compress(args): ...@@ -140,41 +140,38 @@ def compress(args):
# define teacher program # define teacher program
teacher_program = fluid.Program() teacher_program = fluid.Program()
t_startup = fluid.Program() t_startup = fluid.Program()
teacher_scope = fluid.Scope() with fluid.program_guard(teacher_program, t_startup):
with fluid.scope_guard(teacher_scope): with fluid.unique_name.guard():
with fluid.program_guard(teacher_program, t_startup): image = fluid.layers.data(
with fluid.unique_name.guard(): name='image', shape=image_shape, dtype='float32')
image = fluid.layers.data( predict = teacher_model.net(image, class_dim=class_dim)
name='image', shape=image_shape, dtype='float32')
predict = teacher_model.net(image, class_dim=class_dim) #print("="*50+"teacher_model_params"+"="*50)
#for v in teacher_program.list_vars():
#print("="*50+"teacher_model_params"+"="*50) # print(v.name, v.shape)
#for v in teacher_program.list_vars():
# print(v.name, v.shape) exe.run(t_startup)
assert args.teacher_pretrained_model and os.path.exists(
exe.run(t_startup) args.teacher_pretrained_model
assert args.teacher_pretrained_model and os.path.exists( ), "teacher_pretrained_model should be set when teacher_model is not None."
args.teacher_pretrained_model
), "teacher_pretrained_model should be set when teacher_model is not None." def if_exist(var):
return os.path.exists(
def if_exist(var): os.path.join(args.teacher_pretrained_model, var.name)
return os.path.exists( ) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
os.path.join(args.teacher_pretrained_model, var.name)
) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0' fluid.io.load_vars(
exe,
fluid.io.load_vars( args.teacher_pretrained_model,
exe, main_program=teacher_program,
args.teacher_pretrained_model, predicate=if_exist)
main_program=teacher_program,
predicate=if_exist)
data_name_map = {'image': 'image'} data_name_map = {'image': 'image'}
main = merge( main = merge(
teacher_program, teacher_program,
student_program, student_program,
data_name_map, data_name_map,
place, place)
teacher_scope=teacher_scope)
#print("="*50+"teacher_vars"+"="*50) #print("="*50+"teacher_vars"+"="*50)
#for v in teacher_program.list_vars(): #for v in teacher_program.list_vars():
......
...@@ -20,8 +20,7 @@ def merge(teacher_program, ...@@ -20,8 +20,7 @@ def merge(teacher_program,
student_program, student_program,
data_name_map, data_name_map,
place, place,
teacher_scope=fluid.global_scope(), scope=fluid.global_scope(),
student_scope=fluid.global_scope(),
name_prefix='teacher_'): name_prefix='teacher_'):
""" """
Merge teacher program into student program and add a uniform prefix to the Merge teacher program into student program and add a uniform prefix to the
...@@ -33,8 +32,7 @@ def merge(teacher_program, ...@@ -33,8 +32,7 @@ def merge(teacher_program,
and the student var name and the student var name
place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents
paddle run on which device. paddle run on which device.
student_scope(Scope): The input student scope scope(Scope): The input scope
teacher_scope(Scope): The input teacher scope
name_prefix(str): Name prefix added for all vars of the teacher program. name_prefix(str): Name prefix added for all vars of the teacher program.
Return(Program): Merged program. Return(Program): Merged program.
""" """
...@@ -50,9 +48,9 @@ def merge(teacher_program, ...@@ -50,9 +48,9 @@ def merge(teacher_program,
new_name = name_prefix + teacher_var.name new_name = name_prefix + teacher_var.name
if not skip_rename: if not skip_rename:
# scope var rename # scope var rename
scope_var = teacher_scope.var(teacher_var.name).get_tensor() old_var = scope.var(teacher_var.name).get_tensor()
renamed_scope_var = teacher_scope.var(new_name).get_tensor() renamed_var = scope.var(new_name).get_tensor()
renamed_scope_var.set(np.array(scope_var), place) renamed_var.set(np.array(old_var), place)
# program var rename # program var rename
renamed_var = teacher_program.global_block()._rename_var( renamed_var = teacher_program.global_block()._rename_var(
...@@ -60,11 +58,6 @@ def merge(teacher_program, ...@@ -60,11 +58,6 @@ def merge(teacher_program,
for teacher_var in teacher_program.list_vars(): for teacher_var in teacher_program.list_vars():
if teacher_var.name != 'fetch' and teacher_var.name != 'feed': if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
# student scope add var
student_scope_var = student_scope.var(teacher_var.name).get_tensor()
teacher_scope_var = teacher_scope.var(teacher_var.name).get_tensor()
student_scope_var.set(np.array(teacher_scope_var), place)
# student program add var # student program add var
new_var = student_program.global_block()._clone_variable( new_var = student_program.global_block()._clone_variable(
teacher_var, force_persistable=False) teacher_var, force_persistable=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册