未验证 提交 b9ee6a29 编写于 作者: A Asthestarsfalll 提交者: GitHub

【Hackathon No.25】为 Paddle 新增 nanquantile 数学计算API (#41343)

上级 2eac4db8
...@@ -329,6 +329,7 @@ from .tensor.stat import var # noqa: F401 ...@@ -329,6 +329,7 @@ from .tensor.stat import var # noqa: F401
from .tensor.stat import numel # noqa: F401 from .tensor.stat import numel # noqa: F401
from .tensor.stat import median # noqa: F401 from .tensor.stat import median # noqa: F401
from .tensor.stat import quantile # noqa: F401 from .tensor.stat import quantile # noqa: F401
from .tensor.stat import nanquantile # noqa: F401
from .device import get_cudnn_version # noqa: F401 from .device import get_cudnn_version # noqa: F401
from .device import set_device # noqa: F401 from .device import set_device # noqa: F401
from .device import get_device # noqa: F401 from .device import get_device # noqa: F401
...@@ -495,6 +496,7 @@ __all__ = [ # noqa ...@@ -495,6 +496,7 @@ __all__ = [ # noqa
'numel', 'numel',
'median', 'median',
'quantile', 'quantile',
'nanquantile',
'no_grad', 'no_grad',
'set_grad_enabled', 'set_grad_enabled',
'is_grad_enabled', 'is_grad_enabled',
......
...@@ -18,8 +18,11 @@ import unittest ...@@ -18,8 +18,11 @@ import unittest
import numpy as np import numpy as np
import paddle import paddle
API_list = [(paddle.quantile, np.quantile),
(paddle.nanquantile, np.nanquantile)]
class TestQuantile(unittest.TestCase):
class TestQuantileAndNanquantile(unittest.TestCase):
""" """
This class is used for numerical precision testing. If there is This class is used for numerical precision testing. If there is
a corresponding numpy API, the precision comparison can be performed directly. a corresponding numpy API, the precision comparison can be performed directly.
...@@ -27,50 +30,69 @@ class TestQuantile(unittest.TestCase): ...@@ -27,50 +30,69 @@ class TestQuantile(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
np.random.seed(678) self.input_data = np.random.rand(4, 7, 6)
self.input_data = np.random.rand(6, 7, 8, 9, 10)
# Test correctness when q and axis are set. # Test correctness when q and axis are set.
def test_quantile_single_q(self): def test_single_q(self):
x = paddle.to_tensor(self.input_data) inp = self.input_data
paddle_res = paddle.quantile(x, q=0.5, axis=2) for (func, res_func) in API_list:
np_res = np.quantile(self.input_data, q=0.5, axis=2) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0.5, axis=2)
np_res = res_func(inp, q=0.5, axis=2)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
inp[0, 1, 2] = np.nan
# Test correctness for default axis. # Test correctness for default axis.
def test_quantile_with_no_axis(self): def test_with_no_axis(self):
x = paddle.to_tensor(self.input_data) inp = self.input_data
paddle_res = paddle.quantile(x, q=0.35) for (func, res_func) in API_list:
np_res = np.quantile(self.input_data, q=0.35) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0.35)
np_res = res_func(inp, q=0.35)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
inp[0, 2, 1] = np.nan
inp[0, 1, 2] = np.nan
# Test correctness for multiple axis. # Test correctness for multiple axis.
def test_quantile_with_multi_axis(self): def test_with_multi_axis(self):
x = paddle.to_tensor(self.input_data) inp = self.input_data
paddle_res = paddle.quantile(x, q=0.75, axis=[0, 2, 3]) for (func, res_func) in API_list:
np_res = np.quantile(self.input_data, q=0.75, axis=[0, 2, 3]) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0.75, axis=[0, 2])
np_res = res_func(inp, q=0.75, axis=[0, 2])
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
inp[0, 5, 3] = np.nan
inp[0, 6, 2] = np.nan
# Test correctness when keepdim is set. # Test correctness when keepdim is set.
def test_quantile_with_keepdim(self): def test_with_keepdim(self):
x = paddle.to_tensor(self.input_data) inp = self.input_data
paddle_res = paddle.quantile(x, q=0.35, axis=4, keepdim=True) for (func, res_func) in API_list:
np_res = np.quantile(self.input_data, q=0.35, axis=4, keepdims=True) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0.35, axis=2, keepdim=True)
np_res = res_func(inp, q=0.35, axis=2, keepdims=True)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
inp[0, 3, 4] = np.nan
# Test correctness when all parameters are set. # Test correctness when all parameters are set.
def test_quantile_with_keepdim_and_multiple_axis(self): def test_with_keepdim_and_multiple_axis(self):
x = paddle.to_tensor(self.input_data) inp = self.input_data
paddle_res = paddle.quantile(x, q=0.1, axis=[1, 4], keepdim=True) for (func, res_func) in API_list:
np_res = np.quantile(self.input_data, q=0.1, axis=[1, 4], keepdims=True) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0.1, axis=[1, 2], keepdim=True)
np_res = res_func(inp, q=0.1, axis=[1, 2], keepdims=True)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
inp[0, 6, 3] = np.nan
# Test correctness when q = 0. # Test correctness when q = 0.
def test_quantile_with_boundary_q(self): def test_with_boundary_q(self):
x = paddle.to_tensor(self.input_data) inp = self.input_data
paddle_res = paddle.quantile(x, q=0, axis=3) for (func, res_func) in API_list:
np_res = np.quantile(self.input_data, q=0, axis=3) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0, axis=1)
np_res = res_func(inp, q=0, axis=1)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
inp[0, 2, 5] = np.nan
# Test correctness when input includes NaN. # Test correctness when input includes NaN.
def test_quantile_include_NaN(self): def test_quantile_include_NaN(self):
...@@ -78,17 +100,26 @@ class TestQuantile(unittest.TestCase): ...@@ -78,17 +100,26 @@ class TestQuantile(unittest.TestCase):
input_data[0, 1, 1] = np.nan input_data[0, 1, 1] = np.nan
x = paddle.to_tensor(input_data) x = paddle.to_tensor(input_data)
paddle_res = paddle.quantile(x, q=0.35, axis=0) paddle_res = paddle.quantile(x, q=0.35, axis=0)
self.assertTrue(paddle.isnan(paddle_res[1, 1])) np_res = np.quantile(input_data, q=0.35, axis=0)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res, equal_nan=True))
# Test correctness when input filled with NaN.
def test_nanquantile_all_NaN(self):
input_data = np.full(shape=[2, 3], fill_value=np.nan)
input_data[0, 2] = 0
x = paddle.to_tensor(input_data)
paddle_res = paddle.nanquantile(x, q=0.35, axis=0)
np_res = np.nanquantile(input_data, q=0.35, axis=0)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res, equal_nan=True))
class TestQuantileMuitlpleQ(unittest.TestCase): class TestMuitlpleQ(unittest.TestCase):
""" """
This class is used to test multiple input of q. This class is used to test multiple input of q.
""" """
def setUp(self): def setUp(self):
np.random.seed(678) self.input_data = np.random.rand(5, 3, 4)
self.input_data = np.random.rand(10, 3, 4, 5, 4)
def test_quantile(self): def test_quantile(self):
x = paddle.to_tensor(self.input_data) x = paddle.to_tensor(self.input_data)
...@@ -111,7 +142,7 @@ class TestQuantileMuitlpleQ(unittest.TestCase): ...@@ -111,7 +142,7 @@ class TestQuantileMuitlpleQ(unittest.TestCase):
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
class TestQuantileError(unittest.TestCase): class TestError(unittest.TestCase):
""" """
This class is used to test that exceptions are thrown correctly. This class is used to test that exceptions are thrown correctly.
Validity of all parameter values and types should be considered. Validity of all parameter values and types should be considered.
...@@ -184,8 +215,7 @@ class TestQuantileRuntime(unittest.TestCase): ...@@ -184,8 +215,7 @@ class TestQuantileRuntime(unittest.TestCase):
""" """
def setUp(self): def setUp(self):
np.random.seed(678) self.input_data = np.random.rand(4, 7)
self.input_data = np.random.rand(6, 7, 8, 9, 10)
self.dtypes = ['float32', 'float64'] self.dtypes = ['float32', 'float64']
self.devices = ['cpu'] self.devices = ['cpu']
if paddle.device.is_compiled_with_cuda(): if paddle.device.is_compiled_with_cuda():
...@@ -193,6 +223,7 @@ class TestQuantileRuntime(unittest.TestCase): ...@@ -193,6 +223,7 @@ class TestQuantileRuntime(unittest.TestCase):
def test_dygraph(self): def test_dygraph(self):
paddle.disable_static() paddle.disable_static()
for (func, res_func) in API_list:
for device in self.devices: for device in self.devices:
# Check different devices # Check different devices
paddle.set_device(device) paddle.set_device(device)
...@@ -200,12 +231,13 @@ class TestQuantileRuntime(unittest.TestCase): ...@@ -200,12 +231,13 @@ class TestQuantileRuntime(unittest.TestCase):
# Check different dtypes # Check different dtypes
np_input_data = self.input_data.astype(dtype) np_input_data = self.input_data.astype(dtype)
x = paddle.to_tensor(np_input_data, dtype=dtype) x = paddle.to_tensor(np_input_data, dtype=dtype)
paddle_res = paddle.quantile(x, q=0.5, axis=2) paddle_res = func(x, q=0.5, axis=1)
np_res = np.quantile(np_input_data, q=0.5, axis=2) np_res = res_func(np_input_data, q=0.5, axis=1)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_static(self): def test_static(self):
paddle.enable_static() paddle.enable_static()
for (func, res_func) in API_list:
for device in self.devices: for device in self.devices:
x = paddle.static.data( x = paddle.static.data(
name="x", shape=self.input_data.shape, dtype=paddle.float32) name="x", shape=self.input_data.shape, dtype=paddle.float32)
...@@ -214,9 +246,9 @@ class TestQuantileRuntime(unittest.TestCase): ...@@ -214,9 +246,9 @@ class TestQuantileRuntime(unittest.TestCase):
shape=self.input_data.shape, shape=self.input_data.shape,
dtype=paddle.float64) dtype=paddle.float64)
results = paddle.quantile(x, q=0.5, axis=2) results = func(x, q=0.5, axis=1)
np_input_data = self.input_data.astype('float32') np_input_data = self.input_data.astype('float32')
results_fp64 = paddle.quantile(x_fp64, q=0.5, axis=2) results_fp64 = func(x_fp64, q=0.5, axis=1)
np_input_data_fp64 = self.input_data.astype('float64') np_input_data_fp64 = self.input_data.astype('float64')
exe = paddle.static.Executor(device) exe = paddle.static.Executor(device)
...@@ -225,11 +257,11 @@ class TestQuantileRuntime(unittest.TestCase): ...@@ -225,11 +257,11 @@ class TestQuantileRuntime(unittest.TestCase):
feed={"x": np_input_data, feed={"x": np_input_data,
"x_fp64": np_input_data_fp64}, "x_fp64": np_input_data_fp64},
fetch_list=[results, results_fp64]) fetch_list=[results, results_fp64])
np_res = np.quantile(np_input_data, q=0.5, axis=2) np_res = res_func(np_input_data, q=0.5, axis=1)
np_res_fp64 = np.quantile(np_input_data_fp64, q=0.5, axis=2) np_res_fp64 = res_func(np_input_data_fp64, q=0.5, axis=1)
self.assertTrue( self.assertTrue(
np.allclose(paddle_res, np_res) and np.allclose(paddle_res_fp64, np.allclose(paddle_res, np_res) and
np_res_fp64)) np.allclose(paddle_res_fp64, np_res_fp64))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -262,6 +262,7 @@ from .stat import var # noqa: F401 ...@@ -262,6 +262,7 @@ from .stat import var # noqa: F401
from .stat import numel # noqa: F401 from .stat import numel # noqa: F401
from .stat import median # noqa: F401 from .stat import median # noqa: F401
from .stat import quantile # noqa: F401 from .stat import quantile # noqa: F401
from .stat import nanquantile # noqa: F401
from .to_string import set_printoptions # noqa: F401 from .to_string import set_printoptions # noqa: F401
...@@ -445,6 +446,7 @@ tensor_method_func = [ #noqa ...@@ -445,6 +446,7 @@ tensor_method_func = [ #noqa
'numel', 'numel',
'median', 'median',
'quantile', 'quantile',
'nanquantile',
'is_complex', 'is_complex',
'is_integer', 'is_integer',
'rank', 'rank',
......
...@@ -342,10 +342,11 @@ def median(x, axis=None, keepdim=False, name=None): ...@@ -342,10 +342,11 @@ def median(x, axis=None, keepdim=False, name=None):
return out_tensor return out_tensor
def quantile(x, q, axis=None, keepdim=False): def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
""" """
Compute the quantile of the input along the specified axis. Compute the quantile of the input along the specified axis.
Args:
Args: Args:
x (Tensor): The input Tensor, its data type can be float32, float64. x (Tensor): The input Tensor, its data type can be float32, float64.
q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list, q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
...@@ -360,37 +361,28 @@ def quantile(x, q, axis=None, keepdim=False): ...@@ -360,37 +361,28 @@ def quantile(x, q, axis=None, keepdim=False):
the output Tensor is the same as ``x`` except in the reduced the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False. the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None). ignore_nan (bool, optional): Whether to ignore NaN of input Tensor.
For more information, please refer to :ref:`api_guide_Name`. If ``ignore_nan`` is True, it will calculate nanquantile.
Otherwise it will calculate quantile. Default is False.
Returns: Returns:
Tensor, results of quantile along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32. Tensor, results of quantile along ``axis`` of ``x``.
In order to obtain higher precision, data type of results will be float64.
Examples:
.. code-block:: python
import paddle
x = paddle.randn((2,3))
#[[-1.28740597, 0.49533170, -1.00698614],
# [-1.11656201, -1.01010525, -2.23457789]])
y1 = paddle.quantile(x, q=0.5, axis=[0, 1])
# y1 = -1.06333363
y2 = paddle.quantile(x, q=0.5, axis=1)
# y2 = [-1.00698614, -1.11656201]
y3 = paddle.quantile(x, q=[0.3, 0.5], axis=1)
# y3 =[[-1.11915410, -1.56376839],
# [-1.00698614, -1.11656201]]
y4 = paddle.quantile(x, q=0.8, axis=1, keepdim=True)
# y4 = [[-0.10559537],
# [-1.05268800]])
""" """
# Validate x
if not isinstance(x, Variable): if not isinstance(x, Variable):
raise TypeError("input x should be a Tensor.") raise TypeError("input x should be a Tensor.")
# Validate q
if isinstance(q, (int, float)):
q = [q]
elif isinstance(q, (list, tuple)):
if len(q) <= 0:
raise ValueError("q should not be empty")
else:
raise TypeError("Type of q should be int, float, list or tuple.")
# Validate axis
dims = len(x.shape) dims = len(x.shape)
out_shape = list(x.shape) out_shape = list(x.shape)
if axis is None: if axis is None:
...@@ -399,7 +391,7 @@ def quantile(x, q, axis=None, keepdim=False): ...@@ -399,7 +391,7 @@ def quantile(x, q, axis=None, keepdim=False):
out_shape = [1] * dims out_shape = [1] * dims
else: else:
if isinstance(axis, list): if isinstance(axis, list):
if (len(axis) <= 0): if len(axis) <= 0:
raise ValueError("axis should not be empty") raise ValueError("axis should not be empty")
axis_src, axis_dst = [], [] axis_src, axis_dst = [], []
for axis_single in axis: for axis_single in axis:
...@@ -424,54 +416,177 @@ def quantile(x, q, axis=None, keepdim=False): ...@@ -424,54 +416,177 @@ def quantile(x, q, axis=None, keepdim=False):
if axis < 0: if axis < 0:
axis += dims axis += dims
out_shape[axis] = 1 out_shape[axis] = 1
mask = x.isnan()
valid_counts = mask.logical_not().sum(axis=axis,
keepdim=True,
dtype='float64')
indices = [] indices = []
if isinstance(q, (int, float)):
if q < 0 or q > 1:
raise ValueError("q should be in range [0, 1]")
indices.append(q * (x.shape[axis] - 1))
elif isinstance(q, (list, tuple)):
if len(q) <= 0:
raise ValueError("q should not be empty")
for q_num in q: for q_num in q:
if q_num < 0 or q_num > 1: if q_num < 0 or q_num > 1:
raise ValueError("q should be in range [0, 1]") raise ValueError("q should be in range [0, 1]")
indices.append(q_num * (x.shape[axis] - 1)) if paddle.in_dynamic_mode():
q_num = paddle.to_tensor(q_num, dtype='float64')
if ignore_nan:
indices.append(q_num * (valid_counts - 1))
else: else:
raise TypeError("Type of q should be int, float, list or tuple.") # TODO(Asthestarsfalll): Use paddle.index_fill instead of where
index = q_num * (valid_counts - 1)
last_index = x.shape[axis] - 1
nums = paddle.full_like(index, fill_value=last_index)
index = paddle.where(mask.any(axis=axis, keepdim=True), nums, index)
indices.append(index)
sorted_tensor = paddle.sort(x, axis) sorted_tensor = paddle.sort(x, axis)
indices_tensor = paddle.assign(indices).astype(paddle.float32)
indices_below = paddle.floor(indices_tensor).astype(paddle.int32)
indices_upper = paddle.ceil(indices_tensor).astype(paddle.int32)
outputs = []
def expand_dim(indices, sorted_tensor_shape, axis): outputs = []
assert axis < len(list(sorted_tensor_shape))
expanded_shape = [1] * len(list(sorted_tensor_shape))
expanded_shape = tuple(expanded_shape)
indices = indices.reshape(expanded_shape)
return indices
# TODO(chenjianye): replace the for-loop to directly take elements. # TODO(chenjianye): replace the for-loop to directly take elements.
for i in range(len(indices)): for index in indices:
if (indices_upper[i] != indices_below[i]): indices_below = paddle.floor(index).astype(paddle.int32)
tensor_below = paddle.take_along_axis( indices_upper = paddle.ceil(index).astype(paddle.int32)
sorted_tensor,
expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
tensor_upper = paddle.take_along_axis( tensor_upper = paddle.take_along_axis(
sorted_tensor, sorted_tensor, indices_upper, axis=axis)
expand_dim(indices_upper[i], sorted_tensor.shape, axis), axis) tensor_below = paddle.take_along_axis(
weights = (indices[i] - indices_below[i]).astype(x.dtype) sorted_tensor, indices_below, axis=axis)
out = paddle.lerp(tensor_below, tensor_upper, weights) weights = (index - indices_below.astype('float64'))
else: out = paddle.lerp(
out = paddle.take_along_axis( tensor_below.astype('float64'),
sorted_tensor, tensor_upper.astype('float64'), weights)
expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
if not keepdim: if not keepdim:
out = paddle.squeeze(out, axis=axis) out = paddle.squeeze(out, axis=axis)
else: else:
out = out.reshape(out_shape) out = out.reshape(out_shape)
outputs.append(out) outputs.append(out)
if isinstance(q, (list, tuple)):
return paddle.stack(outputs, 0) if len(q) > 1:
outputs = paddle.stack(outputs, 0)
else: else:
return outputs[0] outputs = outputs[0]
return outputs
def quantile(x, q, axis=None, keepdim=False, name=None):
    """
    Compute the quantile of the input along the specified axis.
    If any values in a reduced row are NaN, then the quantiles for that reduction will be NaN.

    Args:
        x (Tensor): The input Tensor, its data type can be float32, float64.
        q (int|float|list|tuple): The q for calculate quantile, which should be in range [0, 1]. If q is a list
            or tuple, each q will be calculated and the first dimension of output is same to the number of ``q`` .
        axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
            ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
            If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
            If ``axis`` is a list, quantile is calculated over all elements of given axes.
            If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
            in the output Tensor. If ``keepdim`` is True, the dimensions of
            the output Tensor is the same as ``x`` except in the reduced
            dimensions(it is of size 1 in this case). Otherwise, the shape of
            the output Tensor is squeezed in ``axis`` . Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, results of quantile along ``axis`` of ``x``.
        In order to obtain higher precision, data type of results will be float64.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            x = np.arange(0, 8, dtype=np.float32).reshape(4, 2)
            # [[0 1]
            #  [2 3]
            #  [4 5]
            #  [6 7]]
            y = paddle.to_tensor(x)
            y1 = paddle.quantile(y, q=0.5, axis=[0, 1])
            # 3.5

            y2 = paddle.quantile(y, q=0.5, axis=1)
            # [0.5 2.5 4.5 6.5]

            y3 = paddle.quantile(y, q=[0.3, 0.5], axis=0)
            # [[1.8 2.8]
            #  [3.  4. ]]

            x[0][0] = np.nan
            y = paddle.to_tensor(x)
            y4 = paddle.quantile(y, q=0.8, axis=1, keepdim=True)
            # [[nan]
            #  [2.8]
            #  [4.8]
            #  [6.8]]

    """
    # ``name`` is accepted so the signature matches the documented arguments
    # (and ``median``); it is not forwarded because ``_compute_quantile``
    # takes no ``name`` parameter.
    return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=False)
def nanquantile(x, q, axis=None, keepdim=False, name=None):
    """
    Compute the quantile of the input as if NaN values in input did not exist.
    If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN.

    Args:
        x (Tensor): The input Tensor, its data type can be float32, float64.
        q (int|float|list|tuple): The q for calculate quantile, which should be in range [0, 1]. If q is a list
            or tuple, each q will be calculated and the first dimension of output is same to the number of ``q`` .
        axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
            ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
            If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
            If ``axis`` is a list, quantile is calculated over all elements of given axes.
            If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
        keepdim (bool, optional): Whether to reserve the reduced dimension(s)
            in the output Tensor. If ``keepdim`` is True, the dimensions of
            the output Tensor is the same as ``x`` except in the reduced
            dimensions(it is of size 1 in this case). Otherwise, the shape of
            the output Tensor is squeezed in ``axis`` . Default is False.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Tensor, results of quantile along ``axis`` of ``x``.
        In order to obtain higher precision, data type of results will be float64.

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle

            x = np.array(
                [[0, 1, 2, 3, 4],
                 [5, 6, 7, 8, 9]],
                dtype=np.float32
            )
            x[0][0] = np.nan

            x = paddle.to_tensor(x)
            y1 = paddle.nanquantile(x, q=0.5, axis=[0, 1])
            # 5.0

            y2 = paddle.nanquantile(x, q=0.5, axis=1)
            # [2.5 7. ]

            y3 = paddle.nanquantile(x, q=[0.3, 0.5], axis=0)
            # [[5.  2.5 3.5 4.5 5.5]
            #  [5.  3.5 4.5 5.5 6.5]]

            y4 = paddle.nanquantile(x, q=0.8, axis=1, keepdim=True)
            # [[3.4]
            #  [8.2]]

            nan = paddle.full(shape=[2, 3], fill_value=np.nan)
            y5 = paddle.nanquantile(nan, q=0.8, axis=1, keepdim=True)
            # [[nan]
            #  [nan]]

    """
    # ``name`` is accepted so the signature matches the documented arguments
    # (and ``median``); it is not forwarded because ``_compute_quantile``
    # takes no ``name`` parameter.
    return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册