From 807b43851794dbf09b780144029aa01c6979af84 Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Tue, 2 Aug 2022 16:13:45 +0800 Subject: [PATCH] add Compare ops --- docs/inference_model_convertor/op_list.md | 2 +- tests/onnx/test_auto_scan_compare_ops.py | 78 +++++++++++++++++++++++ tests/onnx/test_auto_scan_equal.py | 67 +++++++++++++++++++ 3 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 tests/onnx/test_auto_scan_compare_ops.py create mode 100644 tests/onnx/test_auto_scan_equal.py diff --git a/docs/inference_model_convertor/op_list.md b/docs/inference_model_convertor/op_list.md index 7925eab..6c074b4 100755 --- a/docs/inference_model_convertor/op_list.md +++ b/docs/inference_model_convertor/op_list.md @@ -76,7 +76,7 @@ | 81 | Add | 82 | Concat | 83 | Max | 84 | Min | | 85 | GreaterOrEqual | 86 | GatherND | 87 | And | 88 | cos | | 89 | Neg | 90 | SpaceToDepth | 91 | GatherElement | 92 | Sin | -| 93 | CumSum | 94 | Or | 95 | Xor | | | +| 93 | CumSum | 94 | Or | 95 | Xor | 96 | Mod | ## PyTorch diff --git a/tests/onnx/test_auto_scan_compare_ops.py b/tests/onnx/test_auto_scan_compare_ops.py new file mode 100644 index 0000000..f984a35 --- /dev/null +++ b/tests/onnx/test_auto_scan_compare_ops.py @@ -0,0 +1,78 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import OPConvertAutoScanTest +from hypothesis import reproduce_failure +from onnxbase import randtool +import hypothesis.strategies as st +import numpy as np +import unittest + +min_opset_version_map = { + "Greater": 7, + "Less": 7, + "GreaterOrEqual": 12, + "LessOrEqual": 12, +} + + +class TestCompareopsConvert(OPConvertAutoScanTest): + """ + ONNX op: Compare ops + OPset version: 7~15 + """ + + def sample_convert_config(self, draw): + input1_shape = draw( + st.lists( + st.integers( + min_value=10, max_value=20), min_size=2, max_size=4)) + + if draw(st.booleans()): + input2_shape = [input1_shape[-1]] + else: + input2_shape = input1_shape + + if draw(st.booleans()): + input2_shape = [1] + + input_dtype = draw(st.sampled_from(["float32", "float64"])) + + config = { + "op_names": ["Greater", "Less", "GreaterOrEqual", "LessOrEqual"], + "test_data_shapes": [input1_shape, input2_shape], + "test_data_types": [[input_dtype], [input_dtype]], + "inputs_shape": [], + "min_opset_version": 7, + "inputs_name": ["x", "y"], + "outputs_name": ["z"], + "delta": 1e-4, + "rtol": 1e-4, + "run_dynamic": True, + } + min_opset_versions = list() + for op_name in config["op_names"]: + min_opset_versions.append(min_opset_version_map[op_name]) + config["min_opset_version"] = min_opset_versions + + attrs = {} + + return (config, attrs) + + def test(self): + self.run_and_statis(max_examples=30) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/onnx/test_auto_scan_equal.py b/tests/onnx/test_auto_scan_equal.py new file mode 100644 index 0000000..64e583d --- /dev/null +++ b/tests/onnx/test_auto_scan_equal.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import OPConvertAutoScanTest +from hypothesis import reproduce_failure +from onnxbase import randtool +import hypothesis.strategies as st +import numpy as np +import unittest + + +class TestEqualConvert(OPConvertAutoScanTest): + """ + ONNX op: Equal + OPset version: 7~15 + """ + + def sample_convert_config(self, draw): + input1_shape = draw( + st.lists( + st.integers( + min_value=10, max_value=20), min_size=2, max_size=4)) + + if draw(st.booleans()): + input2_shape = [input1_shape[-1]] + else: + input2_shape = input1_shape + + if draw(st.booleans()): + input2_shape = [1] + + input_dtype = draw(st.sampled_from(["int32", "int64", "bool"])) + + config = { + "op_names": ["Equal"], + "test_data_shapes": [input1_shape, input2_shape], + "test_data_types": [[input_dtype], [input_dtype]], + "inputs_shape": [], + "min_opset_version": 7, + "inputs_name": ["x", "y"], + "outputs_name": ["z"], + "delta": 1e-4, + "rtol": 1e-4, + "run_dynamic": True, + } + + attrs = {} + + return (config, attrs) + + def test(self): + self.run_and_statis(max_examples=30) + + +if __name__ == "__main__": + unittest.main() -- GitLab