From b078ad1d8da1b7cb346eb40be89088d1624bf942 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Wed, 13 Nov 2019 17:41:03 +0800 Subject: [PATCH] Fix op attr issue --- paddlehub/common/paddle_helper.py | 53 +++++++++++++++++-------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/paddlehub/common/paddle_helper.py b/paddlehub/common/paddle_helper.py index 751bdd72..0c665b2f 100644 --- a/paddlehub/common/paddle_helper.py +++ b/paddlehub/common/paddle_helper.py @@ -142,34 +142,36 @@ def from_module_attr_to_param(module_attr): return param +def _copy_vars_and_ops_in_blocks(from_block, to_block): + for var in from_block.vars: + var = from_block.var(var) + var_info = copy.deepcopy(get_variable_info(var)) + if isinstance(var, fluid.framework.Parameter): + to_block.create_parameter(**var_info) + else: + to_block.create_var(**var_info) + + for op in from_block.ops: + op_info = { + 'type': op.type, + 'inputs': { + input: [to_block.var(var) for var in op.input(input)] + for input in op.input_names + }, + 'outputs': { + output: [to_block.var(var) for var in op.output(output)] + for output in op.output_names + }, + 'attrs': copy.deepcopy(op.all_attrs()) + } + to_block.append_op(**op_info) + + def connect_program(pre_program, next_program, input_dict=None, inplace=True, need_log=True): - def _copy_vars_and_ops_in_blocks(from_block, to_block): - for var in from_block.vars: - var = from_block.var(var) - var_info = copy.deepcopy(get_variable_info(var)) - if isinstance(var, fluid.framework.Parameter): - to_block.create_parameter(**var_info) - else: - to_block.create_var(**var_info) - - for op in from_block.ops: - op_info = { - 'type': op.type, - 'inputs': { - input: [block.var(var) for var in op.input(input)] - for input in op.input_names - }, - 'outputs': { - output: [block.var(var) for var in op.output(output)] - for output in op.output_names - }, - 'attrs': copy.deepcopy(op.all_attrs()) - } - to_block.append_op(**op_info) if not isinstance(pre_program, fluid.Program): raise TypeError("pre_program shoule be an instance of fluid.Program") @@ -268,7 +270,10 @@ def set_op_attr(program, is_test=False): def clone_program(origin_program, for_test=False): - dest_program = origin_program.clone(for_test=for_test) + dest_program = fluid.Program() + _copy_vars_and_ops_in_blocks(origin_program.global_block(), + dest_program.global_block()) + dest_program = dest_program.clone(for_test=for_test) if not for_test: for name, var in origin_program.global_block().vars.items(): dest_program.global_block( -- GitLab