From 08ac685e3a2fc6a97ef9334c067bb87bb4335c73 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 10 Jun 2020 03:42:37 +0800 Subject: [PATCH] feat(mge/functional): add logsumexp GitOrigin-RevId: e7ef50e9ece3ba6f71f5fe23c9a81ce9ccd0dc60 --- .../megengine/functional/__init__.py | 14 ++++- python_module/megengine/functional/math.py | 39 ++++++++++++- python_module/megengine/test/__init__.py | 34 +++++++++-- .../test/unit/functional/test_math.py | 56 ++++++++++++++++++- 4 files changed, 134 insertions(+), 9 deletions(-) diff --git a/python_module/megengine/functional/__init__.py b/python_module/megengine/functional/__init__.py index 82cbc171..a2d25ac4 100644 --- a/python_module/megengine/functional/__init__.py +++ b/python_module/megengine/functional/__init__.py @@ -50,7 +50,19 @@ from .loss import ( square_loss, triplet_margin_loss, ) -from .math import argmax, argmin, max, mean, min, norm, normalize, prod, sqrt, sum +from .math import ( + argmax, + argmin, + logsumexp, + max, + mean, + min, + norm, + normalize, + prod, + sqrt, + sum, +) from .nn import ( assert_equal, avg_pool2d, diff --git a/python_module/megengine/functional/math.py b/python_module/megengine/functional/math.py index 295f5ad2..06f9cebe 100644 --- a/python_module/megengine/functional/math.py +++ b/python_module/megengine/functional/math.py @@ -6,12 +6,15 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from typing import Optional +import math +import numbers +from typing import Optional, Sequence, Union import megengine._internal as mgb from ..core import Tensor, wrap_io_tensor -from .elemwise import clamp +from .elemwise import clamp, exp, isinf, log +from .tensor import remove_axis, where, zeros_like @wrap_io_tensor @@ -296,3 +299,35 @@ def normalize( return inp / clamp(norm(inp, p), lower=eps) else: return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps) + + +def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False): + r""" + Compute the log of the sum of exponentials of inputs along the given :attr:`axis`. The computation is numerically stabilized. + + .. math:: + + \mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n)) + + :param inp: The input tensor. + :param axis: Axis over which the sum is taken. It can be a single axis or a list of axes. + :param keepdims: whether to retain :attr:`axis` or not for the output tensor. + + """ + if isinstance(axis, numbers.Integral): + axis = (axis,) + max_value = inp + for dim in axis: + max_value = max_value.max(axis=dim, keepdims=True) + max_value = where( + isinf(max_value).astype("int32"), zeros_like(max_value), max_value + ) + x = exp(inp - max_value) + for dim in axis: + x = x.sum(axis=dim, keepdims=True) + x = max_value + log(x) + if not keepdims: + axis = sorted(axis, reverse=True) + for i in axis: + x = remove_axis(x, axis=i) + return x diff --git a/python_module/megengine/test/__init__.py b/python_module/megengine/test/__init__.py index 28713668..44ed54c2 100644 --- a/python_module/megengine/test/__init__.py +++ b/python_module/megengine/test/__init__.py @@ -9,9 +9,12 @@ import numpy as np -def assertTensorClose(v0, v1, *, max_err=1e-6, name=None): +def assertTensorClose( + v0, v1, *, max_err: float = 1e-6, allow_special_values: bool = False, name=None +): """ - max_err: relative error + :param allow_special_values: whether to allow :attr:`v0` and :attr:`v1` to contain inf and nan values. + :param max_err: relative error """ __tracebackhide__ = True # pylint: disable=unused-variable @@ -20,9 +23,30 @@ def assertTensorClose(v0, v1, *, max_err=1e-6, name=None): ), "Two Tensor must have same dtype, but the inputs are {} and {}".format( v0.dtype, v1.dtype ) - v0 = np.ascontiguousarray(v0, dtype=np.float32) - v1 = np.ascontiguousarray(v1, dtype=np.float32) - assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1) + v0 = np.ascontiguousarray(v0, dtype=np.float32).copy() + v1 = np.ascontiguousarray(v1, dtype=np.float32).copy() + if allow_special_values: + # check nan and rm it + v0_nan_mask = np.isnan(v0) + if np.any(v0_nan_mask): + assert np.array_equiv(v0_nan_mask, np.isnan(v1)), (v0, v1) + v0[v0_nan_mask] = 0 + v1[v0_nan_mask] = 0 + # check inf and rm it + v0_inf_mask = v0 == float("inf") + if np.any(v0_inf_mask): + assert np.array_equiv(v0_inf_mask, v1 == float("inf")), (v0, v1) + v0[v0_inf_mask] = 0 + v1[v0_inf_mask] = 0 + # check -inf and rm it + v0_inf_mask = v0 == float("-inf") + if np.any(v0_inf_mask): + assert np.array_equiv(v0_inf_mask, v1 == float("-inf")), (v0, v1) + v0[v0_inf_mask] = 0 + v1[v0_inf_mask] = 0 + else: + assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1) + assert v0.shape == v1.shape, "Two tensor must have same shape({} v.s. {})".format( v0.shape, v1.shape ) diff --git a/python_module/test/unit/functional/test_math.py b/python_module/test/unit/functional/test_math.py index 9354fee8..6dc5c82c 100644 --- a/python_module/test/unit/functional/test_math.py +++ b/python_module/test/unit/functional/test_math.py @@ -6,10 +6,14 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +from functools import partial + import numpy as np from helpers import opr_test import megengine.functional as F +from megengine.test import assertTensorClose def common_test_reduce(opr, ref_opr): @@ -86,7 +90,6 @@ def test_sqrt(): def test_normalize(): - from functools import partial cases = [ {"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2) @@ -112,3 +115,54 @@ def test_normalize(): cases[0]["input"][0, 0, 0, :] = 0 cases[1]["input"][0, 0, 0, :] = 0 opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3)) + + +def test_logsumexp(): + x = np.arange(10).astype(np.float32) + expected = np.log(np.sum(np.exp(x))) + cases = [{"input": x, "output": expected}] + compare_fn = partial(assertTensorClose, allow_special_values=True) + # large value check + n = 100 + x = np.full(n, 10000, dtype=np.float32) + expected = 10000 + np.log(n) + cases.append({"input": x, "output": expected.astype(np.float32)}) + opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) + + # special value check + x = np.array([np.inf], dtype=np.float32) + expected = x + cases = [{"input": x, "output": expected}] + + x = np.array([-np.inf, 0.0], dtype=np.float32) + expected = np.zeros(1).astype(np.float32) + cases.append({"input": x, "output": expected}) + opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) + + x = np.array([np.nan], dtype=np.float32) + expected = x + cases = [{"input": x, "output": expected}] + + x = np.array([-np.inf, 1], dtype=np.float32) + expected = np.array([1.0], dtype=np.float32) + cases.append({"input": x, "output": expected}) + + opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn) + + # keepdims check + x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32) + expected = np.array([[1e10], [-1e10]], dtype=np.float32) + cases = [{"input": x, "output": expected}] + x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32) + expected = np.array([[1e10], [np.inf]], dtype=np.float32) + cases.append({"input": x, "output": expected}) + opr_test(cases, F.logsumexp, axis=1, keepdims=True, compare_fn=compare_fn) + + # multiple axes check + x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32) + expected = np.array([1e10], dtype=np.float32) + cases = [{"input": x, "output": expected}] + x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32) + expected = np.array([np.inf], dtype=np.float32) + cases.append({"input": x, "output": expected}) + opr_test(cases, F.logsumexp, axis=(0, 1), keepdims=False, compare_fn=compare_fn) -- GitLab