Unverified commit 1e2af54c authored by Zheng_Bicheng, committed by GitHub

[Hackathon No.18] Add frexp API to Paddle (#46401)

* The previous PR merged a large amount of incorrect code; resubmitting a fresh copy

* The previous PR merged a large amount of incorrect code; resubmitting a fresh copy

* Fix formatting issues

* Revert to the original formatting

* Modify as requested

* Adjust the formatting as requested

* Fix issues in the comments

* Update formatting

* Test automatic formatting

* Fix the English comments

* fix docs build error

* pre-commit

* for docs build

* for docs build

* Fix a bug in the mantissa calculation

* Fix the case where exponent was wrongly assumed to possibly be negative, which increased the amount of computation
Co-authored-by: Ligoml <39876205+Ligoml@users.noreply.github.com>
Parent 9a1855ff
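For reference, frexp decomposes each element x of the input into a mantissa m and an exponent e such that x = m * 2**e, with 0.5 <= |m| < 1 for non-zero x and m = e = 0 for x = 0. A minimal sketch of the expected behaviour, using NumPy's np.frexp as the reference (the values match the docstring example added in this PR):

import numpy as np

# 1 = 0.5 * 2**1, 2 = 0.5 * 2**2, 3 = 0.75 * 2**2, 4 = 0.5 * 2**3
m, e = np.frexp(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
print(m)  # mantissas: 0.5, 0.5, 0.75, 0.5
print(e)  # exponents: 1, 2, 2, 3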
@@ -286,6 +286,7 @@ from .tensor.math import heaviside # noqa: F401
from .tensor.math import frac # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.math import frexp # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
@@ -386,7 +387,6 @@ if is_compiled_with_cinn():
    os.environ.setdefault('runtime_include_dir', runtime_include_dir)
disable_static()
__all__ = [  # noqa
    'iinfo',
    'dtype',
@@ -667,4 +667,5 @@ __all__ = [  # noqa
    'sgn',
    'triu_indices',
    'take',
    'frexp',
]
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle
import paddle.fluid


class TestFrexpAPI(unittest.TestCase):

    def setUp(self):
        np.random.seed(1024)
        self.rtol = 1e-5
        self.atol = 1e-8
        self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
            else paddle.CPUPlace()
        self.set_input()

    def set_input(self):
        self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')

    # static-graph unit test
    def test_static_api(self):
        # enable static-graph mode
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            input_data = paddle.fluid.data('X', self.x_np.shape,
                                           self.x_np.dtype)
            out = paddle.frexp(input_data)
            # compute the static-graph results
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])

        out_ref = np.frexp(self.x_np)
        # check that the static-graph results match the numpy reference
        for n, p in zip(out_ref, res):
            np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol)

    # dynamic-graph unit test
    def test_dygraph_api(self):
        # disable static-graph mode (switch to dygraph)
        paddle.disable_static(self.place)
        input_num = paddle.to_tensor(self.x_np)
        # check the results of paddle.frexp and Tensor.frexp in dygraph mode
        out1 = np.frexp(self.x_np)
        out2 = paddle.frexp(input_num)
        np.testing.assert_allclose(out1, out2, rtol=1e-05)

        out1 = np.frexp(self.x_np)
        out2 = input_num.frexp()
        np.testing.assert_allclose(out1, out2, rtol=1e-05)
        paddle.enable_static()


class TestSplitsFloat32Case1(TestFrexpAPI):
    """
    Test frexp with a float32 input of a different shape.
    """

    def set_input(self):
        self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float32')


class TestSplitsFloat64Case1(TestFrexpAPI):
    """
    Test frexp with a float64 input.
    """

    def set_input(self):
        self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float64')


class TestSplitsFloat64Case2(TestFrexpAPI):
    """
    Test frexp with a float64 input of a different shape.
    """

    def set_input(self):
        self.x_np = np.random.uniform(-1, 1, [4, 5, 2]).astype('float64')


if __name__ == "__main__":
    unittest.main()
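A further property one could check, shown here only as a hypothetical extra test that is not part of this PR, is that the decomposition round-trips: mantissa * 2**exponent should reproduce the input elementwise. A self-contained sketch:

import unittest

import numpy as np

import paddle


class TestFrexpRoundTrip(unittest.TestCase):
    # Hypothetical extra check (not in the PR): mantissa * 2**exponent == x elementwise.
    def test_round_trip(self):
        paddle.disable_static()
        x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
        mantissa, exponent = paddle.frexp(paddle.to_tensor(x_np))
        reconstructed = (mantissa * 2**exponent).numpy()
        np.testing.assert_allclose(reconstructed, x_np, rtol=1e-5, atol=1e-8)


if __name__ == "__main__":
    unittest.main()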
@@ -239,6 +239,7 @@ from .math import heaviside # noqa: F401
from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .math import frexp # noqa: F401
from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
@@ -517,6 +518,7 @@ tensor_method_func = [  # noqa
    'take',
    'bucketize',
    'sgn',
    'frexp',
]
# this list used in math_op_patch.py for magic_method bind
@@ -5108,3 +5108,52 @@ def take(x, index, mode='raise', name=None):
    out = input_1d.index_select(index_1d).reshape(index.shape)
    return out
def frexp(x, name=None):
    """
    The function used to decompose a floating point number into mantissa and exponent.

    Args:
        x (Tensor): The input tensor, its data type should be float32 or float64.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        - mantissa (Tensor), A mantissa Tensor. The shape and data type of the mantissa tensor and exponent tensor are
          the same as those of the input.
        - exponent (Tensor), An exponent Tensor. The shape and data type of the mantissa tensor and exponent tensor are
          the same as those of the input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([[1, 2, 3, 4]], dtype="float32")
            print(paddle.tensor.math.frexp(x))
            # (Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True, [[0.50000000, 0.50000000, 0.75000000, 0.50000000]]),
            #  Tensor(shape=[1, 4], dtype=float32, place=Place(cpu), stop_gradient=True, [[1., 2., 2., 3.]]))
    """
    if x.dtype not in [paddle.float32, paddle.float64]:
        raise TypeError(
            "The data type of input must be one of ['float32', 'float64'], but got {}"
            .format(x.dtype))
    input_x = paddle.abs(x)
    exponent = paddle.floor(paddle.log2(input_x))
    # log2(0) is -inf; fill those positions with 0 so zero inputs map to (0, 0)
    exponent = paddle.where(paddle.isinf(exponent),
                            paddle.full_like(exponent, 0), exponent)

    # compute the mantissa
    mantissa = paddle.divide(input_x, 2**exponent)
    # adjust the exponent (and mantissa) when the mantissa is still >= 1,
    # so that the mantissa ends up in [0.5, 1)
    exponent = paddle.where((mantissa >= 1),
                            paddle.add(exponent, paddle.ones_like(exponent)),
                            exponent)
    mantissa = paddle.where((mantissa >= 1),
                            paddle.divide(mantissa,
                                          2**paddle.ones_like(exponent)),
                            mantissa)

    # restore the sign of the input on the mantissa
    mantissa = paddle.where((x < 0), mantissa * -1, mantissa)
    return mantissa, exponent
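As a quick illustration of the zero handling above (the isinf fill after log2), here is a small usage sketch; the expected values follow from the code and agree with NumPy's np.frexp:

import paddle

# log2(0) is -inf, so the isinf branch maps x == 0 to (0, 0);
# negative inputs carry their sign on the mantissa only.
x = paddle.to_tensor([0.0, -3.0, 0.5], dtype='float32')
mantissa, exponent = paddle.frexp(x)
print(mantissa.numpy())  # expected values: 0.0, -0.75, 0.5
print(exponent.numpy())  # expected values: 0.0, 2.0, 0.0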