diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 518e2c0c4d90daec12c3b924ace10fbd667c22ae..94de3fa0adb42f7b358688d4c1af78e822e64613 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -183,6 +183,7 @@ from .tensor.math import addmm #DEFINE_ALIAS from .tensor.math import clamp #DEFINE_ALIAS from .tensor.math import trace #DEFINE_ALIAS from .tensor.math import kron #DEFINE_ALIAS +from .tensor.math import prod #DEFINE_ALIAS # from .tensor.random import gaussin #DEFINE_ALIAS # from .tensor.random import uniform #DEFINE_ALIAS from .tensor.random import shuffle #DEFINE_ALIAS diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index efa60b70001e54e82b18ecc3dbf5f51a16f0ac0d..985db80a22297a41436205d14a91188d560a9c41 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4595,7 +4595,7 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): Args: input (Variable): The input variable which is a Tensor, the data type is float32, float64, int32, int64. - dim (list|int, optional): The dimensions along which the product is performed. If + dim (int|list|tuple, optional): The dimensions along which the product is performed. If :attr:`None`, multiply all elements of :attr:`input` and return a Tensor variable with a single element, otherwise must be in the range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`, @@ -4635,9 +4635,18 @@ def reduce_prod(input, dim=None, keep_dim=False, name=None): fluid.layers.reduce_prod(y, dim=[0, 1]) # [105.0, 384.0] """ helper = LayerHelper('reduce_prod', **locals()) - out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) if dim is not None and not isinstance(dim, list): - dim = [dim] + if isinstance(dim, tuple): + dim = list(dim) + elif isinstance(dim, int): + dim = [dim] + else: + raise TypeError( + "The type of axis must be int, list or tuple, but received {}". + format(type(dim))) + check_variable_and_dtype( + input, 'input', ['float32', 'float64', 'int32', 'int64'], 'reduce_prod') + out = helper.create_variable_for_type_inference(dtype=helper.input_dtype()) helper.append_op( type='reduce_prod', inputs={'X': input}, diff --git a/python/paddle/fluid/tests/unittests/test_prod_op.py b/python/paddle/fluid/tests/unittests/test_prod_op.py new file mode 100644 index 0000000000000000000000000000000000000000..158683907253e2ebc5adab6799c75ffd914df1c7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_prod_op.py @@ -0,0 +1,132 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import unittest +import numpy as np + + +class TestProdOp(unittest.TestCase): + def setUp(self): + self.input = np.random.random(size=(10, 10, 5)).astype(np.float32) + + def run_imperative(self): + input = paddle.to_tensor(self.input) + dy_result = paddle.prod(input) + expected_result = np.prod(self.input) + self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) + + dy_result = paddle.prod(input, axis=1) + expected_result = np.prod(self.input, axis=1) + self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) + + dy_result = paddle.prod(input, axis=-1) + expected_result = np.prod(self.input, axis=-1) + self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) + + dy_result = paddle.prod(input, axis=[0, 1]) + expected_result = np.prod(self.input, axis=(0, 1)) + self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) + + dy_result = paddle.prod(input, axis=1, keepdim=True) + expected_result = np.prod(self.input, axis=1, keepdims=True) + self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) + + dy_result = paddle.prod(input, axis=1, dtype='int64') + expected_result = np.prod(self.input, axis=1, dtype=np.int64) + self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) + + dy_result = paddle.prod(input, axis=1, keepdim=True, dtype='int64') + expected_result = np.prod( + self.input, axis=1, keepdims=True, dtype=np.int64) + self.assertTrue(np.allclose(dy_result.numpy(), expected_result)) + + def run_static(self, use_gpu=False): + input = paddle.data(name='input', shape=[10, 10, 5], dtype='float32') + result0 = paddle.prod(input) + result1 = paddle.prod(input, axis=1) + result2 = paddle.prod(input, axis=-1) + result3 = paddle.prod(input, axis=[0, 1]) + result4 = paddle.prod(input, axis=1, keepdim=True) + result5 = paddle.prod(input, axis=1, dtype='int64') + result6 = paddle.prod(input, axis=1, keepdim=True, dtype='int64') + + place = paddle.CUDAPlace(0) if use_gpu else paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(paddle.static.default_startup_program()) + static_result = exe.run(feed={"input": self.input}, + fetch_list=[ + result0, result1, result2, result3, result4, + result5, result6 + ]) + + expected_result = np.prod(self.input) + self.assertTrue(np.allclose(static_result[0], expected_result)) + expected_result = np.prod(self.input, axis=1) + self.assertTrue(np.allclose(static_result[1], expected_result)) + expected_result = np.prod(self.input, axis=-1) + self.assertTrue(np.allclose(static_result[2], expected_result)) + expected_result = np.prod(self.input, axis=(0, 1)) + self.assertTrue(np.allclose(static_result[3], expected_result)) + expected_result = np.prod(self.input, axis=1, keepdims=True) + self.assertTrue(np.allclose(static_result[4], expected_result)) + expected_result = np.prod(self.input, axis=1, dtype=np.int64) + self.assertTrue(np.allclose(static_result[5], expected_result)) + expected_result = np.prod( + self.input, axis=1, keepdims=True, dtype=np.int64) + self.assertTrue(np.allclose(static_result[6], expected_result)) + + def test_cpu(self): + paddle.disable_static(place=paddle.CPUPlace()) + self.run_imperative() + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + self.run_static() + + def test_gpu(self): + if not paddle.fluid.core.is_compiled_with_cuda(): + return + + paddle.disable_static(place=paddle.CUDAPlace(0)) + self.run_imperative() + paddle.enable_static() + + with paddle.static.program_guard(paddle.static.Program()): + self.run_static(use_gpu=True) + + +class TestProdOpError(unittest.TestCase): + def test_error(self): + with paddle.static.program_guard(paddle.static.Program(), + paddle.static.Program()): + x = paddle.data(name='x', shape=[2, 2, 4], dtype='float32') + bool_x = paddle.data(name='bool_x', shape=[2, 2, 4], dtype='bool') + # The argument x shoule be a Tensor + self.assertRaises(TypeError, paddle.prod, [1]) + + # The data type of x should be float32, float64, int32, int64 + self.assertRaises(TypeError, paddle.prod, bool_x) + + # The argument axis's type shoule be int ,list or tuple + self.assertRaises(TypeError, paddle.prod, x, 1.5) + + # The argument dtype of prod_op should be float32, float64, int32 or int64. + self.assertRaises(TypeError, paddle.prod, x, 'bool') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 8833bf5735c9e0d44cfe7df2517c174ac933c53c..ba108beb0bd93f537065ae71fa76a41c9da23853 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -157,6 +157,7 @@ from .math import addmm #DEFINE_ALIAS from .math import clamp #DEFINE_ALIAS from .math import trace #DEFINE_ALIAS from .math import kron #DEFINE_ALIAS +from .math import prod #DEFINE_ALIAS # from .random import gaussin #DEFINE_ALIAS # from .random import uniform #DEFINE_ALIAS from .random import shuffle #DEFINE_ALIAS diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 50f67574526482fbbb5176186ebc32fecf49f883..ea9e2b4a1550231eda16346016094a763f65090d 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -63,6 +63,7 @@ from ..fluid.layers import tanh #DEFINE_ALIAS from ..fluid.layers import increment #DEFINE_ALIAS from ..fluid.layers import multiplex #DEFINE_ALIAS from ..fluid.layers import sums #DEFINE_ALIAS +from ..fluid import layers __all__ = [ 'abs', @@ -85,6 +86,7 @@ __all__ = [ 'log', 'mul', 'multiplex', + 'prod', 'pow', 'reciprocal', 'reduce_max', @@ -1632,3 +1634,85 @@ def cumsum(x, axis=None, dtype=None, name=None): kwargs[name] = val _cum_sum_ = generate_layer_fn('cumsum') return _cum_sum_(**kwargs) + +def prod(x, axis=None, keepdim=False, dtype=None, name=None): + """ + Compute the product of tensor elements over the given axis. + + Args: + x(Tensor): An N-D Tensor, the data type is float32, float64, int32 or int64. + axis(int|list|tuple, optional): The axis along which the product is computed. If :attr:`None`, + multiply all elements of `x` and return a Tensor with a single element, + otherwise must be in the range :math:`[-x.ndim, x.ndim)`. If :math:`axis[i]<0`, + the axis to reduce is :math:`x.ndim + axis[i]`. Default is None. + dtype(str|np.dtype, optional): The desired date type of returned tensor, can be float32, float64, + int32, int64. If specified, the input tensor is casted to dtype before operator performed. + This is very useful for avoiding data type overflows. The default value is None, the dtype + of output is the same as input Tensor `x`. + keepdim(bool, optional): Whether to reserve the reduced dimension in the output Tensor. The result + tensor will have one fewer dimension than the input unless keep_dim is true. Default is False. + name(string, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + + Returns: + Tensor, result of product on the specified dim of input tensor. + + Raises: + ValueError: The :attr:`dtype` must be float32, float64, int32 or int64. + TypeError: The type of :attr:`axis` must be int, list or tuple. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + # the axis is a int element + data_x = np.array([[0.2, 0.3, 0.5, 0.9], + [0.1, 0.2, 0.6, 0.7]]).astype(np.float32) + x = paddle.to_tensor(data_x) + out1 = paddle.prod(x) + print(out1.numpy()) + # [0.0002268] + + out2 = paddle.prod(x, -1) + print(out2.numpy()) + # [0.027 0.0084] + + out3 = paddle.prod(x, 0) + print(out3.numpy()) + # [0.02 0.06 0.3 0.63] + print(out3.numpy().dtype) + # float32 + + out4 = paddle.prod(x, 0, keepdim=True) + print(out4.numpy()) + # [[0.02 0.06 0.3 0.63]] + + out5 = paddle.prod(x, 0, dtype='int64') + print(out5.numpy()) + # [0 0 0 0] + print(out5.numpy().dtype) + # int64 + + # the axis is list + data_y = np.array([[[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]]]) + y = paddle.to_tensor(data_y) + out6 = paddle.prod(y, [0, 1]) + print(out6.numpy()) + # [105. 384.] + + out7 = paddle.prod(y, (1, 2)) + print(out7.numpy()) + # [ 24. 1680.] + + """ + if dtype is not None: + check_dtype(dtype, 'dtype', ['float32', 'float64', 'int32', 'int64'], 'prod') + if x.dtype != convert_np_dtype_to_dtype_(dtype): + x = layers.cast(x, dtype) + + return layers.reduce_prod(input=x, dim=axis, keep_dim=keepdim, name=name)