未验证 提交 e152e891 编写于 作者: J Jiabin Yang 提交者: GitHub

fix attrs copy error (#51056)

上级 51aa2129
......@@ -655,16 +655,10 @@ def _lower_composite(block, blacklist=[]):
for i in range(len(op.output_names)):
outputs[op.output_names[i]] = op.output(op.output_names[i])
attrs = {}
# When copying op, all attrs defined in api should be kept.But op.attr_names is not complete here.
# Thus, all attrs should be got from init attrs of origin op.
runtime_attrs = op._get_runtime_attrs()
for name in runtime_attrs.keys():
attrs[name] = runtime_attrs[name]
from paddle.fluid.dygraph.base import param_guard
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(op.desc)
with param_guard(inputs), param_guard(outputs):
op = Operator(
block=block,
......@@ -672,7 +666,7 @@ def _lower_composite(block, blacklist=[]):
type=op.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
attrs=None,
)
block.ops.append(op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册