diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index da32aab839cb777cf6fb4d49790fff6e367ebea9..28060ad171a1b6b832afe3a67569ad5832a78bee 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -227,6 +227,8 @@ from .tensor.math import lgamma # noqa: F401 from .tensor.math import lerp # noqa: F401 from .tensor.math import rad2deg # noqa: F401 from .tensor.math import deg2rad # noqa: F401 +from .tensor.math import gcd # noqa: F401 +from .tensor.math import lcm # noqa: F401 from .tensor.math import diff # noqa: F401 from .tensor.math import angle # noqa: F401 @@ -480,6 +482,8 @@ __all__ = [ # noqa 'atan2', 'rad2deg', 'deg2rad', + 'gcd', + 'lcm', 'expand', 'broadcast_to', 'ones_like', diff --git a/python/paddle/fluid/tests/unittests/test_gcd.py b/python/paddle/fluid/tests/unittests/test_gcd.py new file mode 100644 index 0000000000000000000000000000000000000000..820216dc56cd60e75aaca4ef0c3bcb3d93743450 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_gcd.py @@ -0,0 +1,93 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard +from op_test import OpTest + +paddle.enable_static() + + +class TestGcdAPI(unittest.TestCase): + def setUp(self): + self.x_np = 12 + self.y_np = 20 + self.x_shape = [1] + self.y_shape = [1] + + def test_static_graph(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(startup_program, train_program): + x = fluid.data(name='input1', dtype='int32', shape=self.x_shape) + y = fluid.data(name='input2', dtype='int32', shape=self.y_shape) + out = paddle.gcd(x, y) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + res = exe.run(fluid.default_main_program(), + feed={'input1': self.x_np, + 'input2': self.y_np}, + fetch_list=[out]) + self.assertTrue((np.array(res[0]) == np.gcd(self.x_np, self.y_np) + ).all()) + + def test_dygraph(self): + paddle.disable_static() + x = paddle.to_tensor(self.x_np) + y = paddle.to_tensor(self.y_np) + result = paddle.gcd(x, y) + self.assertEqual( + np.allclose(np.gcd(self.x_np, self.y_np), result.numpy()), True) + + paddle.enable_static() + + +class TestGcdAPI2(TestGcdAPI): + def setUp(self): + self.x_np = np.arange(6).astype(np.int32) + self.y_np = np.array([20]).astype(np.int32) + self.x_shape = [6] + self.y_shape = [1] + + +class TestGcdAPI3(TestGcdAPI): + def setUp(self): + self.x_np = 0 + self.y_np = 20 + self.x_shape = [1] + self.y_shape = [1] + + +class TestGcdAPI4(TestGcdAPI): + def setUp(self): + self.x_np = 0 + self.y_np = 0 + self.x_shape = [1] + self.y_shape = [1] + + +class TestGcdAPI5(TestGcdAPI): + def setUp(self): + self.x_np = 12 + self.y_np = -20 + self.x_shape = [1] + self.y_shape = [1] diff --git a/python/paddle/fluid/tests/unittests/test_lcm.py b/python/paddle/fluid/tests/unittests/test_lcm.py new file mode 100644 index 0000000000000000000000000000000000000000..123c3e3d444e1be7062986e4820ff7bfc3df882c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_lcm.py @@ -0,0 +1,93 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid import Program, program_guard +from op_test import OpTest + +paddle.enable_static() + + +class TestLcmAPI(unittest.TestCase): + def setUp(self): + self.x_np = 12 + self.y_np = 20 + self.x_shape = [1] + self.y_shape = [1] + + def test_static_graph(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(startup_program, train_program): + x1 = fluid.data(name='input1', dtype='int32', shape=self.x_shape) + x2 = fluid.data(name='input2', dtype='int32', shape=self.y_shape) + out = paddle.lcm(x1, x2) + + place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + exe = fluid.Executor(place) + res = exe.run(fluid.default_main_program(), + feed={'input1': self.x_np, + 'input2': self.y_np}, + fetch_list=[out]) + self.assertTrue((np.array(res[0]) == np.lcm(self.x_np, self.y_np) + ).all()) + + def test_dygraph(self): + paddle.disable_static() + x1 = paddle.to_tensor(self.x_np) + x2 = paddle.to_tensor(self.y_np) + result = paddle.lcm(x1, x2) + self.assertEqual( + np.allclose(np.lcm(self.x_np, self.y_np), result.numpy()), True) + + paddle.enable_static() + + +class TestLcmAPI2(TestLcmAPI): + def setUp(self): + self.x_np = np.arange(6).astype(np.int32) + self.y_np = np.array([20]).astype(np.int32) + self.x_shape = [6] + self.y_shape = [1] + + +class TestLcmAPI3(TestLcmAPI): + def setUp(self): + self.x_np = 0 + self.y_np = 20 + self.x_shape = [1] + self.y_shape = [1] + + +class TestLcmAPI4(TestLcmAPI): + def setUp(self): + self.x_np = 0 + self.y_np = 0 + self.x_shape = [1] + self.y_shape = [1] + + +class TestLcmAPI5(TestLcmAPI): + def setUp(self): + self.x_np = 12 + self.y_np = -20 + self.x_shape = [1] + self.y_shape = [1] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 53001c071598941449fb0dc6f43a7f4843ea7c32..82727b33f9734b66db78c08a250d4c1878b13e11 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -194,6 +194,8 @@ from .math import lerp # noqa: F401 from .math import lerp_ # noqa: F401 from .math import rad2deg # noqa: F401 from .math import deg2rad # noqa: F401 +from .math import gcd # noqa: F401 +from .math import lcm # noqa: F401 from .math import diff # noqa: F401 from .math import angle # noqa: F401 @@ -409,6 +411,10 @@ tensor_method_func = [ #noqa 'multi_dot', 'solve', 'triangular_solve', + 'rad2deg', + 'deg2rad', + 'gcd', + 'lcm', 'diff', 'lerp', 'lerp_', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 88a381827910cbda677e42cd1cdb318740c0e4f2..b79caf0559d379ae1d77ec072e2d0735308b26c4 100755 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -2788,6 +2788,139 @@ def deg2rad(x, name=None): type='scale', inputs={'X':out_cast}, outputs={'Out': out}, attrs={'scale': deg2rad_scale}) return out +def gcd(x, y, name=None): + """ + Computes the element-wise greatest common divisor (GCD) of input |x| and |y|. + Both x and y must have integer types. + + Note: + gcd(0,0)=0, gcd(0, y)=|y| + + Args: + x, y (Tensor): An N-D Tensor, the data type is int8,int16,int32,int64,uint8. + If x.shape != y.shape, they must be broadcastable to a common shape (which becomes the shape of the output). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor, the data type is the same with input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + x1 = paddle.to_tensor(12) + x2 = paddle.to_tensor(20) + paddle.gcd(x1, x2) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [4]) + + x3 = paddle.to_tensor(np.arange(6)) + paddle.gcd(x3, x2) + # Tensor(shape=[6], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [20, 1 , 2 , 1 , 4 , 5]) + + x4 = paddle.to_tensor(0) + paddle.gcd(x4, x2) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [20]) + + paddle.gcd(x4, x4) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [0]) + + x5 = paddle.to_tensor(-20) + paddle.gcd(x1, x5) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [4]) + """ + shape = paddle.broadcast_shape(x.shape, y.shape) + x = paddle.broadcast_to(x, shape) + y = paddle.broadcast_to(y, shape) + x = paddle.abs(x) + y = paddle.abs(y) + + def _gcd_cond_fn(x, y): + return paddle.any(y != 0) + + def _gcd_body_fn(x, y): + # paddle.mod will raise an error when any element of y is 0. To avoid + # that, we change those zeros to ones. Their values don't matter because + # they won't be used. + y_not_equal_0 = (y != 0) + y_safe = paddle.where(y_not_equal_0, y, paddle.ones(y.shape, y.dtype)) + x, y = (paddle.where(y_not_equal_0, y, x), + paddle.where(y_not_equal_0, paddle.mod(x, y_safe),paddle.zeros(y.shape, y.dtype))) + return (paddle.where(x < y, y, x), paddle.where(x < y, x, y)) + + if in_dygraph_mode(): + while _gcd_cond_fn(x, y): + x, y = _gcd_body_fn(x, y) + + return x + else: + check_variable_and_dtype(x, 'x', ['int32', 'int64', 'int8', 'int16', 'uint8'], 'gcd') + check_variable_and_dtype(y, 'y', ['int32', 'int64', 'int8', 'int16', 'uint8'], 'gcd') + out, _ = paddle.static.nn.while_loop(_gcd_cond_fn, _gcd_body_fn, [x, y]) + return out + +def lcm(x, y, name=None): + """ + Computes the element-wise least common multiple (LCM) of input |x| and |y|. + Both x and y must have integer types. + + Note: + lcm(0,0)=0, lcm(0, y)=0 + + Args: + x, y (Tensor): An N-D Tensor, the data type is int8,int16,int32,int64,uint8. + If x.shape != y.shape, they must be broadcastable to a common shape (which becomes the shape of the output). + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + out (Tensor): An N-D Tensor, the data type is the same with input. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + x1 = paddle.to_tensor(12) + x2 = paddle.to_tensor(20) + paddle.lcm(x1, x2) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [60]) + + x3 = paddle.to_tensor(np.arange(6)) + paddle.lcm(x3, x2) + # Tensor(shape=[6], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [0, 20, 20, 60, 20, 20]) + + x4 = paddle.to_tensor(0) + paddle.lcm(x4, x2) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [0]) + + paddle.lcm(x4, x4) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [0]) + + x5 = paddle.to_tensor(-20) + paddle.lcm(x1, x5) + # Tensor(shape=[1], dtype=int64, place=CUDAPlace(0), stop_gradient=True, + # [60]) + """ + d = paddle.gcd(x, y) + # paddle.mod will raise an error when any element of y is 0. To avoid + # that, we change those zeros to ones. Their values don't matter because + # they won't be used. + d_equal_0 = paddle.equal(d, 0) + d_safe = paddle.where(d_equal_0, paddle.ones(d.shape, d.dtype), d) + out = paddle.where(d_equal_0, paddle.zeros(d.shape, d.dtype), paddle.abs(x * y) // d_safe) + return out + def diff(x, n=1, axis=-1, prepend=None, append=None, name=None): r""" Computes the n-th forward difference along the given axis.