Unverified commit dd9a04b7, authored by xiaoguoguo626807 and committed by GitHub

【prim】Fix the split VJP: Concat can't get one input, "split_0_tmp_1@GRAD" (#54132)

* modify gradOpMaker

* modify concat bug

* modify concat bug

* delete unnecessary block

* modify fill_any_like value from 1 to 0

* modify fill_any_like dtype from other to -1

* fix CI bug
Parent 14425c06
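This change fixes the composite (prim) backward of split: its VJP is concat, which needs a gradient for every forward output, but an output that does not contribute to the loss never gets a gradient var created (hence the missing "split_0_tmp_1@GRAD" in the title). The fix pads such missing gradients with zero-valued fill_any_like ops. A minimal eager-mode sketch of the scenario (variable names are ours; eager mode already zero-fills the unused branch, the bug was in the static composite backward):

    import paddle

    x = paddle.randn([4, 6])
    x.stop_gradient = False

    # Only the first chunk feeds the loss, so no gradient ever flows to
    # the second one; split's VJP (concat) still needs a gradient for both.
    a, b = paddle.split(x, num_or_sections=2, axis=1)
    loss = a.sum()
    loss.backward()
    print(x.grad)  # ones in a's columns, zeros padded in for b's columns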
@@ -751,29 +751,32 @@ def _remove_no_grad_branch_(
     ]
     # Insert fill_any_like_op with value 0
     to_insert = []
-    for idx, op_desc in enumerate(op_descs):
-        for arg in op_desc.input_arg_names():
-            # arg is a gradient var name and arg should not have gradient
-            if core.grad_var_suffix() in arg and arg in no_grad_set:
-                x_in = _strip_grad_suffix_(arg)
-                # the reason should be: arg can be input of another grad op
-                # and the op is a not-to-remove op
-                new_op_desc = _create_op_desc_(
-                    "fill_any_like",
-                    {"X": [x_in]},
-                    {"Out": [arg]},
-                    {'value': 0, 'dtype': -1},
-                )
-                # update the mapping between fwd and bwd
-                if (
-                    grad_op_id_to_fwd_op is not None
-                    and grad_op_id_to_fwd_op.get(op_desc.original_id(), None)
-                    is not None
-                ):
-                    grad_op_id_to_fwd_op[
-                        new_op_desc.original_id()
-                    ] = grad_op_id_to_fwd_op[op_desc.original_id()]
-                to_insert.append((new_op_desc, idx))
+    if not core._is_bwd_prim_enabled():
+        for idx, op_desc in enumerate(op_descs):
+            for arg in op_desc.input_arg_names():
+                # arg is a gradient var name and arg should not have gradient
+                if core.grad_var_suffix() in arg and arg in no_grad_set:
+                    x_in = _strip_grad_suffix_(arg)
+                    # the reason should be: arg can be input of another grad op
+                    # and the op is a not-to-remove op
+                    new_op_desc = _create_op_desc_(
+                        "fill_any_like",
+                        {"X": [x_in]},
+                        {"Out": [arg]},
+                        {'value': 0, 'dtype': -1},
+                    )
+                    # update the mapping between fwd and bwd
+                    if (
+                        grad_op_id_to_fwd_op is not None
+                        and grad_op_id_to_fwd_op.get(
+                            op_desc.original_id(), None
+                        )
+                        is not None
+                    ):
+                        grad_op_id_to_fwd_op[
+                            new_op_desc.original_id()
+                        ] = grad_op_id_to_fwd_op[op_desc.original_id()]
+                    to_insert.append((new_op_desc, idx))

     list([op_descs.insert(p[1], p[0]) for p in reversed(to_insert)])
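The hunk above guards the old zero-padding pass in _remove_no_grad_branch_ behind `if not core._is_bwd_prim_enabled():`, so it only runs for the classic backward; when the composite (prim) backward is enabled, the padding happens earlier, in _append_backward_ops_ (next hunk). The op desc it builds is essentially a zeros-like: value=0, and dtype=-1 meaning "reuse the input's dtype". A rough eager-mode analogue (names are ours):

    import paddle

    # Roughly what the fill_any_like op desc above computes: a tensor of
    # zeros with the same shape and dtype as its input.
    x_in = paddle.randn([2, 3])
    grad_placeholder = paddle.full_like(x_in, fill_value=0)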
@@ -1349,6 +1352,28 @@ def _append_backward_ops_(
     assert isinstance(rename_var_map, dict)
     if core._is_bwd_prim_enabled():
+        grad_name_set = set()
+        for target in target_vars:
+            grad_name_set.add(_append_grad_suffix_(target.name))
+        for op in reversed(block.ops):
+            if op.type == "fill_any_like":
+                for out_name in op.desc.output_arg_names():
+                    grad_name_set.add(out_name)
+                continue
+            for var_name in op.desc.output_arg_names():
+                grad_var_name = _append_grad_suffix_(var_name)
+                if grad_var_name not in grad_name_set:
+                    op_desc = _create_op_desc_(
+                        "fill_any_like",
+                        {"X": [var_name]},
+                        {"Out": [grad_var_name]},
+                        {'value': 0, 'dtype': target_vars[0].dtype},
+                    )
+                    block.desc.append_op().copy_from(op_desc)
+            break
+        block.program._sync_with_cpp()
         composite_block = program.clone().current_block()
         # Create output and infer shape for operators whose output haven't
         # been created.
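The hunk above is the core fix: before the block is cloned for the composite backward, it walks the block's ops in reverse and, for the last real forward op, inserts a zero-valued fill_any_like for every output whose gradient var is missing, so a multi-output VJP such as split's concat receives all of its inputs. A toy model of that pass in plain Python (not the Paddle API; all names are ours):

    def pad_missing_grads(ops, target_grad_names):
        """Toy model: ops is a list of (op_type, output_names) pairs."""
        grad_names = set(target_grad_names)
        inserted = []
        for op_type, outputs in reversed(ops):
            if op_type == "fill_any_like":   # grads already padded earlier
                grad_names.update(outputs)
                continue
            for var in outputs:              # last real forward op reached
                grad_var = var + "@GRAD"
                if grad_var not in grad_names:
                    inserted.append(("fill_any_like", var, grad_var))
            break                            # only the last real op matters
        return inserted

    # split has two outputs, but only one gradient flows back:
    print(pad_missing_grads(
        [("split", ["split_0_tmp_0", "split_0_tmp_1"])],
        ["split_0_tmp_0@GRAD"],
    ))
    # -> [('fill_any_like', 'split_0_tmp_1', 'split_0_tmp_1@GRAD')]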
@@ -489,7 +489,7 @@ class TestPrimEvalBranch(unittest.TestCase):
     def train(self, use_prim):
         core._set_prim_all_enabled(use_prim)
         paddle.seed(2022)
-        net = BatchNorm(2, is_test=True)
+        net = BatchNorm(2, act="relu", is_test=True)
         net = apply_to_static(net, False)
         out = net(self.x)
         loss = paddle.mean(out)
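The hunk above widens the BatchNorm eval-branch test by constructing the layer with a fused activation, so the composite rules are exercised on that path as well. For reference, a minimal standalone construction (the input shape is our assumption):

    import paddle
    from paddle.nn import BatchNorm

    # paddle.nn.BatchNorm accepts a fused activation name via `act`.
    net = BatchNorm(2, act="relu", is_test=True)
    out = net(paddle.randn([4, 2, 8, 8]))  # NCHW input with 2 channels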
@@ -42,6 +42,7 @@ class TestCustomVJP(unittest.TestCase):
             'uniform_random',
             'dropout',
             'fill_any_like',
+            'fill_any_like',
             'cast',
             'elementwise_mul',
             'scale',
@@ -56,6 +57,7 @@ class TestCustomVJP(unittest.TestCase):
             'scale',
             'cast',
             'fill_constant',
+            'fill_constant',
             'cast',
             'elementwise_mul',
             'scale',
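Both expected op sequences in TestCustomVJP gain one entry. The padding pass added in _append_backward_ops_ emits one extra zero-filling op per missing gradient, which shows up as a second 'fill_any_like' in the raw backward program and, once the prim decomposition lowers it, as a second 'fill_constant'.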