提交 cf7fb3c5 编写于 作者: W wanghaoshuang

Merge branch 'develop' into 'develop'

Refine dist.merge(), remove some redundant interface

See merge request !62
......@@ -140,41 +140,38 @@ def compress(args):
# define teacher program
teacher_program = fluid.Program()
t_startup = fluid.Program()
teacher_scope = fluid.Scope()
with fluid.scope_guard(teacher_scope):
with fluid.program_guard(teacher_program, t_startup):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
predict = teacher_model.net(image, class_dim=class_dim)
#print("="*50+"teacher_model_params"+"="*50)
#for v in teacher_program.list_vars():
# print(v.name, v.shape)
exe.run(t_startup)
assert args.teacher_pretrained_model and os.path.exists(
args.teacher_pretrained_model
), "teacher_pretrained_model should be set when teacher_model is not None."
def if_exist(var):
return os.path.exists(
os.path.join(args.teacher_pretrained_model, var.name)
) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
fluid.io.load_vars(
exe,
args.teacher_pretrained_model,
main_program=teacher_program,
predicate=if_exist)
with fluid.program_guard(teacher_program, t_startup):
with fluid.unique_name.guard():
image = fluid.layers.data(
name='image', shape=image_shape, dtype='float32')
predict = teacher_model.net(image, class_dim=class_dim)
#print("="*50+"teacher_model_params"+"="*50)
#for v in teacher_program.list_vars():
# print(v.name, v.shape)
exe.run(t_startup)
assert args.teacher_pretrained_model and os.path.exists(
args.teacher_pretrained_model
), "teacher_pretrained_model should be set when teacher_model is not None."
def if_exist(var):
return os.path.exists(
os.path.join(args.teacher_pretrained_model, var.name)
) and var.name != 'conv1_weights' and var.name != 'fc_0.w_0' and var.name != 'fc_0.b_0'
fluid.io.load_vars(
exe,
args.teacher_pretrained_model,
main_program=teacher_program,
predicate=if_exist)
data_name_map = {'image': 'image'}
main = merge(
teacher_program,
student_program,
data_name_map,
place,
teacher_scope=teacher_scope)
place)
#print("="*50+"teacher_vars"+"="*50)
#for v in teacher_program.list_vars():
......
......@@ -20,8 +20,7 @@ def merge(teacher_program,
student_program,
data_name_map,
place,
teacher_scope=fluid.global_scope(),
student_scope=fluid.global_scope(),
scope=fluid.global_scope(),
name_prefix='teacher_'):
"""
Merge teacher program into student program and add a uniform prefix to the
......@@ -33,8 +32,7 @@ def merge(teacher_program,
and the student var name
place(fluid.CPUPlace()|fluid.CUDAPlace(N)): This parameter represents
paddle run on which device.
student_scope(Scope): The input student scope
teacher_scope(Scope): The input teacher scope
scope(Scope): The input scope
name_prefix(str): Name prefix added for all vars of the teacher program.
Return(Program): Merged program.
"""
......@@ -50,9 +48,9 @@ def merge(teacher_program,
new_name = name_prefix + teacher_var.name
if not skip_rename:
# scope var rename
scope_var = teacher_scope.var(teacher_var.name).get_tensor()
renamed_scope_var = teacher_scope.var(new_name).get_tensor()
renamed_scope_var.set(np.array(scope_var), place)
old_var = scope.var(teacher_var.name).get_tensor()
renamed_var = scope.var(new_name).get_tensor()
renamed_var.set(np.array(old_var), place)
# program var rename
renamed_var = teacher_program.global_block()._rename_var(
......@@ -60,11 +58,6 @@ def merge(teacher_program,
for teacher_var in teacher_program.list_vars():
if teacher_var.name != 'fetch' and teacher_var.name != 'feed':
# student scope add var
student_scope_var = student_scope.var(teacher_var.name).get_tensor()
teacher_scope_var = teacher_scope.var(teacher_var.name).get_tensor()
student_scope_var.set(np.array(teacher_scope_var), place)
# student program add var
new_var = student_program.global_block()._clone_variable(
teacher_var, force_persistable=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册