From 050aa6fe5a524b0e7b85201c54a0da315701518d Mon Sep 17 00:00:00 2001
From: heliqi
Date: Fri, 14 Jan 2022 16:50:56 +0800
Subject: [PATCH] add flatten_contiguous_range OpConverter for Paddle-TRT
 (#38922)

* add trt_convert_flatten_contiguous_range op

* trt version >= 7: support trt_convert_flatten_contiguous_range

* test case: add a TRT version >= 7 check
---
 .../ir_passes/tensorrt_subgraph_pass.cc          |   7 +-
 paddle/fluid/inference/api/analysis_predictor.cc |   1 +
 .../inference/tensorrt/convert/CMakeLists.txt    |   2 +-
 .../convert/flatten_contiguous_range_op.cc       | 136 ++++++++++++++++++
 paddle/fluid/inference/tensorrt/op_teller.cc     |  32 +++++
 ...st_trt_convert_flatten_contiguous_range.py    | 115 +++++++++++++++
 6 files changed, 290 insertions(+), 3 deletions(-)
 create mode 100644 paddle/fluid/inference/tensorrt/convert/flatten_contiguous_range_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flatten_contiguous_range.py

diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index ef50df3084f..55bbc554508 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -46,8 +46,11 @@ void analysis::TensorRtSubgraphPass::ApplyImpl(
               << " is diabled by config in TensorRT";
       return false;
     }
-    return tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
-                                             with_dynamic_shape);
+    bool is_ok = tensorrt::OpTeller::Global().Tell(node, no_calib_int8,
+                                                   with_dynamic_shape);
+    if (!is_ok)
+      VLOG(3) << node->Op()->Type().c_str() << " op is not in TensorRT";
+    return is_ok;
   };
 
   framework::ir::SubGraphFuser fuser(
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 2799fb9e174..d4b680288e3 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1416,6 +1416,7 @@ USE_TRT_CONVERTER(elementwise_min_tensor);
 USE_TRT_CONVERTER(elementwise_pow_tensor);
 USE_TRT_CONVERTER(transpose);
 USE_TRT_CONVERTER(flatten);
+USE_TRT_CONVERTER(flatten_contiguous_range);
 USE_TRT_CONVERTER(matmul);
 USE_TRT_CONVERTER(conv2d);
 USE_TRT_CONVERTER(relu);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index a885b69fa7f..017caca6adc 100644
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -3,7 +3,7 @@ nv_library(tensorrt_converter
       SRCS matmul_op.cc conv2d_op.cc fc_op.cc pool2d_op.cc elementwise_op.cc
       batch_norm_op.cc activation_op.cc softmax_op.cc concat_op.cc dropout_op.cc group_norm_op.cc
       pad_op.cc split_op.cc prelu_op.cc leaky_relu_op.cc gelu_op.cc layer_norm_op.cc multihead_matmul_op.cc
-      shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
+      shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc flatten_contiguous_range_op.cc
       emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc
       clip_op.cc gather_op.cc anchor_generator_op.cc
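
(Reviewer aside, not part of the patch: a minimal Python sketch of how the new
converter is exercised end to end, using the public paddle.inference API. The
model paths are placeholders. Running with GLOG_v=3 surfaces the new VLOG from
tensorrt_subgraph_pass.cc for any op the teller rejects.)

    import paddle.inference as paddle_infer

    # Placeholder paths; point these at a real saved inference model.
    config = paddle_infer.Config("./model/inference.pdmodel",
                                 "./model/inference.pdiparams")
    config.enable_use_gpu(256, 0)  # 256 MB initial memory pool, GPU 0
    # Ops accepted by the op_teller (now including flatten_contiguous_range
    # on TRT >= 7) are fused into TensorRT subgraphs; the rest run in Paddle.
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=paddle_infer.PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False)
    predictor = paddle_infer.create_predictor(config)
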
diff --git a/paddle/fluid/inference/tensorrt/convert/flatten_contiguous_range_op.cc b/paddle/fluid/inference/tensorrt/convert/flatten_contiguous_range_op.cc
new file mode 100644
index 00000000000..706814340a0
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/flatten_contiguous_range_op.cc
@@ -0,0 +1,136 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+/*
+ * flatten_contiguous_range trt converter
+ */
+class FlattenContiguousRangeOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope, bool test_mode) override {
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    int dims = input->getDimensions().nbDims;
+    int start_axis = BOOST_GET_CONST(int, op_desc.GetAttr("start_axis"));
+    int stop_axis = BOOST_GET_CONST(int, op_desc.GetAttr("stop_axis"));
+
+    nvinfer1::IShuffleLayer* layer = nullptr;
+    if (!engine_->with_dynamic_shape()) {
+      if (start_axis < 0) start_axis += dims + 1;
+      if (stop_axis < 0) stop_axis += dims + 1;
+      int dim_prod = 1;
+      nvinfer1::Dims flatten_dim;
+      flatten_dim.nbDims = dims - (stop_axis - start_axis);
+      for (int i = 0, j = 0; i < dims; ++i) {
+        if (start_axis <= i + 1 && i + 1 <= stop_axis) {
+          int dim_i = input->getDimensions().d[i];
+          PADDLE_ENFORCE_GT(dim_i, 0, platform::errors::InvalidArgument(
+                                          "flatten_contiguous_range input dim "
+                                          "should be > 0, but got %d.",
+                                          dim_i));
+          dim_prod *= dim_i;
+          if (i + 1 == stop_axis) {
+            flatten_dim.d[j++] = dim_prod;
+          }
+        } else {
+          flatten_dim.d[j++] = input->getDimensions().d[i];
+        }
+      }
+      layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+      layer->setReshapeDimensions(flatten_dim);
+    } else {
+      if (start_axis < 0) start_axis += dims;
+      if (stop_axis < 0) stop_axis += dims;
+      auto* shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input);
+      auto* shape_layer_itensor = shape_layer->getOutput(0);
+
+      nvinfer1::Dims start_dim, size_dim, stride_dim;
+      start_dim.nbDims = 1;
+      size_dim.nbDims = 1;
+      stride_dim.nbDims = 1;
+      start_dim.d[0] = start_axis;
+      size_dim.d[0] = stop_axis - start_axis + 1;
+      stride_dim.d[0] = 1;
+      auto* slice_layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Slice, *shape_layer_itensor, start_dim,
+                               size_dim, stride_dim);
+      uint32_t reduce_dim = 1;
+      auto* reduce_prod_layer = TRT_ENGINE_ADD_LAYER(
+          engine_, Reduce, *(slice_layer->getOutput(0)),
+          nvinfer1::ReduceOperation::kPROD, reduce_dim, true);
+
+      nvinfer1::ITensor* input_shape = nullptr;
+      if (start_axis == 0 && stop_axis == dims - 1) {
+        input_shape = reduce_prod_layer->getOutput(0);
+      } else {
+        std::vector<nvinfer1::ITensor*> itensors;
+        if (start_axis > 0) {
+          nvinfer1::Dims left_start_dim, left_size_dim, left_stride_dim;
+          left_start_dim.nbDims = 1;
+          left_size_dim.nbDims = 1;
+          left_stride_dim.nbDims = 1;
+          left_start_dim.d[0] = 0;
+          left_size_dim.d[0] = start_axis;
+          left_stride_dim.d[0] = 1;
+          auto* slice_layer_left = TRT_ENGINE_ADD_LAYER(
+              engine_, Slice, *shape_layer_itensor, left_start_dim,
+              left_size_dim, left_stride_dim);
+          itensors.push_back(slice_layer_left->getOutput(0));
+        }
+        itensors.push_back(reduce_prod_layer->getOutput(0));
+        if (stop_axis < dims - 1) {
+          nvinfer1::Dims right_start_dim, right_size_dim, right_stride_dim;
+          right_start_dim.nbDims = 1;
+          right_size_dim.nbDims = 1;
+          right_stride_dim.nbDims = 1;
+          right_start_dim.d[0] = stop_axis + 1;
+          right_size_dim.d[0] = dims - stop_axis - 1;
+          right_stride_dim.d[0] = 1;
+          auto* slice_layer_right = TRT_ENGINE_ADD_LAYER(
+              engine_, Slice, *shape_layer_itensor, right_start_dim,
+              right_size_dim, right_stride_dim);
+          itensors.push_back(slice_layer_right->getOutput(0));
+        }
+        auto* concat_layer = TRT_ENGINE_ADD_LAYER(
+            engine_, Concatenation, itensors.data(), itensors.size());
+        concat_layer->setAxis(0);
+        input_shape = concat_layer->getOutput(0);
+      }
+      layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+      layer->setInput(1, *input_shape);
+    }
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "flatten_contiguous_range", {output_name},
+                             test_mode);
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(flatten_contiguous_range,
+                          FlattenContiguousRangeOpConverter);
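
(Reviewer aside, not part of the patch: the static-shape branch above merges
the extents covered by [start_axis, stop_axis] into a single dimension. A
minimal Python sketch of that computation; the helper name is ours, and it
assumes TensorRT static-shape dims that omit the batch axis while the Paddle
attributes count it, which is what the "i + 1" comparisons encode.)

    def trt_flatten_dims(input_dims, start_axis, stop_axis):
        # input_dims: TRT dims without the batch axis.
        # start_axis/stop_axis: Paddle axes, which count the batch axis.
        dims = len(input_dims)
        if start_axis < 0:
            start_axis += dims + 1
        if stop_axis < 0:
            stop_axis += dims + 1
        out, dim_prod = [], 1
        for i in range(dims):
            if start_axis <= i + 1 <= stop_axis:
                dim_prod *= input_dims[i]
                if i + 1 == stop_axis:
                    out.append(dim_prod)
            else:
                out.append(input_dims[i])
        return out

    # (N, 4, 8, 3): axes 1..3 -> (N, 96); axes 2..3 -> (N, 4, 24).
    assert trt_flatten_dims([4, 8, 3], 1, 3) == [96]
    assert trt_flatten_dims([4, 8, 3], 2, 3) == [4, 24]
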
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index ddee4e0d682..6663103d4ca 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -55,6 +55,7 @@ struct SimpleOpTypeSetTeller : public Teller {
 // #endif
 #if IS_TRT_VERSION_GE(7000)
     teller_set.insert("tile");
+    teller_set.insert("flatten_contiguous_range");
 #endif
 #if CUDA_VERSION >= 10020
     teller_set.insert("reshape");
@@ -531,6 +532,37 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
         if (axis != 1) return false;
       }
     }
+    if (op_type == "flatten_contiguous_range") {
+      if (!with_dynamic_shape) {
+        int start_axis = BOOST_GET_CONST(int, desc.GetAttr("start_axis"));
+        int stop_axis = BOOST_GET_CONST(int, desc.GetAttr("stop_axis"));
+        auto x_var_name = desc.Input("X")[0];
+        auto* block = desc.Block();
+        if (block == nullptr) {
+          VLOG(3) << "The block desc is nullptr, we can't continue to "
+                     "analyze. Developers need to check whether block_desc "
+                     "is passed in the pass.";
+          return false;
+        }
+        auto* x_var_desc = block->FindVar(x_var_name);
+        const auto x_shape = x_var_desc->GetShape();
+        int dims = x_shape.size();
+        if (start_axis < 0) start_axis += dims;
+        if (start_axis == 0) {
+          VLOG(3) << "TRT flatten_contiguous_range does not support "
+                     "changing the batch dimension";
+          return false;
+        }
+        if (stop_axis < 0) stop_axis += dims;
+        for (int i = start_axis; i <= stop_axis; ++i) {
+          if (x_shape[i] < 0) {
+            VLOG(3) << "In TRT static-shape mode, flatten_contiguous_range "
+                       "input dims should be > 0";
+            return false;
+          }
+        }
+      }
+    }
+
     if (op_type == "gather") {
       auto gather_inputs = desc.Inputs();
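
(Reviewer aside, not part of the patch: a minimal Python sketch of the rules
the new teller entry enforces; the helper name is ours. Dynamic shape is
always accepted; static shape must leave the batch axis untouched, and every
flattened extent must be known at build time, with -1 marking an unknown dim.)

    def tell_flatten_contiguous_range(x_shape, start_axis, stop_axis,
                                      with_dynamic_shape):
        if with_dynamic_shape:
            # The dynamic-shape converter builds the target shape at runtime.
            return True
        dims = len(x_shape)
        if start_axis < 0:
            start_axis += dims
        if start_axis == 0:
            # Static-shape TRT cannot fold the batch axis into another dim.
            return False
        if stop_axis < 0:
            stop_axis += dims
        return all(x_shape[i] >= 0 for i in range(start_axis, stop_axis + 1))

    assert not tell_flatten_contiguous_range([-1, 2, 4, 8, 3], 0, 4, False)
    assert tell_flatten_contiguous_range([-1, 2, 4, 8, 3], 1, 4, False)
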
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flatten_contiguous_range.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flatten_contiguous_range.py
new file mode 100644
index 00000000000..a4060349d4b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_flatten_contiguous_range.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
+from program_config import TensorConfig, ProgramConfig
+import unittest
+import numpy as np
+import paddle.inference as paddle_infer
+from functools import partial
+from typing import Optional, List, Callable, Dict, Any, Set
+
+
+class TrtConvertFlattenContiguousRangeTest(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        return True
+
+    def sample_program_configs(self):
+        def generate_input(batch):
+            return np.random.random([2, batch, 4, 8, 3]).astype(np.float32)
+
+        for batch in [1, 2, 4]:
+            for start_axis in range(5):
+                for stop_axis in range(start_axis, 5):
+                    op_type = "flatten_contiguous_range"
+                    op_outputs = {
+                        "Out": ["output_data"],
+                        "XShape": ["xshape_data"]
+                    }
+                    ops_config = [{
+                        "op_type": op_type,
+                        "op_inputs": {
+                            "X": ["input_data"]
+                        },
+                        "op_outputs": op_outputs,
+                        "op_attrs": {
+                            "start_axis": start_axis,
+                            "stop_axis": stop_axis,
+                        }
+                    }]
+                    ops = self.generate_op_config(ops_config)
+
+                    program_config = ProgramConfig(
+                        ops=ops,
+                        weights={},
+                        inputs={
+                            "input_data": TensorConfig(
+                                data_gen=partial(generate_input, batch))
+                        },
+                        outputs=["output_data"])
+                    yield program_config
+
+    def sample_predictor_configs(
+            self, program_config) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            self.dynamic_shape.min_input_shape = {"input_data": [2, 1, 4, 8, 3]}
+            self.dynamic_shape.max_input_shape = {"input_data": [2, 4, 4, 8, 3]}
+            self.dynamic_shape.opt_input_shape = {"input_data": [2, 2, 4, 8, 3]}
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            # The converter is only registered for TRT >= 7.0.
+            ver = paddle_infer.get_trt_compile_version()
+            if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 >= 7000:
+                if dynamic_shape:
+                    return 1, 2
+                else:
+                    # Static shape cannot flatten across the batch axis,
+                    # so the op falls back to Paddle.
+                    if attrs[0]['start_axis'] == 0:
+                        return 0, 3
+                    else:
+                        return 1, 2
+            else:
+                return 0, 3
+
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False), 1e-5
+
+        # for dynamic_shape
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
+                                                                      True), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(attrs,
+                                                                      True), 1e-5
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
--
GitLab