diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index a7caa3e369f80a954f36226c070ff1f7bd822a2b..6388cfc4b2daea2b9d744ca3fbf30a997c938132 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1744,6 +1744,7 @@ USE_TRT_CONVERTER(yolo_box); USE_TRT_CONVERTER(roi_align); USE_TRT_CONVERTER(affine_channel); USE_TRT_CONVERTER(multiclass_nms); +USE_TRT_CONVERTER(multiclass_nms3); USE_TRT_CONVERTER(nearest_interp); USE_TRT_CONVERTER(nearest_interp_v2); USE_TRT_CONVERTER(reshape); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index e91faedb06872a5abe38c1de77b54477e0da8ef4..4f8aa4c14cd7b3770f5a264e94d715330ff7f6fe 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -11,6 +11,7 @@ nv_library(tensorrt_converter roi_align_op.cc affine_channel_op.cc multiclass_nms_op.cc + multiclass_nms3_op.cc nearest_interp_op.cc reshape_op.cc reduce_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/multiclass_nms3_op.cc b/paddle/fluid/inference/tensorrt/convert/multiclass_nms3_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..00f1419f082d1b57b7c8bbd0a75c6bea9ac24f7c --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/multiclass_nms3_op.cc @@ -0,0 +1,136 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class MultiClassNMS3OpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid multiclassNMS3 op to tensorrt plugin"; + + // for now, only work for static shape and regular tensor + framework::OpDesc op_desc(op, nullptr); + + std::string bboxes = op_desc.Input("BBoxes").front(); + std::string scores = op_desc.Input("Scores").front(); + std::string output_name = op_desc.Output("Out").front(); + std::string rois_num_name = op_desc.Output("NmsRoisNum").front(); + + auto* bboxes_tensor = engine_->GetITensor(bboxes); + auto* scores_tensor = engine_->GetITensor(scores); + + int background_label = + BOOST_GET_CONST(int, op_desc.GetAttr("background_label")); + float score_threshold = + BOOST_GET_CONST(float, op_desc.GetAttr("score_threshold")); + int nms_top_k = BOOST_GET_CONST(int, op_desc.GetAttr("nms_top_k")); + float nms_threshold = + BOOST_GET_CONST(float, op_desc.GetAttr("nms_threshold")); + int keep_top_k = BOOST_GET_CONST(int, op_desc.GetAttr("keep_top_k")); + bool normalized = BOOST_GET_CONST(bool, op_desc.GetAttr("normalized")); + int num_classes = scores_tensor->getDimensions().d[0]; + + auto bboxes_dims = bboxes_tensor->getDimensions(); + nvinfer1::Dims3 bboxes_expand_dims(bboxes_dims.d[0], 1, bboxes_dims.d[1]); + auto* bboxes_expand_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor); + bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims); + + nvinfer1::Permutation permutation{1, 0}; + auto* scores_transpose_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor); + scores_transpose_layer->setFirstTranspose(permutation); + + std::vector batch_nms_inputs; + batch_nms_inputs.push_back(bboxes_expand_layer->getOutput(0)); + batch_nms_inputs.push_back(scores_transpose_layer->getOutput(0)); + + constexpr bool shareLocation = true; + constexpr bool clip_boxes = false; + + const std::vector fields{ + {"shareLocation", &shareLocation, nvinfer1::PluginFieldType::kINT32, 1}, + {"backgroundLabelId", &background_label, + nvinfer1::PluginFieldType::kINT32, 1}, + {"numClasses", &num_classes, nvinfer1::PluginFieldType::kINT32, 1}, + {"topK", &nms_top_k, nvinfer1::PluginFieldType::kINT32, 1}, + {"keepTopK", &keep_top_k, nvinfer1::PluginFieldType::kINT32, 1}, + {"scoreThreshold", &score_threshold, + nvinfer1::PluginFieldType::kFLOAT32, 1}, + {"iouThreshold", &nms_threshold, nvinfer1::PluginFieldType::kFLOAT32, + 1}, + {"isNormalized", &normalized, nvinfer1::PluginFieldType::kINT32, 1}, + {"clipBoxes", &clip_boxes, nvinfer1::PluginFieldType::kINT32, 1}, + }; + + nvinfer1::PluginFieldCollection* plugin_collections = + static_cast( + malloc(sizeof(*plugin_collections) + + fields.size() * sizeof(nvinfer1::PluginField))); + plugin_collections->nbFields = static_cast(fields.size()); + plugin_collections->fields = fields.data(); + + auto creator = GetPluginRegistry()->getPluginCreator("BatchedNMS_TRT", "1"); + auto batch_nms_plugin = + creator->createPlugin("BatchNMSPlugin", plugin_collections); + free(plugin_collections); + + auto batch_nms_layer = engine_->network()->addPluginV2( + batch_nms_inputs.data(), batch_nms_inputs.size(), *batch_nms_plugin); + auto nmsed_boxes = batch_nms_layer->getOutput(1); + auto nmsed_scores = batch_nms_layer->getOutput(2); + auto nmsed_classes = batch_nms_layer->getOutput(3); + + auto nmsed_scores_transpose_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_scores); + nmsed_scores_transpose_layer->setReshapeDimensions( + nvinfer1::Dims2(keep_top_k, 1)); + auto nmsed_classes_reshape_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_classes); + nmsed_classes_reshape_layer->setReshapeDimensions( + nvinfer1::Dims2(keep_top_k, 1)); + + std::vector concat_inputs; + concat_inputs.push_back(nmsed_classes_reshape_layer->getOutput(0)); + concat_inputs.push_back(nmsed_scores_transpose_layer->getOutput(0)); + concat_inputs.push_back(nmsed_boxes); + + auto nms_concat_layer = TRT_ENGINE_ADD_LAYER( + engine_, Concatenation, concat_inputs.data(), concat_inputs.size()); + nms_concat_layer->setAxis(1); + + RreplenishLayerAndOutput(batch_nms_layer, "multiclass_nms3", + {rois_num_name}, test_mode); + RreplenishLayerAndOutput(nms_concat_layer, "multiclass_nms3", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(multiclass_nms3, MultiClassNMS3OpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index e73237ab13def6bf78bb991ab6df5ef4e5e80080..fe0332025ed4aaa538f2541496dba3069b7e1b32 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -179,7 +179,8 @@ struct SimpleOpTypeSetTeller : public Teller { "skip_layernorm", "slice", "fused_preln_embedding_eltwise_layernorm", - "preln_skip_layernorm"}; + "preln_skip_layernorm", + "multiclass_nms3"}; }; bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, @@ -646,7 +647,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } - if (op_type == "multiclass_nms") { + if (op_type == "multiclass_nms" || op_type == "multiclass_nms3") { if (with_dynamic_shape) return false; auto* block = desc.Block(); if (block == nullptr) { @@ -655,7 +656,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, "the pass."; return false; } - for (auto& param_name : desc.Inputs()) { + auto multiclass_nms_inputs = desc.Inputs(); + if (multiclass_nms_inputs.find("RoisNum") != + multiclass_nms_inputs.end()) { + if (desc.Input("RoisNum").size() >= 1) { + return false; + } + } + for (auto& param_name : multiclass_nms_inputs) { for (auto& var_name : param_name.second) { auto* var_desc = block->FindVar(var_name); const auto shape = var_desc->GetShape(); @@ -673,6 +681,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, desc.HasAttr("keep_top_k") && desc.HasAttr("normalized")); if (has_attrs == false) return false; + // TODO(wangxinxin08): tricky solution because the outputs of batchedNMS + // plugin are not constient with those of multiclass_nms3 + if (desc.HasAttr("nms_eta") == false) return false; + auto nms_eta = BOOST_GET_CONST(float, desc.GetAttr("nms_eta")); + if (nms_eta <= 1.0) return false; + auto nms_top_k = BOOST_GET_CONST(int, desc.GetAttr("nms_top_k")); if (nms_top_k < 0) return false; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index b1ed0165edb56fae2d206e810e14b6114b5884bd..5be531258edacbabf98bf6c0dd9405cbdc68561d 100755 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -91,6 +91,7 @@ set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30) +set_tests_properties(test_trt_multiclass_nms3_op PROPERTIES TIMEOUT 60) if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU) set_tests_properties(test_emb_eltwise_layernorm_fuse_pass PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ed993ffce7da727b743813638e200223a5cef8a7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_multiclass_nms3_op.py @@ -0,0 +1,300 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import itertools +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +def multiclass_nms(bboxes, + scores, + score_threshold, + nms_top_k, + keep_top_k, + nms_threshold=0.3, + normalized=True, + nms_eta=1., + background_label=-1, + return_index=False, + return_rois_num=True, + rois_num=None, + name=None): + """ + This operator is to do multi-class non maximum suppression (NMS) on + boxes and scores. + In the NMS step, this operator greedily selects a subset of detection bounding + boxes that have high scores larger than score_threshold, if providing this + threshold, then selects the largest nms_top_k confidences scores if nms_top_k + is larger than -1. Then this operator pruns away boxes that have high IOU + (intersection over union) overlap with already selected boxes by adaptive + threshold NMS based on parameters of nms_threshold and nms_eta. + Aftern NMS step, at most keep_top_k number of total bboxes are to be kept + per image if keep_top_k is larger than -1. + Args: + bboxes (Tensor): Two types of bboxes are supported: + 1. (Tensor) A 3-D Tensor with shape + [N, M, 4 or 8 16 24 32] represents the + predicted locations of M bounding bboxes, + N is the batch size. Each bounding box has four + coordinate values and the layout is + [xmin, ymin, xmax, ymax], when box size equals to 4. + 2. (LoDTensor) A 3-D Tensor with shape [M, C, 4] + M is the number of bounding boxes, C is the + class number + scores (Tensor): Two types of scores are supported: + 1. (Tensor) A 3-D Tensor with shape [N, C, M] + represents the predicted confidence predictions. + N is the batch size, C is the class number, M is + number of bounding boxes. For each category there + are total M scores which corresponding M bounding + boxes. Please note, M is equal to the 2nd dimension + of BBoxes. + 2. (LoDTensor) A 2-D LoDTensor with shape [M, C]. + M is the number of bbox, C is the class number. + In this case, input BBoxes should be the second + case with shape [M, C, 4]. + background_label (int): The index of background label, the background + label will be ignored. If set to -1, then all + categories will be considered. Default: 0 + score_threshold (float): Threshold to filter out bounding boxes with + low confidence score. If not provided, + consider all boxes. + nms_top_k (int): Maximum number of detections to be kept according to + the confidences after the filtering detections based + on score_threshold. + nms_threshold (float): The threshold to be used in NMS. Default: 0.3 + nms_eta (float): The threshold to be used in NMS. Default: 1.0 + keep_top_k (int): Number of total bboxes to be kept per image after NMS + step. -1 means keeping all bboxes after NMS step. + normalized (bool): Whether detections are normalized. Default: True + return_index(bool): Whether return selected index. Default: False + rois_num(Tensor): 1-D Tensor contains the number of RoIs in each image. + The shape is [B] and data type is int32. B is the number of images. + If it is not None then return a list of 1-D Tensor. Each element + is the output RoIs' number of each image on the corresponding level + and the shape is [B]. None by default. + name(str): Name of the multiclass nms op. Default: None. + Returns: + A tuple with two Variables: (Out, Index) if return_index is True, + otherwise, a tuple with one Variable(Out) is returned. + Out: A 2-D LoDTensor with shape [No, 6] represents the detections. + Each row has 6 values: [label, confidence, xmin, ymin, xmax, ymax] + or A 2-D LoDTensor with shape [No, 10] represents the detections. + Each row has 10 values: [label, confidence, x1, y1, x2, y2, x3, y3, + x4, y4]. No is the total number of detections. + If all images have not detected results, all elements in LoD will be + 0, and output tensor is empty (None). + Index: Only return when return_index is True. A 2-D LoDTensor with + shape [No, 1] represents the selected index which type is Integer. + The index is the absolute value cross batches. No is the same number + as Out. If the index is used to gather other attribute such as age, + one needs to reshape the input(N, M, 1) to (N * M, 1) as first, where + N is the batch size and M is the number of boxes. + Examples: + .. code-block:: python + import paddle + from ppdet.modeling import ops + boxes = paddle.static.data(name='bboxes', shape=[81, 4], + dtype='float32', lod_level=1) + scores = paddle.static.data(name='scores', shape=[81], + dtype='float32', lod_level=1) + out, index = ops.multiclass_nms(bboxes=boxes, + scores=scores, + background_label=0, + score_threshold=0.5, + nms_top_k=400, + nms_threshold=0.3, + keep_top_k=200, + normalized=False, + return_index=True) + """ + if in_dygraph_mode(): + attrs = ('background_label', background_label, 'score_threshold', + score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold', + nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta, + 'normalized', normalized) + output, index, nms_rois_num = core.ops.multiclass_nms3(bboxes, scores, + rois_num, *attrs) + if not return_index: + index = None + return output, nms_rois_num, index + + else: + helper = LayerHelper('multiclass_nms3', **locals()) + output = helper.create_variable_for_type_inference(dtype=bboxes.dtype) + index = helper.create_variable_for_type_inference(dtype='int32') + + inputs = {'BBoxes': bboxes, 'Scores': scores} + outputs = {'Out': output, 'Index': index} + + if rois_num is not None: + inputs['RoisNum'] = rois_num + + if return_rois_num: + nms_rois_num = helper.create_variable_for_type_inference( + dtype='int32') + outputs['NmsRoisNum'] = nms_rois_num + + helper.append_op( + type="multiclass_nms3", + inputs=inputs, + attrs={ + 'background_label': background_label, + 'score_threshold': score_threshold, + 'nms_top_k': nms_top_k, + 'nms_threshold': nms_threshold, + 'keep_top_k': keep_top_k, + 'nms_eta': nms_eta, + 'normalized': normalized + }, + outputs=outputs) + output.stop_gradient = True + index.stop_gradient = True + if not return_index: + index = None + if not return_rois_num: + nms_rois_num = None + + return output, nms_rois_num, index + + +class TensorRTMultiClassNMS3Test(InferencePassTest): + def setUp(self): + self.enable_trt = True + self.enable_tensorrt_oss = True + self.precision = AnalysisConfig.Precision.Float32 + self.serialize = False + self.bs = 1 + self.background_label = -1 + self.score_threshold = .5 + self.nms_top_k = 8 + self.nms_threshold = .3 + self.keep_top_k = 8 + self.normalized = False + self.num_classes = 8 + self.num_boxes = 8 + self.nms_eta = 1.1 + self.trt_parameters = InferencePassTest.TensorRTParam( + 1 << 30, self.bs, 2, self.precision, self.serialize, False) + + def build(self): + with fluid.program_guard(self.main_program, self.startup_program): + boxes = fluid.data( + name='bboxes', shape=[-1, self.num_boxes, 4], dtype='float32') + scores = fluid.data( + name='scores', + shape=[-1, self.num_classes, self.num_boxes], + dtype='float32') + multiclass_nms_out, _, _ = multiclass_nms( + bboxes=boxes, + scores=scores, + background_label=self.background_label, + score_threshold=self.score_threshold, + nms_top_k=self.nms_top_k, + nms_threshold=self.nms_threshold, + keep_top_k=self.keep_top_k, + normalized=self.normalized, + nms_eta=self.nms_eta) + mutliclass_nms_out = multiclass_nms_out + 1. + multiclass_nms_out = fluid.layers.reshape( + multiclass_nms_out, [self.bs, 1, self.keep_top_k, 6], + name='reshape') + out = fluid.layers.batch_norm(multiclass_nms_out, is_test=True) + + boxes_data = np.arange(self.num_boxes * 4).reshape( + [self.bs, self.num_boxes, 4]).astype('float32') + scores_data = np.arange(1 * self.num_classes * self.num_boxes).reshape( + [self.bs, self.num_classes, self.num_boxes]).astype('float32') + self.feeds = { + 'bboxes': boxes_data, + 'scores': scores_data, + } + self.fetch_list = [out] + + def run_test(self): + self.build() + self.check_output() + + def run_test_all(self): + precision_opt = [ + AnalysisConfig.Precision.Float32, AnalysisConfig.Precision.Half + ] + serialize_opt = [False, True] + max_shape = { + 'bboxes': [self.bs, self.num_boxes, 4], + 'scores': [self.bs, self.num_classes, self.num_boxes], + } + opt_shape = max_shape + dynamic_shape_opt = [ + None, InferencePassTest.DynamicShapeParam({ + 'bboxes': [1, 1, 4], + 'scores': [1, 1, 1] + }, max_shape, opt_shape, False) + ] + for precision, serialize, dynamic_shape in itertools.product( + precision_opt, serialize_opt, dynamic_shape_opt): + self.precision = precision + self.serialize = serialize + self.dynamic_shape_params = dynamic_shape + self.build() + self.check_output() + + def check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + def test_base(self): + self.run_test() + + def test_fp16(self): + self.precision = AnalysisConfig.Precision.Half + self.run_test() + + def test_serialize(self): + self.serialize = True + self.run_test() + + def test_dynamic(self): + max_shape = { + 'bboxes': [self.bs, self.num_boxes, 4], + 'scores': [self.bs, self.num_classes, self.num_boxes], + } + opt_shape = max_shape + self.dynamic_shape_params = InferencePassTest.DynamicShapeParam({ + 'bboxes': [1, 1, 4], + 'scores': [1, 1, 1] + }, max_shape, opt_shape, False) + self.run_test() + + def test_background(self): + self.background = 7 + self.run_test() + + def test_disable_oss(self): + self.diable_tensorrt_oss = False + self.run_test() + + +if __name__ == "__main__": + unittest.main()