提交 984d85ca 编写于 作者: M Megvii Engine Team

feat(mge/functional): argmin and argmax support negtive axis

GitOrigin-RevId: a1bd1102a6c24b6c4ceb257f962ffdebe1669744
上级 87d6ff22
......@@ -165,7 +165,7 @@ def _remove_axis(inp: Tensor, axis) -> Tensor:
return list(map(int, axis))
axis = get_axes()
axis = sorted(i + inp.ndim if i < 0 else i for i in axis)
axis = utils._normalize_axis(inp.ndim, axis)
axis = [a - i for i, a in enumerate(axis)]
op = builtin.RemoveAxis(axis=axis)
......@@ -190,8 +190,7 @@ def _reduce(mode):
op = builtin.Reduce(mode=mode, axis=0)
(result,) = apply(op, data)
elif isinstance(axis, collections.abc.Iterable):
axis = list(axis)
axis.sort(reverse=True)
axis = utils._normalize_axis(self.ndim, axis, reverse=True)
for ai in axis:
op = builtin.Reduce(mode=mode, axis=ai)
(data,) = apply(op, data)
......@@ -199,6 +198,7 @@ def _reduce(mode):
data = _remove_axis(data, ai)
result = data
else:
# builtin.Reduce already accept negtive axis
op = builtin.Reduce(mode=mode, axis=axis)
(result,) = apply(op, data)
......
......@@ -178,3 +178,28 @@ def make_shape_tuple(shape):
s = []
_expand_int(s, shape)
return tuple(s)
def _normalize_axis(
ndim: int, axis: Union[int, Iterable], reverse=False
) -> Union[int, list]:
def convert(x):
x_org = x
if x < 0:
x = ndim + x
assert (
x >= 0 and x < ndim
), "axis {} is out of bounds for tensor of dimension {}".format(x_org, ndim)
return x
if isinstance(axis, int):
return convert(axis)
elif isinstance(axis, Iterable):
axis_org = axis
axis = list(sorted(map(convert, axis), reverse=reverse))
for i in range(len(axis) - 1):
assert axis[i] != axis[i + 1], "axis {} contains duplicated indices".format(
axis_org
)
return axis
raise
......@@ -466,9 +466,13 @@ def argmin(
0
"""
if axis is None:
assert not keepdims, "can not set axis=None and keepdims=True"
inp = inp.flatten()
axis = 0
axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
if isinstance(axis, collections.abc.Iterable):
axis = list(axis)
axis.sort(reverse=True)
for ai in axis:
op = builtin.Argmin(axis=ai)
......@@ -479,11 +483,6 @@ def argmin(
return inp
if axis is None:
assert not keepdims, "can not set axis=None and keepdims=True"
inp = inp.flatten()
axis = 0
op = builtin.Argmin(axis=axis)
(result,) = apply(op, inp)
if not keepdims:
......@@ -525,9 +524,13 @@ def argmax(
5
"""
if axis is None:
assert not keepdims, "can not set axis=None and keepdims=True"
inp = inp.flatten()
axis = 0
axis = utils._normalize_axis(inp.ndim, axis, reverse=True)
if isinstance(axis, collections.abc.Iterable):
axis = list(axis)
axis.sort(reverse=True)
for ai in axis:
op = builtin.Argmax(axis=ai)
......@@ -538,11 +541,6 @@ def argmax(
return inp
if axis is None:
assert not keepdims, "can not set axis=None and keepdims=True"
inp = inp.flatten()
axis = 0
op = builtin.Argmax(axis=axis)
(result,) = apply(op, inp)
if not keepdims:
......
......@@ -811,3 +811,19 @@ def test_assert_not_equal():
y = F.zeros(shape, dtype=np.float32) + 1.1
with pytest.raises(RuntimeError):
z = F.utils._assert_equal(x, y)
def test_neg_axis():
x = tensor(np.random.normal(0, 1, (32, 5)))
y = F.argmax(x, axis=-1)
yy = F.argmax(x, axis=1)
np.testing.assert_equal(y.numpy(), yy.numpy())
y = F.argmax(x, axis=(-1, -2))
yy = F.argmax(x, axis=(0, 1))
np.testing.assert_equal(y.numpy(), yy.numpy())
y = F.argmin(x, axis=(-1, -2))
yy = F.argmin(x, axis=(0, 1))
np.testing.assert_equal(y.numpy(), yy.numpy())
......@@ -9,6 +9,7 @@
from functools import partial
import numpy as np
import pytest
from utils import opr_test
import megengine.functional as F
......@@ -48,6 +49,14 @@ def common_test_reduce(opr, ref_opr):
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
axis=axis,
)
# test negative axis
axis = axis - len(data1_shape)
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
axis=axis,
)
def test_sum():
......@@ -137,3 +146,14 @@ def test_normalize():
cases[0]["input"][0, 0, 0, :] = 0
cases[1]["input"][0, 0, 0, :] = 0
opr_test(cases, partial(F.normalize, axis=3), ref_fn=partial(np_normalize, axis=3))
def test_sum_neg_axis():
shape = (2, 3)
data = np.random.random(shape).astype(np.float32)
for axis in (-1, -2, (-2, 1), (-1, 0)):
get = F.sum(tensor(data), axis=axis)
ref = np.sum(data, axis=axis)
np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6)
with pytest.raises(AssertionError):
F.sum(tensor(data), axis=(-1, 1))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册