diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index 5e1e490c1b73adbfb68f6b5d5aaab2263da6ee8f..099ebc81b900bbcded9aa82673eebb37163ad20c 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -1805,5 +1805,21 @@ void roll_grad(const Tensor& x, set_output(x_grad_output, x_grad); } } + +template +void scatter_nd_add_grad(const Tensor& index, + const Tensor& updates, + const Tensor& out_grad, + Tensor* x_grad, + Tensor* updates_grad) { + if (x_grad) { + by_pass(out_grad, x_grad); + } + if (updates_grad) { + // Gradient by Gather: dUpdates = dO[Ids] + auto tmp_updates_grad = gather_nd(out_grad, index); + set_output(tmp_updates_grad, updates_grad); + } +} } // namespace prim } // namespace paddle diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 2394182ee4bd1ccaf1be30af5d40d0066132f6de..6faf2d0ba7a49b607830f56202edbf716b339ee7 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -1454,6 +1454,7 @@ kernel : func : scatter_nd_add_grad no_need_buffer : updates + composite: scatter_nd_add_grad(index, updates, out_grad, x_grad, updates_grad) - backward_op : segment_pool_grad forward : segment_pool (Tensor x, Tensor segment_ids, str pooltype="SUM") -> Tensor(out), Tensor(summed_ids) diff --git a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py index 0d09e0af5c32a79b4fcd4f7c5e4f1f5fe9e1d338..66799466c59e4b6356fc33742b281c2722a45edc 100644 --- a/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_scatter_nd_op.py @@ -69,6 +69,8 @@ class TestScatterNdAddSimpleOp(OpTest): def setUp(self): self.op_type = "scatter_nd_add" self.python_api = paddle.scatter_nd_add + self.public_python_api = paddle.scatter_nd_add + self.prim_op_type = "prim" self._set_dtype() if self.dtype == np.float64: target_dtype = "float64" @@ -94,7 +96,7 @@ class TestScatterNdAddSimpleOp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out') + self.check_grad(['X', 'Updates'], 'Out', check_prim=True) class TestScatterNdAddSimpleFP16Op(TestScatterNdAddSimpleOp): @@ -127,7 +129,9 @@ class TestScatterNdAddSimpleBF16Op(TestScatterNdAddSimpleOp): def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X', 'Updates'], 'Out') + self.check_grad_with_place( + place, ['X', 'Updates'], 'Out', check_prim=True + ) class TestScatterNdAddWithEmptyIndex(OpTest): @@ -138,6 +142,8 @@ class TestScatterNdAddWithEmptyIndex(OpTest): def setUp(self): self.op_type = "scatter_nd_add" self.python_api = paddle.scatter_nd_add + self.public_python_api = paddle.scatter_nd_add + self.prim_op_type = "prim" self._set_dtype() if self.dtype == np.float64: target_dtype = "float64" @@ -166,7 +172,7 @@ class TestScatterNdAddWithEmptyIndex(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out') + self.check_grad(['X', 'Updates'], 'Out', check_prim=True) class TestScatterNdAddWithEmptyIndexFP16(TestScatterNdAddWithEmptyIndex): @@ -199,7 +205,9 @@ class TestScatterNdAddWithEmptyIndexBF16(TestScatterNdAddWithEmptyIndex): def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X', 'Updates'], 'Out') + self.check_grad_with_place( + place, ['X', 'Updates'], 'Out', check_prim=True + ) class TestScatterNdAddWithHighRankSame(OpTest): @@ -210,6 +218,8 @@ class TestScatterNdAddWithHighRankSame(OpTest): def setUp(self): self.op_type = "scatter_nd_add" self.python_api = paddle.scatter_nd_add + self.public_python_api = paddle.scatter_nd_add + self.prim_op_type = "prim" self._set_dtype() if self.dtype == np.float64: target_dtype = "float64" @@ -241,7 +251,7 @@ class TestScatterNdAddWithHighRankSame(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out') + self.check_grad(['X', 'Updates'], 'Out', check_prim=True) class TestScatterNdAddWithHighRankSameFP16(TestScatterNdAddWithHighRankSame): @@ -274,7 +284,9 @@ class TestScatterNdAddWithHighRankSameBF16(TestScatterNdAddWithHighRankSame): def test_check_grad(self): if core.is_compiled_with_cuda(): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X', 'Updates'], 'Out') + self.check_grad_with_place( + place, ['X', 'Updates'], 'Out', check_prim=True + ) class TestScatterNdAddWithHighRankDiff(OpTest): @@ -285,6 +297,8 @@ class TestScatterNdAddWithHighRankDiff(OpTest): def setUp(self): self.op_type = "scatter_nd_add" self.python_api = paddle.scatter_nd_add + self.public_python_api = paddle.scatter_nd_add + self.prim_op_type = "prim" shape = (8, 2, 2, 1, 10) ref_np = np.random.rand(*shape).astype("double") index = np.vstack([np.random.randint(0, s, size=500) for s in shape]).T @@ -300,7 +314,7 @@ class TestScatterNdAddWithHighRankDiff(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X', 'Updates'], 'Out') + self.check_grad(['X', 'Updates'], 'Out', check_prim=True) # Test Python API @@ -310,70 +324,76 @@ class TestScatterNdOpAPI(unittest.TestCase): """ def testcase1(self): - ref1 = paddle.static.data( - name='ref1', - shape=[10, 9, 8, 1, 3], - dtype='float32', - ) - index1 = paddle.static.data( - name='index1', - shape=[5, 5, 8, 5], - dtype='int32', - ) - updates1 = paddle.static.data( - name='update1', - shape=[5, 5, 8], - dtype='float32', - ) - output1 = paddle.scatter_nd_add(ref1, index1, updates1) + with paddle.fluid.framework._static_guard(): + ref1 = paddle.static.data( + name='ref1', + shape=[10, 9, 8, 1, 3], + dtype='float32', + ) + index1 = paddle.static.data( + name='index1', + shape=[5, 5, 8, 5], + dtype='int32', + ) + updates1 = paddle.static.data( + name='update1', + shape=[5, 5, 8], + dtype='float32', + ) + output1 = paddle.scatter_nd_add(ref1, index1, updates1) def testcase2(self): - ref2 = paddle.static.data( - name='ref2', - shape=[10, 9, 8, 1, 3], - dtype='double', - ) - index2 = paddle.static.data( - name='index2', - shape=[5, 8, 5], - dtype='int32', - ) - updates2 = paddle.static.data( - name='update2', - shape=[5, 8], - dtype='double', - ) - output2 = paddle.scatter_nd_add( - ref2, index2, updates2, name="scatter_nd_add" - ) + with paddle.fluid.framework._static_guard(): + ref2 = paddle.static.data( + name='ref2', + shape=[10, 9, 8, 1, 3], + dtype='double', + ) + index2 = paddle.static.data( + name='index2', + shape=[5, 8, 5], + dtype='int32', + ) + updates2 = paddle.static.data( + name='update2', + shape=[5, 8], + dtype='double', + ) + output2 = paddle.scatter_nd_add( + ref2, index2, updates2, name="scatter_nd_add" + ) def testcase3(self): - shape3 = [10, 9, 8, 1, 3] - index3 = paddle.static.data( - name='index3', - shape=[5, 5, 8, 5], - dtype='int32', - ) - updates3 = paddle.static.data( - name='update3', - shape=[5, 5, 8], - dtype='float32', - ) - output3 = paddle.scatter_nd(index3, updates3, shape3) + with paddle.fluid.framework._static_guard(): + shape3 = [10, 9, 8, 1, 3] + index3 = paddle.static.data( + name='index3', + shape=[5, 5, 8, 5], + dtype='int32', + ) + updates3 = paddle.static.data( + name='update3', + shape=[5, 5, 8], + dtype='float32', + ) + output3 = paddle.scatter_nd(index3, updates3, shape3) def testcase4(self): - shape4 = [10, 9, 8, 1, 3] - index4 = paddle.static.data( - name='index4', - shape=[5, 5, 8, 5], - dtype='int32', - ) - updates4 = paddle.static.data( - name='update4', - shape=[5, 5, 8], - dtype='double', - ) - output4 = paddle.scatter_nd(index4, updates4, shape4, name='scatter_nd') + with paddle.fluid.framework._static_guard(): + shape4 = [10, 9, 8, 1, 3] + index4 = paddle.static.data( + name='index4', + shape=[5, 5, 8, 5], + dtype='int32', + ) + updates4 = paddle.static.data( + name='update4', + shape=[5, 5, 8], + dtype='double', + ) + output4 = paddle.scatter_nd( + index4, updates4, shape4, name='scatter_nd' + ) def testcase5(self): if not fluid.core.is_compiled_with_cuda(): @@ -430,60 +450,65 @@ class TestScatterNdOpAPI(unittest.TestCase): class TestScatterNdOpRaise(unittest.TestCase): def test_check_raise(self): def check_raise_is_test(): - try: - ref5 = paddle.static.data( - name='ref5', shape=[-1, 3, 4, 5], dtype='float32' - ) - index5 = paddle.static.data( - name='index5', shape=[-1, 2, 10], dtype='int32' - ) - updates5 = paddle.static.data( - name='updates5', shape=[-1, 2, 10], dtype='float32' - ) - output5 = paddle.scatter_nd_add(ref5, index5, updates5) - except Exception as e: - t = "The last dimension of Input(Index)'s shape should be no greater " - if t in str(e): - raise IndexError + with paddle.fluid.framework._static_guard(): + try: + ref5 = paddle.static.data( + name='ref5', shape=[-1, 3, 4, 5], dtype='float32' + ) + index5 = paddle.static.data( + name='index5', shape=[-1, 2, 10], dtype='int32' + ) + updates5 = paddle.static.data( + name='updates5', shape=[-1, 2, 10], dtype='float32' + ) + output5 = paddle.scatter_nd_add(ref5, index5, updates5) + except Exception as e: + t = "The last dimension of Input(Index)'s shape should be no greater " + if t in str(e): + raise IndexError self.assertRaises(IndexError, check_raise_is_test) def test_check_raise2(self): with self.assertRaises(ValueError): - ref6 = paddle.static.data( - name='ref6', - shape=[10, 9, 8, 1, 3], - dtype='double', - ) - index6 = paddle.static.data( - name='index6', - shape=[5, 8, 5], - dtype='int32', - ) - updates6 = paddle.static.data( - name='update6', - shape=[5, 8], - dtype='float32', - ) - output6 = paddle.scatter_nd_add(ref6, index6, updates6) - - def test_check_raise3(self): - def check_raise_is_test(): - try: - shape = [3, 4, 5] - index7 = paddle.static.data( - name='index7', shape=[-1, 2, 1], dtype='int32' + with paddle.fluid.framework._static_guard(): + ref6 = paddle.static.data( + name='ref6', + shape=[10, 9, 8, 1, 3], + dtype='double', + ) + index6 = paddle.static.data( + name='index6', + shape=[5, 8, 5], + dtype='int32', ) - updates7 = paddle.static.data( - name='updates7', shape=[-1, 2, 4, 5, 20], dtype='float32' + updates6 = paddle.static.data( + name='update6', + shape=[5, 8], + dtype='float32', ) - output7 = paddle.scatter_nd(index7, updates7, shape) - except Exception as e: - t = "Updates has wrong shape" - if t in str(e): - raise ValueError + output6 = paddle.scatter_nd_add(ref6, index6, updates6) - self.assertRaises(ValueError, check_raise_is_test) + def test_check_raise3(self): + def check_raise_is_test(): + with paddle.fluid.framework._static_guard(): + try: + shape = [3, 4, 5] + index7 = paddle.static.data( + name='index7', shape=[-1, 2, 1], dtype='int32' + ) + updates7 = paddle.static.data( + name='updates7', + shape=[-1, 2, 4, 5, 20], + dtype='float32', + ) + output7 = paddle.scatter_nd(index7, updates7, shape) + except Exception as e: + t = "Updates has wrong shape" + if t in str(e): + raise ValueError + + self.assertRaises(ValueError, check_raise_is_test) class TestDygraph(unittest.TestCase):