From f85eb14693a5e69c192cb31c5b50d62bddf617b3 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Wed, 3 Aug 2022 17:54:07 +0800 Subject: [PATCH] fixed compare bug --- tests/onnx/onnxbase.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/onnx/onnxbase.py b/tests/onnx/onnxbase.py index b5f9db7..34954b4 100644 --- a/tests/onnx/onnxbase.py +++ b/tests/onnx/onnxbase.py @@ -60,7 +60,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): result.shape, expect.shape) assert result.dtype == expect.dtype, "result.dtype: {} != expect.dtype: {}".format( result.dtype, expect.dtype) - elif isinstance(result, (list, tuple)) and len(result) > 1: + elif isinstance(result, (list, tuple)): for i in range(len(result)): if isinstance(result[i], (np.generic, np.ndarray)): compare(result[i], expect[i], delta, rtol) @@ -69,6 +69,8 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): # deal with scalar tensor elif len(expect) == 1: compare(result, expect[0], delta, rtol) + else: + raise Exception("Compare diff wrong!!!!!!") def randtool(dtype, low, high, shape): -- GitLab