From 2309aa585cd9a4d5f35a8ea936b388d9a58e8645 Mon Sep 17 00:00:00 2001
From: gaoziyuan <88373061+gzy19990617@users.noreply.github.com>
Date: Wed, 12 Apr 2023 12:04:17 +0800
Subject: [PATCH] [Hackathon 78] Add the cumsum op to Paddle-TRT (#52518)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../fluid/inference/api/analysis_predictor.cc |   1 +
 .../inference/tensorrt/convert/CMakeLists.txt |   1 +
 .../inference/tensorrt/convert/cumsum_op.cc   | 157 ++++++++++++++++
 .../inference/tensorrt/convert/op_converter.h |  46 +++++
 paddle/fluid/inference/tensorrt/op_teller.cc  |  25 ++-
 .../ir/inference/test_trt_convert_cumsum.py   | 176 ++++++++++++++++++
 6 files changed, 404 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/inference/tensorrt/convert/cumsum_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cumsum.py
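
Note: TensorRT has no native cumsum layer, so the converter below builds one
from an ILoop: it slices out the first element along `axis`, multiplies it by
zero to get a correctly typed and shaped initializer, then adds one slice per
loop trip and concatenates the running sums along `axis`. A minimal NumPy
sketch of the computation the loop performs (names here are illustrative,
not taken from the converter):

    import numpy as np

    def loop_cumsum(x, axis):
        # Zero initializer shaped like one slice of x along `axis`
        # (the converter obtains this by multiplying a slice by zero).
        running = np.zeros_like(np.take(x, 0, axis=axis))
        outs = []
        for i in range(x.shape[axis]):  # one loop trip per slice (kCOUNT)
            running = running + np.take(x, i, axis=axis)  # IRecurrenceLayer
            outs.append(running)
        return np.stack(outs, axis=axis)  # kCONCATENATE loop output

    x = np.arange(6, dtype=np.float32).reshape(2, 3)
    assert np.allclose(loop_cumsum(x, axis=1), np.cumsum(x, axis=1))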
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 38222b797f1..6523e5cfced 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2688,6 +2688,7 @@ USE_TRT_CONVERTER(expand_v2)
 USE_TRT_CONVERTER(take_along_axis)
 USE_TRT_CONVERTER(skip_groupnorm_act)
 USE_TRT_CONVERTER(preln_groupnorm_act)
+USE_TRT_CONVERTER(cumsum)
 #if IS_TRT_VERSION_GE(8522)
 USE_TRT_CONVERTER(flash_multihead_matmul)
 USE_TRT_CONVERTER(cross_multihead_matmul)
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index cbe26a3d31e..1793e120777 100755
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -106,6 +106,7 @@ list(
   skip_groupnorm_act_op.cc
   preln_groupnorm_act_op.cc
   expand_v2_op.cc
+  cumsum_op.cc
   temporal_shift_op.cc)
 
 if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
diff --git a/paddle/fluid/inference/tensorrt/convert/cumsum_op.cc b/paddle/fluid/inference/tensorrt/convert/cumsum_op.cc
new file mode 100644
index 00000000000..a46bf1efa17
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/cumsum_op.cc
@@ -0,0 +1,157 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * Cumsum Op
+ */
+class CumsumOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+#if IS_TRT_VERSION_GE(7220)
+    VLOG(3) << "convert a cumsum op to tensorrt layer";
+    framework::OpDesc op_desc(op, nullptr);
+    std::string input_x_name = op_desc.Input("X").front();
+    std::string output_name = op_desc.Output("Out").front();
+    auto* input_x_tensor = engine_->GetITensor(input_x_name);
+    auto dims = input_x_tensor->getDimensions();
+    auto rank = dims.nbDims;
+    int axis = 0;
+    if (op_desc.HasAttr("axis")) {
+      axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
+      if (axis < 0) {
+        axis += rank;
+      }
+    }
+
+    // getAxisLength returns the length of `axis` as an ITensor (a scalar by
+    // default).
+    auto getAxisLength =
+        [&](nvinfer1::ITensor* inpTensor, int axis, bool scalar = true) {
+          auto dims = inpTensor->getDimensions();
+          int d = dims.d[axis];
+          if (d >= 0) {
+            return Add1DConstantLayer(d, "", scalar);
+          } else {
+            nvinfer1::ITensor* inpShape = Shape(inpTensor);
+            return GetEleTensorOfShape(inpShape, axis, scalar);
+          }
+        };
+
+    // Create an "inputSliced" tensor: the input sliced to length 1 on
+    // dimension[axis].
+    nvinfer1::Dims start;
+    start.nbDims = rank;
+    std::vector<int32_t> start_vec(rank, 0);
+    std::fill(start.d, start.d + rank, 0);
+
+    nvinfer1::Dims size;
+    size.nbDims = rank;
+    nvinfer1::Dims stride;
+    stride.nbDims = rank;
+    auto axisLength = getAxisLength(input_x_tensor, axis, false);
+
+    auto starts_tensor =
+        Add1DConstantLayer(start_vec, output_name + "_start_tensor_");
+    auto sizes_tensor = axis == 0 ? Add1DConstantLayer(1)
+                                  : getAxisLength(input_x_tensor, 0, false);
+    auto strides_tensor = axis == 0 ? axisLength : Add1DConstantLayer(1);
+
+    for (int i = 1; i < rank; i++) {
+      if (i == axis) {
+        std::vector<nvinfer1::ITensor*> strides_itensors = {strides_tensor,
+                                                            axisLength};
+        strides_tensor = Concat(strides_itensors);
+        std::vector<nvinfer1::ITensor*> sizes_itensors = {
+            sizes_tensor, Add1DConstantLayer(1)};
+        sizes_tensor = Concat(sizes_itensors);
+      } else {
+        auto currLength = getAxisLength(input_x_tensor, i, false);
+        std::vector<nvinfer1::ITensor*> strides_itensors = {
+            strides_tensor, Add1DConstantLayer(1)};
+        strides_tensor = Concat(strides_itensors);
+        std::vector<nvinfer1::ITensor*> sizes_itensors = {sizes_tensor,
+                                                          currLength};
+        sizes_tensor = Concat(sizes_itensors);
+      }
+    }
+
+    auto inputSliced = TRT_ENGINE_ADD_LAYER(
+        engine_, Slice, *input_x_tensor, start, size, stride);
+    inputSliced->setInput(1, *starts_tensor);
+    inputSliced->setInput(2, *sizes_tensor);
+    inputSliced->setInput(3, *strides_tensor);
+    auto inputSliced_output = inputSliced->getOutput(0);
+
+    // Scan through each slice across axis and add it to the running sum
+    auto loop = TRT_ENGINE_ADD_LAYER(engine_, Loop);
+    nvinfer1::ITensor* tripLimit = getAxisLength(input_x_tensor, axis);
+    loop->addTripLimit(*tripLimit, nvinfer1::TripLimit::kCOUNT);
+    auto iterator = loop->addIterator(*input_x_tensor, axis);
+    auto data = iterator->getOutput(0);
+
+    // Squeeze inputSliced down to the same shape as `data`
+    auto sliced_dims = inputSliced_output->getDimensions();
+    std::vector<int32_t> subscripts(sliced_dims.nbDims);
+    std::iota(subscripts.begin(), subscripts.end(), 0);
+    auto p = std::remove_if(subscripts.begin(),
+                            subscripts.end(),
+                            [axis](int x) { return x == axis; });
+    subscripts.resize(p - subscripts.begin());
+    auto newDims = Gather(Shape(inputSliced_output), subscripts);
+    inputSliced_output = Reshape(inputSliced_output, newDims);
+
+    // Create a zero tensor with the same dtype and shape as one input slice
+    std::vector<float> zero_vec{0.f};
+    auto zero = Add1DConstantLayer(zero_vec);
+    auto cast = TRT_ENGINE_ADD_LAYER(engine_, Identity, *zero);
+    cast->setOutputType(0, inputSliced_output->getType());
+
+    zero = TRT_ENGINE_ADD_LAYER(
+               engine_,
+               ElementWise,
+               *inputSliced_output,
+               *BroadcastTensors(cast->getOutput(0), inputSliced_output),
+               nvinfer1::ElementWiseOperation::kPROD)
+               ->getOutput(0);
+
+    auto runningSum = loop->addRecurrence(*zero);
+    auto runningSumTensor = runningSum->getOutput(0);
+    auto curSum = TRT_ENGINE_ADD_LAYER(engine_,
+                                       ElementWise,
+                                       *data,
+                                       *runningSumTensor,
+                                       nvinfer1::ElementWiseOperation::kSUM);
+    runningSum->setInput(1, *curSum->getOutput(0));
+    auto reverseFlag = nvinfer1::LoopOutput::kCONCATENATE;
+    nvinfer1::ILoopOutputLayer* loopOut =
+        loop->addLoopOutput(*curSum->getOutput(0), reverseFlag, axis);
+    loopOut->setInput(1, *tripLimit);
+    RreplenishLayerAndOutput(loopOut, "cumsum", {output_name}, test_mode);
+#else
+    VLOG(3) << "Cumsum is not supported when TensorRT < 7.2.2";
+#endif
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(cumsum, CumsumOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h
index db19e5c45d3..e2dfe4d5ba3 100644
--- a/paddle/fluid/inference/tensorrt/convert/op_converter.h
+++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h
@@ -416,6 +416,52 @@ class OpConverter {
     return TRT_ENGINE_ADD_LAYER(engine_, Shape, *input)->getOutput(0);
   }
 
+  nvinfer1::ITensor* Reshape(nvinfer1::ITensor* input,
+                             nvinfer1::ITensor* newShape) {
+    nvinfer1::ITensor* oldShape = Shape(input);
+    if (oldShape == newShape) {
+      return input;
+    }
+    auto* shuffle = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
+    shuffle->setInput(1, *newShape);
+    return shuffle->getOutput(0);
+  }
+
+  nvinfer1::ITensor* BroadcastTensor(nvinfer1::ITensor* input,
+                                     const int nbDims) {
+    auto oldShape = Shape(input);
+    auto oldShapeDims = oldShape->getDimensions();
+    const int rank = oldShapeDims.nbDims;
+    if (rank > nbDims) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Cannot broadcast a higher rank tensor to a lower rank tensor."));
+    }
+    if (rank < nbDims) {
+      nvinfer1::ITensor* concat_shape_tensor;
+      auto* one_rank_tensor =
+          Add1DConstantLayer(std::vector<int32_t>(nbDims - rank, 1));
+      std::vector<nvinfer1::ITensor*> itensors;
+      itensors.push_back(one_rank_tensor);
+      itensors.push_back(oldShape);
+      concat_shape_tensor = Concat(itensors);
+      input = Reshape(input, concat_shape_tensor);
+    }
+    return input;
+  }
+
+  nvinfer1::ITensor* BroadcastTensors(nvinfer1::ITensor* a,
+                                      nvinfer1::ITensor* b) {
+    const int aDims = a->getDimensions().nbDims;
+    const int bDims = b->getDimensions().nbDims;
+    if (aDims == bDims) {
+      VLOG(3) << "Broadcast two equal rank tensors";
+    }
+    if (aDims > bDims) {
+      return BroadcastTensor(b, aDims);
+    }
+    return BroadcastTensor(a, bDims);
+  }
+
   // Concat not make rank changed
   nvinfer1::ITensor* Concat(const std::vector<nvinfer1::ITensor*>& inputs,
                             int axis = 0) {
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 24dca82d3fb..85f5c003746 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -2705,6 +2705,25 @@ struct SimpleOpTypeSetTeller : public Teller {
 #endif
     }
 
+    if (op_type == "cumsum") {
+#if !IS_TRT_VERSION_GE(7220)
+      VLOG(3) << "cumsum is not supported when TensorRT < 7.2.2";
+      return false;
+#endif
+      if (!with_dynamic_shape) {
+        VLOG(3) << "cumsum does not support "
+                   "static shape yet";
+        return false;
+      }
+      auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
+    }
+
     if (op_type == "temporal_shift") {
 #if !IS_TRT_VERSION_GE(8200)
       VLOG(3) << "temporal_shift is not supported when TensorRT < 8.2";
@@ -2906,7 +2925,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "skip_groupnorm_act",
       "preln_groupnorm_act",
       "temporal_shift",
-      "grid_sampler"};
+      "grid_sampler",
+      "cumsum"};
 
   std::unordered_set<std::string> teller_set{
       "mul",
@@ -3064,7 +3084,8 @@ struct SimpleOpTypeSetTeller : public Teller {
       "skip_groupnorm_act",
       "preln_groupnorm_act",
       "temporal_shift",
-      "grid_sampler"};
+      "grid_sampler",
+      "cumsum"};
 };
 
 struct GenericPluginTeller : public Teller {
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cumsum.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cumsum.py
new file mode 100644
index 00000000000..60dbfa37aab
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cumsum.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from typing import List
+
+import numpy as np
+from program_config import ProgramConfig, TensorConfig
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+
+import paddle.inference as paddle_infer
+
+
+class TrtConvertCumsum(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        ver = paddle_infer.get_trt_compile_version()
+        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 7220:
+            return False
+        return True
+
+    def sample_program_configs(self):
+        self.trt_param.workspace_size = 1073741824
+
+        def generate_input1():
+            if self.dims == 2:
+                self.input_shape = [2, 3]
+                return np.random.randint(-10, 10, [2, 3]).astype(np.int32)
+            elif self.dims == 3:
+                self.input_shape = [2, 3, 4]
+                return np.random.randint(-10, 10, [2, 3, 4]).astype(np.int64)
+            elif self.dims == 4:
+                self.input_shape = [4, 3, 32, 32]
+                return np.random.random([4, 3, 32, 32]).astype(np.float32) - 0.5
+
+        for dims in [2, 3, 4]:
+            for axis in range(-1, dims):
+                for input_type in ["int32", "int64", "float32", "float64"]:
+                    self.dims = dims
+                    ops_config = [
+                        {
+                            "op_type": "cumsum",
+                            "op_inputs": {
+                                "X": ["input_data"],
+                            },
+                            "op_outputs": {"Out": ["output_data"]},
+                            "op_attrs": {"axis": axis, "dtype": input_type},
+                        }
+                    ]
+                    ops = self.generate_op_config(ops_config)
+
+                    program_config = ProgramConfig(
+                        ops=ops,
+                        weights={},
+                        inputs={
+                            "input_data": TensorConfig(
+                                data_gen=partial(generate_input1)
+                            ),
+                        },
+                        outputs=["output_data"],
+                    )
+
+                    yield program_config
+
+        # cases without op_attrs
+        for dims in [2, 3, 4]:
+            self.dims = dims
+            ops_config = [
+                {
+                    "op_type": "cumsum",
+                    "op_inputs": {
+                        "X": ["input_data"],
+                    },
+                    "op_outputs": {"Out": ["output_data"]},
+                    "op_attrs": {},
+                }
+            ]
+            ops = self.generate_op_config(ops_config)
+
+            program_config = ProgramConfig(
+                ops=ops,
+                weights={},
+                inputs={
+                    "input_data": TensorConfig(
+                        data_gen=partial(generate_input1)
+                    ),
+                },
+                outputs=["output_data"],
+            )
+
+            yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape():
+            if self.dims == 2:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [2, 3],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [2, 3],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [2, 3],
+                }
+
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [2, 3, 4],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [2, 3, 4],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [2, 3, 4],
+                }
+
+            elif self.dims == 4:
+                self.dynamic_shape.min_input_shape = {
+                    "input_data": [4, 3, 32, 32],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "input_data": [4, 3, 32, 32],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "input_data": [4, 3, 32, 32],
+                }
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            ver = paddle_infer.get_trt_compile_version()
+            if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 7220:
+                return 0, 3
+            return 1, 2
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape
+        clear_dynamic_shape()
+
+        # for dynamic_shape
+        generate_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-2
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab
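
Usage note: as the op_teller change above shows, cumsum is only offloaded to
TensorRT under dynamic shape. A minimal sketch of a predictor config that
satisfies that requirement (model paths, input name, and shapes below are
illustrative, not part of this patch):

    import paddle.inference as paddle_infer

    config = paddle_infer.Config("./model.pdmodel", "./model.pdiparams")
    config.enable_use_gpu(256, 0)
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=1,
        precision_mode=paddle_infer.PrecisionType.Float32,
        use_static=False,
        use_calib_mode=False,
    )
    # Without shape profiles the subgraph runs in static-shape mode and the
    # teller rejects cumsum, leaving the op on the native Paddle path.
    config.set_trt_dynamic_shape_info(
        {"input_data": [2, 3]},  # min
        {"input_data": [4, 6]},  # max
        {"input_data": [2, 3]},  # opt
    )
    predictor = paddle_infer.create_predictor(config)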