diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index c41b667e18a833111fdfc70c37b267430753dffb..bb495860c90e5fe3a749a0e89fc0d6f4546fb5a2 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1962,6 +1962,8 @@ USE_TRT_CONVERTER(recover_padding) USE_TRT_CONVERTER(remove_padding) USE_TRT_CONVERTER(top_k) USE_TRT_CONVERTER(top_k_v2) +USE_TRT_CONVERTER(squeeze2) +USE_TRT_CONVERTER(unsqueeze2) #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) USE_TRT_CONVERTER(sparse_multihead_matmul) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 52a3c1df9a92550dd1edd8c4da41d84cbddf8b6e..4c52d91fa1259a4b8fc66cb299140ec6fe6b5775 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -61,7 +61,9 @@ list( transformer_input_convert_op.cc remove_padding_op.cc recover_padding_op.cc - top_k_op.cc) + top_k_op.cc + squeeze2_op.cc + unsqueeze2_op.cc) if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8) list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc) diff --git a/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc b/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..22bacfcc8e7d96137c6e8e9bccf7ce681da0ea40 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/squeeze2_op.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class Squeeze2OpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid squeeze2 op to tensorrt shuffle layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + auto input_dims = input->getDimensions(); + auto output_name = op_desc.Output("Out")[0]; + + // Get Attrs + std::vector axes = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("axes")); + PADDLE_ENFORCE_GT( + axes.size(), 0, + platform::errors::InvalidArgument( + "Attr(axes).size should be > 0 in squeeze2 op in TensorRT," + "but received axes.size() = %d.", + axes.size())); + + std::vector should_squeeze(input_dims.nbDims, false); + for (size_t i = 0; i < axes.size(); i++) { + if (engine_->with_dynamic_shape()) { + axes[i] += (axes[i] < 0) ? input_dims.nbDims : 0; + } else { + axes[i] += (axes[i] < 0) ? input_dims.nbDims : -1; + } + should_squeeze[axes[i]] = true; + } + + nvinfer1::Dims trt_out_dims; + trt_out_dims.nbDims = 0; + std::vector gather_indices; + for (size_t i = 0; i < should_squeeze.size(); i++) { + if (should_squeeze[i]) continue; + gather_indices.push_back(i); + // for static shape + trt_out_dims.d[trt_out_dims.nbDims] = input_dims.d[i]; + trt_out_dims.nbDims++; + } + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (engine_->with_dynamic_shape()) { + auto* shape_tensor = Shape(input); + auto* real_shape_tensor = Gather(shape_tensor, gather_indices); + layer->setInput(1, *real_shape_tensor); + } else { + layer->setReshapeDimensions(trt_out_dims); + } + RreplenishLayerAndOutput(layer, "squeeze2", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(squeeze2, Squeeze2OpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/unsqueeze2_op.cc b/paddle/fluid/inference/tensorrt/convert/unsqueeze2_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e65fbdbd7652b6ec8966bba4ed91bded8deef70 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/unsqueeze2_op.cc @@ -0,0 +1,99 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class Unsqueeze2OpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert a fluid unsqueeze2 op to tensorrt shuffle layer"; + + framework::OpDesc op_desc(op, nullptr); + // Declare inputs + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + auto input_dims = input->getDimensions(); + auto output_name = op_desc.Output("Out")[0]; + + // Get Attrs + std::vector axes = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("axes")); + PADDLE_ENFORCE_GT( + axes.size(), 0, + platform::errors::InvalidArgument( + "Attr(axes).size should be > 0 in unsqueeze2 op in TensorRT," + "but received axes.size() = %d.", + axes.size())); + + std::vector should_unsqueeze(input_dims.nbDims + axes.size(), false); + int cur_out_rank = input_dims.nbDims; + for (size_t i = 0; i < axes.size(); i++) { + cur_out_rank++; + if (engine_->with_dynamic_shape()) { + axes[i] += (axes[i] < 0) ? cur_out_rank : 0; + } else { + axes[i] += (axes[i] < 0) ? cur_out_rank : -1; + } + // axes[i] is relative to cur_out_rank + // we make [axes[i], cur_out_rank - 2] shift right + // and make (axes[i]) to true! + for (int j = cur_out_rank - 1; j > axes[i]; j--) { + should_unsqueeze[j] = should_unsqueeze[j - 1]; + } + if (axes[i] >= cur_out_rank) + should_unsqueeze[cur_out_rank - 1] = true; + else + should_unsqueeze[axes[i]] = true; + } + + nvinfer1::Dims trt_out_dims; + trt_out_dims.nbDims = should_unsqueeze.size(); + std::vector gather_indices; + int in_rank_i = 0; + for (size_t i = 0; i < should_unsqueeze.size(); i++) { + if (should_unsqueeze[i]) { + trt_out_dims.d[i] = 1; + gather_indices.push_back(input_dims.nbDims); + continue; + } + trt_out_dims.d[i] = input_dims.d[in_rank_i]; + gather_indices.push_back(in_rank_i); + in_rank_i++; + } + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + if (engine_->with_dynamic_shape()) { + auto* shape_tensor = Shape(input); + std::vector all_one(axes.size(), 1); + auto* all_one_tensor = Add1DConstantLayer(all_one); + std::vector concat_inputs = {shape_tensor, + all_one_tensor}; + auto* real_shape_tensor = Gather(Concat(concat_inputs), gather_indices); + layer->setInput(1, *real_shape_tensor); + } else { + layer->setReshapeDimensions(trt_out_dims); + } + RreplenishLayerAndOutput(layer, "unsqueeze2", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(unsqueeze2, Unsqueeze2OpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 05e8d196a8285e7f7bb00214613c570cb7f64a78..6ce9b9c0bf85a77e47e70e590add2f23dbc78aaa 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -153,7 +153,9 @@ struct SimpleOpTypeSetTeller : public Teller { "preln_skip_layernorm", "transformer_input_convert", "recover_padding", - "remove_padding"}; + "remove_padding", + "squeeze2", + "unsqueeze2"}; std::unordered_set teller_set{ "mul", "matmul", @@ -242,7 +244,9 @@ struct SimpleOpTypeSetTeller : public Teller { "multiclass_nms3", "transformer_input_convert", "recover_padding", - "remove_padding"}; + "remove_padding", + "squeeze2", + "unsqueeze2"}; }; bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, @@ -887,6 +891,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } + if (op_type == "squeeze2") { + std::vector axes; + if (desc.HasAttr("axes")) { + axes = BOOST_GET_CONST(std::vector, desc.GetAttr("axes")); + } + if (axes.size() == 0) { + VLOG(3) << "The necessary attributes of the squeeze2 operator axes is " + "missing."; + return false; + } + if (!with_dynamic_shape) { + if (std::find(axes.begin(), axes.end(), 0) != axes.end()) { + VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not " + "supported in static shape"; + return false; + } + } + } + + if (op_type == "unsqueeze2") { + std::vector axes; + if (desc.HasAttr("axes")) { + axes = BOOST_GET_CONST(std::vector, desc.GetAttr("axes")); + } + if (axes.size() == 0) { + VLOG(3) << "The necessary attributes of the squeeze2 operator axes is " + "missing."; + return false; + } + if (!with_dynamic_shape) { + if (std::find(axes.begin(), axes.end(), 0) != axes.end()) { + VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not " + "supported in static shape"; + return false; + } + } + } + if (op_type == "batch_norm") { const std::vector bn_inputs = {"X", "Bias", "Mean", "Scale", "Variance"}; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py new file mode 100644 index 0000000000000000000000000000000000000000..f82791a59123356c98a17d1287d1a2c1cc1d8352 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_squeeze2.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import unittest +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertSplitTest(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + if len(inputs['in_data'].shape) <= max(attrs[0]['axes']): + return False + return True + + def sample_program_configs(self): + for dims in [2, 3, 4]: + for batch in [3, 4]: + for axes in [[2], [2, 3], [-1]]: + self.batch = batch + self.dims = dims + self.axes = axes + dics = [{"axes": axes}] + ops_config = [{ + "op_type": "squeeze2", + "op_inputs": { + "X": ["in_data"] + }, + "op_outputs": { + "Out": ["out_data"], + "XShape": ["XShape_data"] + }, + "op_attrs": dics[0] + }] + # new_axes is the update of axes + new_axes = list(axes) + for i in range(len(new_axes)): + if (new_axes[i] < 0): + new_axes[i] += dims + if (max(new_axes) >= dims): + continue + # generate input data + self.input_shape = [1] * dims + for i in range(dims): + self.input_shape[i] = np.random.randint(1, 20) + + def generate_input1(attrs: List[Dict[str, Any]], batch): + self.input_shape[0] = batch + for i in new_axes: + self.input_shape[i] = 1 + return np.random.random(self.input_shape).astype( + np.float32) + + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "in_data": + TensorConfig( + data_gen=partial(generate_input1, dics, batch)) + }, + outputs=["out_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + max_shape = list(self.input_shape) + min_shape = list(self.input_shape) + opt_shape = list(self.input_shape) + for i in range(len(self.input_shape)): + max_shape[i] = max_shape[i] + 1 + self.dynamic_shape.min_input_shape = {"in_data": min_shape} + self.dynamic_shape.max_input_shape = {"in_data": max_shape} + self.dynamic_shape.opt_input_shape = {"in_data": opt_shape} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + self.trt_param.max_batch_size = 9 + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + + def add_skip_trt_case(self): + pass + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unsqueeze2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unsqueeze2.py new file mode 100644 index 0000000000000000000000000000000000000000..fc99da714f6846c27f4cd8f516d8b639f1eb59dc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_unsqueeze2.py @@ -0,0 +1,124 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import unittest +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertSplitTest(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + for dims in [2, 3, 4]: + for batch in [3, 4]: + for axes in [[-2, 3], [1], [2], [2, 3]]: + self.batch = batch + self.dims = dims + self.axes = axes + dics = [{"axes": axes}] + ops_config = [{ + "op_type": "unsqueeze2", + "op_inputs": { + "X": ["in_data"] + }, + "op_outputs": { + "Out": ["out_data"], + "XShape": ["XShape_data"] + }, + "op_attrs": dics[0] + }] + + # generate input data + self.input_shape = [1] * dims + for i in range(dims): + self.input_shape[i] = np.random.randint(1, 20) + + def generate_input1(attrs: List[Dict[str, Any]], batch): + self.input_shape[0] = batch + return np.random.random(self.input_shape).astype( + np.float32) + + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "in_data": + TensorConfig( + data_gen=partial(generate_input1, dics, batch)) + }, + outputs=["out_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + max_shape = list(self.input_shape) + min_shape = list(self.input_shape) + opt_shape = list(self.input_shape) + for i in range(len(self.input_shape)): + max_shape[i] = max_shape[i] + 1 + self.dynamic_shape.min_input_shape = {"in_data": min_shape} + self.dynamic_shape.max_input_shape = {"in_data": max_shape} + self.dynamic_shape.opt_input_shape = {"in_data": opt_shape} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + self.trt_param.max_batch_size = 9 + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + + def add_skip_trt_case(self): + pass + + def test(self): + self.add_skip_trt_case() + self.run_test() + + +if __name__ == "__main__": + unittest.main()