提交 23b147ae 编写于 作者: W wjj19950828

Add Unsqueeze Tests

上级 ed15a8f1
......@@ -154,6 +154,7 @@ class OPConvertAutoScanTest(unittest.TestCase):
# max_opset_version is a fixed value
max_opset_version = 15
enable_onnx_checker = True
self.num_ran_tests += 1
# add ignore testcases
......@@ -183,11 +184,14 @@ class OPConvertAutoScanTest(unittest.TestCase):
rtol = config["rtol"]
if "max_opset_version" in config.keys():
max_opset_version = config["max_opset_version"]
if "enable_onnx_checker" in config.keys():
enable_onnx_checker = config["enable_onnx_checker"]
for i in range(len(op_names)):
obj = ONNXConverter(op_names[i], min_opset_version,
max_opset_version, op_names[i], inputs_name,
outputs_name, inputs_shape, delta, rtol, attrs)
outputs_name, inputs_shape, delta, rtol, attrs,
enable_onnx_checker)
for input_type in input_type_list:
input_data = list()
for j, shape in enumerate(test_data_shapes):
......
......@@ -100,7 +100,8 @@ class ONNXConverter(object):
inputs_shape=[],
delta=1e-5,
rtol=1e-5,
attrs=[]):
attrs=[],
enable_onnx_checker=True):
self.op_type = op_type
assert isinstance(self.op_type,
str), "The dtype of op_type must be string!"
......@@ -122,6 +123,7 @@ class ONNXConverter(object):
self.outputs_name = outputs_name
self.inputs_shape = inputs_shape
self.attrs = attrs
self.enable_onnx_checker = enable_onnx_checker
def set_input_data(self, group_name, *args):
"""
......@@ -132,17 +134,24 @@ class ONNXConverter(object):
self.kwargs_dict[group_name] = self.kwargs_dict[group_name][0]
i = 0
add_inputs_shape = False
if len(self.inputs_shape) == 0:
add_inputs_shape = True
for in_data in self.kwargs_dict[group_name]:
if isinstance(in_data, list):
for data in in_data:
self.inputs_dtype.append(str(data.dtype))
self.input_feed[self.inputs_name[i]] = data
if add_inputs_shape:
self.inputs_shape.append(data.shape)
i += 1
else:
if isinstance(in_data, tuple):
in_data = in_data[0]
self.inputs_dtype.append(str(in_data.dtype))
self.input_feed[self.inputs_name[i]] = in_data
if add_inputs_shape:
self.inputs_shape.append(in_data.shape)
i += 1
def _mkdir(self):
......@@ -162,7 +171,11 @@ class ONNXConverter(object):
self.name + '_' + str(ver) + '.onnx')
paddle_path = os.path.join(self.pwd, self.name,
self.name + '_' + str(ver) + '_paddle')
onnx2paddle(onnx_path, paddle_path, convert_to_lite=False)
onnx2paddle(
onnx_path,
paddle_path,
convert_to_lite=False,
enable_onnx_checker=self.enable_onnx_checker)
def _mk_paddle_res(self, ver):
"""
......@@ -237,7 +250,8 @@ class ONNXConverter(object):
onnx.save(model,
os.path.join(self.pwd, self.name,
self.name + '_' + str(ver) + '.onnx'))
onnx.checker.check_model(model)
if self.enable_onnx_checker:
onnx.checker.check_model(model)
def run(self):
"""
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import OPConvertAutoScanTest
from hypothesis import reproduce_failure
import hypothesis.strategies as st
import onnx
from onnx import helper
from onnx import TensorProto
import numpy as np
import unittest
class TestUnsqueezeConvert0(OPConvertAutoScanTest):
"""
ONNX op: Unsqueeze
OPset version: 7~12
"""
def sample_convert_config(self, draw):
input_shape = draw(
st.lists(
st.integers(
min_value=2, max_value=6), min_size=2, max_size=5))
input_dtype = draw(
st.sampled_from(["int32", "int64", "float32", "float64"]))
axis = draw(
st.integers(
min_value=-len(input_shape), max_value=len(input_shape) - 1))
if len(input_shape) == 5:
axis = [0]
if len(input_shape) == 4:
axis = [0, 1]
if len(input_shape) == 3:
axis = [1, 2, 3]
if len(input_shape) == 2:
if draw(st.booleans()):
axis = [0, 1, 2]
else:
axis = [1, 3]
config = {
"op_names": ["Unsqueeze"],
"test_data_shapes": [input_shape],
"test_data_types": [[input_dtype]],
"inputs_shape": [input_shape],
"min_opset_version": 7,
"max_opset_version": 12,
"inputs_name": ["x"],
"outputs_name": ["y"],
"delta": 1e-4,
"rtol": 1e-4
}
attrs = {"axes": axis, }
return (config, attrs)
def test(self):
self.run_and_statis(max_examples=30)
class TestUnsqueezeConvert1(OPConvertAutoScanTest):
"""
ONNX op: Unsqueeze
OPset version: 13~15
"""
def sample_convert_config(self, draw):
input_shape = draw(
st.lists(
st.integers(
min_value=2, max_value=6), min_size=2, max_size=5))
input_dtype = draw(
st.sampled_from(["int32", "int64", "float32", "float64"]))
def generator_axis():
axis = list()
if len(input_shape) == 5:
axis = [0]
if len(input_shape) == 4:
axis = [0, 1]
if len(input_shape) == 3:
axis = [1, 2, 3]
if len(input_shape) == 2:
if draw(st.booleans()):
axis = [0, 1, 2]
else:
axis = [1, 3]
axis = np.array(axis)
print("axis:", axis)
return axis
config = {
"op_names": ["Unsqueeze"],
"test_data_shapes": [input_shape, generator_axis],
"test_data_types": [[input_dtype], ["int64"]],
"inputs_shape": [],
"min_opset_version": 13,
"max_opset_version": 15,
"inputs_name": ["x", "axes"],
"outputs_name": ["y"],
"delta": 1e-4,
"rtol": 1e-4,
"enable_onnx_checker": False,
}
attrs = {}
return (config, attrs)
def test(self):
self.run_and_statis(max_examples=30)
if __name__ == "__main__":
unittest.main()
......@@ -263,7 +263,8 @@ def onnx2paddle(model_path,
convert_to_lite=False,
lite_valid_places="arm",
lite_model_type="naive_buffer",
disable_feedback=False):
disable_feedback=False,
enable_onnx_checker=True):
# for convert_id
time_info = int(time.time())
if not disable_feedback:
......@@ -286,7 +287,7 @@ def onnx2paddle(model_path,
from x2paddle.decoder.onnx_decoder import ONNXDecoder
from x2paddle.op_mapper.onnx2paddle.onnx_op_mapper import ONNXOpMapper
model = ONNXDecoder(model_path)
model = ONNXDecoder(model_path, enable_onnx_checker)
mapper = ONNXOpMapper(model)
mapper.paddle_graph.build()
logging.info("Model optimizing ...")
......
......@@ -416,13 +416,14 @@ class ONNXGraph(Graph):
class ONNXDecoder(object):
def __init__(self, onnx_model):
def __init__(self, onnx_model, enable_onnx_checker):
onnx_model = onnx.load(onnx_model)
print('model ir_version: {}, op version: {}'.format(
onnx_model.ir_version, onnx_model.opset_import[0].version))
self.op_set = onnx_model.opset_import[0].version
check_model(onnx_model)
if enable_onnx_checker:
check_model(onnx_model)
onnx_model = self.optimize_model_skip_op(onnx_model)
onnx_model = self.optimize_node_name(onnx_model)
......
......@@ -12,9 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset9 import OpSet9
class OpSet10(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet10(OpSet9):
def __init__(self, decoder, paddle_graph):
super(OpSet10, self).__init__(decoder, paddle_graph)
......@@ -12,9 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset10 import OpSet10
class OpSet11(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet11(OpSet10):
def __init__(self, decoder, paddle_graph):
super(OpSet11, self).__init__(decoder, paddle_graph)
......@@ -12,9 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset11 import OpSet11
class OpSet12(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet12(OpSet11):
def __init__(self, decoder, paddle_graph):
super(OpSet12, self).__init__(decoder, paddle_graph)
......@@ -12,9 +12,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset12 import OpSet12
class OpSet13(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet13(OpSet12):
def __init__(self, decoder, paddle_graph):
super(OpSet13, self).__init__(decoder, paddle_graph)
@print_mapping_info
def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = self.graph.get_input_node(node, idx=1, copy=True)
# deal with scalar(0D) tensor
if len(val_x.out_shapes[0]) == 0 and len(axes.out_shapes[0]) == 1:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1])
else:
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name,
"axis": axes.name},
outputs=[node.name])
......@@ -12,9 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset13 import OpSet13
class OpSet14(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet14(OpSet13):
def __init__(self, decoder, paddle_graph):
super(OpSet14, self).__init__(decoder, paddle_graph)
......@@ -12,9 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset14 import OpSet14
class OpSet15(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet15(OpSet14):
def __init__(self, decoder, paddle_graph):
super(OpSet15, self).__init__(decoder, paddle_graph)
......@@ -15,6 +15,38 @@
from .opset_legacy import OpSet
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet7(OpSet):
def __init__(self, decoder, paddle_graph):
super(OpSet7, self).__init__(decoder, paddle_graph)
@print_mapping_info
def Unsqueeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes')
# deal with scalar(0D) tensor
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": val_x.name},
outputs=[node.name],
shape=[1])
else:
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name},
axis=axes,
outputs=[node.name])
......@@ -12,9 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset7 import OpSet7
class OpSet8(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet8(OpSet7):
def __init__(self, decoder, paddle_graph):
super(OpSet8, self).__init__(decoder, paddle_graph)
......@@ -12,9 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .opset_legacy import OpSet
from .opset8 import OpSet8
class OpSet9(OpSet):
def print_mapping_info(func):
def run_mapping(*args, **kwargs):
node = args[1]
try:
res = func(*args, **kwargs)
except:
raise Exception("convert failed node:{}, op_type is {}".format(
node.name[9:], node.layer_type))
else:
return res
return run_mapping
class OpSet9(OpSet8):
def __init__(self, decoder, paddle_graph):
super(OpSet9, self).__init__(decoder, paddle_graph)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册