未验证 提交 60e37d17 编写于 作者: C cyberslack_lee 提交者: GitHub

Add float16 and bfloat16 support and test for argsort (#55105)

上级 6e61c9f9
......@@ -38,6 +38,10 @@ namespace detail {
// Teach rocPRIM's radix sort how to order Paddle's 16-bit float types:
// both phi::dtype::float16 and phi::dtype::bfloat16 are radix-sorted via
// their raw uint16_t bit patterns, reusing the integral key codec.
// NOTE(review): radix_key_codec_integral presumably applies the usual
// sign-bit flip so the bit pattern orders like the float value — confirm
// against the rocPRIM headers for the toolchain in use.
template <>
struct radix_key_codec_base<phi::dtype::float16>
    : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
template <>
struct radix_key_codec_base<phi::dtype::bfloat16>
    : radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
} // namespace detail
} // namespace rocprim
#else
......@@ -46,6 +50,11 @@ namespace cub {
// Register Paddle's 16-bit float types with CUB so device-side radix sort
// treats them as floating-point keys backed by a uint16_t bit pattern.
// BaseTraits<FLOATING_POINT, true, false, UnsignedBits, T> mirrors the
// specialization CUB itself provides for __half.
template <>
struct NumericTraits<phi::dtype::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
template <>
struct NumericTraits<phi::dtype::bfloat16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
};
} // namespace cub
#endif
......@@ -222,4 +231,5 @@ PD_REGISTER_KERNEL(argsort_grad,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -61,7 +61,13 @@ namespace cub {
// Same CUB trait registration as in the grad kernel's translation unit:
// expose float16/bfloat16 to CUB's radix sort as floating-point keys with a
// uint16_t backing representation. Kept duplicated per .cu file because the
// specialization must be visible before CUB templates are instantiated.
template <>
struct NumericTraits<phi::dtype::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
template <>
struct NumericTraits<phi::dtype::bfloat16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
};
} // namespace cub
#endif
namespace phi {
......@@ -328,6 +334,7 @@ PD_REGISTER_KERNEL(argsort,
double,
int,
int64_t,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -39,7 +39,7 @@ def argsort(x, axis=-1, descending=False, name=None):
Sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
Args:
x (Tensor): An input N-D Tensor with type float16, float32, float64, int16,
x (Tensor): An input N-D Tensor with type bfloat16, float16, float32, float64, int16,
int32, int64, uint8.
axis (int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is Rank(x). When axis<0, it works the same way
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
......@@ -513,5 +514,95 @@ class TestArgsortOpFp16(unittest.TestCase):
out = exe.run(feed={'x': x_np}, fetch_list=[out])
class TestArgsortFP16Op(OpTest):
    """Checks the argsort op against numpy for float16 inputs.

    Subclasses customize the fixture via ``init`` (input shape / sort axis)
    and ``init_direction`` (ascending vs. descending).
    """

    def setUp(self):
        self.init()
        self.init_direction()
        self.op_type = "argsort"
        self.python_api = paddle.argsort
        self.public_python_api = paddle.argsort
        self.python_out_sig = ["Out"]
        self.dtype = np.float16
        # BUGFIX: the original re-assigned ``self.descending = False`` here,
        # after ``init_direction()``, which silently turned the
        # ``DescendingTrue`` subclass into another ascending test.
        self.attrs = {"axis": self.axis, "descending": self.descending}
        X = np.random.rand(*self.input_shape).astype('float16')
        Out = np.sort(X, kind='quicksort', axis=self.axis)
        indices = np.argsort(X, kind='quicksort', axis=self.axis)
        if self.descending:
            # numpy only sorts ascending; reverse along the sort axis to get
            # the expected descending result. Inputs are random floats, so
            # ties are essentially impossible and the flip is well-defined.
            Out = np.flip(Out, axis=self.axis)
            indices = np.flip(indices, axis=self.axis)
        self.inputs = {'X': X}
        self.outputs = {
            'Out': Out,
            'Indices': indices,
        }

    def init(self):
        # Default fixture: a 1-D tensor of 10000 elements sorted along axis 0.
        self.input_shape = [
            10000,
        ]
        self.axis = 0

    def init_direction(self):
        self.descending = False

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_dygraph=False)
class TestArgsortFP16OpDescendingTrue(TestArgsortFP16Op):
    # Variant of the float16 test that requests a descending sort via the
    # op's ``descending`` attribute.
    def init_direction(self):
        self.descending = True
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestArgsortBF16Op(OpTest):
    """Checks the argsort op against numpy for bfloat16 inputs on CUDA.

    bfloat16 tensors are fed to the op as uint16 bit patterns produced by
    ``convert_float_to_uint16``; the reference result is computed in float32.
    """

    def setUp(self):
        self.init()
        self.init_direction()
        self.op_type = "argsort"
        self.python_api = paddle.argsort
        self.public_python_api = paddle.argsort
        self.python_out_sig = ["Out"]
        self.dtype = np.uint16
        self.np_dtype = np.float32
        # BUGFIX: the original re-assigned ``self.descending = False`` here,
        # after ``init_direction()``, which silently turned the
        # ``DescendingTrue`` subclass into another ascending test.
        self.attrs = {"axis": self.axis, "descending": self.descending}
        X = np.random.rand(*self.input_shape).astype(self.np_dtype)
        Out = np.sort(X, kind='quicksort', axis=self.axis)
        indices = np.argsort(X, kind='quicksort', axis=self.axis)
        if self.descending:
            # numpy only sorts ascending; reverse along the sort axis to get
            # the expected descending result (random floats, so no ties).
            Out = np.flip(Out, axis=self.axis)
            indices = np.flip(indices, axis=self.axis)
        self.inputs = {'X': convert_float_to_uint16(X)}
        self.outputs = {
            'Out': convert_float_to_uint16(Out),
            # BUGFIX: Indices are integer positions, not floating-point data;
            # bfloat16-encoding them (as the original did) corrupts the
            # expected values and cannot match the kernel's integer output.
            'Indices': indices,
        }

    def init(self):
        # Default fixture: a 1-D tensor of 10000 elements sorted along axis 0.
        self.input_shape = [
            10000,
        ]
        self.axis = 0

    def init_direction(self):
        self.descending = False

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, ['X'], 'Out', check_dygraph=False)
class TestArgsortBF16OpDescendingTrue(TestArgsortBF16Op):
    # Variant of the bfloat16 test that requests a descending sort via the
    # op's ``descending`` attribute.
    def init_direction(self):
        self.descending = True
# Run the whole test module when executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册