未验证 提交 7ef69202 编写于 作者: H heliqi 提交者: GitHub

add flatten2,reshape2,squueze2_trt_fuse_pass test cast (#41031)

* add flatten2,reshape2,squueze2_trt_fuse_pass  test cast

* add flatten2,reshape2,squueze2_trt_fuse_pass  test cast

* add flatten2,reshape2,squueze2_trt_fuse_pass  test cast
上级 e7928a06
......@@ -498,7 +498,8 @@ void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
BOOST_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
flag = flag && squeeze2_in_x_rank == 4 &&
squeeze2_op_axes == std::vector<int>{2, 3} &&
(matmul_in_x->outputs).size() == 1;
(matmul_in_x->outputs).size() == 1 &&
matmul_in_y->Var()->Persistable();
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
......@@ -654,7 +655,7 @@ void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
flag = flag && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
matmul_in_y_rank == 2 && matmul_in_y->Var()->Persistable();
std::vector<Node*>& next_ops = matmul_out->outputs;
flag = flag && next_ops.size() == 1 &&
......@@ -740,7 +741,7 @@ void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
pattern_found = pattern_found && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
matmul_in_y_rank == 2 && matmul_in_y->Var()->Persistable();
std::vector<Node*>& next_ops = matmul_out->outputs;
// we further require the matmul op is followed by one elementwise
......
......@@ -43,6 +43,7 @@ IRPassManager::IRPassManager(Argument *argument) {
"The scope ptr should not be nullptr."));
graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr);
}
disable_logs_ = argument->disable_logs();
ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes);
CreatePasses(argument, argument->ir_analysis_passes());
......
......@@ -103,6 +103,9 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_trt_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_trt_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_trt_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240)
set_tests_properties(test_shuffle_channel_detect_pass PROPERTIES TIMEOUT 120)
if (WIN32)
set_tests_properties(test_matmul_scale_fuse_pass PROPERTIES TIMEOUT 300)
......
......@@ -39,17 +39,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest):
"""
def sample_predictor_configs(self, program_config):
# TRT
# config = self.create_trt_inference_config()
# config.enable_tensorrt_engine(
# max_batch_size=10,
# workspace_size=102400,
# min_subgraph_size=0,
# precision_mode=paddle_infer.PrecisionType.Float32,
# use_static=False,
# use_calib_mode=False)
# yield config, ['mul', 'elementwise_add'], (1e-5, 1e-5)
# cpu
config = self.create_inference_config(use_gpu=False)
yield config, ["mul", "elementwise_add"], (1e-5, 1e-5)
......@@ -58,33 +47,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest):
config = self.create_inference_config(use_gpu=True)
yield config, ["mul", "elementwise_add"], (1e-5, 1e-5)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
if predictor_config.tensorrt_engine_enabled():
# On 3080, the results of MatMul and Mul are different
# When the input Y is weight
return True
# On TRT when the input Y is weight, Mul is converted to FC
if "matmul_y" not in program_config.weights \
or "bias" not in program_config.weights:
return True
y_shape = list(program_config.weights["matmul_y"].shape)
bias_shape = program_config.weights["bias"].shape
axis = program_config.ops[2].attrs["axis"]
# bias should be [mul_y_shape[-1]]
if axis == 0 or bias_shape[0] != y_shape[1] or len(
bias_shape) != 1:
return True
return False
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"The pass error on TRT while shape of bias is not [out_size].", )
def sample_program_config(self, draw):
# 1. Generate shape and attr of flatten2
x_shape = draw(
......
......@@ -39,17 +39,6 @@ class TestReshape2MatmulFusePass(PassAutoScanTest):
"""
def sample_predictor_configs(self, program_config):
# TRT
# config = self.create_trt_inference_config()
# config.enable_tensorrt_engine(
# max_batch_size=10,
# workspace_size=102400,
# min_subgraph_size=0,
# precision_mode=paddle_infer.PrecisionType.Float32,
# use_static=False,
# use_calib_mode=False)
# yield config, ['mul', 'elementwise_add'], (1e-5, 1e-5)
# cpu
config = self.create_inference_config(use_gpu=False)
yield config, ["mul", "elementwise_add"], (1e-5, 1e-5)
......@@ -58,33 +47,6 @@ class TestReshape2MatmulFusePass(PassAutoScanTest):
config = self.create_inference_config(use_gpu=True)
yield config, ["mul", "elementwise_add"], (1e-5, 1e-5)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
if predictor_config.tensorrt_engine_enabled():
# On 3080, the results of MatMul and Mul are different
# When the input Y is weight
return True
# On TRT when the input Y is weight, Mul is converted to FC
if "matmul_y" not in program_config.weights \
or "bias" not in program_config.weights:
return True
y_shape = list(program_config.weights["matmul_y"].shape)
bias_shape = program_config.weights["bias"].shape
axis = program_config.ops[2].attrs["axis"]
# bias should be [mul_y_shape[-1]]
if axis == 0 or bias_shape[0] != y_shape[1] or len(
bias_shape) != 1:
return True
return False
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"The pass error on TRT while shape of bias is not [out_size].", )
def sample_program_config(self, draw):
# 1. Generate shape and attr of reshape2
reshape = draw(
......@@ -107,14 +69,21 @@ class TestReshape2MatmulFusePass(PassAutoScanTest):
# 4. Generate legal attr:axis of elementwise_add
axis = draw(st.integers(min_value=-1, max_value=1))
if axis == 0:
bias_shape = [x_shape[0]]
elif axis == 1:
bias_shape = [y_shape[1]]
else:
bias_shape = [x_shape[0], y_shape[1]]
if axis == 0 or axis == -1:
if draw(st.booleans()):
bias_shape[1] = 1
if axis == 0:
bias_shape = [x_shape[0], ]
else:
bias_shape = [y_shape[1], ]
else:
bias_shape = [x_shape[0], y_shape[1]]
elif axis == 1:
bias_shape = [y_shape[1], ]
if draw(st.integers(min_value=1, max_value=10)) <= 1:
bias_shape[-1] = 1
if len(bias_shape) == 2 and draw(st.booleans()):
bias_shape[0] = 1
reshape2_op = OpConfig(
"reshape2",
......
......@@ -47,45 +47,6 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest):
config = self.create_inference_config(use_gpu=True)
yield config, ["mul", "elementwise_add"], (1e-5, 1e-5)
# TRT
# config = self.create_trt_inference_config()
# config.enable_tensorrt_engine(
# max_batch_size=10,
# workspace_size=10240,
# min_subgraph_size=0,
# precision_mode=paddle_infer.PrecisionType.Float32,
# use_static=False,
# use_calib_mode=False)
# yield config, ['mul', 'elementwise_add'], (1e-5, 1e-5)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
if predictor_config.tensorrt_engine_enabled():
# On 3080, the results of MatMul and Mul are different
# When the input Y is weight
return True
# On TRT when the input Y is weight, Mul is converted to FC
predictor_config.exp_disable_tensorrt_ops(["elementwise_add"])
if "matmul_y" not in program_config.weights \
or "bias" not in program_config.weights:
return True
y_shape = list(program_config.weights["matmul_y"].shape)
bias_shape = program_config.weights["bias"].shape
axis = program_config.ops[2].attrs["axis"]
# bias should be [mul_y_shape[-1]]
if axis == 0 or bias_shape[0] != y_shape[1] or len(
bias_shape) != 1:
return True
return False
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"The pass error on TRT while shape of bias is not [out_size].", )
def sample_program_config(self, draw):
# 1. Generate shape of input:X of squeeze2
x_shape = draw(
......@@ -111,19 +72,21 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest):
# 4. Generate legal attr:axis of elementwise_add
axis = draw(st.integers(min_value=-1, max_value=1))
if axis == 0 or axis == -1:
bias_shape = [x_shape[0], y_shape[1]]
else:
if draw(st.booleans()):
if axis == 0:
bias_shape = [x_shape[0], ]
else:
bias_shape = [y_shape[1], ]
else:
bias_shape = [x_shape[0], y_shape[1]]
elif axis == 1:
bias_shape = [y_shape[1], ]
if draw(st.booleans()):
if draw(st.integers(min_value=1, max_value=10)) <= 1:
bias_shape[-1] = 1
if len(bias_shape) == 2 and draw(st.booleans()):
bias_shape[0] = 1
axis = 0
bias_shape = [2, ]
x_shape = [2, 1, 1, 1]
y_shape = [1, 2]
squeeze2_op = OpConfig(
"squeeze2",
inputs={"X": ["squeeze2_x"], },
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestFlatten2MatmulFusePass(PassAutoScanTest):
"""
x_var
|
flatten2
\
flatten2_out_var y_var
\ /
matmul bias_var
\ /
elementwise_add
"""
def sample_predictor_configs(self, program_config):
# TRT
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=10,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
yield config, ['mul', 'elementwise_add'], (1e-4, 1e-1)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
y_shape = list(program_config.weights["matmul_y"].shape)
bias_shape = program_config.weights["bias"].shape
axis = program_config.ops[2].attrs["axis"]
# bias should be [mul_y_shape[-1]]
if axis == 0 or bias_shape[0] != y_shape[1] or len(bias_shape) != 1:
return True
return False
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"The pass error on TRT while shape of bias is not [out_size].", )
def sample_program_config(self, draw):
# 1. Generate shape and attr of flatten2
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=10), min_size=4, max_size=4))
# [a, b, c, d] => [a, b*c*d]
flatten_axis = 1
flatten_shape = [x_shape[0], x_shape[1] * x_shape[2] * x_shape[3]]
# 2. Generate attr:transpose_X/transpose_Y/alpha of matmul
alpha = 1.0
transpose_X = False
transpose_Y = False
# 3. Generate legal shape of input:Y of matmul
y_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=2))
y_shape[0] = flatten_shape[1]
# 4. Generate legal attr:axis of elementwise_add
axis = draw(st.integers(min_value=-1, max_value=1))
if axis == 0:
axis = -1
bias_shape = [y_shape[1], ]
flatten2_op = OpConfig(
"flatten2",
inputs={"X": ["flatten2_x"], },
axis=flatten_axis,
outputs={"Out": ["flatten2_out"],
"XShape": ["xshape"]}, )
matmul_op = OpConfig(
"matmul",
inputs={"X": ["flatten2_out"],
"Y": ["matmul_y"]},
outputs={"Out": ["matmul_out"]},
alpha=alpha,
transpose_X=transpose_X,
transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[], )
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["matmul_out"],
"Y": ["bias"]},
outputs={"Out": ["add_out"]},
axis=axis, )
ops = [flatten2_op, matmul_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={
"matmul_y": TensorConfig(shape=y_shape),
"bias": TensorConfig(shape=bias_shape),
},
inputs={"flatten2_x": TensorConfig(shape=x_shape), },
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["trt_flatten2_matmul_fuse_pass"])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestReshape2MatmulFusePass(PassAutoScanTest):
"""
x_var
|
reshape2
\
reshape2_out_var y_var
\ /
matmul bias_var
\ /
elementwise_add
"""
def sample_predictor_configs(self, program_config):
# TRT
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=10,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
yield config, ['mul', 'elementwise_add'], (1e-4, 1e-1)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
y_shape = list(program_config.weights["matmul_y"].shape)
bias_shape = program_config.weights["bias"].shape
axis = program_config.ops[2].attrs["axis"]
# bias should be [mul_y_shape[-1]]
if axis == 0 or bias_shape[0] != y_shape[1] or len(bias_shape) != 1:
return True
return False
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"The pass error on TRT while shape of bias is not [out_size].", )
def sample_program_config(self, draw):
# 1. Generate shape and attr of reshape2
reshape = draw(
st.lists(
st.integers(
min_value=1, max_value=10), min_size=2, max_size=2))
x_shape = reshape + [1, 1]
# 2. Generate attr:transpose_X/transpose_Y/alpha of matmul
alpha = 1.0
transpose_X = False
transpose_Y = False
# 3. Generate legal shape of input:Y of matmul
y_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=2))
y_shape[0] = x_shape[1]
# 4. Generate legal attr:axis of elementwise_add
axis = draw(st.integers(min_value=-1, max_value=1))
if axis == 0:
axis = -1
bias_shape = [y_shape[1], ]
# if axis == -1:
# if draw(st.booleans()):
# bias_shape = [y_shape[1], ]
# else:
# bias_shape = [x_shape[0], y_shape[1]]
reshape2_op = OpConfig(
"reshape2",
inputs={"X": ["reshape2_x"], },
shape=reshape,
outputs={"Out": ["reshape2_out"],
"XShape": ["xshape"]}, )
matmul_op = OpConfig(
"matmul",
inputs={"X": ["reshape2_out"],
"Y": ["matmul_y"]},
outputs={"Out": ["matmul_out"]},
alpha=alpha,
transpose_X=transpose_X,
transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[], )
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["matmul_out"],
"Y": ["bias"]},
outputs={"Out": ["add_out"]},
axis=axis, )
ops = [reshape2_op, matmul_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={
"matmul_y": TensorConfig(shape=y_shape),
"bias": TensorConfig(shape=bias_shape),
},
inputs={"reshape2_x": TensorConfig(shape=x_shape), },
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["trt_reshape2_matmul_fuse_pass"])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestSqueeze2MatmulFusePass(PassAutoScanTest):
"""
x_var
|
squeeze2
\
squeeze2_out_var y_var
\ /
matmul bias_var
\ /
elementwise_add
"""
def sample_predictor_configs(self, program_config):
# TRT
config = self.create_trt_inference_config()
config.enable_tensorrt_engine(
max_batch_size=10,
workspace_size=10240,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False)
yield config, ['mul', 'elementwise_add'], (1e-4, 1e-1)
def add_ignore_pass_case(self):
# Here we put some skip rules to avoid known bugs
def teller1(program_config, predictor_config):
y_shape = list(program_config.weights["matmul_y"].shape)
bias_shape = program_config.weights["bias"].shape
axis = program_config.ops[2].attrs["axis"]
# bias should be [mul_y_shape[-1]]
if axis == 0 or bias_shape[0] != y_shape[1] or len(bias_shape) != 1:
return True
return False
self.add_ignore_check_case(
teller1,
IgnoreReasons.PASS_ACCURACY_ERROR,
"The pass error on TRT while shape of bias is not [out_size].", )
def sample_program_config(self, draw):
# 1. Generate shape of input:X of squeeze2
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=2))
# axes of squeeze2 == [2, 3]
x_shape += [1, 1]
axes = [2, 3]
# 2. Generate attr:transpose_X/transpose_Y/alpha of matmul
alpha = 1.0
transpose_X = False
transpose_Y = False
# 3. Generate legal shape of input:Y of matmul
y_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=2))
y_shape[0] = x_shape[1]
# 4. Generate legal attr:axis of elementwise_add
axis = draw(st.integers(min_value=-1, max_value=1))
if axis == 0:
axis = -1
bias_shape = [y_shape[1], ]
# if axis == -1:
# if draw(st.booleans()):
# bias_shape = [y_shape[1], ]
# else:
# bias_shape = [x_shape[0], y_shape[1]]
squeeze2_op = OpConfig(
"squeeze2",
inputs={"X": ["squeeze2_x"], },
axes=axes,
outputs={"Out": ["squeeze2_out"],
"XShape": ["xshape"]}, )
matmul_op = OpConfig(
"matmul",
inputs={"X": ["squeeze2_out"],
"Y": ["matmul_y"]},
outputs={"Out": ["matmul_out"]},
alpha=alpha,
transpose_X=transpose_X,
transpose_Y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[], )
add_op = OpConfig(
"elementwise_add",
inputs={"X": ["matmul_out"],
"Y": ["bias"]},
outputs={"Out": ["add_out"]},
axis=axis, )
ops = [squeeze2_op, matmul_op, add_op]
program_config = ProgramConfig(
ops=ops,
weights={
"matmul_y": TensorConfig(shape=y_shape),
"bias": TensorConfig(shape=bias_shape),
},
inputs={"squeeze2_x": TensorConfig(shape=x_shape), },
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=50,
passes=["trt_squeeze2_matmul_fuse_pass"])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册