From 7b53923e3065882b6e5c6798c7a8d97d5e807d82 Mon Sep 17 00:00:00 2001 From: Li-fAngyU <56572498+Li-fAngyU@users.noreply.github.com> Date: Sat, 1 Apr 2023 00:27:57 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PaddlePaddle=20Hackathon=204=20NO.23?= =?UTF-8?q?=E3=80=91=E4=B8=BA=20Paddle=20=E6=96=B0=E5=A2=9E=20vander=20API?= =?UTF-8?q?=20(#51048)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/__init__.py | 2 + .../fluid/tests/unittests/test_vander.py | 100 ++++++++++++++++++ python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/math.py | 76 +++++++++++++ 4 files changed, 180 insertions(+) mode change 100755 => 100644 python/paddle/__init__.py create mode 100644 python/paddle/fluid/tests/unittests/test_vander.py mode change 100755 => 100644 python/paddle/tensor/__init__.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py old mode 100755 new mode 100644 index ca7c4b52543..f978cc9dbcf --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -294,6 +294,7 @@ from .tensor.math import take # noqa: F401 from .tensor.math import frexp # noqa: F401 from .tensor.math import trapezoid # noqa: F401 from .tensor.math import cumulative_trapezoid # noqa: F401 +from .tensor.math import vander # noqa: F401 from .tensor.random import bernoulli # noqa: F401 from .tensor.random import poisson # noqa: F401 @@ -687,4 +688,5 @@ __all__ = [ # noqa 'trapezoid', 'cumulative_trapezoid', 'polar', + 'vander', ] diff --git a/python/paddle/fluid/tests/unittests/test_vander.py b/python/paddle/fluid/tests/unittests/test_vander.py new file mode 100644 index 00000000000..7ad2d8e1201 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_vander.py @@ -0,0 +1,100 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.fluid import core + +np.random.seed(10) + + +def ref_vander(x, N=None, increasing=False): + return np.vander(x, N, increasing) + + +class TestVanderAPI(unittest.TestCase): + # test paddle.tensor.math.vander + + def setUp(self): + self.shape = [5] + self.x = np.random.uniform(-1, 1, self.shape).astype(np.float32) + self.place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def api_case(self, N=None, increasing=False): + paddle.enable_static() + out_ref = ref_vander(self.x, N, increasing) + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.shape) + out = paddle.vander(x, N, increasing) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x}, fetch_list=[out]) + if N != 0: + np.testing.assert_allclose(res[0], out_ref, rtol=1e-05) + else: + np.testing.assert_allclose(res[0].size, out_ref.size, rtol=1e-05) + + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x) + out = paddle.vander(x, N, increasing) + np.testing.assert_allclose(out.numpy(), out_ref, rtol=1e-05) + paddle.enable_static() + + def test_api(self): + self.api_case() + N = list(range(9)) + for n in N: + self.api_case(n) + self.api_case(n, increasing=True) + + def test_complex(self): + paddle.disable_static(self.place) + real = np.random.rand(5) + imag = np.random.rand(5) + complex_np = real + 1j * imag + complex_paddle = paddle.complex( + paddle.to_tensor(real), paddle.to_tensor(imag) + ) + + def test_api_case(N, increasing=False): + for n in N: + res_np = np.vander(complex_np, n, increasing) + res_paddle = paddle.vander(complex_paddle, n, increasing) + np.testing.assert_allclose( + res_paddle.numpy(), res_np, rtol=1e-05 + ) + + N = [0, 1, 2, 3, 4] + test_api_case(N) + test_api_case(N, increasing=True) + paddle.enable_static() + + def test_errors(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + self.assertRaises(TypeError, paddle.vander, 1) + x = paddle.static.data('X', [10, 12], 'int32') + self.assertRaises(ValueError, paddle.vander, x) + x1 = paddle.static.data('X1', [10], 'int32') + self.assertRaises(ValueError, paddle.vander, x1, n=-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py old mode 100755 new mode 100644 index c6eb17d7abd..b78ac0e57c2 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -250,6 +250,7 @@ from .math import trapezoid # noqa: F401 from .math import cumulative_trapezoid # noqa: F401 from .math import sigmoid # noqa: F401 from .math import sigmoid_ # noqa: F401 +from .math import vander # noqa: F401 from .random import multinomial # noqa: F401 from .random import standard_normal # noqa: F401 @@ -538,6 +539,7 @@ tensor_method_func = [ # noqa 'polar', 'sigmoid', 'sigmoid_', + 'vander', ] # this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index bb4b9646374..6676d4dc604 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5333,3 +5333,79 @@ def cumulative_trapezoid(y, x=None, dx=None, axis=-1, name=None): # [3.50000000, 8. ]]) """ return _trapezoid(y, x, dx, axis, mode='cumsum') + + +def vander(x, n=None, increasing=False, name=None): + """ + Generate a Vandermonde matrix. + + The columns of the output matrix are powers of the input vector. Order of the powers is + determined by the increasing Boolean parameter. Specifically, when the increment is + "false", the ith output column is a step-up in the order of the elements of the input + vector to the N - i - 1 power. Such a matrix with a geometric progression in each row + is named after Alexandre-Theophile Vandermonde. + + Args: + x (Tensor): The input tensor, it must be 1-D Tensor, and it's data type should be ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64']. + n (int): Number of columns in the output. If n is not specified, a square array is returned (n = len(x)). + increasing(bool): Order of the powers of the columns. If True, the powers increase from left to right, if False (the default) they are reversed. + name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. + Returns: + Tensor, A vandermonde matrix with shape (len(x), N). If increasing is False, the first column is :math:`x^{(N-1)}`, the second :math:`x^{(N-2)}` and so forth. + If increasing is True, the columns are :math:`x^0`, :math:`x^1`, ..., :math:`x^{(N-1)}`. + + Examples: + .. code-block:: python + + import paddle + x = paddle.to_tensor([1., 2., 3.], dtype="float32") + out = paddle.vander(x) + print(out.numpy()) + # [[1., 1., 1.], + # [4., 2., 1.], + # [9., 3., 1.]] + out1 = paddle.vander(x,2) + print(out1.numpy()) + # [[1., 1.], + # [2., 1.], + # [3., 1.]] + out2 = paddle.vander(x, increasing = True) + print(out2.numpy()) + # [[1., 1., 1.], + # [1., 2., 4.], + # [1., 3., 9.]] + real = paddle.to_tensor([2., 4.]) + imag = paddle.to_tensor([1., 3.]) + complex = paddle.complex(real, imag) + out3 = paddle.vander(complex) + print(out3.numpy()) + # [[2.+1.j, 1.+0.j], + # [4.+3.j, 1.+0.j]] + """ + check_variable_and_dtype( + x, + 'x', + ['complex64', 'complex128', 'float32', 'float64', 'int32', 'int64'], + 'vander', + ) + if x.dim() != 1: + raise ValueError( + "The input of x is expected to be a 1-D Tensor." + "But now the dims of Input(X) is %d." % x.dim() + ) + + if n is None: + n = x.shape[0] + + if n < 0: + raise ValueError("N must be non-negative.") + + res = paddle.empty([x.shape[0], n], dtype=x.dtype) + + if n > 0: + res[:, 0] = paddle.to_tensor([1], dtype=x.dtype) + if n > 1: + res[:, 1:] = x[:, None] + res[:, 1:] = paddle.cumprod(res[:, 1:], dim=-1) + res = res[:, ::-1] if not increasing else res + return res -- GitLab