diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index b84f011296ddc3f0402a2e2c94a2ff4418709b44..f13b07db9b290c954f3114bf3cff1936ecdff33d 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -543,6 +543,7 @@ kernel : data_type: x func : gather_grad + composite : gather_grad(x, index, out_grad, axis, overwrite) no_need_buffer : x - backward_op : group_norm_grad diff --git a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py index 284620ba76ae364c3f7e6dcf31618fafa0d60250..1f89b024e819b4757a20ee283ce0573f9aa39e94 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/vjp/static/test_comp_gather_grad.py @@ -113,9 +113,15 @@ class TestGatherGradComp(unittest.TestCase): def test_cinn(self): paddle.disable_static() + use_cinn = True + if isinstance( + framework._current_expected_place(), framework.core.CPUPlace + ): + # TODO(jiabin): CINN will crashed in this case open it when fixed + use_cinn = False dy_res = self.train(use_prim=False, use_cinn=False) - # TODO(jiabin): CINN will crashed in this case open it when fixed - comp_st_cinn_res = self.train(use_prim=True, use_cinn=False) + + comp_st_cinn_res = self.train(use_prim=True, use_cinn=use_cinn) for i in range(len(dy_res)): np.testing.assert_allclose(