提交 40426b6f 编写于 作者: W wuzewu

don't clone the input program in default when connect program

上级 6c167379
......@@ -120,7 +120,7 @@ def from_flexible_data_to_param(flexible_data):
return param
def connect_program(pre_program, next_program, input_dict=None):
def connect_program(pre_program, next_program, input_dict=None, inplace=True):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
......@@ -149,7 +149,8 @@ def connect_program(pre_program, next_program, input_dict=None):
fluid.Program), "pre_program should be fluid.Program"
assert isinstance(next_program,
fluid.Program), "next_program should be fluid.Program"
new_program = pre_program.clone()
output_program = pre_program if inplace else pre_program.clone(
for_test=False)
if input_dict:
assert isinstance(
input_dict,
......@@ -159,11 +160,11 @@ def connect_program(pre_program, next_program, input_dict=None):
var, fluid.framework.Variable
), "the input_dict should be a dict with string-Variable pair"
var_info = copy.deepcopy(get_variable_info(var))
input_var = new_program.global_block().create_var(**var_info)
input_var = output_program.global_block().create_var(**var_info)
output_var = next_program.global_block().var(key)
var_info = copy.deepcopy(get_variable_info(output_var))
output_var = new_program.global_block().create_var(**var_info)
new_program.global_block().append_op(
output_var = output_program.global_block().create_var(**var_info)
output_program.global_block().append_op(
type="assign",
inputs={'X': input_var},
outputs={'Out': output_var})
......@@ -172,17 +173,17 @@ def connect_program(pre_program, next_program, input_dict=None):
logger.info("start to connect program")
for index, block in enumerate(next_program.blocks):
if block.idx == 0:
_copy_vars_and_ops_in_blocks(block, new_program.global_block())
_copy_vars_and_ops_in_blocks(block, output_program.global_block())
else:
block_map[index] = len(new_program.blocks)
block_map[index] = len(output_program.blocks)
logger.info(
"block_%d in next_program merge into block_%d in pre_program" %
(index, block_map[index]))
new_block = new_program._create_block(
new_block = output_program._create_block(
parent_idx=block_map[block.parent_idx])
_copy_vars_and_ops_in_blocks(block, new_block)
logger.info("end of connect program")
return new_program
return output_program
def remove_feed_fetch_op(program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册