From e16eb22cb4dfc1204ff70b75f6f1f147c289f852 Mon Sep 17 00:00:00 2001
From: zxcd <228587199@qq.com>
Date: Thu, 30 Mar 2023 14:53:09 +0800
Subject: [PATCH] add scatter composite rule. (#52005)

* add scatter composite rule.

* add public_python_api

* add python uint16 support.

* fix code style.

* add test_scatter_op to the CINN test list.

* CINN does not support uint16, so disable CINN when dtype == uint16.
---
 paddle/fluid/prim/api/api.yaml               |   1 +
 .../composite_backward_api.h                 |  21 ++++
 paddle/phi/api/yaml/backward.yaml            |   1 +
 .../fluid/tests/unittests/CMakeLists.txt     |   1 +
 .../fluid/tests/unittests/test_scatter_op.py | 116 ++++++++++++++----
 python/paddle/tensor/manipulation.py         |   2 +-
 6 files changed, 115 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/prim/api/api.yaml b/paddle/fluid/prim/api/api.yaml
index c9bc040a1ad..d04cd2d9d56 100644
--- a/paddle/fluid/prim/api/api.yaml
+++ b/paddle/fluid/prim/api/api.yaml
@@ -23,6 +23,7 @@
 - concat
 - elementwise_pow
 - floor
+- gather
 - gather_nd
 - log
 - max
diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 58617f9b36d..106eee9ff0a 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1222,6 +1222,27 @@ void cos_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
   set_output<T>(x_grad_tmp, x_grad);
 }
 
+template <typename T>
+void scatter_grad(const Tensor& index,
+                  const Tensor& updates,
+                  const Tensor& out_grad,
+                  bool overwrite,
+                  Tensor* x_grad,
+                  Tensor* updates_grad) {
+  if (x_grad) {
+    auto zero_tensor =
+        full<T>(phi::vectorize(updates.dims()), 0.0, updates.dtype());
+    auto tmp_grad = scatter<T>(out_grad, index, zero_tensor, false);
+    set_output<T>(tmp_grad, x_grad);
+  }
+
+  if (updates_grad) {
+    Scalar tmp_zero = 0;
+    auto tmp_updates_grad = gather<T>(out_grad, index, tmp_zero);
+    set_output<T>(tmp_updates_grad, updates_grad);
+  }
+}
+
 template <typename T>
 void batch_norm_grad(const Tensor& x,
                      const Tensor& scale,
diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 68ef92c2626..92beb701e5d 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -1318,6 +1318,7 @@
   kernel :
     func : scatter_grad
   no_need_buffer : updates
+  composite: scatter_grad(index, updates, out_grad, overwrite, x_grad, updates_grad)
 
 - backward_op : scatter_nd_add_grad
   forward : scatter_nd_add (Tensor x, Tensor index, Tensor updates) -> Tensor(out)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 683d0d6cbda..ad435c67783 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1197,6 +1197,7 @@ set(TEST_CINN_OPS
     test_mean_op
     test_unsqueeze2_op
     test_meshgrid_op
+    test_scatter_op
     test_gather_op
     test_cast_op
     test_dropout_op
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index 5d3a7d18e8e..34c30e6591d 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -28,6 +28,8 @@ class TestScatterOp(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -46,10 +48,10 @@ class TestScatterOp(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op(TestScatterOp):
@@ -65,22 +67,30 @@ class TestScatterFP16Op(TestScatterOp):
 class TestScatterBF16Op(TestScatterOp):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 class TestScatterOp0(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -100,10 +110,10 @@ class TestScatterOp0(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op0(TestScatterOp0):
@@ -119,22 +129,30 @@ class TestScatterFP16Op0(TestScatterOp0):
 class TestScatterBF16Op0(TestScatterOp0):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 class TestScatterOp1(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -157,10 +175,10 @@ class TestScatterOp1(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op1(TestScatterOp1):
@@ -176,16 +194,22 @@ class TestScatterFP16Op1(TestScatterOp1):
 class TestScatterBF16Op1(TestScatterOp1):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -195,6 +219,8 @@ class TestScatterOp2(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -215,12 +241,17 @@ class TestScatterOp2(OpTest):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3)
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -239,6 +270,7 @@ class TestScatterFP16Op2(TestScatterOp2):
 class TestScatterBF16Op2(TestScatterOp2):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 @unittest.skipIf(
@@ -248,6 +280,8 @@ class TestScatterOp3(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -272,12 +306,17 @@ class TestScatterOp3(OpTest):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3)
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -296,12 +335,15 @@ class TestScatterFP16Op3(TestScatterOp3):
 class TestScatterBF16Op3(TestScatterOp3):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 class TestScatterOp4(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -320,10 +362,10 @@ class TestScatterOp4(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(['X', 'Updates'], 'Out')
+        self.check_grad(['X', 'Updates'], 'Out', check_prim=True)
 
 
 class TestScatterFP16Op4(TestScatterOp4):
@@ -339,16 +381,22 @@ class TestScatterFP16Op4(TestScatterOp4):
 class TestScatterBF16Op4(TestScatterOp4):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -358,6 +406,8 @@ class TestScatterOp5(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 3)).astype(target_dtype)
@@ -378,12 +428,17 @@ class TestScatterOp5(OpTest):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-3)
+            self.check_output_with_place(place, atol=1e-3, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 @unittest.skipIf(
@@ -402,12 +457,16 @@ class TestScatterFP16Op5(TestScatterOp5):
 class TestScatterBF16Op5(TestScatterOp5):
     def _set_dtype(self):
         self.dtype = np.uint16
+        self.enable_cinn = False
 
 
 class TestScatterOp6(OpTest):
     def setUp(self):
         self.op_type = "scatter"
         self.python_api = paddle.scatter
+        self.public_python_api = paddle.scatter
+        self.prim_op_type = "prim"
+        self.enable_cinn = False
         self._set_dtype()
         target_dtype = "float16" if self.dtype == np.float16 else "float32"
         ref_np = np.ones((3, 50)).astype(target_dtype)
@@ -426,10 +485,10 @@ class TestScatterOp6(OpTest):
         self.dtype = np.float32
 
     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_prim=True)
 
     def test_check_grad(self):
-        self.check_grad(["X", "Updates"], "Out")
+        self.check_grad(["X", "Updates"], "Out", check_prim=True)
 
 
 class TestScatterFP16Op6(TestScatterOp6):
@@ -449,12 +508,17 @@ class TestScatterBF16Op6(TestScatterOp6):
     def test_check_output(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_output_with_place(place)
+            self.check_output_with_place(place, check_prim=True)
 
     def test_check_grad(self):
         if core.is_compiled_with_cuda():
             place = core.CUDAPlace(0)
-            self.check_grad_with_place(place, ['X', 'Updates'], 'Out')
+            self.check_grad_with_place(
+                place,
+                ['X', 'Updates'],
+                'Out',
+                check_prim=True,
+            )
 
 
 class TestScatterAPI(unittest.TestCase):
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index b6fda0fd189..a71c997f018 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -2903,7 +2903,7 @@ def scatter(x, index, updates, overwrite=True, name=None):
         check_variable_and_dtype(
             x,
             'dtype',
-            ['float32', 'float64', 'float16', 'int32', 'int64'],
+            ['float32', 'float64', 'float16', 'int32', 'int64', 'uint16'],
             'scatter',
         )
         check_type(overwrite, 'overwrite', bool, 'scatter')
--
GitLab
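
For reference, the composite rule added in composite_backward_api.h can be read as: x_grad is
out_grad with the rows selected by index zeroed out (scattering a zero tensor at index), and
updates_grad is out_grad gathered at index. Below is a minimal Python sketch of that rule using the
public paddle.scatter / paddle.gather APIs touched by this patch; it assumes eager mode, and the
helper name scatter_composite_grad and the concrete tensors in the usage lines are illustrative
only, not part of the change.

import paddle


def scatter_composite_grad(index, updates, out_grad):
    # Illustrative sketch, not part of the patch.
    # dL/dx: rows of x that were replaced by `updates` receive no gradient, so
    # zero those rows of out_grad by scattering a zero tensor at `index`
    # (mirrors scatter<T>(out_grad, index, zero_tensor, false) above).
    zeros = paddle.zeros_like(updates)
    x_grad = paddle.scatter(out_grad, index, zeros, overwrite=False)
    # dL/dupdates: each update row picks up the gradient of the output row it
    # was written to (mirrors gather<T>(out_grad, index, tmp_zero) above).
    updates_grad = paddle.gather(out_grad, index, axis=0)
    return x_grad, updates_grad


# Illustrative usage for the overwrite=True case exercised by the tests above.
x = paddle.ones([3, 2])
index = paddle.to_tensor([1, 2], dtype='int64')
updates = paddle.full([2, 2], 2.0)
out = paddle.scatter(x, index, updates, overwrite=True)
x_grad, updates_grad = scatter_composite_grad(index, updates, paddle.ones_like(out))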