diff --git a/paddlehub/common/paddle_helper.py b/paddlehub/common/paddle_helper.py index 751bdd7246998069684fb3dd5a7bc659c8b94cfe..0c665b2f61b7e4d9035b69ccfb58ef4069265d5f 100644 --- a/paddlehub/common/paddle_helper.py +++ b/paddlehub/common/paddle_helper.py @@ -142,34 +142,36 @@ def from_module_attr_to_param(module_attr): return param +def _copy_vars_and_ops_in_blocks(from_block, to_block): + for var in from_block.vars: + var = from_block.var(var) + var_info = copy.deepcopy(get_variable_info(var)) + if isinstance(var, fluid.framework.Parameter): + to_block.create_parameter(**var_info) + else: + to_block.create_var(**var_info) + + for op in from_block.ops: + op_info = { + 'type': op.type, + 'inputs': { + input: [to_block.var(var) for var in op.input(input)] + for input in op.input_names + }, + 'outputs': { + output: [to_block.var(var) for var in op.output(output)] + for output in op.output_names + }, + 'attrs': copy.deepcopy(op.all_attrs()) + } + to_block.append_op(**op_info) + + def connect_program(pre_program, next_program, input_dict=None, inplace=True, need_log=True): - def _copy_vars_and_ops_in_blocks(from_block, to_block): - for var in from_block.vars: - var = from_block.var(var) - var_info = copy.deepcopy(get_variable_info(var)) - if isinstance(var, fluid.framework.Parameter): - to_block.create_parameter(**var_info) - else: - to_block.create_var(**var_info) - - for op in from_block.ops: - op_info = { - 'type': op.type, - 'inputs': { - input: [block.var(var) for var in op.input(input)] - for input in op.input_names - }, - 'outputs': { - output: [block.var(var) for var in op.output(output)] - for output in op.output_names - }, - 'attrs': copy.deepcopy(op.all_attrs()) - } - to_block.append_op(**op_info) if not isinstance(pre_program, fluid.Program): raise TypeError("pre_program shoule be an instance of fluid.Program") @@ -268,7 +270,10 @@ def set_op_attr(program, is_test=False): def clone_program(origin_program, for_test=False): - dest_program = origin_program.clone(for_test=for_test) + dest_program = fluid.Program() + _copy_vars_and_ops_in_blocks(origin_program.global_block(), + dest_program.global_block()) + dest_program = dest_program.clone(for_test=for_test) if not for_test: for name, var in origin_program.global_block().vars.items(): dest_program.global_block(