提交 9b413219 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(mge/functional): support negative axis in math.py

GitOrigin-RevId: 75143a73096eb46a1ef821b95a4228fd8f06ab58
上级 207527d1
...@@ -106,8 +106,6 @@ def reduce_(src, mode, axis=None, keepdims=False, *, ...@@ -106,8 +106,6 @@ def reduce_(src, mode, axis=None, keepdims=False, *,
inputs.append(1) inputs.append(1)
assert not keepdims, 'can not set axis=None and keepdims=True' assert not keepdims, 'can not set axis=None and keepdims=True'
else: else:
assert isinstance(axis, int) and axis >= 0, (
'bad axis: {!r}'.format(axis))
remove_axis = not keepdims remove_axis = not keepdims
kwargs['axis'] = axis kwargs['axis'] = axis
......
...@@ -196,7 +196,6 @@ def shape(self): ...@@ -196,7 +196,6 @@ def shape(self):
return get_var_shape(self) return get_var_shape(self)
def axis_shape(self, axis): def axis_shape(self, axis):
assert axis >= 0
from .opr import get_var_shape from .opr import get_var_shape
return get_var_shape(self, axis=axis) return get_var_shape(self, axis=axis)
......
...@@ -20,30 +20,31 @@ def common_test_reduce(opr, ref_opr): ...@@ -20,30 +20,31 @@ def common_test_reduce(opr, ref_opr):
cases = [{"input": data1}, {"input": data2}] cases = [{"input": data1}, {"input": data2}]
if opr not in (F.argmin, F.argmax): if opr not in (F.argmin, F.argmax):
# test default axis
opr_test(cases, opr, ref_fn=ref_opr) opr_test(cases, opr, ref_fn=ref_opr)
# test all axises in range of input shape
axis = 2 for axis in range(-3, 3):
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis) # test keepdims False
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis)
axis = 2 # test keepdims True
keepdims = True opr_test(
opr_test( cases,
cases, opr,
opr, ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=True),
ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=keepdims), axis=axis,
axis=axis, keepdims=True,
keepdims=keepdims, )
)
else: else:
# test defaut axis
opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32)) opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32))
# test all axises in range of input shape
axis = 2 for axis in range(0, 3):
opr_test( opr_test(
cases, cases,
opr, opr,
ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32),
axis=axis, axis=axis,
) )
def test_sum(): def test_sum():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册