From 5bcdfbb0aa5b00581792dfba6b9b249ea37fb46c Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Thu, 23 Mar 2023 10:43:24 +0800 Subject: [PATCH] gather and gather nd fp16, bf16 support and add ut (#51903) --- .../phi/kernels/gpu/gather_nd_grad_kernel.cu | 4 +- paddle/phi/kernels/gpu/gather_nd_kernel.cu | 4 +- .../tests/unittests/test_gather_nd_op.py | 321 ++++++++++++++++-- .../fluid/tests/unittests/test_gather_op.py | 110 +++++- python/paddle/tensor/manipulation.py | 1 + 5 files changed, 395 insertions(+), 45 deletions(-) diff --git a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu index a78dc717b04..da1045c27c5 100644 --- a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/gather_nd_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" @@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad, double, int64_t, int, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gather_nd_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_kernel.cu index 7b241295890..b8ac4aa263a 100644 --- a/paddle/phi/kernels/gpu/gather_nd_kernel.cu +++ b/paddle/phi/kernels/gpu/gather_nd_kernel.cu @@ -15,6 +15,7 @@ #include "paddle/phi/kernels/gather_nd_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/gather.cu.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" @@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd, int, int16_t, bool, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py index ac6f1a32cca..8c29da524e1 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_nd_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_nd_op.py @@ -15,10 +15,11 @@ import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, convert_float_to_uint16 import paddle import paddle.fluid as fluid +import paddle.fluid.core as core class TestGatherNdOpWithEmptyIndex(OpTest): @@ -29,11 +30,23 @@ class TestGatherNdOpWithEmptyIndex(OpTest): self.prim_op_type = "prim" self.python_api = paddle.gather_nd self.public_python_api = paddle.gather_nd - xnp = np.random.random((5, 20)).astype("float64") + self.config_dtype() + if self.dtype == np.float64: + target_dtype = "float64" + elif self.dtype == np.float16: + target_dtype = "float16" + else: + target_dtype = "float32" + xnp = np.random.random((5, 20)).astype(target_dtype) + output = np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :])) + if self.dtype == np.uint16: + xnp = convert_float_to_uint16(xnp) + output = convert_float_to_uint16(output) self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")} - self.outputs = { - 'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :])) - } + self.outputs = {'Out': output} + + def config_dtype(self): + self.dtype = np.float64 def test_check_output(self): self.check_output(check_eager=False) @@ -42,15 +55,55 @@ class TestGatherNdOpWithEmptyIndex(OpTest): self.check_grad(['X'], 'Out', check_eager=False, check_prim=True) +class 
TestGatherNdOpWithEmptyIndexFP16(TestGatherNdOpWithEmptyIndex): + def config_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestGatherNdOpWithEmptyIndexBF16(TestGatherNdOpWithEmptyIndex): + def config_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', check_eager=False, check_prim=True + ) + + class TestGatherNdOpWithIndex1(OpTest): def setUp(self): self.op_type = "gather_nd" self.prim_op_type = "prim" self.python_api = paddle.gather_nd self.public_python_api = paddle.gather_nd - xnp = np.random.random((5, 20)).astype("float64") - self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")} - self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]} + self.config_dtype() + if self.dtype == np.float64: + target_dtype = "float64" + elif self.dtype == np.float16: + target_dtype = "float16" + else: + target_dtype = "float32" + xnp = np.random.random((5, 20)).astype(target_dtype) + index = np.array([1]).astype("int32") + output = xnp[index] + if self.dtype == np.uint16: + xnp = convert_float_to_uint16(xnp) + output = convert_float_to_uint16(output) + self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': output} + + def config_dtype(self): + self.dtype = np.float64 def test_check_output(self): self.check_output(check_eager=False) @@ -59,6 +112,31 @@ class TestGatherNdOpWithIndex1(OpTest): self.check_grad(['X'], 'Out', check_eager=False, check_prim=True) +class TestGatherNdOpWithIndex1FP16(TestGatherNdOpWithIndex1): + def config_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestGatherNdOpWithIndex1BF16(TestGatherNdOpWithIndex1): + def config_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', check_eager=False, check_prim=True + ) + + class TestGatherNdOpWithLowIndex(OpTest): # Index has low rank, X has high rank @@ -68,14 +146,26 @@ class TestGatherNdOpWithLowIndex(OpTest): self.python_api = paddle.gather_nd self.public_python_api = paddle.gather_nd self.enable_cinn = False - xnp = np.random.uniform(0, 100, (10, 10)).astype("float64") + self.config_dtype() + if self.dtype == np.float64: + target_dtype = "float64" + elif self.dtype == np.float16: + target_dtype = "float16" + else: + target_dtype = "float32" + xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype) index = np.array([[1], [2]]).astype("int64") + output = xnp[tuple(index.T)] # [[14, 25, 1], [76, 22, 3]] + + if self.dtype == np.uint16: + xnp = convert_float_to_uint16(xnp) + output = convert_float_to_uint16(output) self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': output} - self.outputs = { - 'Out': xnp[tuple(index.T)] - } # [[14, 25, 1], [76, 22, 3]] + def config_dtype(self): + self.dtype = np.float64 def test_check_output(self): self.check_output(check_eager=False) @@ -84,6 +174,31 @@ class 
TestGatherNdOpWithLowIndex(OpTest): self.check_grad(['X'], 'Out', check_eager=False, check_prim=True) +class TestGatherNdOpWithLowIndexFP16(TestGatherNdOpWithLowIndex): + def config_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestGatherNdOpWithLowIndexBF16(TestGatherNdOpWithLowIndex): + def config_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', check_eager=False, check_prim=True + ) + + class TestGatherNdOpIndex1(OpTest): # Index has low rank, X has high rank @@ -92,16 +207,25 @@ class TestGatherNdOpIndex1(OpTest): self.prim_op_type = "prim" self.python_api = paddle.gather_nd self.public_python_api = paddle.gather_nd - self.init_input() - - self.inputs = {'X': self.xnp, 'Index': self.index} - - self.outputs = {'Out': self.xnp[tuple(self.index.T)]} + self.config_dtype() + if self.dtype == np.float64: + target_dtype = "float64" + elif self.dtype == np.float16: + target_dtype = "float16" + else: + target_dtype = "float32" + xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype) + index = np.array([1, 2]).astype("int32") + output = xnp[tuple(index.T)] + if self.dtype == np.uint16: + xnp = convert_float_to_uint16(xnp) + output = convert_float_to_uint16(output) + self.inputs = {'X': xnp, 'Index': index} + self.outputs = {'Out': output} self.enable_cinn = False - def init_input(self): - self.xnp = np.random.uniform(0, 100, (10, 10)).astype("float64") - self.index = np.array([1, 2]).astype("int32") + def config_dtype(self): + self.dtype = np.float64 def test_check_output(self): self.check_output(check_eager=False) @@ -111,9 +235,28 @@ class TestGatherNdOpIndex1(OpTest): class TestGatherNdOpIndex1FP16(TestGatherNdOpIndex1): - def init_input(self): - self.xnp = np.random.uniform(0, 100, (10, 10)).astype("float16") - self.index = np.array([1, 2]).astype("int32") + def config_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestGatherNdOpIndex1BF16(TestGatherNdOpIndex1): + def config_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', check_eager=False, check_prim=True + ) class TestGatherNdOpWithSameIndexAsX(OpTest): @@ -124,11 +267,24 @@ class TestGatherNdOpWithSameIndexAsX(OpTest): self.python_api = paddle.gather_nd self.public_python_api = paddle.gather_nd self.enable_cinn = False - xnp = np.random.uniform(0, 100, (10, 10)).astype("float64") + self.config_dtype() + if self.dtype == np.float64: + target_dtype = "float64" + elif self.dtype == np.float16: + target_dtype = "float16" + else: + target_dtype = "float32" + xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype) index = np.array([[1, 1], [2, 1]]).astype("int64") - + output = xnp[tuple(index.T)] # [25, 22] + if self.dtype == np.uint16: + xnp = convert_float_to_uint16(xnp) + output = convert_float_to_uint16(output) self.inputs = {'X': xnp, 
'Index': index} - self.outputs = {'Out': xnp[tuple(index.T)]} # [25, 22] + self.outputs = {'Out': output} + + def config_dtype(self): + self.dtype = np.float64 def test_check_output(self): self.check_output(check_eager=False) @@ -137,6 +293,31 @@ class TestGatherNdOpWithSameIndexAsX(OpTest): self.check_grad(['X'], 'Out', check_eager=False, check_prim=True) +class TestGatherNdOpWithSameIndexAsXFP16(TestGatherNdOpWithSameIndexAsX): + def config_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestGatherNdOpWithSameIndexAsXBF16(TestGatherNdOpWithSameIndexAsX): + def config_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', check_eager=False, check_prim=True + ) + + class TestGatherNdOpWithHighRankSame(OpTest): # Both Index and X have high rank, and Rank(Index) = Rank(X) @@ -146,11 +327,24 @@ class TestGatherNdOpWithHighRankSame(OpTest): self.python_api = paddle.gather_nd self.public_python_api = paddle.gather_nd shape = (5, 2, 3, 1, 10) - xnp = np.random.rand(*shape).astype("float64") + self.config_dtype() + if self.dtype == np.float64: + target_dtype = "float64" + elif self.dtype == np.float16: + target_dtype = "float16" + else: + target_dtype = "float32" + xnp = np.random.rand(*shape).astype(target_dtype) index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T - + output = xnp[tuple(index.T)] + if self.dtype == np.uint16: + xnp = convert_float_to_uint16(xnp) + output = convert_float_to_uint16(output) self.inputs = {'X': xnp, 'Index': index.astype("int32")} - self.outputs = {'Out': xnp[tuple(index.T)]} + self.outputs = {'Out': output} + + def config_dtype(self): + self.dtype = np.float64 def test_check_output(self): self.check_output(check_eager=False) @@ -159,6 +353,31 @@ class TestGatherNdOpWithHighRankSame(OpTest): self.check_grad(['X'], 'Out', check_eager=False, check_prim=True) +class TestGatherNdOpWithHighRankSameFP16(TestGatherNdOpWithHighRankSame): + def config_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestGatherNdOpWithHighRankSameBF16(TestGatherNdOpWithHighRankSame): + def config_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', check_eager=False, check_prim=True + ) + + class TestGatherNdOpWithHighRankDiff(OpTest): # Both Index and X have high rank, and Rank(Index) < Rank(X) @@ -168,12 +387,25 @@ class TestGatherNdOpWithHighRankDiff(OpTest): self.python_api = paddle.gather_nd self.public_python_api = paddle.gather_nd shape = (2, 3, 4, 1, 10) - xnp = np.random.rand(*shape).astype("float64") + self.config_dtype() + if self.dtype == np.float64: + target_dtype = "float64" + elif self.dtype == np.float16: + target_dtype = "float16" + else: + target_dtype = "float32" + xnp = np.random.rand(*shape).astype(target_dtype) index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T 
index_re = index.reshape([20, 5, 2, 5]) - + output = xnp[tuple(index.T)].reshape([20, 5, 2]) + if self.dtype == np.uint16: + xnp = convert_float_to_uint16(xnp) + output = convert_float_to_uint16(output) self.inputs = {'X': xnp, 'Index': index_re.astype("int32")} - self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])} + self.outputs = {'Out': output} + + def config_dtype(self): + self.dtype = np.float64 def test_check_output(self): self.check_output(check_eager=False) @@ -182,6 +414,31 @@ class TestGatherNdOpWithHighRankDiff(OpTest): self.check_grad(['X'], 'Out', check_eager=False, check_prim=True) +class TestGatherNdOpWithHighRankDiffFP16(TestGatherNdOpWithHighRankDiff): + def config_dtype(self): + self.dtype = np.float16 + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not complied with CUDA and not support the bfloat16", +) +class TestGatherNdOpWithHighRankDiffBF16(TestGatherNdOpWithHighRankDiff): + def config_dtype(self): + self.dtype = np.uint16 + + def test_check_output(self): + place = core.CUDAPlace(0) + self.check_output_with_place(place, check_eager=False) + + def test_check_grad(self): + place = core.CUDAPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', check_eager=False, check_prim=True + ) + + # Test Python API class TestGatherNdOpAPI(unittest.TestCase): def test_case1(self): diff --git a/python/paddle/fluid/tests/unittests/test_gather_op.py b/python/paddle/fluid/tests/unittests/test_gather_op.py index ec3c400d972..d6d64f84c77 100644 --- a/python/paddle/fluid/tests/unittests/test_gather_op.py +++ b/python/paddle/fluid/tests/unittests/test_gather_op.py @@ -55,10 +55,18 @@ class TestGatherOp(OpTest): For multi-dimension input """ self.x_shape = (10, 20) - self.x_type = "float64" + self.config_dtype() self.index = [1, 3, 5] self.index_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestGatherOpFP16(TestGatherOp): + def config_dtype(self): + self.x_type = "float16" + class TestCase1(TestGatherOp): def config(self): @@ -66,10 +74,18 @@ class TestCase1(TestGatherOp): For one dimension input """ self.x_shape = 100 - self.x_type = "float64" + self.config_dtype() self.index = [1, 3, 5] self.index_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestCase1FP16(TestCase1): + def config_dtype(self): + self.x_type = "float16" + class TestCase2(TestGatherOp): def config(self): @@ -77,10 +93,18 @@ class TestCase2(TestGatherOp): For int64_t index type """ self.x_shape = 100 - self.x_type = "float64" + self.config_dtype() self.index = [1, 3, 5] self.index_type = "int64" + def config_dtype(self): + self.x_type = "float64" + + +class TestCase2FP16(TestCase2): + def config_dtype(self): + self.x_type = "float16" + class TestCase3(TestGatherOp): def config(self): @@ -88,37 +112,69 @@ class TestCase3(TestGatherOp): For other input type """ self.x_shape = (10, 20) - self.x_type = "float64" + self.config_dtype() self.index = [1, 3, 5] self.index_type = "int64" + def config_dtype(self): + self.x_type = "float64" + + +class TestCase3Fp16(TestCase3): + def config_dtype(self): + self.x_type = "float16" + class TestCase4(TestGatherOp): def config(self): self.x_shape = (10, 20) self.attrs = {'overwrite': False} - self.x_type = "double" + self.config_dtype() self.index = [1, 1] self.index_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestCase4FP16(TestCase4): + def config_dtype(self): + self.x_type = "float16" + class 
TestCase5(TestGatherOp): def config(self): self.x_shape = (10, 20) self.attrs = {'overwrite': False} - self.x_type = "float64" + self.config_dtype() self.index = [1, 1, 3] self.index_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestCase5FP16(TestCase5): + def config_dtype(self): + self.x_type = "float16" + class TestCase6(TestGatherOp): def config(self): self.x_shape = (10, 20) self.attrs = {'overwrite': True} - self.x_type = "float64" + self.config_dtype() self.index = [1, 3] self.index_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestCase6FP16(TestCase6): + def config_dtype(self): + self.x_type = "float16" + class TestGatherBF16Op(OpTest): def setUp(self): @@ -177,12 +233,20 @@ class TestGatherOp1(OpTest): For multi-dimension input """ self.x_shape = (3, 88, 3) - self.x_type = "float64" + self.config_dtype() self.index = [1, 3, 5] self.index_type = "int32" self.axis = [1] self.axis_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestGatherOp1FP16(TestGatherOp1): + def config_dtype(self): + self.x_type = "float16" + class TestGatherOp2(TestGatherOp1): def config(self): @@ -190,12 +254,20 @@ class TestGatherOp2(TestGatherOp1): For multi-dimension input """ self.x_shape = (10, 88, 10) - self.x_type = "float64" + self.config_dtype() self.index = [1, 3, 5] self.index_type = "int64" self.axis = [0] self.axis_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestGatherOp2FP16(TestGatherOp2): + def config_dtype(self): + self.x_type = "float16" + class TestGatherOp3(TestGatherOp1): def config(self): @@ -203,12 +275,20 @@ class TestGatherOp3(TestGatherOp1): For multi-dimension input """ self.x_shape = (10, 88, 10) - self.x_type = "float64" + self.config_dtype() self.index = [1, 3, 5] self.index_type = "int64" self.axis = [2] self.axis_type = "int32" + def config_dtype(self): + self.x_type = "float64" + + +class TestGatherOp3FP16(TestGatherOp3): + def config_dtype(self): + self.x_type = "float16" + class TestGatherOp4(TestGatherOp1): def config(self): @@ -216,13 +296,21 @@ class TestGatherOp4(TestGatherOp1): For multi-dimension input """ self.x_shape = (3, 100, 10) - self.x_type = "float64" + self.config_dtype() self.index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] self.index_type = "int64" self.axis = [0] self.axis_type = "int32" self.attrs = {'overwrite': False} + def config_dtype(self): + self.x_type = "float64" + + +class TestGatherOp4FP16(TestGatherOp4): + def config_dtype(self): + self.x_type = "float16" + class API_TestGather(unittest.TestCase): def test_out1(self): diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 41a8cfa856f..59ebcbaafdd 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3742,6 +3742,7 @@ def gather_nd(x, index, name=None): [ 'bool', 'float16', + 'uint16', 'float32', 'float64', 'int16', -- GitLab
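
Editor's note (not part of the commit): a minimal sketch of how the GPU kernels registered above can be exercised from Python once this patch is in place. It assumes a CUDA build of Paddle on a device where bfloat16 is supported and where Tensor.cast accepts the 'bfloat16' dtype string; the tensor names, shapes, and values are illustrative and do not appear in the patch.

import numpy as np
import paddle

# Illustrative only: run on GPU so the float16/bfloat16 gather_nd kernels
# registered by this patch are the ones dispatched to.
paddle.set_device('gpu')

x_np = np.random.rand(5, 20).astype('float32')
index_np = np.array([[1], [2]], dtype='int64')

for dtype in ('float16', 'bfloat16'):
    x = paddle.to_tensor(x_np).cast(dtype)   # assumes this Paddle build accepts 'bfloat16'
    x.stop_gradient = False
    index = paddle.to_tensor(index_np)
    out = paddle.gather_nd(x, index)          # forward: gather_nd kernel in low precision
    out.cast('float32').sum().backward()      # backward: gather_nd_grad kernel in low precision
    print(dtype, out.shape, x.grad.dtype)

The unit tests in the patch cover the same path more rigorously (via check_output_with_place / check_grad_with_place on CUDAPlace(0) with uint16-encoded bfloat16 data); the sketch above only shows the user-facing behaviour the new registrations enable.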