未验证 提交 60e37d17 编写于 作者: C cyberslack_lee 提交者: GitHub

Add float16 and bfloat16 support and test for argsort (#55105)

上级 6e61c9f9
...@@ -38,6 +38,10 @@ namespace detail { ...@@ -38,6 +38,10 @@ namespace detail {
template <> template <>
struct radix_key_codec_base<phi::dtype::float16> struct radix_key_codec_base<phi::dtype::float16>
: radix_key_codec_integral<phi::dtype::float16, uint16_t> {}; : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
// Register phi::dtype::bfloat16 with rocPRIM's radix sort by reusing the
// integral key codec over the type's raw uint16_t bit pattern, mirroring the
// float16 specialization above.
// NOTE(review): this relies on radix_key_codec_integral applying the usual
// sign/exponent bit transformation needed for correct floating-point key
// ordering, as it does for float16 — confirm against rocPRIM's codec docs.
template <>
struct radix_key_codec_base<phi::dtype::bfloat16>
    : radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
} // namespace detail } // namespace detail
} // namespace rocprim } // namespace rocprim
#else #else
...@@ -46,6 +50,11 @@ namespace cub { ...@@ -46,6 +50,11 @@ namespace cub {
template <> template <>
struct NumericTraits<phi::dtype::float16> struct NumericTraits<phi::dtype::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {}; : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
// Teach CUB's radix sort about phi::dtype::bfloat16: categorize it as a
// FLOATING_POINT type (signed = true, exact = false) whose key bits travel
// in a uint16_t, exactly parallel to the float16 specialization above.
template <>
struct NumericTraits<phi::dtype::bfloat16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
};
} // namespace cub } // namespace cub
#endif #endif
...@@ -222,4 +231,5 @@ PD_REGISTER_KERNEL(argsort_grad, ...@@ -222,4 +231,5 @@ PD_REGISTER_KERNEL(argsort_grad,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -61,7 +61,13 @@ namespace cub { ...@@ -61,7 +61,13 @@ namespace cub {
template <> template <>
struct NumericTraits<phi::dtype::float16> struct NumericTraits<phi::dtype::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {}; : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
// Teach CUB's radix sort about phi::dtype::bfloat16: categorize it as a
// FLOATING_POINT type (signed = true, exact = false) whose key bits travel
// in a uint16_t, exactly parallel to the float16 specialization above.
template <>
struct NumericTraits<phi::dtype::bfloat16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
};
} // namespace cub } // namespace cub
#endif #endif
namespace phi { namespace phi {
...@@ -328,6 +334,7 @@ PD_REGISTER_KERNEL(argsort, ...@@ -328,6 +334,7 @@ PD_REGISTER_KERNEL(argsort,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) { phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
...@@ -39,7 +39,7 @@ def argsort(x, axis=-1, descending=False, name=None): ...@@ -39,7 +39,7 @@ def argsort(x, axis=-1, descending=False, name=None):
Sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True. Sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
Args: Args:
x (Tensor): An input N-D Tensor with type float16, float32, float64, int16, x (Tensor): An input N-D Tensor with type bfloat16, float16, float32, float64, int16,
int32, int64, uint8. int32, int64, uint8.
axis (int, optional): Axis to compute indices along. The effective range axis (int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is Rank(x). when axis<0, it works the same way is [-R, R), where R is Rank(x). when axis<0, it works the same way
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -513,5 +514,95 @@ class TestArgsortOpFp16(unittest.TestCase): ...@@ -513,5 +514,95 @@ class TestArgsortOpFp16(unittest.TestCase):
out = exe.run(feed={'x': x_np}, fetch_list=[out]) out = exe.run(feed={'x': x_np}, fetch_list=[out])
class TestArgsortFP16Op(OpTest):
    """OpTest for the ``argsort`` kernel with float16 input on a large 1-D
    tensor.

    Fix over the original: ``setUp`` re-assigned ``self.descending = False``
    *after* ``init_direction()`` had run, so the ``DescendingTrue`` subclass
    silently tested the ascending path.  The default is now set before
    ``init_direction()``, and the reference outputs honour the chosen
    direction.
    """

    def setUp(self):
        self.init()
        # Default direction; subclasses override it in init_direction().
        self.descending = False
        self.init_direction()
        self.op_type = "argsort"
        self.python_api = paddle.argsort
        self.public_python_api = paddle.argsort
        self.python_out_sig = ["Out"]
        self.dtype = np.float16
        self.attrs = {"axis": self.axis, "descending": self.descending}
        X = np.random.rand(*self.input_shape).astype('float16')
        Out = np.sort(X, kind='quicksort', axis=self.axis)
        indices = np.argsort(X, kind='quicksort', axis=self.axis)
        if self.descending:
            # np.sort/np.argsort only sort ascending; flip along the sort
            # axis to build the descending reference results.
            Out = np.flip(Out, axis=self.axis)
            indices = np.flip(indices, axis=self.axis)
        self.inputs = {'X': X}
        self.outputs = {
            'Out': Out,
            'Indices': indices,
        }

    def init(self):
        # Large 1-D input so the device radix-sort path is exercised.
        self.input_shape = [
            10000,
        ]
        self.axis = 0

    def init_direction(self):
        self.descending = False

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # The gradient of argsort w.r.t. X is checked against Out only;
        # dygraph check is disabled as in the original test.
        self.check_grad(['X'], 'Out', check_dygraph=False)
class TestArgsortFP16OpDescendingTrue(TestArgsortFP16Op):
    """Variant of TestArgsortFP16Op that sorts in descending order."""

    def init_direction(self):
        self.descending = True
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestArgsortBF16Op(OpTest):
    """OpTest for the ``argsort`` kernel with bfloat16 input on a large 1-D
    tensor (CUDA only).

    Fixes over the original:
    * ``self.descending = False`` no longer overwrites ``init_direction()``,
      so the ``DescendingTrue`` subclass actually tests descending order.
    * ``Indices`` is an int64 index tensor (the kernel registers it as
      INT64); it must not be passed through ``convert_float_to_uint16``,
      which reinterprets float32 data as bfloat16 bit patterns.
    * Skip-message typo ("complied") and conjunction corrected.
    """

    def setUp(self):
        self.init()
        # Default direction; subclasses override it in init_direction().
        self.descending = False
        self.init_direction()
        self.op_type = "argsort"
        self.python_api = paddle.argsort
        self.public_python_api = paddle.argsort
        self.python_out_sig = ["Out"]
        # OpTest convention for bf16: dtype is uint16 (bit container),
        # reference math runs in float32.
        self.dtype = np.uint16
        self.np_dtype = np.float32
        self.attrs = {"axis": self.axis, "descending": self.descending}
        X = np.random.rand(*self.input_shape).astype(self.np_dtype)
        Out = np.sort(X, kind='quicksort', axis=self.axis)
        indices = np.argsort(X, kind='quicksort', axis=self.axis)
        if self.descending:
            # np.sort/np.argsort only sort ascending; flip along the sort
            # axis to build the descending reference results.
            Out = np.flip(Out, axis=self.axis)
            indices = np.flip(indices, axis=self.axis)
        self.inputs = {'X': convert_float_to_uint16(X)}
        self.outputs = {
            'Out': convert_float_to_uint16(Out),
            # Indices are plain integer positions — do NOT bit-convert them.
            'Indices': indices,
        }

    def init(self):
        # Large 1-D input so the device radix-sort path is exercised.
        self.input_shape = [
            10000,
        ]
        self.axis = 0

    def init_direction(self):
        self.descending = False

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, ['X'], 'Out', check_dygraph=False)
class TestArgsortBF16OpDescendingTrue(TestArgsortBF16Op):
    """Variant of TestArgsortBF16Op that sorts in descending order."""

    def init_direction(self):
        self.descending = True
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册