Unverified commit 5bcdfbb0, authored by Yuang Liu, committed via GitHub

gather and gather nd fp16, bf16 support and add ut (#51903)

Parent: a7397e0c
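This commit registers float16 and bfloat16 for the GPU gather_nd forward and backward kernels, adds FP16/BF16 variants to the gather and gather_nd unit tests, and allows 'uint16' (the NumPy storage dtype Paddle uses for bfloat16) in the gather_nd dtype check. A minimal usage sketch of what the change enables, assuming a CUDA build of Paddle that includes it (shapes and values are illustrative only):

    import paddle

    paddle.set_device("gpu")
    x = paddle.rand([5, 20], dtype="float32")
    index = paddle.to_tensor([[1], [2]], dtype="int64")

    # Forward now works for both half-precision dtypes on GPU.
    out_fp16 = paddle.gather_nd(x.astype("float16"), index)
    out_bf16 = paddle.gather_nd(x.astype("bfloat16"), index)

    # Gradients go through the newly registered gather_nd_grad kernels.
    x_bf16 = x.astype("bfloat16")
    x_bf16.stop_gradient = False
    paddle.gather_nd(x_bf16, index).sum().backward()
    print(x_bf16.grad.dtype)  # expected: bfloat16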
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/gather_nd_grad_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
@@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad,
                    double,
                    int64_t,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/gather_nd_kernel.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
@@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd,
                    int,
                    int16_t,
                    bool,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
@@ -15,10 +15,11 @@
 import unittest
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 class TestGatherNdOpWithEmptyIndex(OpTest):
@@ -29,11 +30,23 @@ class TestGatherNdOpWithEmptyIndex(OpTest):
         self.prim_op_type = "prim"
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
-        xnp = np.random.random((5, 20)).astype("float64")
+        self.config_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        xnp = np.random.random((5, 20)).astype(target_dtype)
+        output = np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
+        if self.dtype == np.uint16:
+            xnp = convert_float_to_uint16(xnp)
+            output = convert_float_to_uint16(output)
         self.inputs = {'X': xnp, 'Index': np.array([[], []]).astype("int32")}
-        self.outputs = {
-            'Out': np.vstack((xnp[np.newaxis, :], xnp[np.newaxis, :]))
-        }
+        self.outputs = {'Out': output}
+    def config_dtype(self):
+        self.dtype = np.float64
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -42,15 +55,55 @@ class TestGatherNdOpWithEmptyIndex(OpTest):
         self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
+class TestGatherNdOpWithEmptyIndexFP16(TestGatherNdOpWithEmptyIndex):
+    def config_dtype(self):
+        self.dtype = np.float16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestGatherNdOpWithEmptyIndexBF16(TestGatherNdOpWithEmptyIndex):
+    def config_dtype(self):
+        self.dtype = np.uint16
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_eager=False, check_prim=True
+        )
 class TestGatherNdOpWithIndex1(OpTest):
     def setUp(self):
         self.op_type = "gather_nd"
         self.prim_op_type = "prim"
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
-        xnp = np.random.random((5, 20)).astype("float64")
-        self.inputs = {'X': xnp, 'Index': np.array([1]).astype("int32")}
-        self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
+        self.config_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        xnp = np.random.random((5, 20)).astype(target_dtype)
+        index = np.array([1]).astype("int32")
+        output = xnp[index]
+        if self.dtype == np.uint16:
+            xnp = convert_float_to_uint16(xnp)
+            output = convert_float_to_uint16(output)
+        self.inputs = {'X': xnp, 'Index': index}
+        self.outputs = {'Out': output}
+    def config_dtype(self):
+        self.dtype = np.float64
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -59,6 +112,31 @@ class TestGatherNdOpWithIndex1(OpTest):
         self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
+class TestGatherNdOpWithIndex1FP16(TestGatherNdOpWithIndex1):
+    def config_dtype(self):
+        self.dtype = np.float16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestGatherNdOpWithIndex1BF16(TestGatherNdOpWithIndex1):
+    def config_dtype(self):
+        self.dtype = np.uint16
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_eager=False, check_prim=True
+        )
 class TestGatherNdOpWithLowIndex(OpTest):
     # Index has low rank, X has high rank
@@ -68,14 +146,26 @@ class TestGatherNdOpWithLowIndex(OpTest):
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
         self.enable_cinn = False
-        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
+        self.config_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype)
         index = np.array([[1], [2]]).astype("int64")
+        output = xnp[tuple(index.T)]  # [[14, 25, 1], [76, 22, 3]]
+        if self.dtype == np.uint16:
+            xnp = convert_float_to_uint16(xnp)
+            output = convert_float_to_uint16(output)
         self.inputs = {'X': xnp, 'Index': index}
-        self.outputs = {
-            'Out': xnp[tuple(index.T)]
-        }  # [[14, 25, 1], [76, 22, 3]]
+        self.outputs = {'Out': output}
+    def config_dtype(self):
+        self.dtype = np.float64
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -84,6 +174,31 @@ class TestGatherNdOpWithLowIndex(OpTest):
         self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
+class TestGatherNdOpWithLowIndexFP16(TestGatherNdOpWithLowIndex):
+    def config_dtype(self):
+        self.dtype = np.float16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestGatherNdOpWithLowIndexBF16(TestGatherNdOpWithLowIndex):
+    def config_dtype(self):
+        self.dtype = np.uint16
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_eager=False, check_prim=True
+        )
 class TestGatherNdOpIndex1(OpTest):
     # Index has low rank, X has high rank
@@ -92,16 +207,25 @@ class TestGatherNdOpIndex1(OpTest):
         self.prim_op_type = "prim"
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
-        self.init_input()
-        self.inputs = {'X': self.xnp, 'Index': self.index}
-        self.outputs = {'Out': self.xnp[tuple(self.index.T)]}
+        self.config_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype)
+        index = np.array([1, 2]).astype("int32")
+        output = xnp[tuple(index.T)]
+        if self.dtype == np.uint16:
+            xnp = convert_float_to_uint16(xnp)
+            output = convert_float_to_uint16(output)
+        self.inputs = {'X': xnp, 'Index': index}
+        self.outputs = {'Out': output}
         self.enable_cinn = False
-    def init_input(self):
-        self.xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
-        self.index = np.array([1, 2]).astype("int32")
+    def config_dtype(self):
+        self.dtype = np.float64
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -111,9 +235,28 @@ class TestGatherNdOpIndex1(OpTest):
 class TestGatherNdOpIndex1FP16(TestGatherNdOpIndex1):
-    def init_input(self):
-        self.xnp = np.random.uniform(0, 100, (10, 10)).astype("float16")
-        self.index = np.array([1, 2]).astype("int32")
+    def config_dtype(self):
+        self.dtype = np.float16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestGatherNdOpIndex1BF16(TestGatherNdOpIndex1):
+    def config_dtype(self):
+        self.dtype = np.uint16
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_eager=False, check_prim=True
+        )
 class TestGatherNdOpWithSameIndexAsX(OpTest):
@@ -124,11 +267,24 @@ class TestGatherNdOpWithSameIndexAsX(OpTest):
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
         self.enable_cinn = False
-        xnp = np.random.uniform(0, 100, (10, 10)).astype("float64")
+        self.config_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        xnp = np.random.uniform(0, 100, (10, 10)).astype(target_dtype)
         index = np.array([[1, 1], [2, 1]]).astype("int64")
+        output = xnp[tuple(index.T)]  # [25, 22]
+        if self.dtype == np.uint16:
+            xnp = convert_float_to_uint16(xnp)
+            output = convert_float_to_uint16(output)
         self.inputs = {'X': xnp, 'Index': index}
-        self.outputs = {'Out': xnp[tuple(index.T)]}  # [25, 22]
+        self.outputs = {'Out': output}
+    def config_dtype(self):
+        self.dtype = np.float64
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -137,6 +293,31 @@ class TestGatherNdOpWithSameIndexAsX(OpTest):
         self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
+class TestGatherNdOpWithSameIndexAsXFP16(TestGatherNdOpWithSameIndexAsX):
+    def config_dtype(self):
+        self.dtype = np.float16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestGatherNdOpWithSameIndexAsXBF16(TestGatherNdOpWithSameIndexAsX):
+    def config_dtype(self):
+        self.dtype = np.uint16
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_eager=False, check_prim=True
+        )
 class TestGatherNdOpWithHighRankSame(OpTest):
     # Both Index and X have high rank, and Rank(Index) = Rank(X)
@@ -146,11 +327,24 @@ class TestGatherNdOpWithHighRankSame(OpTest):
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
         shape = (5, 2, 3, 1, 10)
-        xnp = np.random.rand(*shape).astype("float64")
+        self.config_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        xnp = np.random.rand(*shape).astype(target_dtype)
         index = np.vstack([np.random.randint(0, s, size=2) for s in shape]).T
+        output = xnp[tuple(index.T)]
+        if self.dtype == np.uint16:
+            xnp = convert_float_to_uint16(xnp)
+            output = convert_float_to_uint16(output)
         self.inputs = {'X': xnp, 'Index': index.astype("int32")}
-        self.outputs = {'Out': xnp[tuple(index.T)]}
+        self.outputs = {'Out': output}
+    def config_dtype(self):
+        self.dtype = np.float64
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -159,6 +353,31 @@ class TestGatherNdOpWithHighRankSame(OpTest):
         self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
+class TestGatherNdOpWithHighRankSameFP16(TestGatherNdOpWithHighRankSame):
+    def config_dtype(self):
+        self.dtype = np.float16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestGatherNdOpWithHighRankSameBF16(TestGatherNdOpWithHighRankSame):
+    def config_dtype(self):
+        self.dtype = np.uint16
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_eager=False, check_prim=True
+        )
 class TestGatherNdOpWithHighRankDiff(OpTest):
     # Both Index and X have high rank, and Rank(Index) < Rank(X)
@@ -168,12 +387,25 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
         self.python_api = paddle.gather_nd
         self.public_python_api = paddle.gather_nd
         shape = (2, 3, 4, 1, 10)
-        xnp = np.random.rand(*shape).astype("float64")
+        self.config_dtype()
+        if self.dtype == np.float64:
+            target_dtype = "float64"
+        elif self.dtype == np.float16:
+            target_dtype = "float16"
+        else:
+            target_dtype = "float32"
+        xnp = np.random.rand(*shape).astype(target_dtype)
         index = np.vstack([np.random.randint(0, s, size=200) for s in shape]).T
         index_re = index.reshape([20, 5, 2, 5])
+        output = xnp[tuple(index.T)].reshape([20, 5, 2])
+        if self.dtype == np.uint16:
+            xnp = convert_float_to_uint16(xnp)
+            output = convert_float_to_uint16(output)
         self.inputs = {'X': xnp, 'Index': index_re.astype("int32")}
-        self.outputs = {'Out': xnp[tuple(index.T)].reshape([20, 5, 2])}
+        self.outputs = {'Out': output}
+    def config_dtype(self):
+        self.dtype = np.float64
     def test_check_output(self):
         self.check_output(check_eager=False)
@@ -182,6 +414,31 @@ class TestGatherNdOpWithHighRankDiff(OpTest):
         self.check_grad(['X'], 'Out', check_eager=False, check_prim=True)
+class TestGatherNdOpWithHighRankDiffFP16(TestGatherNdOpWithHighRankDiff):
+    def config_dtype(self):
+        self.dtype = np.float16
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not complied with CUDA and not support the bfloat16",
+)
+class TestGatherNdOpWithHighRankDiffBF16(TestGatherNdOpWithHighRankDiff):
+    def config_dtype(self):
+        self.dtype = np.uint16
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_eager=False, check_prim=True
+        )
 # Test Python API
 class TestGatherNdOpAPI(unittest.TestCase):
     def test_case1(self):
......
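The BF16 cases above all follow the same OpTest recipe: generate the data in float32, convert both the input and the expected output to the bfloat16 bit pattern with convert_float_to_uint16 (so the reference sees the same rounded values as the kernel), and run the checks on CUDAPlace(0). A rough, simplified stand-in for that conversion, assuming plain bit truncation (the real helper may round differently):

    import numpy as np

    def float32_to_bfloat16_bits(x):
        # bfloat16 keeps the sign bit, the 8 exponent bits and the top 7
        # mantissa bits of a float32, so its bit pattern is the high 16 bits
        # of the float32 word. Truncation only; no rounding in this sketch.
        x = np.ascontiguousarray(x, dtype=np.float32)
        return (x.view(np.uint32) >> 16).astype(np.uint16)

    xnp = np.random.random((5, 20)).astype("float32")
    xnp_bf16 = float32_to_bfloat16_bits(xnp)  # what would go into self.inputs['X']
    # Decode back to float32 to see the precision the tests tolerate.
    back = (xnp_bf16.astype(np.uint32) << 16).astype(np.uint32).view(np.float32)
    print(np.abs(back - xnp).max())  # small: bfloat16 keeps ~3 significant digits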
@@ -55,10 +55,18 @@ class TestGatherOp(OpTest):
         For multi-dimension input
         """
         self.x_shape = (10, 20)
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3, 5]
         self.index_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestGatherOpFP16(TestGatherOp):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestCase1(TestGatherOp):
     def config(self):
@@ -66,10 +74,18 @@ class TestCase1(TestGatherOp):
         For one dimension input
         """
         self.x_shape = 100
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3, 5]
         self.index_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestCase1FP16(TestCase1):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestCase2(TestGatherOp):
     def config(self):
@@ -77,10 +93,18 @@ class TestCase2(TestGatherOp):
         For int64_t index type
         """
         self.x_shape = 100
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3, 5]
         self.index_type = "int64"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestCase2FP16(TestCase2):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestCase3(TestGatherOp):
     def config(self):
@@ -88,37 +112,69 @@ class TestCase3(TestGatherOp):
         For other input type
         """
         self.x_shape = (10, 20)
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3, 5]
         self.index_type = "int64"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestCase3Fp16(TestCase3):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestCase4(TestGatherOp):
     def config(self):
         self.x_shape = (10, 20)
         self.attrs = {'overwrite': False}
-        self.x_type = "double"
+        self.config_dtype()
         self.index = [1, 1]
         self.index_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestCase4FP16(TestCase4):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestCase5(TestGatherOp):
     def config(self):
         self.x_shape = (10, 20)
         self.attrs = {'overwrite': False}
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 1, 3]
         self.index_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestCase5FP16(TestCase5):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestCase6(TestGatherOp):
     def config(self):
         self.x_shape = (10, 20)
         self.attrs = {'overwrite': True}
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3]
         self.index_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestCase6FP16(TestCase6):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestGatherBF16Op(OpTest):
     def setUp(self):
@@ -177,12 +233,20 @@ class TestGatherOp1(OpTest):
         For multi-dimension input
         """
         self.x_shape = (3, 88, 3)
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3, 5]
         self.index_type = "int32"
         self.axis = [1]
         self.axis_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestGatherOp1FP16(TestGatherOp1):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestGatherOp2(TestGatherOp1):
     def config(self):
@@ -190,12 +254,20 @@ class TestGatherOp2(TestGatherOp1):
         For multi-dimension input
         """
         self.x_shape = (10, 88, 10)
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3, 5]
         self.index_type = "int64"
         self.axis = [0]
         self.axis_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestGatherOp2FP16(TestGatherOp2):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestGatherOp3(TestGatherOp1):
     def config(self):
@@ -203,12 +275,20 @@ class TestGatherOp3(TestGatherOp1):
         For multi-dimension input
         """
         self.x_shape = (10, 88, 10)
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 3, 5]
         self.index_type = "int64"
         self.axis = [2]
         self.axis_type = "int32"
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestGatherOp3FP16(TestGatherOp3):
+    def config_dtype(self):
+        self.x_type = "float16"
 class TestGatherOp4(TestGatherOp1):
     def config(self):
@@ -216,13 +296,21 @@ class TestGatherOp4(TestGatherOp1):
         For multi-dimension input
         """
         self.x_shape = (3, 100, 10)
-        self.x_type = "float64"
+        self.config_dtype()
         self.index = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
         self.index_type = "int64"
         self.axis = [0]
         self.axis_type = "int32"
         self.attrs = {'overwrite': False}
+    def config_dtype(self):
+        self.x_type = "float64"
+class TestGatherOp4FP16(TestGatherOp4):
+    def config_dtype(self):
+        self.x_type = "float16"
 class API_TestGather(unittest.TestCase):
     def test_out1(self):
......
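The gather tests apply the same refactor as the gather_nd ones: the hard-coded x_type assignment becomes a config_dtype() hook, so each FP16 case is a one-method subclass of an existing test. A condensed illustration of that hook pattern, with the OpTest plumbing stripped out (class names here are hypothetical, not part of the diff):

    import numpy as np

    class GatherTestBase:
        # The base test builds its data from the dtype hook ...
        def setUp(self):
            self.config_dtype()
            self.x = np.random.random((10, 20)).astype(self.x_type)
            self.index = np.array([1, 3, 5], dtype=np.int32)
            self.expected = self.x[self.index]

        def config_dtype(self):
            self.x_type = "float64"

    class GatherTestFP16(GatherTestBase):
        # ... so a low-precision variant only overrides that hook.
        def config_dtype(self):
            self.x_type = "float16"

    t = GatherTestFP16()
    t.setUp()
    print(t.x.dtype, t.expected.dtype)  # float16 float16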
@@ -3742,6 +3742,7 @@ def gather_nd(x, index, name=None):
         [
             'bool',
             'float16',
+            'uint16',
             'float32',
             'float64',
             'int16',
......