diff --git a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc index 6c73965a80943da9ee9ea346b5e6c22fa2e0f21c..3caaf08dc9cb54dad541c0c563e656dfb04f4bab 100644 --- a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc @@ -498,7 +498,8 @@ void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { BOOST_GET_CONST(std::vector, squeeze2_op->Op()->GetAttr("axes")); flag = flag && squeeze2_in_x_rank == 4 && squeeze2_op_axes == std::vector{2, 3} && - (matmul_in_x->outputs).size() == 1; + (matmul_in_x->outputs).size() == 1 && + matmul_in_y->Var()->Persistable(); bool transpose_X = BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); @@ -654,7 +655,7 @@ void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); flag = flag && !transpose_X && !transpose_Y && std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && - matmul_in_y_rank == 2; + matmul_in_y_rank == 2 && matmul_in_y->Var()->Persistable(); std::vector& next_ops = matmul_out->outputs; flag = flag && next_ops.size() == 1 && @@ -740,7 +741,7 @@ void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); pattern_found = pattern_found && !transpose_X && !transpose_Y && std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && - matmul_in_y_rank == 2; + matmul_in_y_rank == 2 && matmul_in_y->Var()->Persistable(); std::vector& next_ops = matmul_out->outputs; // we further require the matmul op is followed by one elementwise diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc index 287c896e49bf254d70a5c79c818a39f913472f2f..d6eb39e767825def9f3ba74adec759d45f56e38a 100644 --- a/paddle/fluid/inference/analysis/ir_pass_manager.cc +++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc @@ -43,6 +43,7 @@ IRPassManager::IRPassManager(Argument *argument) { "The scope ptr should not be nullptr.")); graph_->SetNotOwned(framework::ir::kParamScopeAttr, scope_ptr); } + disable_logs_ = argument->disable_logs(); ARGUMENT_CHECK_FIELD(argument, ir_analysis_passes); CreatePasses(argument, argument->ir_analysis_passes()); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 8f7b73fc0e03630f0d1c8a64ed9118f1322c5b65..b1ed0165edb56fae2d206e810e14b6114b5884bd 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -103,6 +103,9 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU) set_tests_properties(test_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240) + set_tests_properties(test_trt_flatten2_matmul_fuse_pass PROPERTIES TIMEOUT 240) + set_tests_properties(test_trt_squeeze2_matmul_fuse_pass PROPERTIES TIMEOUT 240) + set_tests_properties(test_trt_reshape2_matmul_fuse_pass PROPERTIES TIMEOUT 240) set_tests_properties(test_shuffle_channel_detect_pass PROPERTIES TIMEOUT 120) if (WIN32) set_tests_properties(test_matmul_scale_fuse_pass PROPERTIES TIMEOUT 300) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py index ec3bc0287323d7993c2cecf94bb43cd5874fdf06..ba99ac306c700f8180f0670ce9944fbe861530c5 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py @@ -39,17 +39,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest): """ def sample_predictor_configs(self, program_config): - # TRT - # config = self.create_trt_inference_config() - # config.enable_tensorrt_engine( - # max_batch_size=10, - # workspace_size=102400, - # min_subgraph_size=0, - # precision_mode=paddle_infer.PrecisionType.Float32, - # use_static=False, - # use_calib_mode=False) - # yield config, ['mul', 'elementwise_add'], (1e-5, 1e-5) - # cpu config = self.create_inference_config(use_gpu=False) yield config, ["mul", "elementwise_add"], (1e-5, 1e-5) @@ -58,33 +47,6 @@ class TestFlatten2MatmulFusePass(PassAutoScanTest): config = self.create_inference_config(use_gpu=True) yield config, ["mul", "elementwise_add"], (1e-5, 1e-5) - def add_ignore_pass_case(self): - # Here we put some skip rules to avoid known bugs - def teller1(program_config, predictor_config): - if predictor_config.tensorrt_engine_enabled(): - # On 3080, the results of MatMul and Mul are different - # When the input Y is weight - return True - - # On TRT when the input Y is weight, Mul is converted to FC - if "matmul_y" not in program_config.weights \ - or "bias" not in program_config.weights: - return True - - y_shape = list(program_config.weights["matmul_y"].shape) - bias_shape = program_config.weights["bias"].shape - axis = program_config.ops[2].attrs["axis"] - # bias should be [mul_y_shape[-1]] - if axis == 0 or bias_shape[0] != y_shape[1] or len( - bias_shape) != 1: - return True - return False - - self.add_ignore_check_case( - teller1, - IgnoreReasons.PASS_ACCURACY_ERROR, - "The pass error on TRT while shape of bias is not [out_size].", ) - def sample_program_config(self, draw): # 1. Generate shape and attr of flatten2 x_shape = draw( diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py index 6f311ab11fefd1d0dfc780cdb23871e7f449131e..9bec34df5b6e1a5df949e7375a3ba3b5aa680b1a 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py @@ -39,17 +39,6 @@ class TestReshape2MatmulFusePass(PassAutoScanTest): """ def sample_predictor_configs(self, program_config): - # TRT - # config = self.create_trt_inference_config() - # config.enable_tensorrt_engine( - # max_batch_size=10, - # workspace_size=102400, - # min_subgraph_size=0, - # precision_mode=paddle_infer.PrecisionType.Float32, - # use_static=False, - # use_calib_mode=False) - # yield config, ['mul', 'elementwise_add'], (1e-5, 1e-5) - # cpu config = self.create_inference_config(use_gpu=False) yield config, ["mul", "elementwise_add"], (1e-5, 1e-5) @@ -58,33 +47,6 @@ class TestReshape2MatmulFusePass(PassAutoScanTest): config = self.create_inference_config(use_gpu=True) yield config, ["mul", "elementwise_add"], (1e-5, 1e-5) - def add_ignore_pass_case(self): - # Here we put some skip rules to avoid known bugs - def teller1(program_config, predictor_config): - if predictor_config.tensorrt_engine_enabled(): - # On 3080, the results of MatMul and Mul are different - # When the input Y is weight - return True - - # On TRT when the input Y is weight, Mul is converted to FC - if "matmul_y" not in program_config.weights \ - or "bias" not in program_config.weights: - return True - - y_shape = list(program_config.weights["matmul_y"].shape) - bias_shape = program_config.weights["bias"].shape - axis = program_config.ops[2].attrs["axis"] - # bias should be [mul_y_shape[-1]] - if axis == 0 or bias_shape[0] != y_shape[1] or len( - bias_shape) != 1: - return True - return False - - self.add_ignore_check_case( - teller1, - IgnoreReasons.PASS_ACCURACY_ERROR, - "The pass error on TRT while shape of bias is not [out_size].", ) - def sample_program_config(self, draw): # 1. Generate shape and attr of reshape2 reshape = draw( @@ -107,14 +69,21 @@ class TestReshape2MatmulFusePass(PassAutoScanTest): # 4. Generate legal attr:axis of elementwise_add axis = draw(st.integers(min_value=-1, max_value=1)) - if axis == 0: - bias_shape = [x_shape[0]] - elif axis == 1: - bias_shape = [y_shape[1]] - else: - bias_shape = [x_shape[0], y_shape[1]] + if axis == 0 or axis == -1: if draw(st.booleans()): - bias_shape[1] = 1 + if axis == 0: + bias_shape = [x_shape[0], ] + else: + bias_shape = [y_shape[1], ] + else: + bias_shape = [x_shape[0], y_shape[1]] + elif axis == 1: + bias_shape = [y_shape[1], ] + + if draw(st.integers(min_value=1, max_value=10)) <= 1: + bias_shape[-1] = 1 + if len(bias_shape) == 2 and draw(st.booleans()): + bias_shape[0] = 1 reshape2_op = OpConfig( "reshape2", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py index 9600ef7e0d1091e688d114b1cd17038a09d3f367..6d9457f35750b3393eb764a383ff24d13d7bf412 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py @@ -47,45 +47,6 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest): config = self.create_inference_config(use_gpu=True) yield config, ["mul", "elementwise_add"], (1e-5, 1e-5) - # TRT - # config = self.create_trt_inference_config() - # config.enable_tensorrt_engine( - # max_batch_size=10, - # workspace_size=10240, - # min_subgraph_size=0, - # precision_mode=paddle_infer.PrecisionType.Float32, - # use_static=False, - # use_calib_mode=False) - # yield config, ['mul', 'elementwise_add'], (1e-5, 1e-5) - - def add_ignore_pass_case(self): - # Here we put some skip rules to avoid known bugs - def teller1(program_config, predictor_config): - if predictor_config.tensorrt_engine_enabled(): - # On 3080, the results of MatMul and Mul are different - # When the input Y is weight - return True - - # On TRT when the input Y is weight, Mul is converted to FC - predictor_config.exp_disable_tensorrt_ops(["elementwise_add"]) - if "matmul_y" not in program_config.weights \ - or "bias" not in program_config.weights: - return True - - y_shape = list(program_config.weights["matmul_y"].shape) - bias_shape = program_config.weights["bias"].shape - axis = program_config.ops[2].attrs["axis"] - # bias should be [mul_y_shape[-1]] - if axis == 0 or bias_shape[0] != y_shape[1] or len( - bias_shape) != 1: - return True - return False - - self.add_ignore_check_case( - teller1, - IgnoreReasons.PASS_ACCURACY_ERROR, - "The pass error on TRT while shape of bias is not [out_size].", ) - def sample_program_config(self, draw): # 1. Generate shape of input:X of squeeze2 x_shape = draw( @@ -111,19 +72,21 @@ class TestSqueeze2MatmulFusePass(PassAutoScanTest): # 4. Generate legal attr:axis of elementwise_add axis = draw(st.integers(min_value=-1, max_value=1)) if axis == 0 or axis == -1: - bias_shape = [x_shape[0], y_shape[1]] - else: + if draw(st.booleans()): + if axis == 0: + bias_shape = [x_shape[0], ] + else: + bias_shape = [y_shape[1], ] + else: + bias_shape = [x_shape[0], y_shape[1]] + elif axis == 1: bias_shape = [y_shape[1], ] - if draw(st.booleans()): + + if draw(st.integers(min_value=1, max_value=10)) <= 1: bias_shape[-1] = 1 if len(bias_shape) == 2 and draw(st.booleans()): bias_shape[0] = 1 - axis = 0 - bias_shape = [2, ] - x_shape = [2, 1, 1, 1] - y_shape = [1, 2] - squeeze2_op = OpConfig( "squeeze2", inputs={"X": ["squeeze2_x"], }, diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_flatten2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_flatten2_matmul_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0f8857e92b48220dee2b18c4b108d98a675711 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_flatten2_matmul_fuse_pass.py @@ -0,0 +1,146 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestFlatten2MatmulFusePass(PassAutoScanTest): + """ + x_var + | + flatten2 + \ + flatten2_out_var y_var + \ / + matmul bias_var + \ / + elementwise_add + """ + + def sample_predictor_configs(self, program_config): + # TRT + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=10, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, ['mul', 'elementwise_add'], (1e-4, 1e-1) + + def add_ignore_pass_case(self): + # Here we put some skip rules to avoid known bugs + def teller1(program_config, predictor_config): + y_shape = list(program_config.weights["matmul_y"].shape) + bias_shape = program_config.weights["bias"].shape + axis = program_config.ops[2].attrs["axis"] + # bias should be [mul_y_shape[-1]] + if axis == 0 or bias_shape[0] != y_shape[1] or len(bias_shape) != 1: + return True + return False + + self.add_ignore_check_case( + teller1, + IgnoreReasons.PASS_ACCURACY_ERROR, + "The pass error on TRT while shape of bias is not [out_size].", ) + + def sample_program_config(self, draw): + # 1. Generate shape and attr of flatten2 + x_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=10), min_size=4, max_size=4)) + # [a, b, c, d] => [a, b*c*d] + flatten_axis = 1 + flatten_shape = [x_shape[0], x_shape[1] * x_shape[2] * x_shape[3]] + + # 2. Generate attr:transpose_X/transpose_Y/alpha of matmul + alpha = 1.0 + transpose_X = False + transpose_Y = False + + # 3. Generate legal shape of input:Y of matmul + y_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=8), min_size=2, max_size=2)) + y_shape[0] = flatten_shape[1] + + # 4. Generate legal attr:axis of elementwise_add + axis = draw(st.integers(min_value=-1, max_value=1)) + if axis == 0: + axis = -1 + bias_shape = [y_shape[1], ] + + flatten2_op = OpConfig( + "flatten2", + inputs={"X": ["flatten2_x"], }, + axis=flatten_axis, + outputs={"Out": ["flatten2_out"], + "XShape": ["xshape"]}, ) + matmul_op = OpConfig( + "matmul", + inputs={"X": ["flatten2_out"], + "Y": ["matmul_y"]}, + outputs={"Out": ["matmul_out"]}, + alpha=alpha, + transpose_X=transpose_X, + transpose_Y=transpose_Y, + fused_reshape_X=[], + fused_reshape_Y=[], + fused_transpose_X=[], + fused_transpose_Y=[], + fused_reshape_Out=[], + fused_transpose_Out=[], ) + + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["matmul_out"], + "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=axis, ) + + ops = [flatten2_op, matmul_op, add_op] + + program_config = ProgramConfig( + ops=ops, + weights={ + "matmul_y": TensorConfig(shape=y_shape), + "bias": TensorConfig(shape=bias_shape), + }, + inputs={"flatten2_x": TensorConfig(shape=x_shape), }, + outputs=ops[-1].outputs["Out"], ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["trt_flatten2_matmul_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape2_matmul_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..ecfc5c9dac064638152d5d0d14f56151382b6fb3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_reshape2_matmul_fuse_pass.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestReshape2MatmulFusePass(PassAutoScanTest): + """ + x_var + | + reshape2 + \ + reshape2_out_var y_var + \ / + matmul bias_var + \ / + elementwise_add + """ + + def sample_predictor_configs(self, program_config): + # TRT + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=10, + workspace_size=102400, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, ['mul', 'elementwise_add'], (1e-4, 1e-1) + + def add_ignore_pass_case(self): + # Here we put some skip rules to avoid known bugs + def teller1(program_config, predictor_config): + y_shape = list(program_config.weights["matmul_y"].shape) + bias_shape = program_config.weights["bias"].shape + axis = program_config.ops[2].attrs["axis"] + # bias should be [mul_y_shape[-1]] + if axis == 0 or bias_shape[0] != y_shape[1] or len(bias_shape) != 1: + return True + return False + + self.add_ignore_check_case( + teller1, + IgnoreReasons.PASS_ACCURACY_ERROR, + "The pass error on TRT while shape of bias is not [out_size].", ) + + def sample_program_config(self, draw): + # 1. Generate shape and attr of reshape2 + reshape = draw( + st.lists( + st.integers( + min_value=1, max_value=10), min_size=2, max_size=2)) + x_shape = reshape + [1, 1] + + # 2. Generate attr:transpose_X/transpose_Y/alpha of matmul + alpha = 1.0 + transpose_X = False + transpose_Y = False + + # 3. Generate legal shape of input:Y of matmul + y_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=8), min_size=2, max_size=2)) + y_shape[0] = x_shape[1] + + # 4. Generate legal attr:axis of elementwise_add + axis = draw(st.integers(min_value=-1, max_value=1)) + if axis == 0: + axis = -1 + bias_shape = [y_shape[1], ] + # if axis == -1: + # if draw(st.booleans()): + # bias_shape = [y_shape[1], ] + # else: + # bias_shape = [x_shape[0], y_shape[1]] + + reshape2_op = OpConfig( + "reshape2", + inputs={"X": ["reshape2_x"], }, + shape=reshape, + outputs={"Out": ["reshape2_out"], + "XShape": ["xshape"]}, ) + matmul_op = OpConfig( + "matmul", + inputs={"X": ["reshape2_out"], + "Y": ["matmul_y"]}, + outputs={"Out": ["matmul_out"]}, + alpha=alpha, + transpose_X=transpose_X, + transpose_Y=transpose_Y, + fused_reshape_X=[], + fused_reshape_Y=[], + fused_transpose_X=[], + fused_transpose_Y=[], + fused_reshape_Out=[], + fused_transpose_Out=[], ) + + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["matmul_out"], + "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=axis, ) + + ops = [reshape2_op, matmul_op, add_op] + + program_config = ProgramConfig( + ops=ops, + weights={ + "matmul_y": TensorConfig(shape=y_shape), + "bias": TensorConfig(shape=bias_shape), + }, + inputs={"reshape2_x": TensorConfig(shape=x_shape), }, + outputs=ops[-1].outputs["Out"], ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["trt_reshape2_matmul_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_squeeze2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_squeeze2_matmul_fuse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..d2791737a1cbf138ab78e51c0a9146932d4c34f8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_squeeze2_matmul_fuse_pass.py @@ -0,0 +1,150 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import PassAutoScanTest, IgnoreReasons +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + +import hypothesis +from hypothesis import given, settings, seed, example, assume, reproduce_failure +import hypothesis.strategies as st + + +class TestSqueeze2MatmulFusePass(PassAutoScanTest): + """ + x_var + | + squeeze2 + \ + squeeze2_out_var y_var + \ / + matmul bias_var + \ / + elementwise_add + """ + + def sample_predictor_configs(self, program_config): + # TRT + config = self.create_trt_inference_config() + config.enable_tensorrt_engine( + max_batch_size=10, + workspace_size=10240, + min_subgraph_size=0, + precision_mode=paddle_infer.PrecisionType.Float32, + use_static=False, + use_calib_mode=False) + yield config, ['mul', 'elementwise_add'], (1e-4, 1e-1) + + def add_ignore_pass_case(self): + # Here we put some skip rules to avoid known bugs + def teller1(program_config, predictor_config): + y_shape = list(program_config.weights["matmul_y"].shape) + bias_shape = program_config.weights["bias"].shape + axis = program_config.ops[2].attrs["axis"] + # bias should be [mul_y_shape[-1]] + if axis == 0 or bias_shape[0] != y_shape[1] or len(bias_shape) != 1: + return True + return False + + self.add_ignore_check_case( + teller1, + IgnoreReasons.PASS_ACCURACY_ERROR, + "The pass error on TRT while shape of bias is not [out_size].", ) + + def sample_program_config(self, draw): + # 1. Generate shape of input:X of squeeze2 + x_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=8), min_size=2, max_size=2)) + # axes of squeeze2 == [2, 3] + x_shape += [1, 1] + axes = [2, 3] + + # 2. Generate attr:transpose_X/transpose_Y/alpha of matmul + alpha = 1.0 + transpose_X = False + transpose_Y = False + + # 3. Generate legal shape of input:Y of matmul + y_shape = draw( + st.lists( + st.integers( + min_value=1, max_value=8), min_size=2, max_size=2)) + y_shape[0] = x_shape[1] + + # 4. Generate legal attr:axis of elementwise_add + axis = draw(st.integers(min_value=-1, max_value=1)) + if axis == 0: + axis = -1 + bias_shape = [y_shape[1], ] + # if axis == -1: + # if draw(st.booleans()): + # bias_shape = [y_shape[1], ] + # else: + # bias_shape = [x_shape[0], y_shape[1]] + + squeeze2_op = OpConfig( + "squeeze2", + inputs={"X": ["squeeze2_x"], }, + axes=axes, + outputs={"Out": ["squeeze2_out"], + "XShape": ["xshape"]}, ) + matmul_op = OpConfig( + "matmul", + inputs={"X": ["squeeze2_out"], + "Y": ["matmul_y"]}, + outputs={"Out": ["matmul_out"]}, + alpha=alpha, + transpose_X=transpose_X, + transpose_Y=transpose_Y, + fused_reshape_X=[], + fused_reshape_Y=[], + fused_transpose_X=[], + fused_transpose_Y=[], + fused_reshape_Out=[], + fused_transpose_Out=[], ) + + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["matmul_out"], + "Y": ["bias"]}, + outputs={"Out": ["add_out"]}, + axis=axis, ) + + ops = [squeeze2_op, matmul_op, add_op] + program_config = ProgramConfig( + ops=ops, + weights={ + "matmul_y": TensorConfig(shape=y_shape), + "bias": TensorConfig(shape=bias_shape), + }, + inputs={"squeeze2_x": TensorConfig(shape=x_shape), }, + outputs=ops[-1].outputs["Out"], ) + + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=50, + passes=["trt_squeeze2_matmul_fuse_pass"]) + + +if __name__ == "__main__": + unittest.main()