提交 9b413219 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/functional): support negative axis in math.py

GitOrigin-RevId: 75143a73096eb46a1ef821b95a4228fd8f06ab58
上级 207527d1
......@@ -106,8 +106,6 @@ def reduce_(src, mode, axis=None, keepdims=False, *,
inputs.append(1)
assert not keepdims, 'can not set axis=None and keepdims=True'
else:
assert isinstance(axis, int) and axis >= 0, (
'bad axis: {!r}'.format(axis))
remove_axis = not keepdims
kwargs['axis'] = axis
......
......@@ -196,7 +196,6 @@ def shape(self):
return get_var_shape(self)
def axis_shape(self, axis):
assert axis >= 0
from .opr import get_var_shape
return get_var_shape(self, axis=axis)
......
......@@ -20,30 +20,31 @@ def common_test_reduce(opr, ref_opr):
cases = [{"input": data1}, {"input": data2}]
if opr not in (F.argmin, F.argmax):
# test default axis
opr_test(cases, opr, ref_fn=ref_opr)
axis = 2
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis)
axis = 2
keepdims = True
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=keepdims),
axis=axis,
keepdims=keepdims,
)
# test all axises in range of input shape
for axis in range(-3, 3):
# test keepdims False
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis)
# test keepdims True
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=True),
axis=axis,
keepdims=True,
)
else:
# test defaut axis
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32))
axis = 2
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
axis=axis,
)
# test all axises in range of input shape
for axis in range(0, 3):
opr_test(
cases,
opr,
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
axis=axis,
)
def test_sum():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册