From 20dc1ac20bcae2d6b5f06b22daee6726ea1d5a8a Mon Sep 17 00:00:00 2001 From: JYChen Date: Fri, 31 Dec 2021 12:57:49 +0800 Subject: [PATCH] [new api] add new api paddle.quantile and paddle.Tensor.quantile (#38567) * add new api paddle.quantile and paddle.Tensor.quantile * add take_todo and fix UT --- python/paddle/__init__.py | 2 + .../fluid/tests/unittests/test_quantile.py | 150 ++++++++++++++++++ python/paddle/tensor/__init__.py | 3 + python/paddle/tensor/stat.py | 124 +++++++++++++++ 4 files changed, 279 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_quantile.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index f9467400bc..771a9053fc 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -309,6 +309,7 @@ from .tensor.stat import std # noqa: F401 from .tensor.stat import var # noqa: F401 from .tensor.stat import numel # noqa: F401 from .tensor.stat import median # noqa: F401 +from .tensor.stat import quantile # noqa: F401 from .device import get_cudnn_version # noqa: F401 from .device import set_device # noqa: F401 from .device import get_device # noqa: F401 @@ -481,6 +482,7 @@ __all__ = [ # noqa 'load', 'numel', 'median', + 'quantile', 'no_grad', 'set_grad_enabled', 'is_grad_enabled', diff --git a/python/paddle/fluid/tests/unittests/test_quantile.py b/python/paddle/fluid/tests/unittests/test_quantile.py new file mode 100644 index 0000000000..0fd3c1de9c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_quantile.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle + + +class TestQuantile(unittest.TestCase): + def setUp(self): + np.random.seed(678) + self.input_data = np.random.rand(6, 7, 8, 9, 10) + + def test_quantile_single_q(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=0.5, axis=2) + np_res = np.quantile(self.input_data, q=0.5, axis=2) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_with_no_axis(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=0.35) + np_res = np.quantile(self.input_data, q=0.35) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_with_multi_axis(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=0.75, axis=[0, 2, 3]) + np_res = np.quantile(self.input_data, q=0.75, axis=[0, 2, 3]) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_with_keepdim(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=0.35, axis=4, keepdim=True) + np_res = np.quantile(self.input_data, q=0.35, axis=4, keepdims=True) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_with_keepdim_and_multiple_axis(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=0.1, axis=[1, 4], keepdim=True) + np_res = np.quantile(self.input_data, q=0.1, axis=[1, 4], keepdims=True) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def 
test_quantile_with_boundary_q(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=0, axis=3) + np_res = np.quantile(self.input_data, q=0, axis=3) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_include_NaN(self): + input_data = np.random.randn(2, 3, 4) + input_data[0, 1, 1] = np.nan + x = paddle.to_tensor(input_data) + paddle_res = paddle.quantile(x, q=0.35, axis=0) + self.assertTrue(paddle.isnan(paddle_res[1, 1])) + + +class TestQuantileMuitlpleQ(unittest.TestCase): + def setUp(self): + np.random.seed(678) + self.input_data = np.random.rand(10, 3, 4, 5, 4) + + def test_quantile(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=[0.3, 0.44], axis=-2) + np_res = np.quantile(self.input_data, q=[0.3, 0.44], axis=-2) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_multiple_axis(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile(x, q=[0.2, 0.67], axis=[1, -1]) + np_res = np.quantile(self.input_data, q=[0.2, 0.67], axis=[1, -1]) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + def test_quantile_multiple_axis_keepdim(self): + x = paddle.to_tensor(self.input_data) + paddle_res = paddle.quantile( + x, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdim=True) + np_res = np.quantile( + self.input_data, q=[0.1, 0.2, 0.3], axis=[1, 2], keepdims=True) + self.assertTrue(np.allclose(paddle_res.numpy(), np_res)) + + +class TestQuantileError(unittest.TestCase): + def setUp(self): + self.x = paddle.randn((2, 3, 4)) + + def test_errors(self): + def test_q_range_error_1(): + paddle_res = paddle.quantile(self.x, q=1.5) + + self.assertRaises(ValueError, test_q_range_error_1) + + def test_q_range_error_2(): + paddle_res = paddle.quantile(self.x, q=[0.2, -0.3]) + + self.assertRaises(ValueError, test_q_range_error_2) + + def test_q_range_error_3(): + paddle_res = paddle.quantile(self.x, q=[]) + + self.assertRaises(ValueError, 
test_q_range_error_3) + + def test_x_type_error(): + x = [1, 3, 4] + paddle_res = paddle.quantile(x, q=0.9) + + self.assertRaises(TypeError, test_x_type_error) + + def test_axis_type_error_1(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=0.4) + + self.assertRaises(ValueError, test_axis_type_error_1) + + def test_axis_type_error_2(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, 0.4]) + + self.assertRaises(ValueError, test_axis_type_error_2) + + def test_axis_value_error_1(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=10) + + self.assertRaises(ValueError, test_axis_value_error_1) + + def test_axis_value_error_2(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=[1, -10]) + + self.assertRaises(ValueError, test_axis_value_error_2) + + def test_axis_value_error_3(): + paddle_res = paddle.quantile(self.x, q=0.4, axis=[]) + + self.assertRaises(ValueError, test_axis_value_error_3) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 696b062f51..69a1101a2b 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -258,6 +258,8 @@ from .stat import std # noqa: F401 from .stat import var # noqa: F401 from .stat import numel # noqa: F401 from .stat import median # noqa: F401 +from .stat import quantile # noqa: F401 + from .to_string import set_printoptions # noqa: F401 from .array import array_length # noqa: F401 @@ -437,6 +439,7 @@ tensor_method_func = [ #noqa 'var', 'numel', 'median', + 'quantile', 'is_complex', 'is_integer', 'rank', diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index 6a016e42b5..45a663b016 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -333,3 +333,127 @@ def median(x, axis=None, keepdim=False, name=None): newshape = out_tensor.shape out_tensor = out_tensor.reshape(newshape, name=name) return out_tensor + + +def quantile(x, q, axis=None, keepdim=False): + """ 
+ Compute the quantile of the input along the specified axis. + + Args: + x (Tensor): The input Tensor, its data type can be float32, float64. + q (int|float|list|tuple): The quantile(s) to compute, each of which should be in range [0, 1]. If q is a list or tuple, + each q will be calculated and the size of the first dimension of the output is equal to the number of ``q`` . + axis (int|list, optional): The axis along which to calculate quantile. ``axis`` should be int or list of int. + ``axis`` should be in range [-D, D), where D is the number of dimensions of ``x`` . + If ``axis`` is less than 0, it works the same way as :math:`axis + D`. + If ``axis`` is a list, quantile is calculated over all elements of the given axes. + If ``axis`` is None, quantile is calculated over all elements of ``x``. Default is None. + keepdim (bool, optional): Whether to retain the reduced dimension(s) + in the output Tensor. If ``keepdim`` is True, the dimensions of + the output Tensor are the same as ``x`` except in the reduced + dimensions (which are of size 1 in this case). Otherwise, the shape of + the output Tensor is squeezed in ``axis`` . Default is False. + name (str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, results of quantile along ``axis`` of ``x``. If data type of ``x`` is float64, data type of results will be float64, otherwise data type will be float32. + + Examples: + .. 
code-block:: python + + import paddle + + x = paddle.randn((2,3)) + #[[-1.28740597, 0.49533170, -1.00698614], + # [-1.11656201, -1.01010525, -2.23457789]]) + + y1 = paddle.quantile(x, q=0.5, axis=[0, 1]) + # y1 = -1.06333363 + + y2 = paddle.quantile(x, q=0.5, axis=1) + # y2 = [-1.00698614, -1.11656201] + + y3 = paddle.quantile(x, q=[0.3, 0.5], axis=1) + # y3 =[[-1.11915410, -1.56376839], + # [-1.00698614, -1.11656201]] + + y4 = paddle.quantile(x, q=0.8, axis=1, keepdim=True) + # y4 = [[-0.10559537], + # [-1.05268800]]) + """ + if not isinstance(x, Variable): + raise TypeError("input x should be a Tensor.") + dims = len(x.shape) + out_shape = x.shape + if axis is None: + x = paddle.flatten(x) + axis = 0 + out_shape = [1] * dims + else: + if isinstance(axis, list): + if (len(axis) <= 0): + raise ValueError("axis should not be empty") + axis_src, axis_dst = [], [] + for axis_single in axis: + if not isinstance(axis_single, int) or not ( + axis_single < dims and axis_single >= -dims): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." + ) + if axis_single < 0: + axis_single = axis_single + dims + axis_src.append(axis_single) + out_shape[axis_single] = 1 + axis_dst = list(range(-len(axis), 0)) + x = paddle.moveaxis(x, axis_src, axis_dst) + x = paddle.flatten(x, axis_dst[0], axis_dst[-1]) + axis = axis_dst[0] + else: + if not isinstance(axis, int) or not (axis < dims and axis >= -dims): + raise ValueError( + "Axis should be None, int, or a list, element should in range [-rank(x), rank(x))." 
+ ) + if axis < 0: + axis += dims + out_shape[axis] = 1 + indices = [] + if isinstance(q, (int, float)): + if q < 0 or q > 1: + raise ValueError("q should be in range [0, 1]") + indices.append(q * (x.shape[axis] - 1)) + elif isinstance(q, (list, tuple)): + if len(q) <= 0: + raise ValueError("q should not be empty") + for q_num in q: + if q_num < 0 or q_num > 1: + raise ValueError("q should be in range [0, 1]") + indices.append(q_num * (x.shape[axis] - 1)) + else: + raise TypeError("Type of q should be int, float, list or tuple.") + indices = paddle.to_tensor(indices).astype(paddle.float32) + sorted_tensor = paddle.sort(x, axis) + indices_below = paddle.floor(indices).astype(paddle.int32) + indices_upper = paddle.ceil(indices).astype(paddle.int32) + outputs = [] + + # TODO(chenjianye): replace the for-loop to directly take elements. + for i in range(len(indices)): + if (indices_upper[i] != indices_below[i]): + tensor_below = paddle.take_along_axis(sorted_tensor, + indices_below[i], axis) + tensor_upper = paddle.take_along_axis(sorted_tensor, + indices_upper[i], axis) + weights = (indices[i] - indices_below[i]).astype(x.dtype) + out = paddle.lerp(tensor_below, tensor_upper, weights) + else: + out = paddle.take_along_axis(sorted_tensor, indices_below[i], axis) + if not keepdim: + out = paddle.squeeze(out, axis=axis) + else: + out = out.reshape(out_shape) + outputs.append(out) + if isinstance(q, (list, tuple)): + return paddle.stack(outputs, 0) + else: + return outputs[0] -- GitLab