未验证 提交 eefe601c 编写于 作者: P PommesPeter 提交者: GitHub

【Hackathon 4 No.14】Add Polar to paddle (#50901)

* added paddle.polar to paddle

* added paddle.polar unitest

* failed to use pytorch for evaluating results, and changed with numpy implementation

* updated code style

* updated __init__.py

* updated code style

* fixed unitest code

* lkh test polar

* polar add

* fixed errors and optimized code

* fixed error

* optimized polor api

* updated code style

* updated code style

---------
Co-authored-by: Ndiadestiny <1247889154@qq.com>
上级 aba9c4d4
......@@ -119,6 +119,7 @@ from .tensor.creation import complex # noqa: F401
from .tensor.creation import clone # noqa: F401
from .tensor.creation import tril_indices # noqa: F401
from .tensor.creation import triu_indices # noqa: F401
from .tensor.creation import polar # noqa: F401
from .tensor.linalg import matmul # noqa: F401
from .tensor.linalg import dot # noqa: F401
from .tensor.linalg import norm # noqa: F401
......@@ -685,4 +686,5 @@ __all__ = [ # noqa
'triu_indices',
'take',
'frexp',
'polar',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
# import torch
import numpy as np
import paddle
import paddle.fluid.core as core
np.random.seed(10)
def numpy_polar(abs, angle):
real = np.multiply(abs, np.cos(angle))
imag = np.multiply(abs, np.sin(angle))
return real + imag * 1j
class TestPolarAPI(unittest.TestCase):
def setUp(self):
self.abs = np.array([1, 2]).astype("float64")
self.angle = np.array([np.pi / 2, 5 * np.pi / 4]).astype("float64")
self.place = [paddle.CPUPlace()]
if core.is_compiled_with_cuda():
self.place.append(paddle.CUDAPlace(0))
def test_api_static(self):
paddle.enable_static()
def run(place):
with paddle.static.program_guard(paddle.static.Program()):
abs = paddle.static.data(
'abs',
shape=self.abs.shape,
dtype="float64",
)
angle = paddle.static.data(
'angle', shape=self.angle.shape, dtype="float64"
)
out1 = paddle.polar(abs, angle)
exe = paddle.static.Executor(place)
res = exe.run(
feed={'abs': self.abs, 'angle': self.angle},
fetch_list=[out1],
)
out_ref = numpy_polar(self.abs, self.angle)
np.testing.assert_allclose(out_ref, res[0], rtol=1e-05)
for place in self.place:
run(place)
def test_api_dygraph(self):
def run(place):
paddle.disable_static(place)
abs = paddle.to_tensor(self.abs)
angle = paddle.to_tensor(self.angle)
out1 = paddle.polar(abs, angle)
out_ref1 = numpy_polar(self.abs, self.angle)
np.testing.assert_allclose(out_ref1, out1.numpy(), rtol=1e-05)
paddle.enable_static()
for place in self.place:
run(place)
def test_out_complex64(self):
paddle.disable_static()
abs = paddle.to_tensor(self.abs, dtype=paddle.float32)
angle = paddle.to_tensor(self.angle, dtype=paddle.float32)
out = paddle.polar(abs, angle)
self.assertTrue(out.type, 'complex64')
def test_out_complex128(self):
paddle.disable_static()
abs = paddle.to_tensor(self.abs, dtype=paddle.float64)
angle = paddle.to_tensor(self.angle, dtype=paddle.float64)
out = paddle.polar(abs, angle)
self.assertTrue(out.type, 'complex128')
def test_empty_input_error(self):
for place in self.place:
paddle.disable_static(place)
abs = paddle.to_tensor(self.abs)
angle = paddle.to_tensor(self.angle)
self.assertRaises(AttributeError, paddle.polar, None, angle)
self.assertRaises(AttributeError, paddle.polar, abs, None)
if __name__ == "__main__":
unittest.main()
......@@ -39,6 +39,7 @@ from .creation import meshgrid # noqa: F401
from .creation import empty # noqa: F401
from .creation import empty_like # noqa: F401
from .creation import complex # noqa: F401
from .creation import polar # noqa: F401
from .linalg import matmul # noqa: F401
from .linalg import dot # noqa: F401
from .linalg import cov # noqa: F401
......@@ -529,6 +530,7 @@ tensor_method_func = [ # noqa
'bucketize',
'sgn',
'frexp',
'polar',
'sigmoid',
'sigmoid_',
]
......
......@@ -2380,3 +2380,41 @@ def triu_indices(row, col=None, offset=0, dtype='int64'):
attrs={'row': row, 'col': col, 'offset': offset, 'dtype': dtype},
)
return out
def polar(abs, angle, name=None):
"""Return a Cartesian coordinates corresponding to the polar coordinates compelx tensor given the ``abs`` and ``angle`` component.
Args:
abs (Tensor): The abs component. The data type should be 'float32' or 'float64'.
angle (Tensor): The anglee component. The data type should be the same as ``abs``.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Tensor: The output tensor. The data type is 'complex64' or 'complex128', with the same precision as ``abs`` and ``angle``.
Note:
``paddle.polar`` supports broadcasting. If you want know more about broadcasting, please refer to `Introduction to Tensor`_ .
.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor
Examples:
.. code-block:: python
import paddle
import numpy as np
abs = paddle.to_tensor([1, 2], dtype=paddle.float64)
angle = paddle.to_tensor([np.pi / 2, 5 * np.pi / 4], dtype=paddle.float64)
out = paddle.polar(abs, angle)
print(out)
# Tensor(shape=[2], dtype=complex128, place=Place(cpu), stop_gradient=True,
# [ (6.123233995736766e-17+1j) ,
# (-1.4142135623730954-1.414213562373095j)])
"""
check_variable_and_dtype(abs, 'abs', ['float32', 'float64'], 'paddle.polar')
check_variable_and_dtype(
angle, 'angle', ['float32', 'float64'], 'paddle.polar'
)
return paddle.complex(abs * paddle.cos(angle), abs * paddle.sin(angle))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册