提交 8cbf9efa 编写于 作者: W wjj19950828

add autoscan test

上级 dbcb7398
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import unittest
import os
import time
import logging
import paddle
import hypothesis
from hypothesis import given, settings, seed, reproduce_failure
import hypothesis.strategies as st
from onnxbase import ONNXConverter, randtool
from itertools import product
import copy
from inspect import isfunction
paddle.set_device("cpu")
logging.basicConfig(level=logging.INFO, format="%(message)s")
settings.register_profile(
"ci",
max_examples=100,
suppress_health_check=hypothesis.HealthCheck.all(),
deadline=None,
print_blob=True,
derandomize=True,
report_multiple_bugs=False)
settings.register_profile(
"dev",
max_examples=1000,
suppress_health_check=hypothesis.HealthCheck.all(),
deadline=None,
print_blob=True,
derandomize=True,
report_multiple_bugs=False)
if float(os.getenv('TEST_NUM_PERCENT_CASES', default='1.0')) < 1 or \
os.getenv('HYPOTHESIS_TEST_PROFILE', 'dev') == 'ci':
settings.load_profile("ci")
else:
settings.load_profile("dev")
class OPConvertAutoScanTest(unittest.TestCase):
def __init__(self, *args, **kwargs):
super(OPConvertAutoScanTest, self).__init__(*args, **kwargs)
np.random.seed(1024)
paddle.enable_static()
self.num_ran_models = 0
def run_and_statis(self,
max_examples=100,
opset_version=[7, 9, 15],
reproduce=None,
min_success_num=25,
max_duration=-1):
if os.getenv("CE_STAGE", "OFF") == "ON":
max_examples *= 10
min_success_num *= 10
# while at ce phase, there's no limit on time
max_duration = -1
start_time = time.time()
settings.register_profile(
"ci",
max_examples=max_examples,
suppress_health_check=hypothesis.HealthCheck.all(),
deadline=None,
print_blob=True,
derandomize=True,
report_multiple_bugs=False, )
settings.load_profile("ci")
def sample_convert_generator(draw):
return self.sample_convert_config(draw)
def run_test(configs):
return self.run_test(configs=configs)
generator = st.composite(sample_convert_generator)
loop_func = given(generator())(run_test)
if reproduce is not None:
loop_func = reproduce(loop_func)
logging.info("Start to running test of {}".format(type(self)))
paddle.disable_static()
loop_func()
logging.info(
"===================Statistical Information===================")
logging.info("Number of Generated Programs: {}".format(
self.num_ran_models))
successful_ran_programs = int(self.num_ran_models)
if successful_ran_programs < min_success_num:
logging.warning("satisfied_programs = ran_programs")
logging.error(
"At least {} programs need to ran successfully, but now only about {} programs satisfied.".
format(min_success_num, successful_ran_programs))
assert False
used_time = time.time() - start_time
logging.info("Used time: {} s".format(round(used_time, 2)))
if max_duration > 0 and used_time > max_duration:
logging.error(
"The duration exceeds {} seconds, if this is neccessary, try to set a larger number for parameter `max_duration`.".
format(max_duration))
assert False
def run_test(self, configs):
config, attrs = configs
logging.info("Run configs: {}".format(config))
logging.info("Run attrs: {}".format(attrs))
assert "op_names" in config.keys(
), "config must include op_names in dict keys"
assert "test_data_shapes" in config.keys(
), "config must include test_data_shapes in dict keys"
assert "test_data_types" in config.keys(
), "config must include test_data_types in dict keys"
assert "opset_version" in config.keys(
), "config must include opset_version in dict keys"
assert "inputs_name" in config.keys(
), "config must include inputs_name in dict keys"
assert "outputs_name" in config.keys(
), "config must include outputs_name in dict keys"
assert "inputs_shape" in config.keys(
), "config must include inputs_shape in dict keys"
assert "outputs_shape" in config.keys(
), "config must include outputs_shape in dict keys"
assert "outputs_dtype" in config.keys(
), "config must include outputs_dtype in dict keys"
op_names = config["op_names"]
test_data_shapes = config["test_data_shapes"]
test_data_types = config["test_data_types"]
opset_version = config["opset_version"]
inputs_name = config["inputs_name"]
outputs_name = config["outputs_name"]
inputs_shape = config["inputs_shape"]
outputs_shape = config["outputs_shape"]
outputs_dtype = config["outputs_dtype"]
use_gpu = True
if "use_gpu" in config.keys():
use_gpu = config["use_gpu"]
self.num_ran_models += 1
if not isinstance(op_names, (tuple, list)):
op_names = [op_names]
if not isinstance(opset_version[0], (tuple, list)):
opset_version = [opset_version]
if len(opset_version) == 1 and len(op_names) != len(opset_version):
opset_version = opset_version * len(op_names)
input_type_list = None
if len(test_data_types) > 1:
input_type_list = list(product(*test_data_types))
elif len(test_data_types) == 1:
if isinstance(test_data_types[0], str):
input_type_list = [test_data_types[0]]
else:
input_type_list = test_data_types
elif len(test_data_types) == 0:
input_type_list = [["float32"] * len(test_data_shapes)]
delta = 1e-5
rtol = 1e-5
if "delta" in config.keys():
delta = config["delta"]
if "rtol" in config.keys():
rtol = config["rtol"]
for i in range(len(op_names)):
obj = ONNXConverter(op_names[i], opset_version[i], op_names[i],
inputs_name, outputs_name, inputs_shape,
outputs_shape, outputs_dtype, delta, rtol,
use_gpu, attrs)
for input_type in input_type_list:
input_data = list()
for j, shape in enumerate(test_data_shapes):
# Determine whether it is a user-defined data generation function
if isfunction(shape):
data = shape()
data = data.astype(input_type[j])
input_data.append(data)
continue
if input_type[j].count('int') > 0:
input_data.append(
randtool("int", -20, 20, shape).astype(input_type[
j]))
elif input_type[j].count('bool') > 0:
input_data.append(
randtool("bool", -2, 2, shape).astype(input_type[
j]))
else:
input_data.append(
randtool("float", -2, 2, shape).astype(input_type[
j]))
obj.set_input_data("input_data", tuple(input_data))
logging.info("Now Run >>> dtype: {}, op_name: {}".format(
input_type, op_names[i]))
obj.run()
if len(input_type_list) == 0:
obj.run()
logging.info("Run Successfully!")
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import logging
import paddle
import onnx
from onnx import helper
from onnx import TensorProto
from onnxruntime import InferenceSession
DTYPE_ONNX_STR_MAP = {
'float32': TensorProto.FLOAT,
'float64': TensorProto.DOUBLE,
'int16': TensorProto.INT16,
'int32': TensorProto.INT32,
'int64': TensorProto.INT64,
'bool': TensorProto.BOOL,
}
def compare(result, expect, delta=1e-10, rtol=1e-10):
"""
比较函数
:param result: 输入值
:param expect: 输出值
:param delta: 误差值
:return:
"""
if type(result) == np.ndarray:
if type(expect) == list:
expect = expect[0]
expect = np.array(expect)
res = np.allclose(result, expect, atol=delta, rtol=rtol, equal_nan=True)
# 出错打印错误数据
if res is False:
if result.dtype == np.bool_:
diff = abs(result.astype("int32") - expect.astype("int32"))
else:
diff = abs(result - expect)
logging.error("Output has diff! max diff: {}".format(np.amax(diff)))
if result.dtype != expect.dtype:
logging.error(
"Different output data types! res type is: {}, and expect type is: {}".
format(result.dtype, expect.dtype))
assert res
assert result.shape == expect.shape, "result.shape: {} != expect.shape: {}".format(
result.shape, expect.shape)
assert result.dtype == expect.dtype, "result.dtype: {} != expect.dtype: {}".format(
result.dtype, expect.dtype)
elif isinstance(result, (list, tuple)) and len(result) > 1:
for i in range(len(result)):
if isinstance(result[i], (np.generic, np.ndarray)):
compare(result[i], expect[i], delta, rtol)
else:
compare(result[i].numpy(), expect[i], delta, rtol)
elif len(result) == 1:
compare(result[0], expect[0], delta, rtol)
def randtool(dtype, low, high, shape):
"""
np random tools
"""
if dtype == "int":
return np.random.randint(low, high, shape)
elif dtype == "float":
return low + (high - low) * np.random.random(shape)
elif dtype == "bool":
return np.random.randint(low, high, shape).astype("bool")
class ONNXConverter(object):
"""
onnx model transfer to paddle
"""
def __init__(self,
file_name,
ver_list,
op_type=[],
inputs_name=[],
outputs_name=[],
inputs_shape=[],
outputs_shape=[],
outputs_dtype=[],
delta=1e-5,
rtol=1e-5,
use_gpu=True,
attrs=[]):
self.op_type = op_type
assert isinstance(self.op_type,
str), "The dtype of op_type must be string!"
self.seed = 33
np.random.seed(self.seed)
paddle.seed(self.seed)
if use_gpu and paddle.device.is_compiled_with_cuda() is True:
self.places = ['gpu']
else:
self.places = ['cpu']
self.name = file_name
self._version = ver_list
self.pwd = os.getcwd()
self.delta = delta
self.rtol = rtol
self.static = False
self.kwargs_dict = {"input_data": ()}
self.input_feed = {}
self.inputs_dtype = []
self.inputs_name = inputs_name
self.outputs_name = outputs_name
self.inputs_shape = inputs_shape
self.outputs_shape = outputs_shape
self.outputs_dtype = outputs_dtype
self.attrs = attrs
def set_input_data(self, group_name, *args):
"""
set input data
"""
self.kwargs_dict[group_name] = args
if isinstance(self.kwargs_dict[group_name][0], tuple):
self.kwargs_dict[group_name] = self.kwargs_dict[group_name][0]
i = 0
for in_data in self.kwargs_dict[group_name]:
if isinstance(in_data, list):
for data in in_data:
self.inputs_dtype.append(str(data.dtype))
self.input_feed[self.inputs_name[i]] = data
i += 1
else:
if isinstance(in_data, tuple):
in_data = in_data[0]
self.inputs_dtype.append(str(in_data.dtype))
self.input_feed[self.inputs_name[i]] = in_data
i += 1
def _mkdir(self):
"""
make dir to save all
"""
save_path = os.path.join(self.pwd, self.name)
if not os.path.exists(save_path):
os.mkdir(save_path)
def _onnx_to_paddle(self, ver):
"""
convert onnx to paddle
"""
from x2paddle.convert import onnx2paddle
onnx_path = os.path.join(self.pwd, self.name,
self.name + '_' + str(ver) + '.onnx')
paddle_path = os.path.join(self.pwd, self.name,
self.name + '_' + str(ver) + '_paddle')
onnx2paddle(onnx_path, paddle_path, convert_to_lite=False)
def _mk_paddle_res(self, ver):
"""
make paddle res
"""
paddle_path = os.path.join(
self.pwd, self.name,
self.name + '_' + str(ver) + '_paddle/inference_model/model')
paddle.disable_static()
# run
model = paddle.jit.load(paddle_path)
paddle_feed = list()
for i in range(len(self.input_feed)):
paddle_feed.append(self.input_feed[self.inputs_name[i]])
result = model(*paddle_feed)
# get paddle outputs
if isinstance(result, (tuple, list)):
result = tuple(out.numpy() for out in result)
else:
result = (result.numpy(), )
print("paddle result:", result[0].shape)
return result
def _mk_onnx_res(self, ver):
"""
make onnx res
"""
sess = InferenceSession(
os.path.join(self.pwd, self.name, self.name + '_' + str(ver) +
'.onnx'))
ort_outs = sess.run(output_names=None, input_feed=self.input_feed)
print("onnx result:", ort_outs[0].shape)
return ort_outs
def set_onnx_inputs(self):
graph_inputs = list()
for i in range(len(self.inputs_name)):
graph_inputs.append(
helper.make_tensor_value_info(self.inputs_name[
i], DTYPE_ONNX_STR_MAP[self.inputs_dtype[i]],
self.inputs_shape[i]))
return graph_inputs
def set_onnx_outputs(self):
graph_outputs = list()
for i in range(len(self.outputs_name)):
graph_outputs.append(
helper.make_tensor_value_info(self.outputs_name[
i], DTYPE_ONNX_STR_MAP[self.outputs_dtype[i][0]],
self.outputs_shape[i]))
return graph_outputs
def _mk_onnx_graph(self, ver):
"""
make onnx graph
"""
node = onnx.helper.make_node(
self.op_type,
inputs=self.inputs_name,
outputs=self.outputs_name,
**self.attrs, )
graph_inputs = self.set_onnx_inputs()
graph_outputs = self.set_onnx_outputs()
graph = helper.make_graph(
[node],
self.name,
graph_inputs, # graph inputs
graph_outputs, # graph outputs
)
opset_imports = [helper.make_opsetid("", ver)]
model = helper.make_model(
graph, producer_name='onnx-example', opset_imports=opset_imports)
onnx.save(model,
os.path.join(self.pwd, self.name,
self.name + '_' + str(ver) + '.onnx'))
onnx.checker.check_model(model)
def run(self):
"""
1. make onnx model
2. convert onnx to paddle
3. use onnx to make res
4. compare diff
"""
self._mkdir()
for place in self.places:
paddle.set_device(place)
onnx_res = {}
paddle_res = {}
# export onnx models and make onnx res
for v in self._version:
self._mk_onnx_graph(ver=v)
self._onnx_to_paddle(ver=v)
onnx_res[str(v)] = self._mk_onnx_res(ver=v)
paddle_res[str(v)] = self._mk_paddle_res(ver=v)
for v in self._version:
compare(
onnx_res[str(v)],
paddle_res[str(v)],
delta=self.delta,
rtol=self.rtol)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import OPConvertAutoScanTest
from hypothesis import reproduce_failure
import hypothesis.strategies as st
import onnx
from onnx import helper
from onnx import TensorProto
import numpy as np
import unittest
class TestConvConvert(OPConvertAutoScanTest):
"""
api: paddle.nn.Conv2d
OPset version: 9
1.OPset version需要根据op_mapper中定义的version来设置。
2.测试中所有OP对应升级到Opset version 15。
"""
def sample_convert_config(self, draw):
input_shape = draw(
st.lists(
st.integers(
min_value=20, max_value=30), min_size=4, max_size=4))
kernel_size = draw(
st.lists(
st.integers(
min_value=1, max_value=7), min_size=4, max_size=4))
data_format = "NCHW"
groups = draw(st.integers(min_value=1, max_value=4))
muti1 = draw(st.integers(min_value=1, max_value=4))
kernel_size[0] = groups * muti1
input_shape[1] = kernel_size[1] * groups
strides = draw(
st.lists(
st.integers(
min_value=1, max_value=5), min_size=1, max_size=2))
if len(strides) == 1:
strides = strides[0]
if strides > kernel_size[2]:
strides = kernel_size[2]
if strides > kernel_size[3]:
strides = kernel_size[3]
strides = [strides, strides]
else:
if strides[0] > kernel_size[2]:
strides[0] = kernel_size[2]
if strides[1] > kernel_size[3]:
strides[1] = kernel_size[3]
if draw(st.booleans()):
auto_pad = "NOTSET"
padding = None
if draw(st.booleans()):
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=2,
max_size=2))
padding = [0, 0] + padding
else:
padding = draw(
st.lists(
st.integers(
min_value=1, max_value=5),
min_size=4,
max_size=4))
else:
auto_pad = draw(
st.sampled_from(
["SAME_LOWER", "SAME_UPPER", "VALID", "NOTSET"]))
padding = None
dilations = draw(
st.lists(
st.integers(
min_value=1, max_value=3), min_size=2, max_size=2))
config = {
"op_names": ["Conv"],
"test_data_shapes": [input_shape, kernel_size],
"test_data_types": [['float32'], ['float32']],
"inputs_shape": [[-1, input_shape[1], -1, -1], kernel_size],
"outputs_shape": [[-1, kernel_size[0], -1, -1]],
"outputs_dtype": [['float32']],
"opset_version": [7, 9],
"inputs_name": ["x", "W"],
"outputs_name": ["y"],
"delta": 1e-4,
"rtol": 1e-4
}
# Warning:
# 1、SAME_UPPER and SAME_LOWER does not yet support dynamic shapes
# 2、dilations only support 1
if "SAME" in auto_pad:
dilations = [1, 1]
config["inputs_shape"] = [input_shape, kernel_size]
config["outputs_shape"] = [[
input_shape[0], kernel_size[0],
int(input_shape[2] / strides[0]),
int(input_shape[3] / strides[1])
]]
if not isinstance(dilations, (tuple, list)):
dilations = [dilations]
if not isinstance(strides, (tuple, list)):
strides = [strides]
attrs = {
"auto_pad": auto_pad,
"dilations": dilations,
"group": groups,
"kernel_shape": kernel_size[2:],
"pads": padding,
"strides": strides,
}
return (config, attrs)
def test(self):
self.run_and_statis(max_examples=30)
if __name__ == "__main__":
unittest.main()
...@@ -92,12 +92,15 @@ def _is_static_shape(shape): ...@@ -92,12 +92,15 @@ def _is_static_shape(shape):
return True return True
def _get_same_padding(in_size, kernel_size, stride): def _get_same_padding(in_size, kernel_size, stride, autopad):
new_size = int(math.ceil(in_size * 1.0 / stride)) new_size = int(math.ceil(in_size * 1.0 / stride))
pad_size = (new_size - 1) * stride + kernel_size - in_size pad_size = (new_size - 1) * stride + kernel_size - in_size
pad0 = int(pad_size / 2) pad0 = int(pad_size / 2)
pad1 = pad_size - pad0 pad1 = pad_size - pad0
return [pad0, pad1] if autopad == "SAME_UPPER":
return [pad0, pad1]
if autopad == "SAME_LOWER":
return [pad1, pad0]
def print_mapping_info(func): def print_mapping_info(func):
...@@ -1540,9 +1543,9 @@ class OpSet9(): ...@@ -1540,9 +1543,9 @@ class OpSet9():
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
input_shape = val_x.out_shapes[0] input_shape = val_x.out_shapes[0]
pad_h = _get_same_padding(input_shape[2], kernel_shape[0], pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
strides[0]) strides[0], auto_pad)
pad_w = _get_same_padding(input_shape[3], kernel_shape[1], pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
strides[1]) strides[1], auto_pad)
paddings = pad_h + pad_w paddings = pad_h + pad_w
op_name = name_generator("pool", self.nn_name2id) op_name = name_generator("pool", self.nn_name2id)
...@@ -1967,9 +1970,9 @@ class OpSet9(): ...@@ -1967,9 +1970,9 @@ class OpSet9():
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
input_shape = val_x.out_shapes[0] input_shape = val_x.out_shapes[0]
pad_h = _get_same_padding(input_shape[2], kernel_shape[0], pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
strides[0]) strides[0], auto_pad)
pad_w = _get_same_padding(input_shape[3], kernel_shape[1], pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
strides[1]) strides[1], auto_pad)
paddings = pad_h + pad_w paddings = pad_h + pad_w
layer_attrs = { layer_attrs = {
...@@ -2196,13 +2199,15 @@ class OpSet9(): ...@@ -2196,13 +2199,15 @@ class OpSet9():
pads = node.get_attr('pads', [0] * (convnd * 2)) pads = node.get_attr('pads', [0] * (convnd * 2))
input_shape = val_x.out_shapes[0] input_shape = val_x.out_shapes[0]
paddings, val_x = self._pad_if_asymmetric(node, pads, val_x) paddings = np.array(pads).reshape((2, -1)).transpose().astype("int32")
paddings = paddings.flatten().tolist()
if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER": if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
assert -1 not in input_shape, 'SAME_UPPER and SAME_LOWER does not yet support dynamic shapes'
pad_h = _get_same_padding(input_shape[2], kernel_shape[0], pad_h = _get_same_padding(input_shape[2], kernel_shape[0],
strides[0]) strides[0], auto_pad)
pad_w = _get_same_padding(input_shape[3], kernel_shape[1], pad_w = _get_same_padding(input_shape[3], kernel_shape[1],
strides[1]) strides[1], auto_pad)
paddings = pad_h + pad_w paddings = pad_h + pad_w
layer_inputs = {'x': val_x if isinstance(val_x, str) else val_x.name} layer_inputs = {'x': val_x if isinstance(val_x, str) else val_x.name}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册