未验证 提交 4ae9945b 编写于 作者: C cyberslack_lee 提交者: GitHub

Add FP16 & BF16 for nanmedian (#56056)

上级 08e46d6f
...@@ -123,4 +123,5 @@ PD_REGISTER_KERNEL(nanmedian_grad, ...@@ -123,4 +123,5 @@ PD_REGISTER_KERNEL(nanmedian_grad,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
...@@ -287,6 +287,7 @@ PD_REGISTER_KERNEL(nanmedian, ...@@ -287,6 +287,7 @@ PD_REGISTER_KERNEL(nanmedian,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) { phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
...@@ -265,7 +265,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None): ...@@ -265,7 +265,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
the average value of both elements in the middle is calculated as the median. the average value of both elements in the middle is calculated as the median.
Args: Args:
x (Tensor): The input Tensor, it's data type can be int32, int64, float16, float32, float64. x (Tensor): The input Tensor, it's data type can be int32, int64, float16, bfloat16, float32, float64.
axis (None|int|list|tuple, optional): axis (None|int|list|tuple, optional):
The axis along which to perform median calculations ``axis`` should be int or list of int. The axis along which to perform median calculations ``axis`` should be int or list of int.
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
...@@ -319,7 +319,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None): ...@@ -319,7 +319,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
check_variable_and_dtype( check_variable_and_dtype(
x, x,
'X', 'X',
['int32', 'int64', 'float16', 'float32', 'float64'], ['int32', 'int64', 'float16', 'float32', 'float64', 'uint16'],
'nanmedian', 'nanmedian',
) )
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
...@@ -243,5 +244,50 @@ class TestNanmedian(unittest.TestCase): ...@@ -243,5 +244,50 @@ class TestNanmedian(unittest.TestCase):
np.testing.assert_allclose(x.grad, np.array(0.0)) np.testing.assert_allclose(x.grad, np.array(0.0))
class TestNanmedianFP16Op(OpTest):
def setUp(self):
self.op_type = "nanmedian"
self.python_api = paddle.nanmedian
self.public_python_api = paddle.nanmedian
self.dtype = np.float16
self.python_out_sig = ["Out"]
X = np.random.random((100, 100)).astype('float16')
Out = np.nanmedian(X)
self.inputs = {'X': X}
self.outputs = {'Out': Out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestNanmedianBF16Op(OpTest):
def setUp(self):
self.op_type = "nanmedian"
self.python_api = paddle.nanmedian
self.public_python_api = paddle.nanmedian
self.dtype = np.uint16
self.python_out_sig = ["Out"]
X = np.random.random((100, 100)).astype('float32')
Out = np.nanmedian(X)
self.inputs = {'X': convert_float_to_uint16(X)}
self.outputs = {'Out': convert_float_to_uint16(Out)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册