From d17e39c20721d9fdb54ceab708524d4f5c559db0 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Wed, 20 Apr 2022 09:57:29 +0800 Subject: [PATCH] strided_slice (#41573) (#41914) --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../tensorrt/convert/strided_slice_op.cc | 131 ++++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 2 + .../test_trt_convert_strided_slice.py | 120 ++++++++++++++++ 5 files changed, 255 insertions(+) create mode 100644 paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 19de09ab152..f5ad3df909c 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1761,6 +1761,7 @@ USE_TRT_CONVERTER(deformable_conv); USE_TRT_CONVERTER(pool3d) USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm) USE_TRT_CONVERTER(preln_skip_layernorm) +USE_TRT_CONVERTER(strided_slice) #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 4f8aa4c14cd..f1800afcb1d 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -23,6 +23,7 @@ nv_library(tensorrt_converter pool3d_op.cc deformable_conv_op.cc preln_emb_eltwise_layernorm.cc + strided_slice_op.cc preln_skip_layernorm.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) diff --git a/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc b/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc new file mode 100644 index 00000000000..26046d38bcb --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/strided_slice_op.cc @@ -0,0 +1,131 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +/* + * Stack converter from fluid to tensorRT. + */ +class StridedSliceOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(4) << "convert fluid StridedSlice op to tensorrt Slice layer"; + + framework::OpDesc op_desc(op, nullptr); + auto* input = engine_->GetITensor(op_desc.Input("Input")[0]); + nvinfer1::Dims input_dims = input->getDimensions(); + + std::vector axes = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("axes")); + std::vector starts = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("starts")); + std::vector ends = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("ends")); + std::vector strides = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("strides")); + + nvinfer1::Dims start; + start.nbDims = input_dims.nbDims; + int axes_size = axes.size(); + for (int i = 0; i < start.nbDims; i++) { + start.d[i] = 0; + } + for (int i = 0; i < axes_size; i++) { + start.d[axes[i]] = starts[i]; + } + + nvinfer1::Dims stride; + stride.nbDims = input_dims.nbDims; + for (int i = 0; i < stride.nbDims; i++) { + stride.d[i] = 1; + } + for (int i = 0; i < axes_size; i++) { + stride.d[axes[i]] = strides[i]; + } + + nvinfer1::Dims size; + size.nbDims = input_dims.nbDims; + for (int i = 0; i < size.nbDims; i++) { + size.d[i] = 1; + } + + auto output_name = op_desc.Output("Out")[0]; + + auto create_weights = [&](const std::vector& data, + const std::string& type) -> int* { + std::unique_ptr tmp_tensor(new framework::Tensor()); + int data_size = data.size(); + tmp_tensor->Resize({data_size}); + auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); + for (int i = 0; i < data_size; i++) { + tmp_data[i] = data[i]; + } + + engine_->SetWeights(output_name + "_add_slice_op_" + type, + std::move(tmp_tensor)); + return tmp_data; + }; + + std::vector const_weight(input_dims.nbDims, 1); + for (int i = 0; i < axes_size; i++) { + const_weight[axes[i]] = strides[i]; + } + + int* weight_data = create_weights(const_weight, "size"); + + TensorRTEngine::Weight weight{nvinfer1::DataType::kINT32, + static_cast(weight_data), + static_cast(input_dims.nbDims)}; + + int input_dim_size = input_dims.nbDims; + nvinfer1::Dims input_shape; + input_shape.nbDims = 1; + input_shape.d[0] = input_dim_size; + + auto const_layer = + TRT_ENGINE_ADD_LAYER(engine_, Constant, input_shape, weight.get()); + + auto shape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shape, *input); + + auto size_layer = TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *shape_layer->getOutput(0), + *const_layer->getOutput(0), nvinfer1::ElementWiseOperation::kDIV); + + auto* layer = + TRT_ENGINE_ADD_LAYER(engine_, Slice, *input, start, size, stride); + layer->setInput(2, *size_layer->getOutput(0)); + + RreplenishLayerAndOutput(layer, "strided_slice", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(strided_slice, StridedSliceOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6ccaf80c9f0..d9a874dd2b6 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -117,6 +117,7 @@ struct SimpleOpTypeSetTeller : public Teller { "multihead_matmul", "skip_layernorm", "slice", + "strided_slice", "fused_preln_embedding_eltwise_layernorm", "preln_skip_layernorm"}; std::unordered_set teller_set{ @@ -178,6 +179,7 @@ struct SimpleOpTypeSetTeller : public Teller { "multihead_matmul", "skip_layernorm", "slice", + "strided_slice", "fused_preln_embedding_eltwise_layernorm", "preln_skip_layernorm", "multiclass_nms3"}; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py new file mode 100644 index 00000000000..04eb3ab10ba --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_strided_slice.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest + + +class TrtConvertStridedSliceTest(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + inputs = program_config.inputs + weights = program_config.weights + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + return True + + def sample_program_configs(self): + def generate_input1(attrs: List[Dict[str, Any]]): + return np.ones([1, 56, 56, 192]).astype(np.float32) + + for axes in [[1, 2]]: + for starts in [[1, 1]]: + for ends in [[10000000, 10000000]]: + for decrease_axis in [[]]: + for infer_flags in [[1, 1]]: + for strides in [[2, 2]]: + dics = [{ + "axes": axes, + "starts": starts, + "ends": ends, + "decrease_axis": decrease_axis, + "infer_flags": infer_flags, + "strides": strides + }] + + ops_config = [{ + "op_type": "strided_slice", + "op_inputs": { + "Input": ["input_data"] + }, + "op_outputs": { + "Out": ["slice_output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": TensorConfig( + data_gen=partial(generate_input1, + dics)) + }, + outputs=["slice_output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = { + "input_data": [1, 56, 56, 192] + } + self.dynamic_shape.max_input_shape = { + "input_data": [8, 56, 56, 192] + } + self.dynamic_shape.opt_input_shape = { + "input_data": [4, 56, 56, 192] + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + inputs = program_config.inputs + + if dynamic_shape: + for i in range(len(attrs[0]["starts"])): + if attrs[0]["starts"][i] < 0 or attrs[0]["ends"][i] < 0: + return 0, 3 + if not dynamic_shape: + for x in attrs[0]["axes"]: + if x == 0: + return 0, 3 + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + + def test(self): + self.run_test() -- GitLab