diff --git a/tests/onnx/onnxbase.py b/tests/onnx/onnxbase.py index f6871083d86a928a775b9787da2ef781d4b0b152..c83456ce4e1af5a2db30f2c8e842cffb22f3a62c 100644 --- a/tests/onnx/onnxbase.py +++ b/tests/onnx/onnxbase.py @@ -19,6 +19,9 @@ import numpy as np import logging import paddle import onnx +import shutil +from paddle.inference import create_predictor, PrecisionType +from paddle.inference import Config from onnx import helper from onnx import TensorProto from onnxruntime import InferenceSession @@ -46,7 +49,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): expect = expect[0] expect = np.array(expect) res = np.allclose(result, expect, atol=delta, rtol=rtol, equal_nan=True) - # 出错打印错误数据 + # print wrong results if res is False: if result.dtype == np.bool_: diff = abs(result.astype("int32") - expect.astype("int32")) @@ -190,10 +193,15 @@ class ONNXConverter(object): """ # input data paddle_tensor_feed = list() + result = list() for i in range(len(self.input_feed)): paddle_tensor_feed.append( paddle.to_tensor(self.input_feed[self.inputs_name[i]])) + ## PaddleInference not support float64 + if "float64" in self.inputs_dtype: + self.run_dynamic = True + if self.run_dynamic: paddle_path = os.path.join(self.pwd, self.name, self.name + '_' + str(ver) + '_paddle/') @@ -207,18 +215,49 @@ class ONNXConverter(object): model.eval() result = model(*paddle_tensor_feed) else: - paddle_path = os.path.join( - self.pwd, self.name, - self.name + '_' + str(ver) + '_paddle/inference_model/model') - paddle.disable_static() - # run - model = paddle.jit.load(paddle_path) - result = model(*paddle_tensor_feed) + paddle_model_path = os.path.join( + self.pwd, self.name, self.name + '_' + str(ver) + + '_paddle/inference_model/model.pdmodel') + paddle_param_path = os.path.join( + self.pwd, self.name, self.name + '_' + str(ver) + + '_paddle/inference_model/model.pdiparams') + config = Config() + config.set_prog_file(paddle_model_path) + if os.path.exists(paddle_param_path): + config.set_params_file(paddle_param_path) + # initial GPU memory(M), device ID + config.enable_use_gpu(200, 0) + # optimize graph and fuse op + config.switch_ir_optim(False) + config.enable_memory_optim() + # disable feed, fetch OP, needed by zero_copy_run + config.switch_use_feed_fetch_ops(False) + config.disable_glog_info() + pass_builder = config.pass_builder() + predictor = create_predictor(config) + input_names = predictor.get_input_names() + output_names = predictor.get_output_names() + for i in range(len(input_names)): + input_tensor = predictor.get_input_handle(input_names[i]) + input_tensor.copy_from_cpu(self.input_feed[self.inputs_name[i]]) + predictor.run() + for output_name in output_names: + output_tensor = predictor.get_output_handle(output_name) + result.append(output_tensor.copy_to_cpu()) + shutil.rmtree( + os.path.join(self.pwd, self.name, self.name + '_' + str(ver) + + '_paddle/')) # get paddle outputs if isinstance(result, (tuple, list)): - result = tuple(out.numpy() for out in result) + if isinstance(result[0], np.ndarray): + result = tuple(out for out in result) + else: + result = tuple(out.numpy() for out in result) else: - result = (result.numpy(), ) + if isinstance(result, np.ndarray): + result = (result, ) + else: + result = (result.numpy(), ) return result def _mk_onnx_res(self, ver): @@ -293,8 +332,6 @@ class ONNXConverter(object): self._onnx_to_paddle(ver=v) onnx_res[str(v)] = self._mk_onnx_res(ver=v) paddle_res[str(v)] = self._mk_paddle_res(ver=v) - - for v in range(self.min_opset_version, self.max_opset_version + 1): compare( onnx_res[str(v)], paddle_res[str(v)], diff --git a/tests/onnx/test_auto_scan_averagepool_10.py b/tests/onnx/test_auto_scan_averagepool_10.py new file mode 100644 index 0000000000000000000000000000000000000000..782990adf2085c939c7ff752819e482f21bb1d51 --- /dev/null +++ b/tests/onnx/test_auto_scan_averagepool_10.py @@ -0,0 +1,104 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import OPConvertAutoScanTest +from hypothesis import reproduce_failure +from onnxbase import randtool +import hypothesis.strategies as st +import numpy as np +import unittest + + +class TestAveragePoolConvert(OPConvertAutoScanTest): + """ + ONNX op: AveragePool + OPset version: 10~15 + """ + + def sample_convert_config(self, draw): + input_shape = draw( + st.lists( + st.integers( + min_value=20, max_value=30), min_size=4, max_size=4)) + # BS = 1 + input_shape[0] = 1 + kernel_size = draw( + st.lists( + st.integers( + min_value=7, max_value=10), min_size=2, max_size=2)) + + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + + if draw(st.booleans()): + auto_pad = "NOTSET" + padding = None + if draw(st.booleans()): + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), + min_size=2, + max_size=2)) + padding = [0, 0] + padding + else: + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), + min_size=4, + max_size=4)) + else: + auto_pad = draw( + st.sampled_from( + ["SAME_LOWER", "SAME_UPPER", "VALID", "NOTSET"])) + padding = None + + if draw(st.booleans()): + ceil_mode = 0 + else: + ceil_mode = 1 + if padding == "VALID": + ceil_mode = False + + config = { + "op_names": ["AveragePool"], + "test_data_shapes": [input_shape], + "test_data_types": [["float32"], ], + "inputs_shape": [], + "min_opset_version": 10, + "max_opset_version": 15, + "inputs_name": ["x"], + "outputs_name": ["y"], + "delta": 1e-4, + "rtol": 1e-4 + } + + attrs = { + "auto_pad": auto_pad, + "ceil_mode": ceil_mode, + "kernel_shape": kernel_size, + "pads": padding, + "strides": strides, + } + return (config, attrs) + + def test(self): + self.run_and_statis(max_examples=30) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/onnx/test_auto_scan_averagepool_7.py b/tests/onnx/test_auto_scan_averagepool_7.py new file mode 100644 index 0000000000000000000000000000000000000000..666800bf47d0c1aab0975a681aae5f34bafa6859 --- /dev/null +++ b/tests/onnx/test_auto_scan_averagepool_7.py @@ -0,0 +1,95 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import OPConvertAutoScanTest +from hypothesis import reproduce_failure +from onnxbase import randtool +import hypothesis.strategies as st +import numpy as np +import unittest + + +class TestAveragePoolConvert(OPConvertAutoScanTest): + """ + ONNX op: AveragePool + OPset version: 7~9 + """ + + def sample_convert_config(self, draw): + input_shape = draw( + st.lists( + st.integers( + min_value=10, max_value=20), min_size=4, max_size=4)) + + kernel_size = draw( + st.lists( + st.integers( + min_value=7, max_value=10), min_size=2, max_size=2)) + + strides = draw( + st.lists( + st.integers( + min_value=1, max_value=2), min_size=2, max_size=2)) + + if draw(st.booleans()): + auto_pad = "NOTSET" + padding = None + if draw(st.booleans()): + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), + min_size=2, + max_size=2)) + padding = [0, 0] + padding + else: + padding = draw( + st.lists( + st.integers( + min_value=1, max_value=5), + min_size=4, + max_size=4)) + else: + auto_pad = draw( + st.sampled_from( + ["SAME_LOWER", "SAME_UPPER", "VALID", "NOTSET"])) + padding = None + + config = { + "op_names": ["AveragePool"], + "test_data_shapes": [input_shape], + "test_data_types": [["float32"], ], + "inputs_shape": [], + "min_opset_version": 7, + "max_opset_version": 9, + "inputs_name": ["x"], + "outputs_name": ["y"], + "delta": 1e-4, + "rtol": 1e-4, + } + + attrs = { + "auto_pad": auto_pad, + "kernel_shape": kernel_size, + "pads": padding, + "strides": strides, + } + return (config, attrs) + + def test(self): + self.run_and_statis(max_examples=30) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/onnx/test_auto_scan_mod.py b/tests/onnx/test_auto_scan_mod.py index 05e6b0e31bbeaa948f653078cfaaf9fd4eea4cb8..201a943e0fe48ae7712644b4dd04d95c456372a1 100644 --- a/tests/onnx/test_auto_scan_mod.py +++ b/tests/onnx/test_auto_scan_mod.py @@ -42,8 +42,7 @@ class TestModConvert(OPConvertAutoScanTest): input_data[abs(input_data) < 1.0] = 1.0 return input_data - input_dtype = draw( - st.sampled_from(["int32", "int64", "float32", "float64"])) + input_dtype = draw(st.sampled_from(["int32", "int64"])) config = { "op_names": ["Mod"], diff --git a/x2paddle/op_mapper/onnx2paddle/opset10.py b/x2paddle/op_mapper/onnx2paddle/opset10.py index 48df1e4bc5e49da1f05313e9e59dd3bf2dad4c63..5a697e95a31d5805ac090c8c3c6533df00f3a3ca 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset10.py +++ b/x2paddle/op_mapper/onnx2paddle/opset10.py @@ -13,6 +13,9 @@ # limitations under the License. from .opset9 import OpSet9 +from x2paddle.core.util import * +import numpy as np +import math def print_mapping_info(func): @@ -29,8 +32,83 @@ def print_mapping_info(func): return run_mapping +def _get_same_padding(in_size, kernel_size, stride, autopad): + new_size = int(math.ceil(in_size * 1.0 / stride)) + pad_size = (new_size - 1) * stride + kernel_size - in_size + pad0 = int(pad_size / 2) + pad1 = pad_size - pad0 + if autopad == "SAME_UPPER": + return [pad0, pad1] + if autopad == "SAME_LOWER": + return [pad1, pad0] + + class OpSet10(OpSet9): def __init__(self, decoder, paddle_graph): super(OpSet10, self).__init__(decoder, paddle_graph) + + @print_mapping_info + def AveragePool(self, node): + val_x = self.graph.get_input_node(node, idx=0, copy=True) + + auto_pad = node.get_attr('auto_pad', 'NOTSET') + kernel_shape = node.get_attr("kernel_shape") + count_include_pad = node.get_attr("count_include_pad", 0) + # Support ceil_mode Since opset version >= 10 + ceil_mode = bool(node.get_attr("ceil_mode", 0)) + exclusive = True + if count_include_pad > 0: + exclusive = False + poolnd = len(kernel_shape) + strides = node.get_attr("strides") + pad_mode = node.get_attr("pads") + pads = node.get_attr('pads', [0] * (poolnd * 2)) + + input_shape = val_x.out_shapes[0] + paddings = np.array(pads).reshape((2, -1)).transpose().astype("int32") + paddings = paddings.flatten().tolist() + + if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": + # Warning: SAME_UPPER and SAME_LOWER does not yet support dynamic shapes + if input_shape[2] == -1 or input_shape[3] == -1: + _logger.warning( + 'SAME_UPPER and SAME_LOWER does not yet support dynamic shapes, the conversion result may have a diff!!!' + ) + pad_h = _get_same_padding(input_shape[2], kernel_shape[0], + strides[0], auto_pad) + pad_w = _get_same_padding(input_shape[3], kernel_shape[1], + strides[1], auto_pad) + paddings = pad_h + pad_w + + op_name = name_generator("pool", self.nn_name2id) + output_name = node.name + layer_outputs = [op_name, output_name] + paddle_op = 'paddle.nn.AvgPool{}D'.format(poolnd) + assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported' + layer_attrs = { + "kernel_size": kernel_shape, + "stride": strides, + "ceil_mode": ceil_mode, + "padding": paddings, + "exclusive": exclusive, + } + self.paddle_graph.add_layer( + paddle_op, + inputs={'x': val_x if isinstance(val_x, str) else val_x.name}, + outputs=layer_outputs, + **layer_attrs) + + @print_mapping_info + def Mod(self, node): # Support Mod op Since opset version >= 10 - self.elementwise_ops.update({"Mod": "paddle.mod"}) + val_x = self.graph.get_input_node(node, idx=0, copy=True) + val_y = self.graph.get_input_node(node, idx=1, copy=True) + input_dtype = str(val_x.dtype) + assert "int" in input_dtype, 'Now only support int32 or int64 dtype' + fmod = node.get_attr('fmod', 0) + assert fmod == 0, 'Now only support fmod == 0' + self.paddle_graph.add_layer( + 'paddle.mod', + inputs={"x": val_x.name, + "y": val_y.name}, + outputs=[node.name]) diff --git a/x2paddle/op_mapper/onnx2paddle/opset7.py b/x2paddle/op_mapper/onnx2paddle/opset7.py index 6c0c22ecea7aeb9ca941305e591d52c7896ffee2..2fe457ef0ab0d32c4206f75270ccf506398980fe 100644 --- a/x2paddle/op_mapper/onnx2paddle/opset7.py +++ b/x2paddle/op_mapper/onnx2paddle/opset7.py @@ -13,6 +13,9 @@ # limitations under the License. from .opset_legacy import OpSet +from x2paddle.core.util import * +import numpy as np +import math def print_mapping_info(func): @@ -29,10 +32,69 @@ def print_mapping_info(func): return run_mapping +def _get_same_padding(in_size, kernel_size, stride, autopad): + new_size = int(math.ceil(in_size * 1.0 / stride)) + pad_size = (new_size - 1) * stride + kernel_size - in_size + pad0 = int(pad_size / 2) + pad1 = pad_size - pad0 + if autopad == "SAME_UPPER": + return [pad0, pad1] + if autopad == "SAME_LOWER": + return [pad1, pad0] + + class OpSet7(OpSet): def __init__(self, decoder, paddle_graph): super(OpSet7, self).__init__(decoder, paddle_graph) + @print_mapping_info + def AveragePool(self, node): + val_x = self.graph.get_input_node(node, idx=0, copy=True) + + auto_pad = node.get_attr('auto_pad', 'NOTSET') + kernel_shape = node.get_attr("kernel_shape") + count_include_pad = node.get_attr("count_include_pad", 0) + exclusive = True + if count_include_pad > 0: + exclusive = False + poolnd = len(kernel_shape) + strides = node.get_attr("strides") + pad_mode = node.get_attr("pads") + pads = node.get_attr('pads', [0] * (poolnd * 2)) + + input_shape = val_x.out_shapes[0] + paddings = np.array(pads).reshape((2, -1)).transpose().astype("int32") + paddings = paddings.flatten().tolist() + + if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": + # Warning: SAME_UPPER and SAME_LOWER does not yet support dynamic shapes + if input_shape[2] == -1 or input_shape[3] == -1: + _logger.warning( + 'SAME_UPPER and SAME_LOWER does not yet support dynamic shapes, the conversion result may have a diff!!!' + ) + pad_h = _get_same_padding(input_shape[2], kernel_shape[0], + strides[0], auto_pad) + pad_w = _get_same_padding(input_shape[3], kernel_shape[1], + strides[1], auto_pad) + paddings = pad_h + pad_w + + op_name = name_generator("pool", self.nn_name2id) + output_name = node.name + layer_outputs = [op_name, output_name] + paddle_op = 'paddle.nn.AvgPool{}D'.format(poolnd) + assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported' + layer_attrs = { + "kernel_size": kernel_shape, + "stride": strides, + "padding": paddings, + "exclusive": exclusive, + } + self.paddle_graph.add_layer( + paddle_op, + inputs={'x': val_x if isinstance(val_x, str) else val_x.name}, + outputs=layer_outputs, + **layer_attrs) + @print_mapping_info def Or(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True)