提交 906b8268 编写于 作者: W wjj19950828

Add AveragePool tests

上级 33ea85b8
...@@ -19,6 +19,7 @@ import numpy as np ...@@ -19,6 +19,7 @@ import numpy as np
import logging import logging
import paddle import paddle
import onnx import onnx
import shutil
from onnx import helper from onnx import helper
from onnx import TensorProto from onnx import TensorProto
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
...@@ -46,7 +47,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10): ...@@ -46,7 +47,7 @@ def compare(result, expect, delta=1e-10, rtol=1e-10):
expect = expect[0] expect = expect[0]
expect = np.array(expect) expect = np.array(expect)
res = np.allclose(result, expect, atol=delta, rtol=rtol, equal_nan=True) res = np.allclose(result, expect, atol=delta, rtol=rtol, equal_nan=True)
# 出错打印错误数据 # print wrong results
if res is False: if res is False:
if result.dtype == np.bool_: if result.dtype == np.bool_:
diff = abs(result.astype("int32") - expect.astype("int32")) diff = abs(result.astype("int32") - expect.astype("int32"))
...@@ -214,6 +215,9 @@ class ONNXConverter(object): ...@@ -214,6 +215,9 @@ class ONNXConverter(object):
# run # run
model = paddle.jit.load(paddle_path) model = paddle.jit.load(paddle_path)
result = model(*paddle_tensor_feed) result = model(*paddle_tensor_feed)
shutil.rmtree(
os.path.join(self.pwd, self.name, self.name + '_' + str(ver) +
'_paddle/'))
# get paddle outputs # get paddle outputs
if isinstance(result, (tuple, list)): if isinstance(result, (tuple, list)):
result = tuple(out.numpy() for out in result) result = tuple(out.numpy() for out in result)
...@@ -293,8 +297,6 @@ class ONNXConverter(object): ...@@ -293,8 +297,6 @@ class ONNXConverter(object):
self._onnx_to_paddle(ver=v) self._onnx_to_paddle(ver=v)
onnx_res[str(v)] = self._mk_onnx_res(ver=v) onnx_res[str(v)] = self._mk_onnx_res(ver=v)
paddle_res[str(v)] = self._mk_paddle_res(ver=v) paddle_res[str(v)] = self._mk_paddle_res(ver=v)
for v in range(self.min_opset_version, self.max_opset_version + 1):
compare( compare(
onnx_res[str(v)], onnx_res[str(v)],
paddle_res[str(v)], paddle_res[str(v)],
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import OPConvertAutoScanTest
from hypothesis import reproduce_failure
from onnxbase import randtool
import hypothesis.strategies as st
import numpy as np
import unittest
class TestAveragePoolConvert(OPConvertAutoScanTest):
"""
ONNX op: AveragePool
OPset version: 10~15
"""
def sample_convert_config(self, draw):
input_shape = draw(
st.lists(
st.integers(
min_value=10, max_value=20), min_size=4, max_size=4))
kernel_size = draw(
st.lists(
st.integers(
min_value=7, max_value=10), min_size=2, max_size=2))
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
if draw(st.booleans()):
auto_pad = "NOTSET"
padding = None
if draw(st.booleans()):
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=2,
max_size=2))
padding = [0, 0] + padding
else:
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=4,
max_size=4))
else:
auto_pad = draw(
st.sampled_from(
["SAME_LOWER", "SAME_UPPER", "VALID", "NOTSET"]))
padding = None
if draw(st.booleans()):
ceil_mode = 0
else:
ceil_mode = 1
if padding == "VALID":
ceil_mode = False
config = {
"op_names": ["AveragePool"],
"test_data_shapes": [input_shape],
"test_data_types": [["float32"], ],
"inputs_shape": [],
"min_opset_version": 10,
"max_opset_version": 15,
"inputs_name": ["x"],
"outputs_name": ["y"],
"delta": 1e-4,
"rtol": 1e-4
}
attrs = {
"auto_pad": auto_pad,
"ceil_mode": ceil_mode,
"kernel_shape": kernel_size,
"pads": padding,
"strides": strides,
}
return (config, attrs)
def test(self):
self.run_and_statis(max_examples=30)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import OPConvertAutoScanTest
from hypothesis import reproduce_failure
from onnxbase import randtool
import hypothesis.strategies as st
import numpy as np
import unittest
class TestAveragePoolConvert(OPConvertAutoScanTest):
"""
ONNX op: AveragePool
OPset version: 7~15
"""
def sample_convert_config(self, draw):
input_shape = draw(
st.lists(
st.integers(
min_value=10, max_value=20), min_size=4, max_size=4))
kernel_size = draw(
st.lists(
st.integers(
min_value=7, max_value=10), min_size=2, max_size=2))
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=2, max_size=2))
if draw(st.booleans()):
auto_pad = "NOTSET"
padding = None
if draw(st.booleans()):
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=2,
max_size=2))
padding = [0, 0] + padding
else:
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=4,
max_size=4))
else:
auto_pad = draw(
st.sampled_from(
["SAME_LOWER", "SAME_UPPER", "VALID", "NOTSET"]))
padding = None
config = {
"op_names": ["AveragePool"],
"test_data_shapes": [input_shape],
"test_data_types": [["float32"], ],
"inputs_shape": [],
"min_opset_version": 7,
"max_opset_version": 9,
"inputs_name": ["x"],
"outputs_name": ["y"],
"delta": 1e-4,
"rtol": 1e-4,
}
attrs = {
"auto_pad": auto_pad,
"kernel_shape": kernel_size,
"pads": padding,
"strides": strides,
}
return (config, attrs)
def test(self):
self.run_and_statis(max_examples=30)
if __name__ == "__main__":
unittest.main()
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
from .opset9 import OpSet9 from .opset9 import OpSet9
from x2paddle.core.util import *
import numpy as np
import math
def print_mapping_info(func): def print_mapping_info(func):
...@@ -29,8 +32,70 @@ def print_mapping_info(func): ...@@ -29,8 +32,70 @@ def print_mapping_info(func):
return run_mapping return run_mapping
def _get_same_padding(in_size, kernel_size, stride, autopad):
new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size
pad0 = int(pad_size / 2)
pad1 = pad_size - pad0
if autopad == "SAME_UPPER":
return [pad0, pad1]
if autopad == "SAME_LOWER":
return [pad1, pad0]
class OpSet10(OpSet9): class OpSet10(OpSet9):
def __init__(self, decoder, paddle_graph): def __init__(self, decoder, paddle_graph):
super(OpSet10, self).__init__(decoder, paddle_graph) super(OpSet10, self).__init__(decoder, paddle_graph)
# Support Mod op Since opset version >= 10 # Support Mod op Since opset version >= 10
self.elementwise_ops.update({"Mod": "paddle.mod"}) self.elementwise_ops.update({"Mod": "paddle.mod"})
@print_mapping_info
def AveragePool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr("kernel_shape")
count_include_pad = node.get_attr("count_include_pad", 0)
# Support ceil_mode Since opset version >= 10
ceil_mode = bool(node.get_attr("ceil_mode", 0))
exclusive = True
if count_include_pad > 0:
exclusive = False
poolnd = len(kernel_shape)
strides = node.get_attr("strides")
pad_mode = node.get_attr("pads")
pads = node.get_attr('pads', [0] * (poolnd * 2))
input_shape = val_x.out_shapes[0]
paddings = np.array(pads).reshape((2, -1)).transpose().astype("int32")
paddings = paddings.flatten().tolist()
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
# Warning: SAME_UPPER and SAME_LOWER does not yet support dynamic shapes
if input_shape[2] == -1 or input_shape[3] == -1:
_logger.warning(
'SAME_UPPER and SAME_LOWER does not yet support dynamic shapes, the conversion result may have a diff!!!'
)
pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
strides[0], auto_pad)
pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
strides[1], auto_pad)
paddings = pad_h + pad_w
op_name = name_generator("pool", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
paddle_op = 'paddle.nn.AvgPool{}D'.format(poolnd)
assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'
layer_attrs = {
"kernel_size": kernel_shape,
"stride": strides,
"ceil_mode": ceil_mode,
"padding": paddings,
"exclusive": exclusive,
}
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x if isinstance(val_x, str) else val_x.name},
outputs=layer_outputs,
**layer_attrs)
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
# limitations under the License. # limitations under the License.
from .opset_legacy import OpSet from .opset_legacy import OpSet
from x2paddle.core.util import *
import numpy as np
import math
def print_mapping_info(func): def print_mapping_info(func):
...@@ -29,10 +32,69 @@ def print_mapping_info(func): ...@@ -29,10 +32,69 @@ def print_mapping_info(func):
return run_mapping return run_mapping
def _get_same_padding(in_size, kernel_size, stride, autopad):
new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size
pad0 = int(pad_size / 2)
pad1 = pad_size - pad0
if autopad == "SAME_UPPER":
return [pad0, pad1]
if autopad == "SAME_LOWER":
return [pad1, pad0]
class OpSet7(OpSet): class OpSet7(OpSet):
def __init__(self, decoder, paddle_graph): def __init__(self, decoder, paddle_graph):
super(OpSet7, self).__init__(decoder, paddle_graph) super(OpSet7, self).__init__(decoder, paddle_graph)
@print_mapping_info
def AveragePool(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
auto_pad = node.get_attr('auto_pad', 'NOTSET')
kernel_shape = node.get_attr("kernel_shape")
count_include_pad = node.get_attr("count_include_pad", 0)
exclusive = True
if count_include_pad > 0:
exclusive = False
poolnd = len(kernel_shape)
strides = node.get_attr("strides")
pad_mode = node.get_attr("pads")
pads = node.get_attr('pads', [0] * (poolnd * 2))
input_shape = val_x.out_shapes[0]
paddings = np.array(pads).reshape((2, -1)).transpose().astype("int32")
paddings = paddings.flatten().tolist()
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
# Warning: SAME_UPPER and SAME_LOWER does not yet support dynamic shapes
if input_shape[2] == -1 or input_shape[3] == -1:
_logger.warning(
'SAME_UPPER and SAME_LOWER does not yet support dynamic shapes, the conversion result may have a diff!!!'
)
pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
strides[0], auto_pad)
pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
strides[1], auto_pad)
paddings = pad_h + pad_w
op_name = name_generator("pool", self.nn_name2id)
output_name = node.name
layer_outputs = [op_name, output_name]
paddle_op = 'paddle.nn.AvgPool{}D'.format(poolnd)
assert 1 <= poolnd <= 3, 'only Pool1D, Pool2D and Pool3D are supported'
layer_attrs = {
"kernel_size": kernel_shape,
"stride": strides,
"padding": paddings,
"exclusive": exclusive,
}
self.paddle_graph.add_layer(
paddle_op,
inputs={'x': val_x if isinstance(val_x, str) else val_x.name},
outputs=layer_outputs,
**layer_attrs)
@print_mapping_info @print_mapping_info
def Or(self, node): def Or(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册