提交 08ac685e 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

feat(mge/functional): add logsumexp

GitOrigin-RevId: e7ef50e9ece3ba6f71f5fe23c9a81ce9ccd0dc60
上级 65ec4f7c
......@@ -50,7 +50,19 @@ from .loss import (
square_loss,
triplet_margin_loss,
)
from .math import argmax, argmin, max, mean, min, norm, normalize, prod, sqrt, sum
from .math import (
argmax,
argmin,
logsumexp,
max,
mean,
min,
norm,
normalize,
prod,
sqrt,
sum,
)
from .nn import (
assert_equal,
avg_pool2d,
......
......@@ -6,12 +6,15 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional
import math
import numbers
from typing import Optional, Sequence, Union
import megengine._internal as mgb
from ..core import Tensor, wrap_io_tensor
from .elemwise import clamp
from .elemwise import clamp, exp, isinf, log
from .tensor import remove_axis, where, zeros_like
@wrap_io_tensor
......@@ -296,3 +299,35 @@ def normalize(
return inp / clamp(norm(inp, p), lower=eps)
else:
return inp / clamp(norm(inp, p, axis, keepdims=True), lower=eps)
def logsumexp(inp: Tensor, axis: Union[int, Sequence[int]], keepdims: bool = False):
r"""
Compute the log of the sum of exponentials of inputs along the given :attr:`axis`. The computation is numerically stabilized.
.. math::
\mathsf{logsumexp}(x_1, \dots, x_n) = \log(\exp(x_1) + \cdots + \exp(x_n))
:param inp: The input tensor.
:param axis: Axis over which the sum is taken. It can be a single axis or a list of axes.
:param keepdims: whether to retain :attr:`axis` or not for the output tensor.
"""
if isinstance(axis, numbers.Integral):
axis = (axis,)
max_value = inp
for dim in axis:
max_value = max_value.max(axis=dim, keepdims=True)
max_value = where(
isinf(max_value).astype("int32"), zeros_like(max_value), max_value
)
x = exp(inp - max_value)
for dim in axis:
x = x.sum(axis=dim, keepdims=True)
x = max_value + log(x)
if not keepdims:
axis = sorted(axis, reverse=True)
for i in axis:
x = remove_axis(x, axis=i)
return x
......@@ -9,9 +9,12 @@
import numpy as np
def assertTensorClose(v0, v1, *, max_err=1e-6, name=None):
def assertTensorClose(
v0, v1, *, max_err: float = 1e-6, allow_special_values: bool = False, name=None
):
"""
max_err: relative error
:param allow_special_values: whether to allow :attr:`v0` and :attr:`v1` to contain inf and nan values.
:param max_err: relative error
"""
__tracebackhide__ = True # pylint: disable=unused-variable
......@@ -20,9 +23,30 @@ def assertTensorClose(v0, v1, *, max_err=1e-6, name=None):
), "Two Tensor must have same dtype, but the inputs are {} and {}".format(
v0.dtype, v1.dtype
)
v0 = np.ascontiguousarray(v0, dtype=np.float32)
v1 = np.ascontiguousarray(v1, dtype=np.float32)
assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1)
v0 = np.ascontiguousarray(v0, dtype=np.float32).copy()
v1 = np.ascontiguousarray(v1, dtype=np.float32).copy()
if allow_special_values:
# check nan and rm it
v0_nan_mask = np.isnan(v0)
if np.any(v0_nan_mask):
assert np.array_equiv(v0_nan_mask, np.isnan(v1)), (v0, v1)
v0[v0_nan_mask] = 0
v1[v0_nan_mask] = 0
# check inf and rm it
v0_inf_mask = v0 == float("inf")
if np.any(v0_inf_mask):
assert np.array_equiv(v0_inf_mask, v1 == float("inf")), (v0, v1)
v0[v0_inf_mask] = 0
v1[v0_inf_mask] = 0
# check -inf and rm it
v0_inf_mask = v0 == float("-inf")
if np.any(v0_inf_mask):
assert np.array_equiv(v0_inf_mask, v1 == float("-inf")), (v0, v1)
v0[v0_inf_mask] = 0
v1[v0_inf_mask] = 0
else:
assert np.isfinite(v0.sum()) and np.isfinite(v1.sum()), (v0, v1)
assert v0.shape == v1.shape, "Two tensor must have same shape({} v.s. {})".format(
v0.shape, v1.shape
)
......
......@@ -6,10 +6,14 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
import numpy as np
from helpers import opr_test
import megengine.functional as F
from megengine.test import assertTensorClose
def common_test_reduce(opr, ref_opr):
......@@ -86,7 +90,6 @@ def test_sqrt():
def test_normalize():
from functools import partial
cases = [
{"input": np.random.random((2, 3, 12, 12)).astype(np.float32)} for i in range(2)
......@@ -112,3 +115,54 @@ def test_normalize():
cases[0]["input"][0, 0, 0, :] = 0
cases[1]["input"][0, 0, 0, :] = 0
opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3))
def test_logsumexp():
x = np.arange(10).astype(np.float32)
expected = np.log(np.sum(np.exp(x)))
cases = [{"input": x, "output": expected}]
compare_fn = partial(assertTensorClose, allow_special_values=True)
# large value check
n = 100
x = np.full(n, 10000, dtype=np.float32)
expected = 10000 + np.log(n)
cases.append({"input": x, "output": expected.astype(np.float32)})
opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn)
# special value check
x = np.array([np.inf], dtype=np.float32)
expected = x
cases = [{"input": x, "output": expected}]
x = np.array([-np.inf, 0.0], dtype=np.float32)
expected = np.zeros(1).astype(np.float32)
cases.append({"input": x, "output": expected})
opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn)
x = np.array([np.nan], dtype=np.float32)
expected = x
cases = [{"input": x, "output": expected}]
x = np.array([-np.inf, 1], dtype=np.float32)
expected = np.array([1.0], dtype=np.float32)
cases.append({"input": x, "output": expected})
opr_test(cases, F.logsumexp, axis=0, compare_fn=compare_fn)
# keepdims check
x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32)
expected = np.array([[1e10], [-1e10]], dtype=np.float32)
cases = [{"input": x, "output": expected}]
x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32)
expected = np.array([[1e10], [np.inf]], dtype=np.float32)
cases.append({"input": x, "output": expected})
opr_test(cases, F.logsumexp, axis=1, keepdims=True, compare_fn=compare_fn)
# multiple axes check
x = np.array([[1e10, 1e-10], [-1e10, -np.inf]], dtype=np.float32)
expected = np.array([1e10], dtype=np.float32)
cases = [{"input": x, "output": expected}]
x = np.array([[1e10, -1e-10, 1e-10], [1e10, 1e-10, np.inf]], dtype=np.float32)
expected = np.array([np.inf], dtype=np.float32)
cases.append({"input": x, "output": expected})
opr_test(cases, F.logsumexp, axis=(0, 1), keepdims=False, compare_fn=compare_fn)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册