diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 8f2b217a2fde0a05ccb5e09d867d0dc9a892511b..0007582e2c73d2320ebd860207a7fe9890c40429 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1192,7 +1192,7 @@ USE_TRT_CONVERTER(scale); USE_TRT_CONVERTER(stack); USE_TRT_CONVERTER(clip); USE_TRT_CONVERTER(gather); - +USE_TRT_CONVERTER(multiclass_nms); USE_TRT_CONVERTER(nearest_interp); #endif diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index b0d0229ec0531f9b907fc449b62b71752b4b17a5..be7fa0548d9f34c0272f4107abd93cc20ea659b9 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -6,7 +6,7 @@ nv_library(tensorrt_converter shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc gather_op.cc - + multiclass_nms_op.cc nearest_interp_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) diff --git a/paddle/fluid/inference/tensorrt/convert/multiclass_nms_op.cc b/paddle/fluid/inference/tensorrt/convert/multiclass_nms_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b0d67a5bf90ca9fcad742367a4c1a3c2c3eb0ee2 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/multiclass_nms_op.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class MultiClassNMSOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid multiclassNMS op to tensorrt plugin"; + + // for now, only work for static shape and regular tensor + framework::OpDesc op_desc(op, nullptr); + + std::string bboxes = op_desc.Input("BBoxes").front(); + std::string scores = op_desc.Input("Scores").front(); + std::string output_name = op_desc.Output("Out").front(); + + auto* bboxes_tensor = engine_->GetITensor(bboxes); + auto* scores_tensor = engine_->GetITensor(scores); + + int background_label = + BOOST_GET_CONST(int, op_desc.GetAttr("background_label")); + float score_threshold = + BOOST_GET_CONST(float, op_desc.GetAttr("score_threshold")); + int nms_top_k = BOOST_GET_CONST(int, op_desc.GetAttr("nms_top_k")); + float nms_threshold = + BOOST_GET_CONST(float, op_desc.GetAttr("nms_threshold")); + int keep_top_k = BOOST_GET_CONST(int, op_desc.GetAttr("keep_top_k")); + bool normalized = BOOST_GET_CONST(bool, op_desc.GetAttr("normalized")); + int num_classes = scores_tensor->getDimensions().d[0]; + + auto bboxes_dims = bboxes_tensor->getDimensions(); + nvinfer1::Dims3 bboxes_expand_dims(bboxes_dims.d[0], 1, bboxes_dims.d[1]); + auto* bboxes_expand_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor); + bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims); + + nvinfer1::Permutation permutation{1, 0}; + auto* scores_transpose_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor); + scores_transpose_layer->setFirstTranspose(permutation); + + std::vector batch_nms_inputs; + batch_nms_inputs.push_back(bboxes_expand_layer->getOutput(0)); + batch_nms_inputs.push_back(scores_transpose_layer->getOutput(0)); + + constexpr bool shareLocation = true; + constexpr bool clip_boxes = false; + + const std::vector fields{ + {"shareLocation", &shareLocation, nvinfer1::PluginFieldType::kINT32, 1}, + {"backgroundLabelId", &background_label, + nvinfer1::PluginFieldType::kINT32, 1}, + {"numClasses", &num_classes, nvinfer1::PluginFieldType::kINT32, 1}, + {"topK", &nms_top_k, nvinfer1::PluginFieldType::kINT32, 1}, + {"keepTopK", &keep_top_k, nvinfer1::PluginFieldType::kINT32, 1}, + {"scoreThreshold", &score_threshold, + nvinfer1::PluginFieldType::kFLOAT32, 1}, + {"iouThreshold", &nms_threshold, nvinfer1::PluginFieldType::kFLOAT32, + 1}, + {"isNormalized", &normalized, nvinfer1::PluginFieldType::kINT32, 1}, + {"clipBoxes", &clip_boxes, nvinfer1::PluginFieldType::kINT32, 1}, + }; + + nvinfer1::PluginFieldCollection* plugin_collections = + static_cast( + malloc(sizeof(*plugin_collections) + + fields.size() * sizeof(nvinfer1::PluginField))); + plugin_collections->nbFields = static_cast(fields.size()); + plugin_collections->fields = fields.data(); + + auto creator = GetPluginRegistry()->getPluginCreator("BatchedNMS_TRT", "1"); + auto batch_nms_plugin = + creator->createPlugin("BatchNMSPlugin", plugin_collections); + free(plugin_collections); + + auto batch_nms_layer = engine_->network()->addPluginV2( + batch_nms_inputs.data(), batch_nms_inputs.size(), *batch_nms_plugin); + auto nmsed_boxes = batch_nms_layer->getOutput(1); + auto nmsed_scores = batch_nms_layer->getOutput(2); + auto nmsed_classes = batch_nms_layer->getOutput(3); + + auto nmsed_scores_transpose_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_scores); + nmsed_scores_transpose_layer->setReshapeDimensions( + nvinfer1::Dims2(keep_top_k, 1)); + auto nmsed_classes_reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_classes); + nmsed_classes_reshape_layer->setReshapeDimensions( + nvinfer1::Dims2(keep_top_k, 1)); + + std::vector concat_inputs; + concat_inputs.push_back(nmsed_classes_reshape_layer->getOutput(0)); + concat_inputs.push_back(nmsed_scores_transpose_layer->getOutput(0)); + concat_inputs.push_back(nmsed_boxes); + + auto nms_concat_layer = TRT_ENGINE_ADD_LAYER( + engine_, Concatenation, concat_inputs.data(), concat_inputs.size()); + nms_concat_layer->setAxis(1); + + RreplenishLayerAndOutput(nms_concat_layer, "multiclass_nms", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(multiclass_nms, MultiClassNMSOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 11752d71a45e1b8545b727eb63d48fdca6d157a1..82f58254fe8e0dc4cb462ba392ec55b9f033017e 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -111,7 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller { "flatten2", "flatten", "gather", - + "multiclass_nms", "nearest_interp", }; }; @@ -195,6 +195,38 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, // current not support axis from input, use default 0 if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false; } + + if (op_type == "multiclass_nms") { + if (with_dynamic_shape) return false; + auto* block = desc.Block(); + for (auto& param_name : desc.Inputs()) { + for (auto& var_name : param_name.second) { + auto* var_desc = block->FindVar(var_name); + const auto shape = var_desc->GetShape(); + if (shape.size() != 3) { + VLOG(1) << "multiclass_nms op dims != 3 not supported in tensorrt, " + "but got dims " + << shape.size() << ", so jump it."; + return false; + } + } + } + bool has_attrs = + (desc.HasAttr("background_label") && + desc.HasAttr("score_threshold") && desc.HasAttr("nms_top_k") && + desc.HasAttr("keep_top_k") && desc.HasAttr("normalized")); + if (has_attrs == false) return false; + + auto nms_top_k = BOOST_GET_CONST(int, desc.GetAttr("nms_top_k")); + if (nms_top_k < 0) return false; + + auto keep_top_k = BOOST_GET_CONST(int, desc.GetAttr("keep_top_k")); + if (keep_top_k < 0) return false; + + auto registry = GetPluginRegistry(); + if (registry == nullptr) return false; + } + if (op_type == "fc" || op_type == "mul") { const int x_num_col_dims = desc.HasAttr("x_num_col_dims") diff --git a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py index 993493a3ccf2b6fd28448b0059e5f648836deec3..010086bfbbc47ffe65b6379b65b05900235e83d3 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py @@ -46,6 +46,7 @@ class InferencePassTest(unittest.TestCase): self.enable_mkldnn = False self.enable_mkldnn_bfloat16 = False self.enable_trt = False + self.enable_tensorrt_oss = True self.trt_parameters = None self.dynamic_shape_params = None self.enable_lite = False @@ -133,6 +134,8 @@ class InferencePassTest(unittest.TestCase): self.dynamic_shape_params.max_input_shape, self.dynamic_shape_params.optim_input_shape, self.dynamic_shape_params.disable_trt_plugin_fp16) + if self.enable_tensorrt_oss: + config.enable_tensorrt_oss() elif use_mkldnn: config.enable_mkldnn() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3ca6985985985e2a60ef8f6ff5a8ef8c2a129ec2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms_op.py @@ -0,0 +1,144 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import itertools +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TensorRTMultiClassNMSTest(InferencePassTest): + def setUp(self): + self.enable_trt = True + self.enable_tensorrt_oss = True + self.precision = AnalysisConfig.Precision.Float32 + self.serialize = False + self.bs = 1 + self.background_label = -1 + self.score_threshold = .5 + self.nms_top_k = 8 + self.nms_threshold = .3 + self.keep_top_k = 8 + self.normalized = False + self.num_classes = 8 + self.num_boxes = 8 + self.trt_parameters = InferencePassTest.TensorRTParam( + 1 << 30, self.bs, 2, self.precision, self.serialize, False) + + def build(self): + with fluid.program_guard(self.main_program, self.startup_program): + boxes = fluid.data( + name='bboxes', shape=[-1, self.num_boxes, 4], dtype='float32') + scores = fluid.data( + name='scores', + shape=[-1, self.num_classes, self.num_boxes], + dtype='float32') + multiclass_nms_out = fluid.layers.multiclass_nms( + bboxes=boxes, + scores=scores, + background_label=self.background_label, + score_threshold=self.score_threshold, + nms_top_k=self.nms_top_k, + nms_threshold=self.nms_threshold, + keep_top_k=self.keep_top_k, + normalized=self.normalized) + mutliclass_nms_out = multiclass_nms_out + 1. + multiclass_nms_out = fluid.layers.reshape( + multiclass_nms_out, [self.bs, 1, self.keep_top_k, 6], + name='reshape') + out = fluid.layers.batch_norm(multiclass_nms_out, is_test=True) + + boxes_data = np.arange(self.num_boxes * 4).reshape( + [self.bs, self.num_boxes, 4]).astype('float32') + scores_data = np.arange(1 * self.num_classes * self.num_boxes).reshape( + [self.bs, self.num_classes, self.num_boxes]).astype('float32') + self.feeds = { + 'bboxes': boxes_data, + 'scores': scores_data, + } + self.fetch_list = [out] + + def run_test(self): + self.build() + self.check_output() + + def run_test_all(self): + precision_opt = [ + AnalysisConfig.Precision.Float32, AnalysisConfig.Precision.Half + ] + serialize_opt = [False, True] + max_shape = { + 'bboxes': [self.bs, self.num_boxes, 4], + 'scores': [self.bs, self.num_classes, self.num_boxes], + } + opt_shape = max_shape + dynamic_shape_opt = [ + None, InferencePassTest.DynamicShapeParam({ + 'bboxes': [1, 1, 4], + 'scores': [1, 1, 1] + }, max_shape, opt_shape, False) + ] + for precision, serialize, dynamic_shape in itertools.product( + precision_opt, serialize_opt, dynamic_shape_opt): + self.precision = precision + self.serialize = serialize + self.dynamic_shape_params = dynamic_shape + self.build() + self.check_output() + + def check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + def test_base(self): + self.run_test() + + def test_fp16(self): + self.precision = AnalysisConfig.Precision.Half + self.run_test() + + def test_serialize(self): + self.serialize = True + self.run_test() + + def test_dynamic(self): + max_shape = { + 'bboxes': [self.bs, self.num_boxes, 4], + 'scores': [self.bs, self.num_classes, self.num_boxes], + } + opt_shape = max_shape + self.dynamic_shape_params = InferencePassTest.DynamicShapeParam({ + 'bboxes': [1, 1, 4], + 'scores': [1, 1, 1] + }, max_shape, opt_shape, False) + self.run_test() + + def test_background(self): + self.background = 7 + self.run_test() + + def test_disable_oss(self): + self.diable_tensorrt_oss = False + self.run_test() + + +if __name__ == "__main__": + unittest.main()