未验证 提交 57069f8b 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-TRT] Remove matrix_multiply unit test (#52606)

* remove matrix_multiply unitest
上级 39278731
......@@ -238,8 +238,6 @@ if(WITH_GPU AND TENSORRT_FOUND)
240)
# Per-test CTest TIMEOUT overrides; CTest interprets the value in seconds.
set_tests_properties(test_trt_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT
240)
set_tests_properties(test_trt_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT
240)
set_tests_properties(test_shuffle_channel_detect_pass PROPERTIES TIMEOUT
120)
if(WIN32)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from functools import partial
from typing import Any, Dict, List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertFcTest(TrtLayerAutoScanTest):
    """Auto-scan test for the ``fc`` op fed with a 5-D input tensor.

    Programs are produced for batch sizes 1 and 4; each one is then run under
    four predictor configurations (FP32 and FP16, twice).
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Return False on Windows, True on every other platform."""
        # The output has diff between gpu and trt in CI windows
        return os.name != 'nt'

    def sample_program_configs(self):
        """Yield one single-``fc`` ``ProgramConfig`` per batch size."""
        self.trt_param.workspace_size = 1 << 30  # 1073741824

        def make_input(batch, attrs: List[Dict[str, Any]]):
            # The last two axes are [m / 2, 2], i.e. they hold m values.
            half_m = int(attrs[0]["m"] / 2)
            shape = [batch, 3, 64, half_m, 2]
            return np.random.random(shape).astype(np.float32)

        def make_weight(batch, attrs: List[Dict[str, Any]]):
            shape = [attrs[0]["m"], attrs[0]["n"]]
            return np.random.random(shape).astype(np.float32)

        def make_bias(batch, attrs: List[Dict[str, Any]]):
            shape = [attrs[0]["n"]]
            return np.random.random(shape).astype(np.float32)

        for batch in (1, 4):
            for m, n in ((32, 23),):
                fc_attrs = {
                    "in_num_col_dims": 3,
                    # "m" and "n" ride along so the data generators above can
                    # read the weight shape; they are also passed as op attrs.
                    "m": m,
                    "n": n,
                }
                dics = [fc_attrs, {}]

                fc_op = {
                    "op_type": "fc",
                    "op_inputs": {
                        "Input": ["input_data"],
                        "W": ["w_data"],
                        "Bias": ["bias_data"],
                    },
                    "op_outputs": {"Out": ["output_data"]},
                    "op_attrs": dics[0],
                }
                ops = self.generate_op_config([fc_op])

                weights = {
                    "w_data": TensorConfig(
                        data_gen=partial(make_weight, batch, dics)
                    ),
                    "bias_data": TensorConfig(
                        data_gen=partial(make_bias, batch, dics)
                    ),
                }
                inputs = {
                    "input_data": TensorConfig(
                        data_gen=partial(make_input, batch, dics)
                    ),
                }

                yield ProgramConfig(
                    ops=ops,
                    weights=weights,
                    inputs=inputs,
                    outputs=["output_data"],
                )

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        """Yield ``(inference config, (1, 2), tolerance)`` for each setup.

        Four tuples are produced: FP32 and FP16 before the dynamic-shape
        ranges are assigned, then FP32 and FP16 after.
        """

        def set_dynamic_shape():
            shapes = self.dynamic_shape
            shapes.min_input_shape = {"input_data": [1, 3, 32, 16, 2]}
            shapes.max_input_shape = {"input_data": [4, 3, 64, 16, 2]}
            shapes.opt_input_shape = {"input_data": [1, 3, 64, 16, 2]}

        def clear_dynamic_shape():
            # Not invoked below; see the note in the loop.
            shapes = self.dynamic_shape
            shapes.min_input_shape = {}
            shapes.max_input_shape = {}
            shapes.opt_input_shape = {}

        def expected_trt_nodes(attrs, dynamic_shape):
            # Same value for every configuration.
            # presumably (TRT engine count, total op count) - verify against
            # TrtLayerAutoScanTest.
            return 1, 2

        attrs = [op.attrs for op in program_config.ops]

        fp32 = paddle_infer.PrecisionType.Float32
        fp16 = paddle_infer.PrecisionType.Half
        precision_cases = ((fp32, 1e-5), (fp16, (1e-3, 1e-3)))

        for use_dynamic_shape in (False, True):
            # NOTE(review): the first pass does not call clear_dynamic_shape(),
            # so whatever self.dynamic_shape already holds is used unchanged.
            # This mirrors the existing behaviour - confirm it is intended.
            if use_dynamic_shape:
                set_dynamic_shape()
            for precision, tolerance in precision_cases:
                self.trt_param.precision = precision
                yield self.create_inference_config(), expected_trt_nodes(
                    attrs, use_dynamic_shape
                ), tolerance

    # Plain run of the auto-scan test.
    def test(self):
        self.run_test()

    # Same run with quant=True.
    def test_quant(self):
        self.run_test(quant=True)
