From b61aaa2c1030c8d3204872b61e75c194859f66a8 Mon Sep 17 00:00:00 2001 From: liuwei1031 <46661762+liuwei1031@users.noreply.github.com> Date: Sun, 12 Apr 2020 09:47:09 +0800 Subject: [PATCH] add logsumexp op, test=develop (#23585) --- python/paddle/__init__.py | 2 +- .../fluid/tests/unittests/test_logsumexp.py | 85 +++++++++++++++++++ python/paddle/tensor/__init__.py | 2 +- python/paddle/tensor/math.py | 57 ++++++++++++- 4 files changed, 143 insertions(+), 3 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_logsumexp.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 5e28e698be2..5587fd795b1 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -137,7 +137,7 @@ from .tensor.math import mm #DEFINE_ALIAS from .tensor.math import div #DEFINE_ALIAS from .tensor.math import add #DEFINE_ALIAS # from .tensor.math import atan #DEFINE_ALIAS -# from .tensor.math import logsumexp #DEFINE_ALIAS +from .tensor.math import logsumexp #DEFINE_ALIAS # from .tensor.math import inverse #DEFINE_ALIAS # from .tensor.math import log1p #DEFINE_ALIAS # from .tensor.math import erf #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_logsumexp.py b/python/paddle/fluid/tests/unittests/test_logsumexp.py new file mode 100644 index 00000000000..791f0307cd8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_logsumexp.py @@ -0,0 +1,85 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import paddle +import paddle.fluid as fluid +import unittest +import numpy as np +from op_test import OpTest +from paddle.fluid import Program, program_guard +from paddle.fluid.layer_helper import LayerHelper + + +class TestLogSumOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + + x1 = fluid.layers.data(name='x1', shape=[120], dtype="uint8") + self.assertRaises(Exception, paddle.logsumexp, x1) + + x2 = fluid.layers.data(name='x2', shape=[2, 3], dtype="int") + self.assertRaises(Exception, paddle.logsumexp, x2) + + x3 = fluid.layers.data(name='x3', shape=[3], dtype="float16") + self.assertRaises(Exception, paddle.logsumexp, x3) + + self.assertRaises(AssertionError, paddle.logsumexp, None) + + +class TestLogSumExpOp(unittest.TestCase): + def test_dygraph(self): + with fluid.dygraph.guard(): + np_x = np.random.uniform(0.1, 1, [123]).astype(np.float32) + x = fluid.dygraph.to_variable(np_x) + self.assertTrue( + np.allclose( + paddle.logsumexp(x).numpy(), np.log(np.sum(np.exp(np_x))))) + + np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32) + x = fluid.dygraph.to_variable(np_x) + self.assertTrue( + np.allclose( + paddle.logsumexp( + x, dim=[1, 2]).numpy(), + np.log(np.sum(np.exp(np_x), axis=(1, 2))))) + + np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32) + x = fluid.dygraph.to_variable(np_x) + self.assertTrue( + np.allclose( + paddle.logsumexp( + x, dim=[2]).numpy(), + np.log(np.sum(np.exp(np_x), axis=(2))))) + + np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32) + x = fluid.dygraph.to_variable(np_x) + self.assertTrue( + np.allclose( + paddle.logsumexp( + x, keepdim=True).numpy(), + np.log(np.sum(np.exp(np_x), keepdims=True)))) + + np_x = np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32) + x = fluid.dygraph.to_variable(np_x) + helper = LayerHelper("test_logsumexp") + out = helper.create_variable( + type=x.type, name='out', dtype=x.dtype, persistable=False) + paddle.logsumexp(x, out=out) + self.assertTrue( + np.allclose(out.numpy(), np.log(np.sum(np.exp(np_x))))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 4748172a00c..15a1607ac87 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -112,7 +112,7 @@ from .math import mm #DEFINE_ALIAS from .math import div #DEFINE_ALIAS from .math import add #DEFINE_ALIAS # from .math import atan #DEFINE_ALIAS -# from .math import logsumexp #DEFINE_ALIAS +from .math import logsumexp #DEFINE_ALIAS # from .math import inverse #DEFINE_ALIAS # from .math import log1p #DEFINE_ALIAS # from .math import erf #DEFINE_ALIAS diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index fc350bc7817..b138ac9624e 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -18,6 +18,7 @@ math functions from __future__ import print_function from paddle.common_ops_import import * +from ..fluid import layers from ..fluid.framework import core from ..fluid.layers.layer_function_generator import _generate_doc_string_ @@ -70,7 +71,7 @@ __all__ = [ 'div', 'add', # 'atan', -# 'logsumexp', + 'logsumexp', # 'inverse', # 'log1p', # 'erf', @@ -994,3 +995,57 @@ def addmm(input, x, y, alpha=1.0, beta=1.0, name=None): helper.append_op( type="addmm", inputs=inputs, attrs=attrs, outputs={"Out": out}) return out + + +def logsumexp(x, dim=None, keepdim=False, out=None, name=None): + """ +This operator calculates the log of the sum of exponentials of the input Tensor. + +.. math:: + logsumexp(x) = \log\sum exp(x) + + +Parameters: + x (Variable): Input LoDTensor or Tensor. Must be one of the following types: float32, float64. + dim (list|int, optional): The dimensions along which the sum is performed. If :attr:`None`, + sum all elements of :attr:`input` and return a Tensor variable with a single element, + otherwise must be in the range :math:`[-rank(input), rank(input))`. If :math:`dim[i] < 0`, + the dimension to reduce is :math:`rank + dim[i]`. + keep_dim (bool, optional): Whether to reserve the reduced dimension in the output Tensor. + The result tensor will have one fewer dimension than the :attr:`input` unless :attr:`keep_dim` + is true, default value is False. + name (str, optional): The default value is None. Normally there is no need for user to + set this property. For more information, please refer to :ref:`api_guide_Name` + + +Examples: + +.. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + with fluid.dygraph.guard(): + np_x = np.random.uniform(0.1, 1, [10]).astype(np.float32) + x = fluid.dygraph.to_variable(np_x) + print(paddle.logsumexp(x).numpy()) + + + """ + op_type = 'logsumexp' + assert x is not None, 'x cannot be None in {}'.format(op_type) + + # reduce_sum does not support float16 + check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type) + + exp_out = layers.exp(x) + sum_out = layers.reduce_sum(exp_out, dim, keepdim) + + if out is not None: + check_variable_and_dtype(out, 'out', [x.dtype], op_type) + helper = LayerHelper(op_type, **locals()) + helper.append_op(type="log", inputs={"X": sum_out}, outputs={"Out": out}) + return out + + return layers.log(sum_out, name) -- GitLab