From 963ac9bad86e2e11c57680a6b68b47fc9227c5bd Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Wed, 24 Aug 2022 14:11:52 +0800 Subject: [PATCH] fixed reduce bug (#880) --- tests/onnx/test_auto_scan_reduce_ops.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/onnx/test_auto_scan_reduce_ops.py b/tests/onnx/test_auto_scan_reduce_ops.py index 3c5d9fc..77b3b8b 100644 --- a/tests/onnx/test_auto_scan_reduce_ops.py +++ b/tests/onnx/test_auto_scan_reduce_ops.py @@ -19,6 +19,11 @@ import numpy as np import unittest import random +min_opset_version_map = { + "ReduceL1": 7, + "ReduceL2": 7, +} + class TestReduceOpsConvert(OPConvertAutoScanTest): """ @@ -30,7 +35,7 @@ class TestReduceOpsConvert(OPConvertAutoScanTest): input_shape = draw( st.lists( st.integers( - min_value=20, max_value=30), min_size=3, max_size=5)) + min_value=10, max_value=20), min_size=3, max_size=5)) input_dtype = draw(st.sampled_from(["float32", "int32", "int64"])) @@ -55,6 +60,10 @@ class TestReduceOpsConvert(OPConvertAutoScanTest): "delta": 1e-4, "rtol": 1e-4, } + min_opset_versions = list() + for op_name in config["op_names"]: + min_opset_versions.append(min_opset_version_map[op_name]) + config["min_opset_version"] = min_opset_versions attrs = { "axes": axes, -- GitLab