class TrtConvertFcTest2(TrtLayerAutoScanTest):
    """Auto-scan test for the ``fc`` op fed with a 4-D input tensor.

    Programs are produced for batch sizes 1 and 4; each one is run with
    static shapes and then dynamic shapes, in FP32 and FP16.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Return False on Windows, True on every other platform."""
        # The output has diff between gpu and trt in CI windows
        return os.name != 'nt'

    def sample_program_configs(self):
        """Yield one single-``fc`` ``ProgramConfig`` per batch size."""
        self.trt_param.workspace_size = 1 << 30  # 1073741824

        def make_input(batch, attrs: List[Dict[str, Any]]):
            shape = [batch, 3, 64, 14]
            return np.random.random(shape).astype(np.float32)

        def make_weight(batch, attrs: List[Dict[str, Any]]):
            shape = [attrs[0]["m"], attrs[0]["n"]]
            return np.random.random(shape).astype(np.float32)

        def make_bias(batch, attrs: List[Dict[str, Any]]):
            shape = [attrs[0]["n"]]
            return np.random.random(shape).astype(np.float32)

        for batch in (1, 4):
            for m, n in ((14, 43),):
                fc_attrs = {
                    "in_num_col_dims": 3,
                    # "m" and "n" ride along so the data generators above can
                    # read the weight shape; they are also passed as op attrs.
                    "m": m,
                    "n": n,
                }
                dics = [fc_attrs, {}]

                fc_op = {
                    "op_type": "fc",
                    "op_inputs": {
                        "Input": ["input_data"],
                        "W": ["w_data"],
                        "Bias": ["bias_data"],
                    },
                    "op_outputs": {"Out": ["output_data"]},
                    "op_attrs": dics[0],
                }
                ops = self.generate_op_config([fc_op])

                weights = {
                    "w_data": TensorConfig(
                        data_gen=partial(make_weight, batch, dics)
                    ),
                    "bias_data": TensorConfig(
                        data_gen=partial(make_bias, batch, dics)
                    ),
                }
                inputs = {
                    "input_data": TensorConfig(
                        data_gen=partial(make_input, batch, dics)
                    ),
                }

                yield ProgramConfig(
                    ops=ops,
                    weights=weights,
                    inputs=inputs,
                    outputs=["output_data"],
                )

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        """Yield ``(inference config, (1, 2), tolerance)`` for each setup.

        Order: static shape FP32, static shape FP16, dynamic shape FP32,
        dynamic shape FP16.
        """

        def generate_dynamic_shape():
            shapes = self.dynamic_shape
            shapes.min_input_shape = {"input_data": [1, 3, 32, 14]}
            shapes.max_input_shape = {"input_data": [4, 3, 64, 14]}
            shapes.opt_input_shape = {"input_data": [1, 3, 64, 14]}

        def clear_dynamic_shape():
            shapes = self.dynamic_shape
            shapes.min_input_shape = {}
            shapes.max_input_shape = {}
            shapes.opt_input_shape = {}

        fp32 = paddle_infer.PrecisionType.Float32
        fp16 = paddle_infer.PrecisionType.Half
        precision_cases = ((fp32, 1e-5), (fp16, (1e-3, 1e-3)))

        # Static shape first (ranges emptied), then dynamic shape.
        for prepare_shapes in (clear_dynamic_shape, generate_dynamic_shape):
            prepare_shapes()
            for precision, tolerance in precision_cases:
                self.trt_param.precision = precision
                yield self.create_inference_config(), (1, 2), tolerance

    # Plain run of the auto-scan test.
    def test(self):
        self.run_test()
