提交 b078ad1d 编写于 作者: W wuzewu

Fix op attr issue

上级 99390835
......@@ -142,34 +142,36 @@ def from_module_attr_to_param(module_attr):
return param
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, fluid.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
op_info = {
'type': op.type,
'inputs': {
input: [to_block.var(var) for var in op.input(input)]
for input in op.input_names
},
'outputs': {
output: [to_block.var(var) for var in op.output(output)]
for output in op.output_names
},
'attrs': copy.deepcopy(op.all_attrs())
}
to_block.append_op(**op_info)
def connect_program(pre_program,
next_program,
input_dict=None,
inplace=True,
need_log=True):
def _copy_vars_and_ops_in_blocks(from_block, to_block):
for var in from_block.vars:
var = from_block.var(var)
var_info = copy.deepcopy(get_variable_info(var))
if isinstance(var, fluid.framework.Parameter):
to_block.create_parameter(**var_info)
else:
to_block.create_var(**var_info)
for op in from_block.ops:
op_info = {
'type': op.type,
'inputs': {
input: [block.var(var) for var in op.input(input)]
for input in op.input_names
},
'outputs': {
output: [block.var(var) for var in op.output(output)]
for output in op.output_names
},
'attrs': copy.deepcopy(op.all_attrs())
}
to_block.append_op(**op_info)
if not isinstance(pre_program, fluid.Program):
raise TypeError("pre_program shoule be an instance of fluid.Program")
......@@ -268,7 +270,10 @@ def set_op_attr(program, is_test=False):
def clone_program(origin_program, for_test=False):
dest_program = origin_program.clone(for_test=for_test)
dest_program = fluid.Program()
_copy_vars_and_ops_in_blocks(origin_program.global_block(),
dest_program.global_block())
dest_program = dest_program.clone(for_test=for_test)
if not for_test:
for name, var in origin_program.global_block().vars.items():
dest_program.global_block(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册