diff --git a/tests/onnx/onnxbase.py b/tests/onnx/onnxbase.py index b5f9db718b8c8be7fe2d76c002cc04233029ee97..34954b4df1c0cb2d7c14a92af3b8140c1f44783d 100644 --- a/tests/onnx/onnxbase.py +++ b/tests/onnx/onnxbase.py @@ -60,7 +60,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): result.shape, expect.shape) assert result.dtype == expect.dtype, "result.dtype: {} != expect.dtype: {}".format( result.dtype, expect.dtype) - elif isinstance(result, (list, tuple)) and len(result) > 1: + elif isinstance(result, (list, tuple)): for i in range(len(result)): if isinstance(result[i], (np.generic, np.ndarray)): compare(result[i], expect[i], delta, rtol) @@ -69,6 +69,8 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): # deal with scalar tensor elif len(expect) == 1: compare(result, expect[0], delta, rtol) + else: + raise Exception("Compare diff wrong!!!!!!") def randtool(dtype, low, high, shape):