提交 b078ad1d 编写于 作者: W wuzewu

Fix op attr issue

上级 99390835
...@@ -142,12 +142,7 @@ def from_module_attr_to_param(module_attr): ...@@ -142,12 +142,7 @@ def from_module_attr_to_param(module_attr):
return param return param
def connect_program(pre_program, def _copy_vars_and_ops_in_blocks(from_block, to_block):
next_program,
input_dict=None,
inplace=True,
need_log=True):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars: for var in from_block.vars:
var = from_block.var(var) var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var)) var_info = copy.deepcopy(get_variable_info(var))
...@@ -160,17 +155,24 @@ def connect_program(pre_program, ...@@ -160,17 +155,24 @@ def connect_program(pre_program,
op_info = { op_info = {
'type': op.type, 'type': op.type,
'inputs': { 'inputs': {
input: [block.var(var) for var in op.input(input)] input: [to_block.var(var) for var in op.input(input)]
for input in op.input_names for input in op.input_names
}, },
'outputs': { 'outputs': {
output: [block.var(var) for var in op.output(output)] output: [to_block.var(var) for var in op.output(output)]
for output in op.output_names for output in op.output_names
}, },
'attrs': copy.deepcopy(op.all_attrs()) 'attrs': copy.deepcopy(op.all_attrs())
} }
to_block.append_op(**op_info) to_block.append_op(**op_info)
def connect_program(pre_program,
next_program,
input_dict=None,
inplace=True,
need_log=True):
if not isinstance(pre_program, fluid.Program): if not isinstance(pre_program, fluid.Program):
raise TypeError("pre_program shoule be an instance of fluid.Program") raise TypeError("pre_program shoule be an instance of fluid.Program")
...@@ -268,7 +270,10 @@ def set_op_attr(program, is_test=False): ...@@ -268,7 +270,10 @@ def set_op_attr(program, is_test=False):
def clone_program(origin_program, for_test=False): def clone_program(origin_program, for_test=False):
dest_program = origin_program.clone(for_test=for_test) dest_program = fluid.Program()
_copy_vars_and_ops_in_blocks(origin_program.global_block(),
dest_program.global_block())
dest_program = dest_program.clone(for_test=for_test)
if not for_test: if not for_test:
for name, var in origin_program.global_block().vars.items(): for name, var in origin_program.global_block().vars.items():
dest_program.global_block( dest_program.global_block(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册