From aee2db012751f27f7447f0fdc6a13f19d092fa8b Mon Sep 17 00:00:00 2001 From: xiaoxiaohehe001 <49090790+xiaoxiaohehe001@users.noreply.github.com> Date: Mon, 5 Dec 2022 10:47:38 +0800 Subject: [PATCH] [Paddle Inference] Support range trt converter and add scalar interface. (#48697) * add_range * add_range --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../inference/tensorrt/convert/range_op.cc | 65 +++++ paddle/fluid/inference/tensorrt/engine.cc | 16 +- paddle/fluid/inference/tensorrt/engine.h | 5 +- paddle/fluid/inference/tensorrt/op_teller.cc | 8 + .../ir/inference/test_trt_convert_range.py | 230 ++++++++++++++++++ 7 files changed, 321 insertions(+), 5 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/range_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_range.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 293de6bcd3..67e6478bff 100755 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2329,6 +2329,7 @@ USE_TRT_CONVERTER(remove_padding) USE_TRT_CONVERTER(equal); USE_TRT_CONVERTER(top_k) USE_TRT_CONVERTER(top_k_v2) +USE_TRT_CONVERTER(range) USE_TRT_CONVERTER(squeeze2) USE_TRT_CONVERTER(unsqueeze2) USE_TRT_CONVERTER(sum) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index b796cf1c2a..cec617c2f5 100755 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -71,6 +71,7 @@ list( preln_residual_bias.cc c_allreduce_op.cc top_k_op.cc + range_op.cc squeeze2_op.cc unsqueeze2_op.cc rnn_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/range_op.cc b/paddle/fluid/inference/tensorrt/convert/range_op.cc new file mode 100644 index 0000000000..7288f4877b --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/range_op.cc @@ -0,0 +1,65 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class RangeOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(3) << "convert a range op to tensorrt layer"; + framework::OpDesc op_desc(op, nullptr); + nvinfer1::ILayer* layer = nullptr; + nvinfer1::ITensor* quotient_tensor; + + // Declare inputs + auto* start = engine_->GetITensor(op_desc.Input("Start")[0]); + auto* end = engine_->GetITensor(op_desc.Input("End")[0]); + auto* step = engine_->GetITensor(op_desc.Input("Step")[0]); + auto output_name = op_desc.Output("Out")[0]; + + auto zero_tensor = Add1DConstantLayer(0, output_name + "_zero_tensor_"); + auto fquotient_tensor = FloorDiv(Sub(start, end), step); + if (start->getType() == nvinfer1::DataType::kFLOAT) { + auto* cast_int32_layer = + TRT_ENGINE_ADD_LAYER(engine_, Identity, *fquotient_tensor); + cast_int32_layer->setOutputType(0, nvinfer1::DataType::kINT32); + cast_int32_layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); + quotient_tensor = cast_int32_layer->getOutput(0); + } else { + quotient_tensor = fquotient_tensor; + } + auto number_tensor = Max(Sub(zero_tensor, quotient_tensor), zero_tensor); + auto* start1 = engine_->GetITensor(op_desc.Input("Start")[0], true); + + layer = TRT_ENGINE_ADD_LAYER( + engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE); + layer->setInput(0, *number_tensor); + layer->setInput(1, *start1); + layer->setInput(2, *step); + + RreplenishLayerAndOutput(layer, "range", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(range, RangeOpConverter); diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 0b3c099934..255ef5d6d6 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -451,7 +451,11 @@ void TensorRTEngine::SetITensor(const std::string &name, itensor_map_[name] = tensor; } -nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) { +nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name, + bool scalar) { + if (scalar) { + return ConvertWeight2ITensor(name, true); + } if (itensor_map_.count(name)) { return itensor_map_[name]; } else { @@ -463,7 +467,7 @@ nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) { // For cases when input is not middle-tensor , but persistable tensor // you should call this. nvinfer1::ITensor *TensorRTEngine::ConvertWeight2ITensor( - const std::string &name) { + const std::string &name, bool scalar) { auto *var_v = scope_->FindVar(name); PADDLE_ENFORCE_NOT_NULL( var_v, @@ -489,9 +493,15 @@ nvinfer1::ITensor *TensorRTEngine::ConvertWeight2ITensor( trt_in_shape.d[i] = trt_in_shape.d[i + 1]; } } + if (scalar) { + trt_in_shape.nbDims = 0; + trt_in_shape.d[0] = var_dims[0]; + } nvinfer1::ILayer *layer = TRT_ENGINE_ADD_LAYER(this, Constant, trt_in_shape, weight.get()); - this->SetITensor(name, layer->getOutput(0)); + if (!scalar) { + this->SetITensor(name, layer->getOutput(0)); + } return layer->getOutput(0); } diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index b0e300dca6..91876ab154 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -295,8 +295,9 @@ class TensorRTEngine { void DeleteITensor(const std::string& name, nvinfer1::ITensor* tensor); void SetITensor(const std::string& name, nvinfer1::ITensor* tensor); // Get an ITensor called name. - nvinfer1::ITensor* GetITensor(const std::string& name); - nvinfer1::ITensor* ConvertWeight2ITensor(const std::string& name); + nvinfer1::ITensor* GetITensor(const std::string& name, bool scalar = false); + nvinfer1::ITensor* ConvertWeight2ITensor(const std::string& name, + bool scalar = false); std::unordered_map* GetITensorMap(); nvinfer1::ICudaEngine* engine() { return infer_engine_.get(); } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index d8801bd8f5..7344755790 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -337,6 +337,12 @@ struct SimpleOpTypeSetTeller : public Teller { } } + if (op_type == "range") { + if (!with_dynamic_shape) { + return false; + } + } + if (op_type == "sign") { #if IS_TRT_VERSION_GE(8200) if (!with_dynamic_shape) { @@ -2369,6 +2375,7 @@ struct SimpleOpTypeSetTeller : public Teller { "matmul", "matmul_v2", "bmm", + "range", "conv2d", "conv2d_fusion", "pool2d", @@ -2507,6 +2514,7 @@ struct SimpleOpTypeSetTeller : public Teller { "matmul", "matmul_v2", "bmm", + "range", "conv2d", "conv2d_fusion", "pool2d", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_range.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_range.py new file mode 100644 index 0000000000..42c00181f2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_range.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from functools import partial +from typing import List + +import numpy as np +from program_config import ProgramConfig, TensorConfig +from trt_layer_auto_scan_test import TrtLayerAutoScanTest + +import paddle.inference as paddle_infer + + +class TrtConvertRangeDynamicTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input(): + return np.array([1]).astype(np.int32) + + for in_dtype in [2]: + self.in_dtype = in_dtype + dics = [{}] + ops_config = [ + { + "op_type": "fill_constant", + "op_inputs": {}, + "op_outputs": {"Out": ["start_data"]}, + "op_attrs": { + "dtype": self.in_dtype, + "str_value": "7", + "shape": [1], + }, + }, + { + "op_type": "fill_constant", + "op_inputs": {}, + "op_outputs": {"Out": ["end_data"]}, + "op_attrs": { + "dtype": self.in_dtype, + "str_value": "256", + "shape": [1], + }, + }, + { + "op_type": "fill_constant", + "op_inputs": {}, + "op_outputs": {"Out": ["step_data"]}, + "op_attrs": { + "dtype": self.in_dtype, + "str_value": "1", + "shape": [1], + }, + }, + { + "op_type": "range", + "op_inputs": { + "Start": ["start_data"], + "End": ["end_data"], + "Step": ["step_data"], + }, + "op_outputs": {"Out": ["range_output_data1"]}, + "op_attrs": dics[0], + }, + { + "op_type": "cast", + "op_inputs": {"X": ["range_output_data1"]}, + "op_outputs": {"Out": ["range_output_data"]}, + "op_attrs": {"in_dtype": self.in_dtype, "out_dtype": 5}, + }, + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "step_data": TensorConfig(data_gen=partial(generate_input)), + }, + outputs=["range_output_data"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = { + "start_data": [1], + "end_data": [1], + "step_data": [1], + } + self.dynamic_shape.max_input_shape = { + "start_data": [1], + "end_data": [1], + "step_data": [1], + } + self.dynamic_shape.opt_input_shape = { + "start_data": [1], + "end_data": [1], + "step_data": [1], + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True + ), 1e-2 + + def test(self): + self.run_test() + + +class TrtConvertRangeStaticTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input(): + return np.array([0]).astype(np.int32) + + def generate_input1(): + return np.array([128]).astype(np.int32) + + def generate_input2(): + return np.array([1]).astype(np.int32) + + for in_dtype in [2, 5]: + self.in_dtype = in_dtype + dics = [{}] + ops_config = [ + { + "op_type": "range", + "op_inputs": { + "Start": ["start_data"], + "End": ["end_data"], + "Step": ["step_data"], + }, + "op_outputs": {"Out": ["range_output_data1"]}, + "op_attrs": dics[0], + }, + { + "op_type": "cast", + "op_inputs": {"X": ["range_output_data1"]}, + "op_outputs": {"Out": ["range_output_data"]}, + "op_attrs": {"in_dtype": self.in_dtype, "out_dtype": 5}, + }, + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "start_data": TensorConfig( + data_gen=partial(generate_input) + ), + "end_data": TensorConfig(data_gen=partial(generate_input1)), + "step_data": TensorConfig( + data_gen=partial(generate_input2) + ), + }, + outputs=["range_output_data"], + ) + + yield program_config + + def sample_predictor_configs( + self, program_config + ) -> (paddle_infer.Config, List[int], float): + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 0, 6 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False + ), 1e-2 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab