未验证 提交 cfe3ff48 编写于 作者: A Ainavo 提交者: GitHub

【Hackathon 4th 】add Trapezoid API && add Cumulative_trapezoid API (#51195)

上级 c1838da6
......@@ -296,6 +296,8 @@ from .tensor.math import frac # noqa: F401
from .tensor.math import sgn # noqa: F401
from .tensor.math import take # noqa: F401
from .tensor.math import frexp # noqa: F401
from .tensor.math import trapezoid # noqa: F401
from .tensor.math import cumulative_trapezoid # noqa: F401
from .tensor.random import bernoulli # noqa: F401
from .tensor.random import poisson # noqa: F401
......@@ -686,5 +688,7 @@ __all__ = [ # noqa
'triu_indices',
'take',
'frexp',
'trapezoid',
'cumulative_trapezoid',
'polar',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from scipy.integrate import cumulative_trapezoid
from test_trapezoid import (
Testfp16Trapezoid,
TestTrapezoidAPI,
TestTrapezoidError,
)
import paddle
class TestCumulativeTrapezoidAPI(TestTrapezoidAPI):
def set_api(self):
self.ref_api = cumulative_trapezoid
self.paddle_api = paddle.cumulative_trapezoid
class TestCumulativeTrapezoidWithX(TestCumulativeTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float32')
self.x = np.array([[1, 2, 3], [3, 4, 5]]).astype('float32')
self.dx = None
self.axis = -1
class TestCumulativeTrapezoidAxis(TestCumulativeTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float32')
self.x = None
self.dx = 1.0
self.axis = 0
class TestCumulativeTrapezoidWithDx(TestCumulativeTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float32')
self.x = None
self.dx = 3.0
self.axis = -1
class TestCumulativeTrapezoidfloat64(TestCumulativeTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float64')
self.x = np.array([[1, 2, 3], [3, 4, 5]]).astype('float64')
self.dx = None
self.axis = -1
class TestCumulativeTrapezoidWithOutDxX(TestCumulativeTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float64')
self.x = None
self.dx = None
self.axis = -1
class TestCumulativeTrapezoidBroadcast(TestCumulativeTrapezoidAPI):
def set_args(self):
self.y = np.random.random((3, 3, 4)).astype('float32')
self.x = np.random.random((3)).astype('float32')
self.dx = None
self.axis = 1
class TestCumulativeTrapezoidAxis1(TestCumulativeTrapezoidAPI):
def set_args(self):
self.y = np.random.random((3, 3, 4)).astype('float32')
self.x = None
self.dx = 1
self.axis = 1
class TestCumulativeTrapezoidError(TestTrapezoidError):
def set_api(self):
self.paddle_api = paddle.cumulative_trapezoid
class Testfp16CumulativeTrapezoid(Testfp16Trapezoid):
def set_api(self):
self.paddle_api = paddle.cumulative_trapezoid
self.ref_api = cumulative_trapezoid
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid.framework import _test_eager_guard
class TestTrapezoidAPI(unittest.TestCase):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float32')
self.x = None
self.dx = None
self.axis = -1
def get_output(self):
if self.x is None and self.dx is None:
self.output = self.ref_api(
y=self.y, x=self.x, dx=1.0, axis=self.axis
)
else:
self.output = self.ref_api(
y=self.y, x=self.x, dx=self.dx, axis=self.axis
)
def set_api(self):
self.ref_api = np.trapz
self.paddle_api = paddle.trapezoid
def setUp(self):
self.set_api()
self.set_args()
self.get_output()
self.places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))
def func_dygraph(self):
for place in self.places:
paddle.disable_static()
y = paddle.to_tensor(self.y, place=place)
if self.x is not None:
self.x = paddle.to_tensor(self.x, place=place)
if self.dx is not None:
self.dx = paddle.to_tensor(self.dx, place=place)
out = self.paddle_api(y=y, x=self.x, dx=self.dx, axis=self.axis)
np.testing.assert_allclose(out, self.output, rtol=1e-05)
def test_dygraph(self):
with _test_eager_guard():
self.setUp()
self.func_dygraph()
self.setUp()
self.func_dygraph()
def test_static(self):
paddle.enable_static()
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
for place in places:
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
y = paddle.static.data(
name="y", shape=self.y.shape, dtype=self.y.dtype
)
x = None
dx = None
if self.x is not None:
x = paddle.static.data(
name="x", shape=self.x.shape, dtype=self.x.dtype
)
if self.dx is not None:
dx = paddle.static.data(
name="dx", shape=[1], dtype='float32'
)
exe = paddle.static.Executor(place)
out = self.paddle_api(y=y, x=x, dx=dx, axis=self.axis)
fetches = exe.run(
paddle.static.default_main_program(),
feed={
"y": self.y,
"x": self.x,
"dx": self.dx,
"axis": self.axis,
},
fetch_list=[out],
)
np.testing.assert_allclose(fetches[0], self.output, rtol=1e-05)
class TestTrapezoidWithX(TestTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float32')
self.x = np.array([[1, 2, 3], [3, 4, 5]]).astype('float32')
self.dx = None
self.axis = -1
class TestTrapezoidAxis(TestTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float32')
self.x = None
self.dx = 1.0
self.axis = 0
class TestTrapezoidWithDx(TestTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float32')
self.x = None
self.dx = 3.0
self.axis = -1
class TestTrapezoidfloat64(TestTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float64')
self.x = np.array([[1, 2, 3], [3, 4, 5]]).astype('float64')
self.dx = None
self.axis = -1
class TestTrapezoidWithOutDxX(TestTrapezoidAPI):
def set_args(self):
self.y = np.array([[2, 4, 8], [3, 5, 9]]).astype('float64')
self.x = None
self.dx = None
self.axis = -1
class TestTrapezoidBroadcast(TestTrapezoidAPI):
def set_args(self):
self.y = np.random.random((3, 3, 4)).astype('float32')
self.x = np.random.random((3)).astype('float32')
self.dx = None
self.axis = 1
class TestTrapezoidAxis1(TestTrapezoidAPI):
def set_args(self):
self.y = np.random.random((3, 3, 4)).astype('float32')
self.x = None
self.dx = 1
self.axis = 1
class TestTrapezoidError(unittest.TestCase):
# test error
def set_api(self):
self.paddle_api = paddle.trapezoid
def test_errors(self):
self.set_api()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
def test_y_dtype():
y = paddle.static.data(
name='y',
shape=[4, 4],
dtype="int64",
)
x = paddle.static.data(name='x', shape=[4, 4], dtype="float32")
dx = None
self.paddle_api(y, x, dx)
self.assertRaises(TypeError, test_y_dtype)
def test_x_dtype():
y1 = paddle.static.data(
name='y1',
shape=[4, 4],
dtype="float32",
)
x1 = paddle.static.data(name='x1', shape=[4, 4], dtype="int64")
dx1 = None
self.paddle_api(y1, x1, dx1)
self.assertRaises(TypeError, test_x_dtype)
def test_dx_dim():
y2 = paddle.static.data(
name='y2',
shape=[4, 4],
dtype="float32",
)
x2 = None
dx2 = paddle.static.data(
name='dx2', shape=[4, 4], dtype="float32"
)
self.paddle_api(y2, x2, dx2)
self.assertRaises(ValueError, test_dx_dim)
def test_xwithdx():
y3 = paddle.static.data(
name='y3',
shape=[4, 4],
dtype="float32",
)
x3 = paddle.static.data(
name='x3', shape=[4, 4], dtype="float32"
)
dx3 = 1.0
self.paddle_api(y3, x3, dx3)
self.assertRaises(ValueError, test_xwithdx)
class Testfp16Trapezoid(TestTrapezoidAPI):
def set_api(self):
self.paddle_api = paddle.trapezoid
self.ref_api = np.trapz
def test_fp16_with_gpu(self):
paddle.enable_static()
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
input_y = np.random.random([4, 4]).astype("float16")
y = paddle.static.data(name="y", shape=[4, 4], dtype="float16")
input_x = np.random.random([4, 4]).astype("float16")
x = paddle.static.data(name="x", shape=[4, 4], dtype="float16")
exe = paddle.static.Executor(place)
out = self.paddle_api(y=y, x=x, dx=self.dx, axis=self.axis)
res = exe.run(
paddle.static.default_main_program(),
feed={
"y": input_y,
"x": input_x,
"dx": self.dx,
"axis": self.axis,
},
fetch_list=[out],
)
def test_fp16_func_dygraph(self):
if paddle.fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
paddle.disable_static()
input_y = np.random.random([4, 4])
y = paddle.to_tensor(input_y, dtype='float16', place=place)
input_x = np.random.random([4, 4])
x = paddle.to_tensor(input_x, dtype='float16', place=place)
out = self.paddle_api(y=y, x=x)
def test_fp16_dygraph(self):
with _test_eager_guard():
self.func_dygraph()
self.func_dygraph()
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -246,6 +246,8 @@ from .math import frac # noqa: F401
from .math import sgn # noqa: F401
from .math import take # noqa: F401
from .math import frexp # noqa: F401
from .math import trapezoid # noqa: F401
from .math import cumulative_trapezoid # noqa: F401
from .math import sigmoid # noqa: F401
from .math import sigmoid_ # noqa: F401
......@@ -531,6 +533,8 @@ tensor_method_func = [ # noqa
'bucketize',
'sgn',
'frexp',
'trapezoid',
'cumulative_trapezoid',
'polar',
'sigmoid',
'sigmoid_',
......
......@@ -4611,7 +4611,7 @@ def diff(x, n=1, axis=-1, prepend=None, append=None, name=None):
Only n=1 is currently supported.
Args:
x (Tensor): The input tensor to compute the forward difference on, the data type is float16(GPU), float32, float64, bool, int32, int64.
x (Tensor): The input tensor to compute the forward difference on, the data type is float16, float32, float64, bool, int32, int64.
n (int, optional): The number of times to recursively compute the difference.
Only support n=1. Default:1
axis (int, optional): The axis to compute the difference along. Default:-1
......@@ -5147,3 +5147,186 @@ def frexp(x, name=None):
mantissa = paddle.where((x < 0), mantissa * -1, mantissa)
return mantissa, exponent
def _trapezoid(y, x=None, dx=None, axis=-1, mode='sum'):
"""
Integrate along the given axis using the composite trapezoidal rule.
Args:
y (Tensor): Input tensor to integrate. It's data type should be float16, float32, float64.
x (Tensor, optional): The sample points corresponding to the :attr:`y` values, the same type as :attr:`y`.
It is known that the size of :attr:`y` is `[d_1, d_2, ... , d_n]` and :math:`axis=k`, then the size of :attr:`x` can only be `[d_k]` or `[d_1, d_2, ... , d_n ]`.
If :attr:`x` is None, the sample points are assumed to be evenly spaced :attr:`dx` apart. The default is None.
dx (float, optional): The spacing between sample points when :attr:`x` is None. If neither :attr:`x` nor :attr:`dx` is provided then the default is :math:`dx = 1`.
axis (int, optional): The axis along which to integrate. The default is -1.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
sum_mode (str): use a different summation. The default is `sum`.
Returns:
Tensor, Definite integral of :attr:`y` is N-D tensor as approximated along a single axis by the trapezoidal rule.
"""
if mode == 'sum':
sum_mode = paddle.sum
elif mode == 'cumsum':
sum_mode = paddle.cumsum
if not (x is None or dx is None):
raise ValueError("Not permitted to specify both x and dx input args.")
if y.dtype not in [paddle.float16, paddle.float32, paddle.float64]:
raise TypeError(
"The data type of input must be Tensor, and dtype should be one of ['paddle.float16', 'paddle.float32', 'paddle.float64'], but got {}".format(
y.dtype
)
)
y_shape = y.shape
length = y_shape[axis]
if axis < 0:
axis += y.dim()
if x is None:
if dx is None:
dx = 1.0
dx = paddle.to_tensor(dx)
if dx.dim() > 1:
raise ValueError('Expected dx to be a scalar, got dx={}'.format(dx))
else:
if x.dtype not in [paddle.float16, paddle.float32, paddle.float64]:
raise TypeError(
"The data type of input must be Tensor, and dtype should be one of ['paddle.float16', 'paddle.float32', 'paddle.float64'], but got {}".format(
x.dtype
)
)
# Reshape to correct shape
if x.dim() == 1:
dx = paddle.diff(x)
shape = [1] * y.dim()
shape[axis] = dx.shape[0]
dx = dx.reshape(shape)
else:
dx = paddle.diff(x, axis=axis)
return 0.5 * sum_mode(
(
paddle.gather(y, paddle.arange(1, length), axis=axis)
+ paddle.gather(y, paddle.arange(0, length - 1), axis=axis)
)
* dx,
axis=axis,
)
def trapezoid(y, x=None, dx=None, axis=-1, name=None):
"""
Integrate along the given axis using the composite trapezoidal rule. Use the sum method.
Args:
y (Tensor): Input tensor to integrate. It's data type should be float16, float32, float64.
x (Tensor, optional): The sample points corresponding to the :attr:`y` values, the same type as :attr:`y`.
It is known that the size of :attr:`y` is `[d_1, d_2, ... , d_n]` and :math:`axis=k`, then the size of :attr:`x` can only be `[d_k]` or `[d_1, d_2, ... , d_n ]`.
If :attr:`x` is None, the sample points are assumed to be evenly spaced :attr:`dx` apart. The default is None.
dx (float, optional): The spacing between sample points when :attr:`x` is None. If neither :attr:`x` nor :attr:`dx` is provided then the default is :math:`dx = 1`.
axis (int, optional): The axis along which to integrate. The default is -1.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, Definite integral of :attr:`y` is N-D tensor as approximated along a single axis by the trapezoidal rule.
If :attr:`y` is a 1D tensor, then the result is a float. If N is greater than 1, then the result is an (N-1)-D tensor.
Examples:
.. code-block:: python
import paddle
y = paddle.to_tensor([4, 5, 6], dtype='float32')
print(paddle.trapezoid(y))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [10.])
print(paddle.trapezoid(y, dx=2.))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [20.])
y = paddle.to_tensor([4, 5, 6], dtype='float32')
x = paddle.to_tensor([1, 2, 3], dtype='float32')
print(paddle.trapezoid(y, x))
# Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
# [10.])
y = paddle.to_tensor([1, 2, 3], dtype='float64')
x = paddle.to_tensor([8, 6, 4], dtype='float64')
print(paddle.trapezoid(y, x))
# Tensor(shape=[1], dtype=float64, place=Place(cpu), stop_gradient=True,
# [-8.])
y = paddle.arange(6).reshape((2, 3)).astype('float32')
print(paddle.trapezoid(y, axis=0))
# Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
# [1.50000000, 2.50000000, 3.50000000])
print(paddle.trapezoid(y, axis=1))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [2., 8.])
"""
return _trapezoid(y, x, dx, axis, mode='sum')
def cumulative_trapezoid(y, x=None, dx=None, axis=-1, name=None):
"""
Integrate along the given axis using the composite trapezoidal rule. Use the cumsum method
Args:
y (Tensor): Input tensor to integrate. It's data type should be float16, float32, float64.
x (Tensor, optional): The sample points corresponding to the :attr:`y` values, the same type as :attr:`y`.
It is known that the size of :attr:`y` is `[d_1, d_2, ... , d_n]` and :math:`axis=k`, then the size of :attr:`x` can only be `[d_k]` or `[d_1, d_2, ... , d_n ]`.
If :attr:`x` is None, the sample points are assumed to be evenly spaced :attr:`dx` apart. The default is None.
dx (float, optional): The spacing between sample points when :attr:`x` is None. If neither :attr:`x` nor :attr:`dx` is provided then the default is :math:`dx = 1`.
axis (int, optional): The axis along which to integrate. The default is -1.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor, Definite integral of :attr:`y` is N-D tensor as approximated along a single axis by the trapezoidal rule.
The result is an N-D tensor.
Examples:
.. code-block:: python
import paddle
y = paddle.to_tensor([4, 5, 6], dtype='float32')
print(paddle.cumulative_trapezoid(y))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [4.50000000, 10. ])
print(paddle.cumulative_trapezoid(y, dx=2.))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [9. , 20.])
y = paddle.to_tensor([4, 5, 6], dtype='float32')
x = paddle.to_tensor([1, 2, 3], dtype='float32')
print(paddle.cumulative_trapezoid(y, x))
# Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [4.50000000, 10. ])
y = paddle.to_tensor([1, 2, 3], dtype='float64')
x = paddle.to_tensor([8, 6, 4], dtype='float64')
print(paddle.cumulative_trapezoid(y, x))
# Tensor(shape=[2], dtype=float64, place=Place(cpu), stop_gradient=True,
# [-3., -8.])
y = paddle.arange(6).reshape((2, 3)).astype('float32')
print(paddle.cumulative_trapezoid(y, axis=0))
# Tensor(shape=[1, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[1.50000000, 2.50000000, 3.50000000]])
print(paddle.cumulative_trapezoid(y, axis=1))
# Tensor(shape=[2, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
# [[0.50000000, 2. ],
# [3.50000000, 8. ]])
"""
return _trapezoid(y, x, dx, axis, mode='cumsum')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册