diff --git a/paddle_hub/tools/paddle_helper.py b/paddle_hub/tools/paddle_helper.py index ccc6f0ce4996ba0824dd32e6b4039caaca429263..e1555edec9a076137caf3fc7d53cc964ef27cc4b 100644 --- a/paddle_hub/tools/paddle_helper.py +++ b/paddle_hub/tools/paddle_helper.py @@ -120,7 +120,7 @@ def from_flexible_data_to_param(flexible_data): return param -def connect_program(pre_program, next_program, input_dict=None): +def connect_program(pre_program, next_program, input_dict=None, inplace=True): def _copy_vars_and_ops_in_blocks(from_block, to_block): for var in from_block.vars: var = from_block.var(var) @@ -149,7 +149,8 @@ def connect_program(pre_program, next_program, input_dict=None): fluid.Program), "pre_program should be fluid.Program" assert isinstance(next_program, fluid.Program), "next_program should be fluid.Program" - new_program = pre_program.clone() + output_program = pre_program if inplace else pre_program.clone( + for_test=False) if input_dict: assert isinstance( input_dict, @@ -159,11 +160,11 @@ def connect_program(pre_program, next_program, input_dict=None): var, fluid.framework.Variable ), "the input_dict should be a dict with string-Variable pair" var_info = copy.deepcopy(get_variable_info(var)) - input_var = new_program.global_block().create_var(**var_info) + input_var = output_program.global_block().create_var(**var_info) output_var = next_program.global_block().var(key) var_info = copy.deepcopy(get_variable_info(output_var)) - output_var = new_program.global_block().create_var(**var_info) - new_program.global_block().append_op( + output_var = output_program.global_block().create_var(**var_info) + output_program.global_block().append_op( type="assign", inputs={'X': input_var}, outputs={'Out': output_var}) @@ -172,17 +173,17 @@ def connect_program(pre_program, next_program, input_dict=None): logger.info("start to connect program") for index, block in enumerate(next_program.blocks): if block.idx == 0: - _copy_vars_and_ops_in_blocks(block, new_program.global_block()) + _copy_vars_and_ops_in_blocks(block, output_program.global_block()) else: - block_map[index] = len(new_program.blocks) + block_map[index] = len(output_program.blocks) logger.info( "block_%d in next_program merge into block_%d in pre_program" % (index, block_map[index])) - new_block = new_program._create_block( + new_block = output_program._create_block( parent_idx=block_map[block.parent_idx]) _copy_vars_and_ops_in_blocks(block, new_block) logger.info("end of connect program") - return new_program + return output_program def remove_feed_fetch_op(program):