Extend the scatter op to more dtypes and add a float16 CUDA test.

* paddle/fluid/operators/scatter_op.cu: add one more kernel entry to the
  `scatter` registration and re-register `scatter_grad` with five entries.
  All five are ScatterGradOpCUDAKernel; an earlier revision of this patch
  listed the forward ScatterOpCUDAKernel class for two of them.
* python/paddle/fluid/tests/unittests/test_scatter_op.py: add
  TestScatterOpFp16 (skipped when not compiled with CUDA). It runs
  paddle.scatter on float16 inputs in dygraph mode on CUDAPlace(0) and
  compares the gradients of `x` and `updates` with references built from
  NumPy and paddle.gather.
* python/paddle/tensor/manipulation.py: widen the dtype check that
  scatter() performs outside dygraph mode to
  float32/float64/float16/int32/int64 and report the checked input as 'x'.

NOTE(review): the kernel entries in scatter_op.cu carry no template
arguments in this text, so they are indistinguishable from one another;
presumably the `<T>` arguments were lost when the patch was extracted --
confirm against the repository before applying.

diff --git a/paddle/fluid/operators/scatter_op.cu b/paddle/fluid/operators/scatter_op.cu
index 1556099d6f11f2722457becd8687486e4a5ee92b..c8dae35822544219cb75f3f79f70c921f1f7c7d9 100644
--- a/paddle/fluid/operators/scatter_op.cu
+++ b/paddle/fluid/operators/scatter_op.cu
@@ -110,6 +110,11 @@ namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(scatter, ops::ScatterOpCUDAKernel,
                         ops::ScatterOpCUDAKernel,
                         ops::ScatterOpCUDAKernel,
-                        ops::ScatterOpCUDAKernel);
-REGISTER_OP_CUDA_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel,
-                        ops::ScatterGradOpCUDAKernel);
+                        ops::ScatterOpCUDAKernel,
+                        ops::ScatterOpCUDAKernel);
+
+REGISTER_OP_CUDA_KERNEL(
+    scatter_grad, ops::ScatterGradOpCUDAKernel,
+    ops::ScatterGradOpCUDAKernel, ops::ScatterGradOpCUDAKernel,
+    ops::ScatterGradOpCUDAKernel,
+    ops::ScatterGradOpCUDAKernel);
diff --git a/python/paddle/fluid/tests/unittests/test_scatter_op.py b/python/paddle/fluid/tests/unittests/test_scatter_op.py
index e58b2279e56535f96abb223e78ba607324e745ee..ad542da781670e1357cdb2f46b61a3b71d060ccf 100644
--- a/python/paddle/fluid/tests/unittests/test_scatter_op.py
+++ b/python/paddle/fluid/tests/unittests/test_scatter_op.py
@@ -269,6 +269,49 @@ class TestScatterAPI(unittest.TestCase):
         self.assertTrue(np.array_equal(test_dygraph(), test_static_graph()))
 
 
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestScatterOpFp16(OpTest):
+    def setUp(self):
+        self.__class__.op_type = "scatter"
+        # compute grad in the following code manually.
+        self.__class__.no_need_check_grad = True
+        self.x_type = 'float16'
+        self.x_np = np.ones((3, 3)).astype(self.x_type)
+        self.index_np = np.array([1, 2]).astype("int32")
+        self.updates_np = np.random.random((2, 3)).astype(self.x_type)
+        self.output_np = np.copy(self.x_np)
+        self.output_np[self.index_np] = self.updates_np
+        self.dout_np = np.random.random((3, 3)).astype(self.x_type)
+
+        # compute ref_dx
+        self.ref_dx = np.copy(self.dout_np)
+        zero_np = np.zeros((2, 3)).astype(self.x_type)
+        self.ref_dx[self.index_np] = zero_np
+
+    def compute_ref_grad_updates(self):
+        ref_grad_updates = paddle.gather(
+            paddle.to_tensor(self.dout_np), paddle.to_tensor(self.index_np))
+        return ref_grad_updates
+
+    def test_scatter_fp16(self):
+        paddle.disable_static(place=paddle.CUDAPlace(0))
+        x_tensor = paddle.to_tensor(self.x_np, stop_gradient=False)
+        index_tensor = paddle.to_tensor(self.index_np)
+        updates_tensor = paddle.to_tensor(self.updates_np, stop_gradient=False)
+        out_tensor = paddle.scatter(x_tensor, index_tensor, updates_tensor)
+        paddle.autograd.backward(
+            [out_tensor], [paddle.to_tensor(self.dout_np)], retain_graph=True)
+        ref_grad_updates = self.compute_ref_grad_updates()
+        np.testing.assert_allclose(
+            ref_grad_updates.numpy(),
+            updates_tensor.grad.numpy(),
+            rtol=1e-5,
+            atol=1e-5)
+        np.testing.assert_allclose(
+            self.ref_dx, x_tensor.grad.numpy(), rtol=1e-5, atol=1e-5)
+
+
 class TestScatterInplaceAPI(TestScatterAPI):
     def executed_api(self):
         self.scatter = paddle.scatter_
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index 42abf4b1466fa2d10d7383495f18d13b355cebe4..5d263bde8b3b5c9c02ebe58c9e1f75dab05c3f4f 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -1566,7 +1566,9 @@ def scatter(x, index, updates, overwrite=True, name=None):
     if in_dygraph_mode():
         return _C_ops.scatter(x, index, updates, 'overwrite', overwrite)
 
-    check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'scatter')
+    check_variable_and_dtype(
+        x, 'x', ['float32', 'float64', 'float16', 'int32', 'int64'],
+        'scatter')
     check_type(overwrite, 'overwrite', bool, 'scatter')
     helper = LayerHelper('scatter', **locals())
     out = helper.create_variable_for_type_inference(x.dtype)