未验证 提交 9bac4a76 编写于 作者: L Li Min 提交者: GitHub

Add float16 type for scatter op. (#38136)

* Add float16 type for scatter op.

* Add fp16 test for scatter op.

* Add int and int64 support for scatter_grad on gpu.

* Add int and int64 for check_variable_and_dtype routine.

* Minors.

* Code format.
上级 08482a86
...@@ -110,6 +110,11 @@ namespace ops = paddle::operators; ...@@ -110,6 +110,11 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>, REGISTER_OP_CUDA_KERNEL(scatter, ops::ScatterOpCUDAKernel<float>,
ops::ScatterOpCUDAKernel<double>, ops::ScatterOpCUDAKernel<double>,
ops::ScatterOpCUDAKernel<int>, ops::ScatterOpCUDAKernel<int>,
ops::ScatterOpCUDAKernel<int64_t>); ops::ScatterOpCUDAKernel<int64_t>,
REGISTER_OP_CUDA_KERNEL(scatter_grad, ops::ScatterGradOpCUDAKernel<float>, ops::ScatterOpCUDAKernel<paddle::platform::float16>);
ops::ScatterGradOpCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(
scatter_grad, ops::ScatterGradOpCUDAKernel<float>,
ops::ScatterGradOpCUDAKernel<double>, ops::ScatterOpCUDAKernel<int>,
ops::ScatterOpCUDAKernel<int64_t>,
ops::ScatterGradOpCUDAKernel<paddle::platform::float16>);
...@@ -269,6 +269,49 @@ class TestScatterAPI(unittest.TestCase): ...@@ -269,6 +269,49 @@ class TestScatterAPI(unittest.TestCase):
self.assertTrue(np.array_equal(test_dygraph(), test_static_graph())) self.assertTrue(np.array_equal(test_dygraph(), test_static_graph()))
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestScatterOpFp16(OpTest):
def setUp(self):
self.__class__.op_type = "scatter"
# compute grad in the following code handly.
self.__class__.no_need_check_grad = True
self.x_type = 'float16'
self.x_np = np.ones((3, 3)).astype(self.x_type)
self.index_np = np.array([1, 2]).astype("int32")
self.updates_np = np.random.random((2, 3)).astype(self.x_type)
self.output_np = np.copy(self.x_np)
self.output_np[self.index_np] = self.updates_np
self.dout_np = np.random.random((3, 3)).astype(self.x_type)
# compute ref_dx
self.ref_dx = np.copy(self.dout_np)
zero_np = np.zeros((2, 3)).astype(self.x_type)
self.ref_dx[self.index_np] = zero_np
def compute_ref_grad_updates(self):
ref_grad_updates = paddle.gather(
paddle.to_tensor(self.dout_np), paddle.to_tensor(self.index_np))
return ref_grad_updates
def test_scatter_fp16(self):
paddle.disable_static(place=paddle.CUDAPlace(0))
x_tensor = paddle.to_tensor(self.x_np, stop_gradient=False)
index_tensor = paddle.to_tensor(self.index_np)
updates_tensor = paddle.to_tensor(self.updates_np, stop_gradient=False)
out_tensor = paddle.scatter(x_tensor, index_tensor, updates_tensor)
paddle.autograd.backward(
[out_tensor], [paddle.to_tensor(self.dout_np)], retain_graph=True)
ref_grad_updates = self.compute_ref_grad_updates()
np.testing.assert_allclose(
ref_grad_updates.numpy(),
updates_tensor.grad.numpy(),
rtol=1e-5,
atol=1e-5)
np.testing.assert_allclose(
self.ref_dx, x_tensor.grad.numpy(), rtol=1e-5, atol=1e-5)
class TestScatterInplaceAPI(TestScatterAPI): class TestScatterInplaceAPI(TestScatterAPI):
def executed_api(self): def executed_api(self):
self.scatter = paddle.scatter_ self.scatter = paddle.scatter_
......
...@@ -1566,7 +1566,9 @@ def scatter(x, index, updates, overwrite=True, name=None): ...@@ -1566,7 +1566,9 @@ def scatter(x, index, updates, overwrite=True, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.scatter(x, index, updates, 'overwrite', overwrite) return _C_ops.scatter(x, index, updates, 'overwrite', overwrite)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'scatter') check_variable_and_dtype(
x, 'dtype', ['float32', 'float64', 'float16', 'int32', 'int64'],
'scatter')
check_type(overwrite, 'overwrite', bool, 'scatter') check_type(overwrite, 'overwrite', bool, 'scatter')
helper = LayerHelper('scatter', **locals()) helper = LayerHelper('scatter', **locals())
out = helper.create_variable_for_type_inference(x.dtype) out = helper.create_variable_for_type_inference(x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册