diff --git a/python_module/src/python/opr_template.py b/python_module/src/python/opr_template.py index c83d4172d7cf4c82d97b22e243d0018f96063fdf..29ea0a9d4df138b8dc4583074e7404673179802e 100644 --- a/python_module/src/python/opr_template.py +++ b/python_module/src/python/opr_template.py @@ -106,8 +106,6 @@ def reduce_(src, mode, axis=None, keepdims=False, *, inputs.append(1) assert not keepdims, 'can not set axis=None and keepdims=True' else: - assert isinstance(axis, int) and axis >= 0, ( - 'bad axis: {!r}'.format(axis)) remove_axis = not keepdims kwargs['axis'] = axis diff --git a/python_module/src/swig/symbol_var_SymbolVar.py b/python_module/src/swig/symbol_var_SymbolVar.py index f2c104378e7cf1b41e8f41d47cc23e4ba2a9787c..96fb338350fc212f045b823047a3a6d092b938e1 100644 --- a/python_module/src/swig/symbol_var_SymbolVar.py +++ b/python_module/src/swig/symbol_var_SymbolVar.py @@ -196,7 +196,6 @@ def shape(self): return get_var_shape(self) def axis_shape(self, axis): - assert axis >= 0 from .opr import get_var_shape return get_var_shape(self, axis=axis) diff --git a/python_module/test/unit/functional/test_math.py b/python_module/test/unit/functional/test_math.py index 96e4940811a36e02e8dce47f660f010fd19c9423..b5cb4a4981c7f8b830b75e8deeba778684a0fbea 100644 --- a/python_module/test/unit/functional/test_math.py +++ b/python_module/test/unit/functional/test_math.py @@ -20,30 +20,31 @@ def common_test_reduce(opr, ref_opr): cases = [{"input": data1}, {"input": data2}] if opr not in (F.argmin, F.argmax): + # test default axis opr_test(cases, opr, ref_fn=ref_opr) - - axis = 2 - opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis) - - axis = 2 - keepdims = True - opr_test( - cases, - opr, - ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=keepdims), - axis=axis, - keepdims=keepdims, - ) + # test all axises in range of input shape + for axis in range(-3, 3): + # test keepdims False + opr_test(cases, opr, ref_fn=lambda x: ref_opr(x, axis=axis), axis=axis) + # test keepdims True + opr_test( + cases, + opr, + ref_fn=lambda x: ref_opr(x, axis=axis, keepdims=True), + axis=axis, + keepdims=True, + ) else: + # test defaut axis opr_test(cases, opr, ref_fn=lambda x: ref_opr(x).astype(np.int32)) - - axis = 2 - opr_test( - cases, - opr, - ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), - axis=axis, - ) + # test all axises in range of input shape + for axis in range(0, 3): + opr_test( + cases, + opr, + ref_fn=lambda x: ref_opr(x, axis=axis).astype(np.int32), + axis=axis, + ) def test_sum():