未验证 提交 b9ee6a29 编写于 作者: A Asthestarsfalll 提交者: GitHub

【Hackathon No.25】为 Paddle 新增 nanquantile 数学计算API (#41343)

上级 2eac4db8
...@@ -329,6 +329,7 @@ from .tensor.stat import var # noqa: F401 ...@@ -329,6 +329,7 @@ from .tensor.stat import var # noqa: F401
from .tensor.stat import numel # noqa: F401 from .tensor.stat import numel # noqa: F401
from .tensor.stat import median # noqa: F401 from .tensor.stat import median # noqa: F401
from .tensor.stat import quantile # noqa: F401 from .tensor.stat import quantile # noqa: F401
from .tensor.stat import nanquantile # noqa: F401
from .device import get_cudnn_version # noqa: F401 from .device import get_cudnn_version # noqa: F401
from .device import set_device # noqa: F401 from .device import set_device # noqa: F401
from .device import get_device # noqa: F401 from .device import get_device # noqa: F401
...@@ -495,6 +496,7 @@ __all__ = [ # noqa ...@@ -495,6 +496,7 @@ __all__ = [ # noqa
'numel', 'numel',
'median', 'median',
'quantile', 'quantile',
'nanquantile',
'no_grad', 'no_grad',
'set_grad_enabled', 'set_grad_enabled',
'is_grad_enabled', 'is_grad_enabled',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle import paddle
API_list = [(paddle.quantile, np.quantile),
class TestQuantile(unittest.TestCase): (paddle.nanquantile, np.nanquantile)]
"""
This class is used for numerical precision testing. If there is
a corresponding numpy API, the precision comparison can be performed directly. class TestQuantileAndNanquantile(unittest.TestCase):
Otherwise, it needs to be verified by numpy implementated function. """
""" This class is used for numerical precision testing. If there is
a corresponding numpy API, the precision comparison can be performed directly.
def setUp(self): Otherwise, it needs to be verified by numpy implementated function.
np.random.seed(678) """
self.input_data = np.random.rand(6, 7, 8, 9, 10)
def setUp(self):
# Test correctness when q and axis are set. self.input_data = np.random.rand(4, 7, 6)
def test_quantile_single_q(self):
x = paddle.to_tensor(self.input_data) # Test correctness when q and axis are set.
paddle_res = paddle.quantile(x, q=0.5, axis=2) def test_single_q(self):
np_res = np.quantile(self.input_data, q=0.5, axis=2) inp = self.input_data
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) for (func, res_func) in API_list:
x = paddle.to_tensor(inp)
# Test correctness for default axis. paddle_res = func(x, q=0.5, axis=2)
def test_quantile_with_no_axis(self): np_res = res_func(inp, q=0.5, axis=2)
x = paddle.to_tensor(self.input_data) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
paddle_res = paddle.quantile(x, q=0.35) inp[0, 1, 2] = np.nan
np_res = np.quantile(self.input_data, q=0.35)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) # Test correctness for default axis.
def test_with_no_axis(self):
# Test correctness for multiple axis. inp = self.input_data
def test_quantile_with_multi_axis(self): for (func, res_func) in API_list:
x = paddle.to_tensor(self.input_data) x = paddle.to_tensor(inp)
paddle_res = paddle.quantile(x, q=0.75, axis=[0, 2, 3]) paddle_res = func(x, q=0.35)
np_res = np.quantile(self.input_data, q=0.75, axis=[0, 2, 3]) np_res = res_func(inp, q=0.35)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
inp[0, 2, 1] = np.nan
# Test correctness when keepdim is set. inp[0, 1, 2] = np.nan
def test_quantile_with_keepdim(self):
x = paddle.to_tensor(self.input_data) # Test correctness for multiple axis.
paddle_res = paddle.quantile(x, q=0.35, axis=4, keepdim=True) def test_with_multi_axis(self):
np_res = np.quantile(self.input_data, q=0.35, axis=4, keepdims=True) inp = self.input_data
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) for (func, res_func) in API_list:
x = paddle.to_tensor(inp)
# Test correctness when all parameters are set. paddle_res = func(x, q=0.75, axis=[0, 2])
def test_quantile_with_keepdim_and_multiple_axis(self): np_res = res_func(inp, q=0.75, axis=[0, 2])
x = paddle.to_tensor(self.input_data) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
paddle_res = paddle.quantile(x, q=0.1, axis=[1, 4], keepdim=True) inp[0, 5, 3] = np.nan
np_res = np.quantile(self.input_data, q=0.1, axis=[1, 4], keepdims=True) inp[0, 6, 2] = np.nan
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
# Test correctness when keepdim is set.
# Test correctness when q = 0. def test_with_keepdim(self):
def test_quantile_with_boundary_q(self): inp = self.input_data
x = paddle.to_tensor(self.input_data) for (func, res_func) in API_list:
paddle_res = paddle.quantile(x, q=0, axis=3) x = paddle.to_tensor(inp)
np_res = np.quantile(self.input_data, q=0, axis=3) paddle_res = func(x, q=0.35, axis=2, keepdim=True)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) np_res = res_func(inp, q=0.35, axis=2, keepdims=True)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
# Test correctness when input includes NaN. inp[0, 3, 4] = np.nan
def test_quantile_include_NaN(self):
input_data = np.random.randn(2, 3, 4) # Test correctness when all parameters are set.
input_data[0, 1, 1] = np.nan def test_with_keepdim_and_multiple_axis(self):
x = paddle.to_tensor(input_data) inp = self.input_data
paddle_res = paddle.quantile(x, q=0.35, axis=0) for (func, res_func) in API_list:
self.assertTrue(paddle.isnan(paddle_res[1, 1])) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0.1, axis=[1, 2], keepdim=True)
np_res = res_func(inp, q=0.1, axis=[1, 2], keepdims=True)
class TestQuantileMuitlpleQ(unittest.TestCase): self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
""" inp[0, 6, 3] = np.nan
This class is used to test multiple input of q.
""" # Test correctness when q = 0.
def test_with_boundary_q(self):
def setUp(self): inp = self.input_data
np.random.seed(678) for (func, res_func) in API_list:
self.input_data = np.random.rand(10, 3, 4, 5, 4) x = paddle.to_tensor(inp)
paddle_res = func(x, q=0, axis=1)
def test_quantile(self): np_res = res_func(inp, q=0, axis=1)
x = paddle.to_tensor(self.input_data) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
paddle_res = paddle.quantile(x, q=[0.3, 0.44], axis=-2) inp[0, 2, 5] = np.nan
np_res = np.quantile(self.input_data, q=[0.3, 0.44], axis=-2)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) # Test correctness when input includes NaN.
def test_quantile_include_NaN(self):
def test_quantile_multiple_axis(self): input_data = np.random.randn(2, 3, 4)
x = paddle.to_tensor(self.input_data) input_data[0, 1, 1] = np.nan
paddle_res = paddle.quantile(x, q=[0.2, 0.67], axis=[1, -1]) x = paddle.to_tensor(input_data)
np_res = np.quantile(self.input_data, q=[0.2, 0.67], axis=[1, -1]) paddle_res = paddle.quantile(x, q=0.35, axis=0)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) np_res = np.quantile(input_data, q=0.35, axis=0)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res, equal_nan=True))
def test_quantile_multiple_axis_keepdim(self):
x = paddle.to_tensor(self.input_data) # Test correctness when input filled with NaN.
paddle_res = paddle.quantile( def test_nanquantile_all_NaN(self):
x, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdim=True) input_data = np.full(shape=[2, 3], fill_value=np.nan)
np_res = np.quantile( input_data[0, 2] = 0
self.input_data, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdims=True) x = paddle.to_tensor(input_data)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) paddle_res = paddle.nanquantile(x, q=0.35, axis=0)
np_res = np.nanquantile(input_data, q=0.35, axis=0)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res, equal_nan=True))
class TestQuantileError(unittest.TestCase):
"""
This class is used to test that exceptions are thrown correctly. class TestMuitlpleQ(unittest.TestCase):
Validity of all parameter values and types should be considered. """
""" This class is used to test multiple input of q.
"""
def setUp(self):
self.x = paddle.randn((2, 3, 4)) def setUp(self):
self.input_data = np.random.rand(5, 3, 4)
def test_errors(self):
# Test error when q > 1 def test_quantile(self):
def test_q_range_error_1(): x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(self.x, q=1.5) paddle_res = paddle.quantile(x, q=[0.3, 0.44], axis=-2)
np_res = np.quantile(self.input_data, q=[0.3, 0.44], axis=-2)
self.assertRaises(ValueError, test_q_range_error_1) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
# Test error when q < 0 def test_quantile_multiple_axis(self):
def test_q_range_error_2(): x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(self.x, q=[0.2, -0.3]) paddle_res = paddle.quantile(x, q=[0.2, 0.67], axis=[1, -1])
np_res = np.quantile(self.input_data, q=[0.2, 0.67], axis=[1, -1])
self.assertRaises(ValueError, test_q_range_error_2) self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
# Test error with no valid q def test_quantile_multiple_axis_keepdim(self):
def test_q_range_error_3(): x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(self.x, q=[]) paddle_res = paddle.quantile(
x, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdim=True)
self.assertRaises(ValueError, test_q_range_error_3) np_res = np.quantile(
self.input_data, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdims=True)
# Test error when x is not Tensor self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_x_type_error():
x = [1, 3, 4]
paddle_res = paddle.quantile(x, q=0.9) class TestError(unittest.TestCase):
"""
self.assertRaises(TypeError, test_x_type_error) This class is used to test that exceptions are thrown correctly.
Validity of all parameter values and types should be considered.
# Test error when scalar axis is not int """
def test_axis_type_error_1():
paddle_res = paddle.quantile(self.x, q=0.4, axis=0.4) def setUp(self):
self.x = paddle.randn((2, 3, 4))
self.assertRaises(ValueError, test_axis_type_error_1)
def test_errors(self):
# Test error when axis in List is not int # Test error when q > 1
def test_axis_type_error_2(): def test_q_range_error_1():
paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, 0.4]) paddle_res = paddle.quantile(self.x, q=1.5)
self.assertRaises(ValueError, test_axis_type_error_2) self.assertRaises(ValueError, test_q_range_error_1)
# Test error when axis not in [-D, D) # Test error when q < 0
def test_axis_value_error_1(): def test_q_range_error_2():
paddle_res = paddle.quantile(self.x, q=0.4, axis=10) paddle_res = paddle.quantile(self.x, q=[0.2, -0.3])
self.assertRaises(ValueError, test_axis_value_error_1) self.assertRaises(ValueError, test_q_range_error_2)
# Test error when axis not in [-D, D) # Test error with no valid q
def test_axis_value_error_2(): def test_q_range_error_3():
paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, -10]) paddle_res = paddle.quantile(self.x, q=[])
self.assertRaises(ValueError, test_axis_value_error_2) self.assertRaises(ValueError, test_q_range_error_3)
# Test error with no valid axis # Test error when x is not Tensor
def test_axis_value_error_3(): def test_x_type_error():
paddle_res = paddle.quantile(self.x, q=0.4, axis=[]) x = [1, 3, 4]
paddle_res = paddle.quantile(x, q=0.9)
self.assertRaises(ValueError, test_axis_value_error_3)
self.assertRaises(TypeError, test_x_type_error)
class TestQuantileRuntime(unittest.TestCase): # Test error when scalar axis is not int
""" def test_axis_type_error_1():
This class is used to test the API could run correctly with paddle_res = paddle.quantile(self.x, q=0.4, axis=0.4)
different devices, different data types, and dygraph/static mode.
""" self.assertRaises(ValueError, test_axis_type_error_1)
def setUp(self): # Test error when axis in List is not int
np.random.seed(678) def test_axis_type_error_2():
self.input_data = np.random.rand(6, 7, 8, 9, 10) paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, 0.4])
self.dtypes = ['float32', 'float64']
self.devices = ['cpu'] self.assertRaises(ValueError, test_axis_type_error_2)
if paddle.device.is_compiled_with_cuda():
self.devices.append('gpu') # Test error when axis not in [-D, D)
def test_axis_value_error_1():
def test_dygraph(self): paddle_res = paddle.quantile(self.x, q=0.4, axis=10)
paddle.disable_static()
for device in self.devices: self.assertRaises(ValueError, test_axis_value_error_1)
# Check different devices
paddle.set_device(device) # Test error when axis not in [-D, D)
for dtype in self.dtypes: def test_axis_value_error_2():
# Check different dtypes paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, -10])
np_input_data = self.input_data.astype(dtype)
x = paddle.to_tensor(np_input_data, dtype=dtype) self.assertRaises(ValueError, test_axis_value_error_2)
paddle_res = paddle.quantile(x, q=0.5, axis=2)
np_res = np.quantile(np_input_data, q=0.5, axis=2) # Test error with no valid axis
self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) def test_axis_value_error_3():
paddle_res = paddle.quantile(self.x, q=0.4, axis=[])
def test_static(self):
paddle.enable_static() self.assertRaises(ValueError, test_axis_value_error_3)
for device in self.devices:
x = paddle.static.data(
name="x", shape=self.input_data.shape, dtype=paddle.float32) class TestQuantileRuntime(unittest.TestCase):
x_fp64 = paddle.static.data( """
name="x_fp64", This class is used to test the API could run correctly with
shape=self.input_data.shape, different devices, different data types, and dygraph/static mode.
dtype=paddle.float64) """
results = paddle.quantile(x, q=0.5, axis=2) def setUp(self):
np_input_data = self.input_data.astype('float32') self.input_data = np.random.rand(4, 7)
results_fp64 = paddle.quantile(x_fp64, q=0.5, axis=2) self.dtypes = ['float32', 'float64']
np_input_data_fp64 = self.input_data.astype('float64') self.devices = ['cpu']
if paddle.device.is_compiled_with_cuda():
exe = paddle.static.Executor(device) self.devices.append('gpu')
paddle_res, paddle_res_fp64 = exe.run(
paddle.static.default_main_program(), def test_dygraph(self):
feed={"x": np_input_data, paddle.disable_static()
"x_fp64": np_input_data_fp64}, for (func, res_func) in API_list:
fetch_list=[results, results_fp64]) for device in self.devices:
np_res = np.quantile(np_input_data, q=0.5, axis=2) # Check different devices
np_res_fp64 = np.quantile(np_input_data_fp64, q=0.5, axis=2) paddle.set_device(device)
self.assertTrue( for dtype in self.dtypes:
np.allclose(paddle_res, np_res) and np.allclose(paddle_res_fp64, # Check different dtypes
np_res_fp64)) np_input_data = self.input_data.astype(dtype)
x = paddle.to_tensor(np_input_data, dtype=dtype)
paddle_res = func(x, q=0.5, axis=1)
if __name__ == '__main__': np_res = res_func(np_input_data, q=0.5, axis=1)
unittest.main() self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_static(self):
paddle.enable_static()
for (func, res_func) in API_list:
for device in self.devices:
x = paddle.static.data(
name="x", shape=self.input_data.shape, dtype=paddle.float32)
x_fp64 = paddle.static.data(
name="x_fp64",
shape=self.input_data.shape,
dtype=paddle.float64)
results = func(x, q=0.5, axis=1)
np_input_data = self.input_data.astype('float32')
results_fp64 = func(x_fp64, q=0.5, axis=1)
np_input_data_fp64 = self.input_data.astype('float64')
exe = paddle.static.Executor(device)
paddle_res, paddle_res_fp64 = exe.run(
paddle.static.default_main_program(),
feed={"x": np_input_data,
"x_fp64": np_input_data_fp64},
fetch_list=[results, results_fp64])
np_res = res_func(np_input_data, q=0.5, axis=1)
np_res_fp64 = res_func(np_input_data_fp64, q=0.5, axis=1)
self.assertTrue(
np.allclose(paddle_res, np_res) and
np.allclose(paddle_res_fp64, np_res_fp64))
if __name__ == '__main__':
unittest.main()
...@@ -262,6 +262,7 @@ from .stat import var # noqa: F401 ...@@ -262,6 +262,7 @@ from .stat import var # noqa: F401
from .stat import numel # noqa: F401 from .stat import numel # noqa: F401
from .stat import median # noqa: F401 from .stat import median # noqa: F401
from .stat import quantile # noqa: F401 from .stat import quantile # noqa: F401
from .stat import nanquantile # noqa: F401
from .to_string import set_printoptions # noqa: F401 from .to_string import set_printoptions # noqa: F401
...@@ -445,6 +446,7 @@ tensor_method_func = [ #noqa ...@@ -445,6 +446,7 @@ tensor_method_func = [ #noqa
'numel', 'numel',
'median', 'median',
'quantile', 'quantile',
'nanquantile',
'is_complex', 'is_complex',
'is_integer', 'is_integer',
'rank', 'rank',
......
...@@ -342,13 +342,14 @@ def median(x, axis=None, keepdim=False, name=None): ...@@ -342,13 +342,14 @@ def median(x, axis=None, keepdim=False, name=None):
return out_tensor return out_tensor
def quantile(x, q, axis=None, keepdim=False): def _compute_quantile(x, q, axis=None, keepdim=False, ignore_nan=False):
""" """
Compute the quantile of the input along the specified axis. Compute the quantile of the input along the specified axis.
Args:
Args: Args:
x (Tensor): The input Tensor, it's data type can be float32, float64. x (Tensor): The input Tensor, it's data type can be float32, float64.
q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list, q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
each q will be calculated and the first dimension of output is same to the number of ``q`` . each q will be calculated and the first dimension of output is same to the number of ``q`` .
axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` . ``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
...@@ -360,37 +361,28 @@ def quantile(x, q, axis=None, keepdim=False): ...@@ -360,37 +361,28 @@ def quantile(x, q, axis=None, keepdim=False):
the output Tensor is the same as ``x`` except in the reduced the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False. the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None). ignore_nan: (bool, optional): Whether to ignore NaN of input Tensor.
For more information, please refer to :ref:`api_guide_Name`. If ``ignore_nan`` is True, it will calculate nanquantile.
Otherwise it will calculate quantile. Default is False.
Returns: Returns:
Tensor, results of quantile along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32. Tensor, results of quantile along ``axis`` of ``x``.
In order to obtain higher precision, data type of results will be float64.
Examples:
.. code-block:: python
import paddle
x = paddle.randn((2,3))
#[[-1.28740597, 0.49533170, -1.00698614],
# [-1.11656201, -1.01010525, -2.23457789]])
y1 = paddle.quantile(x, q=0.5, axis=[0, 1])
# y1 = -1.06333363
y2 = paddle.quantile(x, q=0.5, axis=1)
# y2 = [-1.00698614, -1.11656201]
y3 = paddle.quantile(x, q=[0.3, 0.5], axis=1)
# y3 =[[-1.11915410, -1.56376839],
# [-1.00698614, -1.11656201]]
y4 = paddle.quantile(x, q=0.8, axis=1, keepdim=True)
# y4 = [[-0.10559537],
# [-1.05268800]])
""" """
# Validate x
if not isinstance(x, Variable): if not isinstance(x, Variable):
raise TypeError("input x should be a Tensor.") raise TypeError("input x should be a Tensor.")
# Validate q
if isinstance(q, (int, float)):
q = [q]
elif isinstance(q, (list, tuple)):
if len(q) <= 0:
raise ValueError("q should not be empty")
else:
raise TypeError("Type of q should be int, float, list or tuple.")
# Validate axis
dims = len(x.shape) dims = len(x.shape)
out_shape = list(x.shape) out_shape = list(x.shape)
if axis is None: if axis is None:
...@@ -399,7 +391,7 @@ def quantile(x, q, axis=None, keepdim=False): ...@@ -399,7 +391,7 @@ def quantile(x, q, axis=None, keepdim=False):
out_shape = [1] * dims out_shape = [1] * dims
else: else:
if isinstance(axis, list): if isinstance(axis, list):
if (len(axis) <= 0): if len(axis) <= 0:
raise ValueError("axis should not be empty") raise ValueError("axis should not be empty")
axis_src, axis_dst = [], [] axis_src, axis_dst = [], []
for axis_single in axis: for axis_single in axis:
...@@ -424,54 +416,177 @@ def quantile(x, q, axis=None, keepdim=False): ...@@ -424,54 +416,177 @@ def quantile(x, q, axis=None, keepdim=False):
if axis < 0: if axis < 0:
axis += dims axis += dims
out_shape[axis] = 1 out_shape[axis] = 1
mask = x.isnan()
valid_counts = mask.logical_not().sum(axis=axis,
keepdim=True,
dtype='float64')
indices = [] indices = []
if isinstance(q, (int, float)):
if q < 0 or q > 1: for q_num in q:
if q_num < 0 or q_num > 1:
raise ValueError("q should be in range [0, 1]") raise ValueError("q should be in range [0, 1]")
indices.append(q * (x.shape[axis] - 1)) if paddle.in_dynamic_mode():
elif isinstance(q, (list, tuple)): q_num = paddle.to_tensor(q_num, dtype='float64')
if len(q) <= 0: if ignore_nan:
raise ValueError("q should not be empty") indices.append(q_num * (valid_counts - 1))
for q_num in q: else:
if q_num < 0 or q_num > 1: # TODO(Asthestarsfalll): Use paddle.index_fill instead of where
raise ValueError("q should be in range [0, 1]") index = q_num * (valid_counts - 1)
indices.append(q_num * (x.shape[axis] - 1)) last_index = x.shape[axis] - 1
else: nums = paddle.full_like(index, fill_value=last_index)
raise TypeError("Type of q should be int, float, list or tuple.") index = paddle.where(mask.any(axis=axis, keepdim=True), nums, index)
indices.append(index)
sorted_tensor = paddle.sort(x, axis) sorted_tensor = paddle.sort(x, axis)
indices_tensor = paddle.assign(indices).astype(paddle.float32)
indices_below = paddle.floor(indices_tensor).astype(paddle.int32)
indices_upper = paddle.ceil(indices_tensor).astype(paddle.int32)
outputs = []
def expand_dim(indices, sorted_tensor_shape, axis): outputs = []
assert axis < len(list(sorted_tensor_shape))
expanded_shape = [1] * len(list(sorted_tensor_shape))
expanded_shape = tuple(expanded_shape)
indices = indices.reshape(expanded_shape)
return indices
# TODO(chenjianye): replace the for-loop to directly take elements. # TODO(chenjianye): replace the for-loop to directly take elements.
for i in range(len(indices)): for index in indices:
if (indices_upper[i] != indices_below[i]): indices_below = paddle.floor(index).astype(paddle.int32)
tensor_below = paddle.take_along_axis( indices_upper = paddle.ceil(index).astype(paddle.int32)
sorted_tensor, tensor_upper = paddle.take_along_axis(
expand_dim(indices_below[i], sorted_tensor.shape, axis), axis) sorted_tensor, indices_upper, axis=axis)
tensor_upper = paddle.take_along_axis( tensor_below = paddle.take_along_axis(
sorted_tensor, sorted_tensor, indices_below, axis=axis)
expand_dim(indices_upper[i], sorted_tensor.shape, axis), axis) weights = (index - indices_below.astype('float64'))
weights = (indices[i] - indices_below[i]).astype(x.dtype) out = paddle.lerp(
out = paddle.lerp(tensor_below, tensor_upper, weights) tensor_below.astype('float64'),
else: tensor_upper.astype('float64'), weights)
out = paddle.take_along_axis(
sorted_tensor,
expand_dim(indices_below[i], sorted_tensor.shape, axis), axis)
if not keepdim: if not keepdim:
out = paddle.squeeze(out, axis=axis) out = paddle.squeeze(out, axis=axis)
else: else:
out = out.reshape(out_shape) out = out.reshape(out_shape)
outputs.append(out) outputs.append(out)
if isinstance(q, (list, tuple)):
return paddle.stack(outputs, 0) if len(q) > 1:
outputs = paddle.stack(outputs, 0)
else: else:
return outputs[0] outputs = outputs[0]
return outputs
def quantile(x, q, axis=None, keepdim=False):
"""
Compute the quantile of the input along the specified axis.
If any values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
Args:
x (Tensor): The input Tensor, it's data type can be float32, float64.
q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
each q will be calculated and the first dimension of output is same to the number of ``q`` .
axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
If ``axis`` is a list, quantile is calculated over all elements of given axises.
If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
keepdim (bool, optional): Whether to reserve the reduced dimension(s)
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, results of quantile along ``axis`` of ``x``.
In order to obtain higher precision, data type of results will be float64.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.arange(0, 8, dtype=np.float32).reshape(4, 2)
# [[0 1]
# [2 3]
# [4 5]
# [6 7]]
y = paddle.to_tensor(x)
y1 = paddle.quantile(y, q=0.5, axis=[0, 1])
# 3.5
y2 = paddle.quantile(y, q=0.5, axis=1)
# [0.5 2.5 4.5 6.5]
y3 = paddle.quantile(y, q=[0.3, 0.5], axis=0)
# [[1.8 2.8]
# [3. 4. ]]
x[0][0] = np.nan
y = paddle.to_tensor(x)
y4 = paddle.quantile(y, q=0.8, axis=1, keepdim=True)
# [[nan]
# [2.8]
# [4.8]
# [6.8]]
"""
return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=False)
def nanquantile(x, q, axis=None, keepdim=False):
"""
Compute the quantile of the input as if NaN values in input did not exist.
If all values in a reduced row are NaN, then the quantiles for that reduction will be NaN.
Args:
x (Tensor): The input Tensor, it's data type can be float32, float64.
q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
each q will be calculated and the first dimension of output is same to the number of ``q`` .
axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
If ``axis`` is a list, quantile is calculated over all elements of given axises.
If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
keepdim (bool, optional): Whether to reserve the reduced dimension(s)
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, results of quantile along ``axis`` of ``x``.
In order to obtain higher precision, data type of results will be float64.
Examples:
.. code-block:: python
import numpy as np
import paddle
x = np.array(
[[0, 1, 2, 3, 4],
[5, 6, 7, 8, 9]],
dtype=np.float32
)
x[0][0] = np.nan
x = paddle.to_tensor(x)
y1 = paddle.nanquantile(x, q=0.5, axis=[0, 1])
# 5.0
y2 = paddle.nanquantile(x, q=0.5, axis=1)
# [2.5 7. ]
y3 = paddle.nanquantile(x, q=[0.3, 0.5], axis=0)
# [[5. 2.5 3.5 4.5 5.5]
# [5. 3.5 4.5 5.5 6.5]
y4 = paddle.nanquantile(x, q=0.8, axis=1, keepdim=True)
# [[3.4]
# [8.2]]
nan = paddle.full(shape=[2, 3], fill_value=np.nan)
y5 = paddle.nanquantile(nan, q=0.8, axis=1, keepdim=True)
# [[nan]
# [nan]]
"""
return _compute_quantile(x, q, axis=axis, keepdim=keepdim, ignore_nan=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册