Unverified commit cb2476cf, authored by LoneRanger and committed by GitHub

[PaddlePaddle Hackathon Round 4] No.6: Add the ldexp API to Paddle (#51395)

* add ldexp api

* fix ldexp

* Update math.py

fix math.py

* rewrite the ldexp

* Simplify ldexp implementation

* fix codestyle

* Update math.py

* Update math.py

* modify the test_ldexp.py

* fix the input bug of np.ldexp

* fix the bug of np.ldexp in windows

* modify the ldexp function and add the dtype check of output

* Update test_ldexp.py

* fix the dtype

* fix codestyle

* Update python/paddle/tensor/math.py

add note for code description
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>

---------
Co-authored-by: zachary sun <70642955+sunzhongkai588@users.noreply.github.com>
Parent a6f5021f
@@ -301,6 +301,7 @@ from .tensor.math import frac # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.math import frexp # noqa: F401
from .tensor.math import ldexp # noqa: F401
from .tensor.math import trapezoid # noqa: F401
from .tensor.math import cumulative_trapezoid # noqa: F401
from .tensor.math import vander # noqa: F401
@@ -699,6 +700,7 @@ __all__ = [ # noqa
    'triu_indices',
    'take',
    'frexp',
    'ldexp',
    'trapezoid',
    'cumulative_trapezoid',
    'polar',
@@ -251,6 +251,7 @@ from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .math import frexp # noqa: F401
from .math import ldexp # noqa: F401
from .math import trapezoid # noqa: F401
from .math import cumulative_trapezoid # noqa: F401
from .math import sigmoid # noqa: F401
@@ -550,6 +551,7 @@ tensor_method_func = [ # noqa
    'bucketize',
    'sgn',
    'frexp',
    'ldexp',
    'trapezoid',
    'cumulative_trapezoid',
    'polar',
@@ -5785,3 +5785,55 @@ def polygamma(x, n, name=None):
        attrs={'n': n},
    )
    return out


def ldexp(x, y, name=None):
    """
    Compute the result of multiplying x by 2 to the power of y. The equation is:

    .. math::

        out = x * 2^{y}

    Args:
        x (Tensor): The input Tensor, the data type is float32, float64, int32 or int64.
        y (Tensor): A Tensor of exponents, typically integers.
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        out (Tensor): An N-D Tensor. If x and y have different shapes and are "broadcastable", the resulting tensor shape is the broadcast shape of x and y; if they have the same shape, the output has that shape as well. The data type is float32 or float64.

    Examples:
        .. code-block:: python

            import paddle

            # example 1
            x = paddle.to_tensor([1, 2, 3], dtype='float32')
            y = paddle.to_tensor([2, 3, 4], dtype='int32')
            res = paddle.ldexp(x, y)
            print(res)
            # Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [4., 16., 48.])

            # example 2
            x = paddle.to_tensor([1, 2, 3], dtype='float32')
            y = paddle.to_tensor([2], dtype='int32')
            res = paddle.ldexp(x, y)
            print(res)
            # Tensor(shape=[3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
            #        [4., 8., 12.])

    """
    if not isinstance(x, (paddle.Tensor, Variable)):
        raise TypeError(f"x must be tensor type, but got {type(x)}")
    if not isinstance(y, (paddle.Tensor, Variable)):
        raise TypeError(f"y must be tensor type, but got {type(y)}")
    if x.dtype == paddle.float64 or y.dtype == paddle.float64:
        # keep full precision when either operand is already float64
        out_dtype = paddle.float64
    else:
        # otherwise compute in the framework default dtype (float32 by default)
        out_dtype = paddle.get_default_dtype()
    x = paddle.cast(x, dtype=out_dtype)
    y = paddle.cast(y, dtype=out_dtype)
    two = paddle.to_tensor(2, dtype=out_dtype)
    return paddle.multiply(x, paddle.pow(two, y))
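
For reference, a minimal usage sketch of the dtype-promotion rule implemented above (not part of the patch; it assumes the framework default dtype is left at float32):

import paddle

# Both operands are integer tensors, so neither triggers the float64 branch;
# ldexp casts them to the default dtype before computing x * 2**y.
x = paddle.to_tensor([1, 2, 3], dtype='int64')
y = paddle.to_tensor([2, 3, 4], dtype='int32')
res = paddle.ldexp(x, y)
print(res.dtype)  # paddle.float32
print(res)        # [4., 16., 48.]

# If either operand is float64, the computation and the result are promoted.
x64 = paddle.to_tensor([1.0, 2.0, 3.0], dtype='float64')
print(paddle.ldexp(x64, y).dtype)  # paddle.float64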
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle
from paddle.fluid import core
from paddle.static import Program, program_guard

DYNAMIC = 1
STATIC = 2


def _run_ldexp(mode, x, y, device='cpu'):
    # dynamic mode
    if mode == DYNAMIC:
        paddle.disable_static()
        # Set device
        paddle.set_device(device)
        x_ = paddle.to_tensor(x)
        # y is scalar
        if isinstance(y, (int)):
            y_ = y
        # y is tensor
        else:
            y_ = paddle.to_tensor(y)
        res = paddle.ldexp(x_, y_)
        return res.numpy()
    # static graph mode
    elif mode == STATIC:
        paddle.enable_static()
        # y is scalar
        if isinstance(y, (int)):
            with program_guard(Program(), Program()):
                x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
                y_ = y
                res = paddle.ldexp(x_, y_)
                place = (
                    paddle.CPUPlace()
                    if device == 'cpu'
                    else paddle.CUDAPlace(0)
                )
                exe = paddle.static.Executor(place)
                outs = exe.run(feed={'x': x, 'y': y}, fetch_list=[res])
                return outs[0]
        # y is tensor
        else:
            with program_guard(Program(), Program()):
                x_ = paddle.static.data(name="x", shape=x.shape, dtype=x.dtype)
                y_ = paddle.static.data(name="y", shape=y.shape, dtype=y.dtype)
                res = paddle.ldexp(x_, y_)
                place = (
                    paddle.CPUPlace()
                    if device == 'cpu'
                    else paddle.CUDAPlace(0)
                )
                exe = paddle.static.Executor(place)
                outs = exe.run(feed={'x': x, 'y': y}, fetch_list=[res])
                return outs[0]


def check_dtype(input, desired_dtype):
    if input.dtype != desired_dtype:
        raise ValueError(
            "The expected data type to be obtained is {}, but got {}".format(
                desired_dtype, input.dtype
            )
        )


class TestLdexpAPI(unittest.TestCase):
    def setUp(self):
        self.places = ['cpu']
        if core.is_compiled_with_cuda():
            self.places.append('gpu')

    def test_ldexp(self):
        np.random.seed(7)
        for place in self.places:
            # test 1-d float tensor and 1-d int tensor
            dims = (np.random.randint(200, 300),)
            x = (np.random.rand(*dims) * 10).astype(np.float64)
            y = (np.random.randint(-10, 10, dims)).astype(np.int32)
            res = _run_ldexp(DYNAMIC, x, y, place)
            check_dtype(res, np.float64)
            np.testing.assert_allclose(res, np.ldexp(x, y))
            res = _run_ldexp(STATIC, x, y, place)
            check_dtype(res, np.float64)
            np.testing.assert_allclose(res, np.ldexp(x, y))

            dims = (np.random.randint(200, 300),)
            x = (np.random.rand(*dims) * 10).astype(np.float32)
            y = (np.random.randint(-10, 10, dims)).astype(np.int32)
            res = _run_ldexp(DYNAMIC, x, y, place)
            check_dtype(res, np.float32)
            np.testing.assert_allclose(res, np.ldexp(x, y))
            res = _run_ldexp(STATIC, x, y, place)
            check_dtype(res, np.float32)
            np.testing.assert_allclose(res, np.ldexp(x, y))

            # test 1-d int tensor and 1-d int tensor
            dims = (np.random.randint(200, 300),)
            x = (np.random.randint(-10, 10, dims)).astype(np.int64)
            y = (np.random.randint(-10, 10, dims)).astype(np.int32)
            res = _run_ldexp(DYNAMIC, x, y, place)
            check_dtype(res, np.float32)
            np.testing.assert_allclose(res, np.ldexp(x, y))
            res = _run_ldexp(STATIC, x, y, place)
            check_dtype(res, np.float32)
            np.testing.assert_allclose(res, np.ldexp(x, y))

            dims = (np.random.randint(200, 300),)
            x = (np.random.randint(-10, 10, dims)).astype(np.int32)
            y = (np.random.randint(-10, 10, dims)).astype(np.int32)
            res = _run_ldexp(DYNAMIC, x, y, place)
            check_dtype(res, np.float32)
            np.testing.assert_allclose(res, np.ldexp(x, y))
            res = _run_ldexp(STATIC, x, y, place)
            check_dtype(res, np.float32)
            np.testing.assert_allclose(res, np.ldexp(x, y))

            # test broadcast
            dims = (
                np.random.randint(1, 10),
                np.random.randint(5, 10),
                np.random.randint(5, 10),
            )
            x = (np.random.rand(*dims) * 10).astype(np.float64)
            y = (np.random.randint(-10, 10, dims[-1])).astype(np.int32)
            res = _run_ldexp(DYNAMIC, x, y)
            check_dtype(res, np.float64)
            np.testing.assert_allclose(res, np.ldexp(x, y))
            res = _run_ldexp(STATIC, x, y)
            check_dtype(res, np.float64)
            np.testing.assert_allclose(res, np.ldexp(x, y))


class TestLdexpError(unittest.TestCase):
    """TestLdexpError."""

    def test_errors(self):
        """test_errors."""
        np.random.seed(7)

        # test 1-d float and int tensor
        dims = (np.random.randint(200, 300),)
        x = (np.random.rand(*dims) * 10).astype(np.float64)
        y = (np.random.randint(-10, 10, dims)).astype(np.int32)
        self.assertRaises(TypeError, paddle.ldexp, x, paddle.to_tensor(y))

        # test 1-d float tensor and int
        dims = (np.random.randint(200, 300),)
        x = (np.random.rand(*dims) * 10).astype(np.float64)
        y = (np.random.randint(-10, 10, dims)).astype(np.int32)
        self.assertRaises(TypeError, paddle.ldexp, paddle.to_tensor(x), y)


if __name__ == '__main__':
    unittest.main()