未验证 提交 20dc1ac2 编写于 作者: J JYChen 提交者: GitHub

[new api] add new api paddle.quantile and paddle.Tensor.quantile (#38567)

* add new api paddle.quantile and paddle.Tensor.quantile

* add take_todo and fix UT
上级 2ce91c33
......@@ -309,6 +309,7 @@ from .tensor.stat import std # noqa: F401
from .tensor.stat import var # noqa: F401
from .tensor.stat import numel # noqa: F401
from .tensor.stat import median # noqa: F401
from .tensor.stat import quantile # noqa: F401
from .device import get_cudnn_version # noqa: F401
from .device import set_device # noqa: F401
from .device import get_device # noqa: F401
......@@ -481,6 +482,7 @@ __all__ = [ # noqa
'load',
'numel',
'median',
'quantile',
'no_grad',
'set_grad_enabled',
'is_grad_enabled',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
class TestQuantile(unittest.TestCase):
def setUp(self):
np.random.seed(678)
self.input_data = np.random.rand(6, 7, 8, 9, 10)
def test_quantile_single_q(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=0.5, axis=2)
np_res = np.quantile(self.input_data, q=0.5, axis=2)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_with_no_axis(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=0.35)
np_res = np.quantile(self.input_data, q=0.35)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_with_multi_axis(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=0.75, axis=[0, 2, 3])
np_res = np.quantile(self.input_data, q=0.75, axis=[0, 2, 3])
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_with_keepdim(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=0.35, axis=4, keepdim=True)
np_res = np.quantile(self.input_data, q=0.35, axis=4, keepdims=True)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_with_keepdim_and_multiple_axis(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=0.1, axis=[1, 4], keepdim=True)
np_res = np.quantile(self.input_data, q=0.1, axis=[1, 4], keepdims=True)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_with_boundary_q(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=0, axis=3)
np_res = np.quantile(self.input_data, q=0, axis=3)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_include_NaN(self):
input_data = np.random.randn(2, 3, 4)
input_data[0, 1, 1] = np.nan
x = paddle.to_tensor(input_data)
paddle_res = paddle.quantile(x, q=0.35, axis=0)
self.assertTrue(paddle.isnan(paddle_res[1, 1]))
class TestQuantileMuitlpleQ(unittest.TestCase):
def setUp(self):
np.random.seed(678)
self.input_data = np.random.rand(10, 3, 4, 5, 4)
def test_quantile(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=[0.3, 0.44], axis=-2)
np_res = np.quantile(self.input_data, q=[0.3, 0.44], axis=-2)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_multiple_axis(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(x, q=[0.2, 0.67], axis=[1, -1])
np_res = np.quantile(self.input_data, q=[0.2, 0.67], axis=[1, -1])
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
def test_quantile_multiple_axis_keepdim(self):
x = paddle.to_tensor(self.input_data)
paddle_res = paddle.quantile(
x, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdim=True)
np_res = np.quantile(
self.input_data, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdims=True)
self.assertTrue(np.allclose(paddle_res.numpy(), np_res))
class TestQuantileError(unittest.TestCase):
def setUp(self):
self.x = paddle.randn((2, 3, 4))
def test_errors(self):
def test_q_range_error_1():
paddle_res = paddle.quantile(self.x, q=1.5)
self.assertRaises(ValueError, test_q_range_error_1)
def test_q_range_error_2():
paddle_res = paddle.quantile(self.x, q=[0.2, -0.3])
self.assertRaises(ValueError, test_q_range_error_2)
def test_q_range_error_3():
paddle_res = paddle.quantile(self.x, q=[])
self.assertRaises(ValueError, test_q_range_error_3)
def test_x_type_error():
x = [1, 3, 4]
paddle_res = paddle.quantile(x, q=0.9)
self.assertRaises(TypeError, test_x_type_error)
def test_axis_type_error_1():
paddle_res = paddle.quantile(self.x, q=0.4, axis=0.4)
self.assertRaises(ValueError, test_axis_type_error_1)
def test_axis_type_error_2():
paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, 0.4])
self.assertRaises(ValueError, test_axis_type_error_2)
def test_axis_value_error_1():
paddle_res = paddle.quantile(self.x, q=0.4, axis=10)
self.assertRaises(ValueError, test_axis_value_error_1)
def test_axis_value_error_2():
paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, -10])
self.assertRaises(ValueError, test_axis_value_error_2)
def test_axis_value_error_3():
paddle_res = paddle.quantile(self.x, q=0.4, axis=[])
self.assertRaises(ValueError, test_axis_value_error_3)
if __name__ == '__main__':
unittest.main()
......@@ -258,6 +258,8 @@ from .stat import std # noqa: F401
from .stat import var # noqa: F401
from .stat import numel # noqa: F401
from .stat import median # noqa: F401
from .stat import quantile # noqa: F401
from .to_string import set_printoptions # noqa: F401
from .array import array_length # noqa: F401
......@@ -437,6 +439,7 @@ tensor_method_func = [ #noqa
'var',
'numel',
'median',
'quantile',
'is_complex',
'is_integer',
'rank',
......
......@@ -333,3 +333,127 @@ def median(x, axis=None, keepdim=False, name=None):
newshape = out_tensor.shape
out_tensor = out_tensor.reshape(newshape, name=name)
return out_tensor
def quantile(x, q, axis=None, keepdim=False):
"""
Compute the quantile of the input along the specified axis.
Args:
x (Tensor): The input Tensor, it's data type can be float32, float64.
q (int|float|list): The q for calculate quantile, which should be in range [0, 1]. If q is a list,
each q will be calculated and the first dimension of output is same to the number of ``q`` .
axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int.
``axis`` should be in range [-D, D), where D is the dimensions of ``x`` .
If ``axis`` is less than 0, it works the same way as :math:`axis + D`.
If ``axis`` is a list, quantile is calculated over all elements of given axises.
If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None.
keepdim (bool, optional): Whether to reserve the reduced dimension(s)
in the output Tensor. If ``keepdim`` is True, the dimensions of
the output Tensor is the same as ``x`` except in the reduced
dimensions(it is of size 1 in this case). Otherwise, the shape of
the output Tensor is squeezed in ``axis`` . Default is False.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, results of quantile along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32.
Examples:
.. code-block:: python
import paddle
x = paddle.randn((2,3))
#[[-1.28740597, 0.49533170, -1.00698614],
# [-1.11656201, -1.01010525, -2.23457789]])
y1 = paddle.quantile(x, q=0.5, axis=[0, 1])
# y1 = -1.06333363
y2 = paddle.quantile(x, q=0.5, axis=1)
# y2 = [-1.00698614, -1.11656201]
y3 = paddle.quantile(x, q=[0.3, 0.5], axis=1)
# y3 =[[-1.11915410, -1.56376839],
# [-1.00698614, -1.11656201]]
y4 = paddle.quantile(x, q=0.8, axis=1, keepdim=True)
# y4 = [[-0.10559537],
# [-1.05268800]])
"""
if not isinstance(x, Variable):
raise TypeError("input x should be a Tensor.")
dims = len(x.shape)
out_shape = x.shape
if axis is None:
x = paddle.flatten(x)
axis = 0
out_shape = [1] * dims
else:
if isinstance(axis, list):
if (len(axis) <= 0):
raise ValueError("axis should not be empty")
axis_src, axis_dst = [], []
for axis_single in axis:
if not isinstance(axis_single, int) or not (
axis_single < dims and axis_single >= -dims):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
if axis_single < 0:
axis_single = axis_single + dims
axis_src.append(axis_single)
out_shape[axis_single] = 1
axis_dst = list(range(-len(axis), 0))
x = paddle.moveaxis(x, axis_src, axis_dst)
x = paddle.flatten(x, axis_dst[0], axis_dst[-1])
axis = axis_dst[0]
else:
if not isinstance(axis, int) or not (axis < dims and axis >= -dims):
raise ValueError(
"Axis should be None, int, or a list, element should in range [-rank(x), rank(x))."
)
if axis < 0:
axis += dims
out_shape[axis] = 1
indices = []
if isinstance(q, (int, float)):
if q < 0 or q > 1:
raise ValueError("q should be in range [0, 1]")
indices.append(q * (x.shape[axis] - 1))
elif isinstance(q, (list, tuple)):
if len(q) <= 0:
raise ValueError("q should not be empty")
for q_num in q:
if q_num < 0 or q_num > 1:
raise ValueError("q should be in range [0, 1]")
indices.append(q_num * (x.shape[axis] - 1))
else:
raise TypeError("Type of q should be int, float, list or tuple.")
indices = paddle.to_tensor(indices).astype(paddle.float32)
sorted_tensor = paddle.sort(x, axis)
indices_below = paddle.floor(indices).astype(paddle.int32)
indices_upper = paddle.ceil(indices).astype(paddle.int32)
outputs = []
# TODO(chenjianye): replace the for-loop to directly take elements.
for i in range(len(indices)):
if (indices_upper[i] != indices_below[i]):
tensor_below = paddle.take_along_axis(sorted_tensor,
indices_below[i], axis)
tensor_upper = paddle.take_along_axis(sorted_tensor,
indices_upper[i], axis)
weights = (indices[i] - indices_below[i]).astype(x.dtype)
out = paddle.lerp(tensor_below, tensor_upper, weights)
else:
out = paddle.take_along_axis(sorted_tensor, indices_below[i], axis)
if not keepdim:
out = paddle.squeeze(out, axis=axis)
else:
out = out.reshape(out_shape)
outputs.append(out)
if isinstance(q, (list, tuple)):
return paddle.stack(outputs, 0)
else:
return outputs[0]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册