# this is the special case when x_dim.nbDims == 4 && x_num_col_dims == 1
class TrtConvertFcTest3(TrtLayerAutoScanTest):
    """Auto-scan test for ``fc`` with a 4-D input and ``in_num_col_dims`` 1.

    ``is_program_valid`` always returns False for this class, because the
    case is known to hit a bug in fc_op.cc.
    """

    # this case will invoke a bug in fc_op.cc, so return False
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Always return False."""
        return False

    def sample_program_configs(self):
        """Yield one int8-enabled ``fc`` ``ProgramConfig`` per batch size."""
        self.trt_param.workspace_size = 1 << 30  # 1073741824

        def make_input(batch, attrs: List[Dict[str, Any]]):
            return np.ones([batch, 14, 1, 2]).astype(np.float32)

        def make_weight(batch, attrs: List[Dict[str, Any]]):
            shape = [attrs[0]["m"], attrs[0]["n"]]
            return np.ones(shape).astype(np.float32)

        def make_bias(batch, attrs: List[Dict[str, Any]]):
            return np.ones([attrs[0]["n"]]).astype(np.float32)

        for batch in (1, 4):
            for m, n in ((28, 43),):
                fc_attrs = {
                    "in_num_col_dims": 1,
                    "Input_scale": 0.1,
                    "out_threshold": 0.1,
                    "enable_int8": True,
                    # "m" and "n" ride along so the data generators above can
                    # read the weight shape; they are also passed as op attrs.
                    "m": m,
                    "n": n,
                }
                dics = [fc_attrs, {}]

                fc_op = {
                    "op_type": "fc",
                    "op_inputs": {
                        "Input": ["input_data"],
                        "W": ["w_data"],
                        "Bias": ["bias_data"],
                    },
                    "op_outputs": {"Out": ["output_data"]},
                    "op_attrs": dics[0],
                }
                ops = self.generate_op_config([fc_op])

                weights = {
                    "w_data": TensorConfig(
                        data_gen=partial(make_weight, batch, dics)
                    ),
                    "bias_data": TensorConfig(
                        data_gen=partial(make_bias, batch, dics)
                    ),
                }
                inputs = {
                    "input_data": TensorConfig(
                        data_gen=partial(make_input, batch, dics)
                    ),
                }

                yield ProgramConfig(
                    ops=ops,
                    weights=weights,
                    inputs=inputs,
                    outputs=["output_data"],
                )

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        """Yield ``(inference config, (1, 2), tolerance)`` for each setup.

        Order: static shape FP32, static shape FP16, dynamic shape FP32,
        dynamic shape FP16, and finally Int8 with the dynamic shapes still
        in place.
        """

        def generate_dynamic_shape():
            shapes = self.dynamic_shape
            shapes.min_input_shape = {"input_data": [1, 14, 1, 2]}
            shapes.max_input_shape = {"input_data": [4, 14, 1, 2]}
            shapes.opt_input_shape = {"input_data": [1, 14, 1, 2]}

        def clear_dynamic_shape():
            shapes = self.dynamic_shape
            shapes.min_input_shape = {}
            shapes.max_input_shape = {}
            shapes.opt_input_shape = {}

        fp32 = paddle_infer.PrecisionType.Float32
        fp16 = paddle_infer.PrecisionType.Half
        precision_cases = ((fp32, 1e-5), (fp16, (1e-3, 1e-3)))

        # Static shape first (ranges emptied), then dynamic shape.
        for prepare_shapes in (clear_dynamic_shape, generate_dynamic_shape):
            prepare_shapes()
            for precision, tolerance in precision_cases:
                self.trt_param.precision = precision
                yield self.create_inference_config(), (1, 2), tolerance

        # Int8 is exercised once, after the dynamic-shape ranges were set.
        self.trt_param.precision = paddle_infer.PrecisionType.Int8
        yield self.create_inference_config(), (1, 2), (1e-3, 1e-3)

    # Plain run of the auto-scan test.
    def test(self):
        self.run_test()

    # Same run with quant=True.
    def test_quant(self):
        self.run_test(quant=True)
# Run the test cases above when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import IgnoreReasons, PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class TestSqueeze2MatmulFusePass(PassAutoScanTest):
    r"""
       x_var
         |
      squeeze2
          \
    squeeze2_out_var    y_var
             \           /
                 matmul      bias_var
                    \          /
                   elementwise_add
    """

    def sample_predictor_configs(self, program_config):
        """Yield the single TensorRT FP32 predictor configuration."""
        # TRT
        trt_config = self.create_trt_inference_config()
        trt_config.enable_tensorrt_engine(
            max_batch_size=10,
            workspace_size=10240,
            min_subgraph_size=0,
            precision_mode=paddle_infer.PrecisionType.Float32,
            use_static=False,
            use_calib_mode=False,
        )
        # presumably the op types expected after the pass - verify against
        # PassAutoScanTest.
        expected_ops = ['mul', 'elementwise_add']
        tolerance = (1e-4, 1e-1)
        yield trt_config, expected_ops, tolerance

    def add_ignore_pass_case(self):
        """Register the ignore rule for biases that are not ``[out_size]``."""

        # Here we put some skip rules to avoid known bugs
        def bias_is_not_out_size(program_config, predictor_config):
            y_shape = list(program_config.weights["matmul_y"].shape)
            bias_shape = program_config.weights["bias"].shape
            axis = program_config.ops[2].attrs["axis"]
            # bias should be [mul_y_shape[-1]]
            return bool(
                axis == 0
                or bias_shape[0] != y_shape[1]
                or len(bias_shape) != 1
            )

        self.add_ignore_check_case(
            bias_is_not_out_size,
            IgnoreReasons.PASS_ACCURACY_ERROR,
            "The pass error on TRT while shape of bias is not [out_size].",
        )

    def sample_program_config(self, draw):
        """Build a squeeze2 -> matmul -> elementwise_add program.

        ``draw`` is called three times, in this order: the two leading
        dims of X, the shape of Y, and the ``axis`` of elementwise_add.
        """

        def two_dims():
            # A fresh strategy per draw: a list of two integers in [1, 8].
            return st.lists(
                st.integers(min_value=1, max_value=8), min_size=2, max_size=2
            )

        # 1. Generate shape of input:X of squeeze2
        x_shape = draw(two_dims())
        # axes of squeeze2 == [2, 3]
        x_shape.extend([1, 1])
        axes = [2, 3]

        # 2. Generate attr:transpose_X/transpose_Y/alpha of matmul
        alpha = 1.0
        transpose_X = False
        transpose_Y = False

        # 3. Generate legal shape of input:Y of matmul
        y_shape = draw(two_dims())
        y_shape[0] = x_shape[1]

        # 4. Generate legal attr:axis of elementwise_add
        drawn_axis = draw(st.integers(min_value=-1, max_value=1))
        axis = -1 if drawn_axis == 0 else drawn_axis

        # The bias is always 1-D with y_shape[1] elements; no 2-D
        # [x_shape[0], y_shape[1]] variant is generated.
        bias_shape = [y_shape[1]]

        squeeze2_op = OpConfig(
            "squeeze2",
            inputs={"X": ["squeeze2_x"]},
            axes=axes,
            outputs={"Out": ["squeeze2_out"], "XShape": ["xshape"]},
        )
        matmul_op = OpConfig(
            "matmul",
            inputs={"X": ["squeeze2_out"], "Y": ["matmul_y"]},
            outputs={"Out": ["matmul_out"]},
            alpha=alpha,
            transpose_X=transpose_X,
            transpose_Y=transpose_Y,
        )
        add_op = OpConfig(
            "elementwise_add",
            inputs={"X": ["matmul_out"], "Y": ["bias"]},
            outputs={"Out": ["add_out"]},
            axis=axis,
        )

        return ProgramConfig(
            ops=[squeeze2_op, matmul_op, add_op],
            weights={
                "matmul_y": TensorConfig(shape=y_shape),
                "bias": TensorConfig(shape=bias_shape),
            },
            inputs={"squeeze2_x": TensorConfig(shape=x_shape)},
            outputs=add_op.outputs["Out"],
        )

    # Run trt_squeeze2_matmul_fuse_pass over 25 sampled programs, no quant.
    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["trt_squeeze2_matmul_fuse_pass"],
        )
# Run the test cases above when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册