未验证 提交 4ae9945b 编写于 作者: C cyberslack_lee 提交者: GitHub

Add FP16 & BF16 for nanmedian (#56056)

上级 08e46d6f
......@@ -123,4 +123,5 @@ PD_REGISTER_KERNEL(nanmedian_grad,
double,
int,
int64_t,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -287,6 +287,7 @@ PD_REGISTER_KERNEL(nanmedian,
double,
int,
int64_t,
phi::dtype::float16) {
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -265,7 +265,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
the average value of both elements in the middle is calculated as the median.
Args:
x (Tensor): The input Tensor, it's data type can be int32, int64, float16, float32, float64.
x (Tensor): The input Tensor, it's data type can be int32, int64, float16, bfloat16, float32, float64.
axis (None|int|list|tuple, optional):
The axis along which to perform median calculations ``axis`` should be int or list of int.
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
......@@ -319,7 +319,7 @@ def nanmedian(x, axis=None, keepdim=False, name=None):
check_variable_and_dtype(
x,
'X',
['int32', 'int64', 'float16', 'float32', 'float64'],
['int32', 'int64', 'float16', 'float32', 'float64', 'uint16'],
'nanmedian',
)
......
......@@ -15,6 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
......@@ -243,5 +244,50 @@ class TestNanmedian(unittest.TestCase):
np.testing.assert_allclose(x.grad, np.array(0.0))
class TestNanmedianFP16Op(OpTest):
def setUp(self):
self.op_type = "nanmedian"
self.python_api = paddle.nanmedian
self.public_python_api = paddle.nanmedian
self.dtype = np.float16
self.python_out_sig = ["Out"]
X = np.random.random((100, 100)).astype('float16')
Out = np.nanmedian(X)
self.inputs = {'X': X}
self.outputs = {'Out': Out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestNanmedianBF16Op(OpTest):
def setUp(self):
self.op_type = "nanmedian"
self.python_api = paddle.nanmedian
self.public_python_api = paddle.nanmedian
self.dtype = np.uint16
self.python_out_sig = ["Out"]
X = np.random.random((100, 100)).astype('float32')
Out = np.nanmedian(X)
self.inputs = {'X': convert_float_to_uint16(X)}
self.outputs = {'Out': convert_float_to_uint16(Out)}
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册