Unverified commit dd9a04b7 authored by xiaoguoguo626807, committed by GitHub

[prim] Modify split vjp: Concat can't get one input "split_0_tmp_1@GRAD" (#54132)

* modify gradOpMaker

* fix concat bug

* fix concat bug

* delete unnecessary block

* modify fill_any_like value from 1 to 0

* modify fill_any_like dtype from other to -1

* fix ci bug
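
For context, fill_any_like emits a tensor shaped like its input and filled with a constant value; dtype = -1 is taken here to mean "reuse the input's dtype", which is what the last two bullets switch to. A minimal NumPy sketch of that semantics (an illustration, not the Paddle operator itself):

import numpy as np

def fill_any_like(x, value=0, dtype=-1):
    # Shape follows x; dtype=-1 means "keep x's dtype" (assumed convention).
    out_dtype = x.dtype if dtype == -1 else dtype
    return np.full(x.shape, value, dtype=out_dtype)

zero_grad = fill_any_like(np.ones((2, 3), dtype=np.float32))
print(zero_grad.shape, zero_grad.dtype, float(zero_grad.max()))  # (2, 3) float32 0.0
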
Parent 14425c06
......@@ -751,29 +751,32 @@ def _remove_no_grad_branch_(
]
# Insert fill_any_like_op with value 0
to_insert = []
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
# arg is a gradient var name and arg should not have gradient
if core.grad_var_suffix() in arg and arg in no_grad_set:
x_in = _strip_grad_suffix_(arg)
# the reason should be: arg can be input of another grad op
# and the op is a not-to-remove op
new_op_desc = _create_op_desc_(
"fill_any_like",
{"X": [x_in]},
{"Out": [arg]},
{'value': 0, 'dtype': -1},
)
# update the mapping between fwd and bwd
if (
grad_op_id_to_fwd_op is not None
and grad_op_id_to_fwd_op.get(op_desc.original_id(), None)
is not None
):
grad_op_id_to_fwd_op[
new_op_desc.original_id()
] = grad_op_id_to_fwd_op[op_desc.original_id()]
to_insert.append((new_op_desc, idx))
if not core._is_bwd_prim_enabled():
for idx, op_desc in enumerate(op_descs):
for arg in op_desc.input_arg_names():
# arg is a gradient var name and arg should not have gradient
if core.grad_var_suffix() in arg and arg in no_grad_set:
x_in = _strip_grad_suffix_(arg)
# the reason should be: arg can be input of another grad op
# and the op is a not-to-remove op
new_op_desc = _create_op_desc_(
"fill_any_like",
{"X": [x_in]},
{"Out": [arg]},
{'value': 0, 'dtype': -1},
)
# update the mapping between fwd and bwd
if (
grad_op_id_to_fwd_op is not None
and grad_op_id_to_fwd_op.get(
op_desc.original_id(), None
)
is not None
):
grad_op_id_to_fwd_op[
new_op_desc.original_id()
] = grad_op_id_to_fwd_op[op_desc.original_id()]
to_insert.append((new_op_desc, idx))
list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
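
The change above guards the legacy placeholder insertion with if not core._is_bwd_prim_enabled(): — when the composite (prim) backward is enabled, these fill_any_like ops are no longer spliced into op_descs here, because the prim path creates its zero gradients in _append_backward_ops_ (next hunk). A simplified, self-contained model of what this branch does, with plain dicts standing in for OpDesc (GRAD_SUFFIX and the dict layout are illustrative assumptions):

GRAD_SUFFIX = "@GRAD"

def insert_zero_grad_placeholders(op_descs, no_grad_set):
    to_insert = []
    for idx, op in enumerate(op_descs):
        for arg in op["inputs"]:
            # arg is a gradient name that must not receive a real gradient, so
            # feed the op a zero tensor shaped like the forward variable instead
            if GRAD_SUFFIX in arg and arg in no_grad_set:
                x_in = arg[: -len(GRAD_SUFFIX)]
                to_insert.append(
                    ({"type": "fill_any_like", "inputs": [x_in], "outputs": [arg],
                      "attrs": {"value": 0, "dtype": -1}}, idx)
                )
    # insert back-to-front so the recorded indices stay valid
    for new_op, idx in reversed(to_insert):
        op_descs.insert(idx, new_op)
    return op_descs

ops = [{"type": "concat", "inputs": ["split_0_tmp_1@GRAD"], "outputs": ["x@GRAD"]}]
insert_zero_grad_placeholders(ops, {"split_0_tmp_1@GRAD"})
print([op["type"] for op in ops])  # ['fill_any_like', 'concat']
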
......@@ -1349,6 +1352,28 @@ def _append_backward_ops_(
assert isinstance(rename_var_map, dict)
if core._is_bwd_prim_enabled():
grad_name_set = set()
for target in target_vars:
grad_name_set.add(_append_grad_suffix_(target.name))
for op in reversed(block.ops):
if op.type == "fill_any_like":
for out_name in op.desc.output_arg_names():
grad_name_set.add(out_name)
continue
for var_name in op.desc.output_arg_names():
grad_var_name = _append_grad_suffix_(var_name)
if grad_var_name not in grad_name_set:
op_desc = _create_op_desc_(
"fill_any_like",
{"X": [var_name]},
{"Out": [grad_var_name]},
{'value': 0, 'dtype': target_vars[0].dtype},
)
block.desc.append_op().copy_from(op_desc)
break
block.program._sync_with_cpp()
composite_block = program.clone().current_block()
# Create output and infer shape for operators whose output haven't
# been created.
......
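
In the prim-enabled path added above, the gradient names of the target vars are registered first, then the block's ops are walked in reverse: outputs of existing fill_any_like ops are added to the known-gradient set, and for the last real op any output whose gradient name is still missing gets a zero fill_any_like appended, which matches the error in the commit title, where Concat could not find "split_0_tmp_1@GRAD". A rough, self-contained model of that reverse scan (plain dicts and names are illustrative, not Paddle's data structures):

GRAD_SUFFIX = "@GRAD"

def fill_missing_target_grads(ops, target_names):
    grad_names = {name + GRAD_SUFFIX for name in target_names}
    appended = []
    for op in reversed(ops):
        if op["type"] == "fill_any_like":
            # outputs of earlier placeholder ops already count as gradients
            grad_names.update(op["outputs"])
            continue
        for var_name in op["outputs"]:
            grad_var = var_name + GRAD_SUFFIX
            if grad_var not in grad_names:
                appended.append({"type": "fill_any_like", "inputs": [var_name],
                                 "outputs": [grad_var], "attrs": {"value": 0}})
        break  # only the last non-placeholder op is inspected
    return appended

ops = [{"type": "split", "outputs": ["split_0_tmp_0", "split_0_tmp_1"]}]
print(fill_missing_target_grads(ops, ["split_0_tmp_0"]))
# -> one fill_any_like producing split_0_tmp_1@GRAD, so concat gets all its inputs
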
......@@ -489,7 +489,7 @@ class TestPrimEvalBranch(unittest.TestCase):
def train(self, use_prim):
core._set_prim_all_enabled(use_prim)
paddle.seed(2022)
net = BatchNorm(2, is_test=True)
net = BatchNorm(2, act="relu", is_test=True)
net = apply_to_static(net, False)
out = net(self.x)
loss = paddle.mean(out)
......
......@@ -42,6 +42,7 @@ class TestCustomVJP(unittest.TestCase):
'uniform_random',
'dropout',
'fill_any_like',
'fill_any_like',
'cast',
'elementwise_mul',
'scale',
......@@ -56,6 +57,7 @@ class TestCustomVJP(unittest.TestCase):
'scale',
'cast',
'fill_constant',
'fill_constant',
'cast',
'elementwise_mul',
'scale',
......