From 44973c65c3f5748776d67dd862419a36feff2d1a Mon Sep 17 00:00:00 2001
From: Ryan <44900829+DrRyanHuang@users.noreply.github.com>
Date: Tue, 20 Dec 2022 10:47:48 +0800
Subject: [PATCH] [Paddle Inference] Add add arg_min trt converter (#49113)

---
 .../fluid/inference/api/analysis_predictor.cc |   1 +
 .../inference/tensorrt/convert/CMakeLists.txt |   1 +
 .../inference/tensorrt/convert/arg_min_op.cc  |  77 ++++++++++
 paddle/fluid/inference/tensorrt/op_teller.cc  |  17 +++
 .../ir/inference/test_trt_convert_arg_min.py  | 144 ++++++++++++++++++
 5 files changed, 240 insertions(+)
 create mode 100644 paddle/fluid/inference/tensorrt/convert/arg_min_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_min.py

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 8461224ec5..5dced520f3 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2334,6 +2334,7 @@ USE_TRT_CONVERTER(anchor_generator);
 USE_TRT_CONVERTER(yolo_box);
 USE_TRT_CONVERTER(yolo_box_head);
 USE_TRT_CONVERTER(arg_max);
+USE_TRT_CONVERTER(arg_min);
 USE_TRT_CONVERTER(roi_align);
 USE_TRT_CONVERTER(affine_channel);
 USE_TRT_CONVERTER(multiclass_nms);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index e23ea7e7b1..617898a4e5 100755
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -47,6 +47,7 @@ list(
   yolo_box_op.cc
   yolo_box_head_op.cc
   arg_max_op.cc
+  arg_min_op.cc
   roi_align_op.cc
   affine_channel_op.cc
   multiclass_nms_op.cc
diff --git a/paddle/fluid/inference/tensorrt/convert/arg_min_op.cc b/paddle/fluid/inference/tensorrt/convert/arg_min_op.cc
new file mode 100644
index 0000000000..81c9d058f8
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/arg_min_op.cc
@@ -0,0 +1,77 @@
+/* Copyright (c) 2022 
PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace framework {
+class Scope;
+namespace proto {
+class OpDesc;
+}  // namespace proto
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+class ArgMinOpConverter : public OpConverter {  // arg_min -> TRT TopK(kMIN)
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,  // not referenced below
+                  bool test_mode) override {  // forwarded to the output helper
+    VLOG(3) << "convert a fluid arg_min op to tensorrt topk layer";
+    framework::OpDesc op_desc(op, nullptr);
+    // Declare inputs: look up the ITensor registered for X[0]
+    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
+    auto input_dims = input->getDimensions();
+    int rank = input_dims.nbDims;
+    int axis = op_desc.HasAttr("axis")  // -1 (last dim) when attr is absent
+                   ? PADDLE_GET_CONST(int64_t, op_desc.GetAttr("axis"))
+                   : -1;  // (the int64_t attr value above is narrowed to int)
+    if (axis > 0 && !engine_->with_dynamic_shape()) {
+      axis -= 1;  // presumably static-shape dims omit batch -- TODO confirm
+    }
+    if (axis < 0) axis += rank;  // negative axis counts back from rank
+    auto* topk_layer = TRT_ENGINE_ADD_LAYER(  // presumably (op, k=1, axis mask)
+        engine_, TopK, *input, nvinfer1::TopKOperation::kMIN, 1, 1 << axis);
+
+    auto output_name = op_desc.Output("Out")[0];
+    bool keepdims = PADDLE_GET_CONST(bool, op_desc.GetAttr("keepdims"));
+    if (keepdims) {  // hand the TopK layer itself to the output helper
+      RreplenishLayerAndOutput(topk_layer,
+                               "arg_min",
+                               {output_name + "_value", output_name},  // presumably [values, indices]
+                               test_mode);
+    } else {  // reshape TopK output #1 with the reduced axis removed
+      auto squeeze_layer =
+          TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *topk_layer->getOutput(1));
+      auto dims = input_dims;  // start from the input shape
+      dims.nbDims -= 1;
+      for (int i = axis; i < dims.nbDims; i++) {
+        dims.d[i] = dims.d[i + 1];  // shift later extents left over `axis`
+      }
+      squeeze_layer->setReshapeDimensions(dims);
+      RreplenishLayerAndOutput(
+          squeeze_layer, "arg_min", {output_name}, test_mode);
+    }
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(arg_min, ArgMinOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index fea54fea3f..c469c4fbf3 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -703,6 +703,21 @@ struct SimpleOpTypeSetTeller : public Teller {
       if (axis == 0 || flatten || dtype != 2) return false;
     }
 
+    if (op_type == "arg_min") {  // gate for arg_min -> TRT conversion
+      if (!desc.HasAttr("axis", /*with_attr_var=*/false)) {
+        VLOG(3) << "Skip to convert into TRT while found Attribute('axis') is "
+                   "Variable type in arg_min.";
+        return false;
+      }
+
+      int axis = desc.HasAttr("axis")
+                     ? PADDLE_GET_CONST(int64_t, desc.GetAttr("axis"))
+                     : -1;  // NOTE(review): likely dead after the guard above
+      bool flatten = PADDLE_GET_CONST(bool, desc.GetAttr("flatten"));
+      int dtype = PADDLE_GET_CONST(int, desc.GetAttr("dtype"));
+      if (axis == 0 || flatten || dtype != 2) return false;
+    }
+
     if (op_type == "affine_channel") {
       if (!desc.HasAttr("data_layout")) return false;
       auto data_layout = phi::StringToDataLayout(
@@ -2524,6 +2539,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "yolo_box",
       "yolo_box_head",
       "arg_max",
+      "arg_min",
       "roi_align",
       "affine_channel",
       "nearest_interp",
@@ -2669,6 +2685,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       "yolo_box",
       "yolo_box_head",
       "arg_max",
+      "arg_min",
       "roi_align",
       "affine_channel",
       "nearest_interp",
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_min.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_min.py
new file mode 100644
index 0000000000..9867a9e615
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_arg_min.py
@@ -0,0 +1,144 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from typing import List, Tuple
+
+import numpy as np
+from program_config import ProgramConfig, TensorConfig
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+
+import paddle.inference as paddle_infer
+
+
+class TrtConvertArgMinTest(TrtLayerAutoScanTest):  # arg_min TRT auto-scan test
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        input_shape = program_config.inputs["arg_min_input"].shape
+        axis = program_config.ops[0].attrs["axis"]
+        if axis < 0:
+            axis += len(input_shape)  # normalise a negative axis
+        if len(input_shape) <= axis or axis == 0:  # out of range, or axis 0
+            return False
+        return True
+
+    def sample_program_configs(self):
+        def generate_input(rank, batch):
+            dims = [batch]  # final shape: [batch, 8, 16, ...]
+            for i in range(rank - 1):
+                dims.append((i + 1) * 8)  # extents 8, 16, 24, ...
+            size = np.prod(dims)  # values below cycle through -5..4
+            return (np.arange(size) % 10 - 5).astype("float32").reshape(dims)
+
+        for rank in [3, 4]:
+            for batch in [1, 4]:
+                for axis in [-1, 0, 1, 2, 3]:  # 0 and >= rank are rejected
+                    for keepdims in [True, False]:
+                        self.rank = rank  # read by generate_dynamic_shape
+                        flatten = False
+                        dtype = 2  # NOTE(review): presumably INT32 -- confirm
+                        ops_config = [
+                            {
+                                "op_type": "arg_min",
+                                "op_inputs": {"X": ["arg_min_input"]},
+                                "op_outputs": {"Out": ["arg_min_out"]},
+                                "op_attrs": {
+                                    "axis": axis,
+                                    "keepdims": keepdims,
+                                    "flatten": flatten,
+                                    "dtype": dtype,
+                                },
+                            }
+                        ]
+                        ops = self.generate_op_config(ops_config)
+                        program_config = ProgramConfig(
+                            ops=ops,
+                            weights={},  # the op has no weights
+                            inputs={
+                                "arg_min_input": TensorConfig(
+                                    data_gen=partial(
+                                        generate_input, rank, batch
+                                    )
+                                )
+                            },
+                            outputs=["arg_min_out"],
+                        )
+                        yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> Tuple[paddle_infer.Config, List[int], float]:
+        def generate_dynamic_shape(attrs):  # attrs is not used
+            if self.rank == 3:  # rank set in sample_program_configs
+                self.dynamic_shape.min_input_shape = {  # only dim 0 varies
+                    "arg_min_input": [1, 8, 16]
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "arg_min_input": [4, 8, 16]
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "arg_min_input": [3, 8, 16]
+                }
+            else:  # any other rank gets the 4-D shapes
+                self.dynamic_shape.min_input_shape = {
+                    "arg_min_input": [1, 8, 16, 24]
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "arg_min_input": [4, 8, 16, 24]
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "arg_min_input": [1, 8, 16, 24]
+                }
+
+        def clear_dynamic_shape():  # empty dicts for the static-shape runs
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):  # args are unused
+            return 1, 2  # presumably (trt_engines, paddle_ops) -- confirm
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        self.trt_param.workspace_size = 1024000
+        # for static_shape: FP32 first, then FP16
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5  # third item: presumably the comparison tolerance
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-3
+
+        # for dynamic_shape: same two precisions with shape ranges set
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-3
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab