未验证 提交 61a3e30b 编写于 作者: Z Zhang Jun 提交者: GitHub

fix trt multiclass_nms3 (#45166) (#46034)

* Support dynamic shape in multiclass_nms3 Plugin for Paddle-TensorRT.
上级 5130b0a1
......@@ -54,18 +54,34 @@ class MultiClassNMS3OpConverter : public OpConverter {
PADDLE_GET_CONST(float, op_desc.GetAttr("nms_threshold"));
int keep_top_k = PADDLE_GET_CONST(int, op_desc.GetAttr("keep_top_k"));
bool normalized = PADDLE_GET_CONST(bool, op_desc.GetAttr("normalized"));
int num_classes = scores_tensor->getDimensions().d[0];
int class_index = engine_->with_dynamic_shape() ? 1 : 0;
int num_classes = scores_tensor->getDimensions().d[class_index];
auto bboxes_dims = bboxes_tensor->getDimensions();
nvinfer1::Dims3 bboxes_expand_dims(bboxes_dims.d[0], 1, bboxes_dims.d[1]);
auto* bboxes_expand_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor);
bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims);
nvinfer1::Permutation permutation{1, 0};
auto* scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor);
scores_transpose_layer->setFirstTranspose(permutation);
nvinfer1::IShuffleLayer* bboxes_expand_layer = nullptr;
nvinfer1::IShuffleLayer* scores_transpose_layer = nullptr;
if (engine_->with_dynamic_shape()) {
nvinfer1::Dims4 bboxes_expand_dims(
bboxes_dims.d[0], bboxes_dims.d[1], 1, bboxes_dims.d[2]);
bboxes_expand_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor);
bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims);
nvinfer1::Permutation permutation{0, 2, 1};
scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor);
scores_transpose_layer->setFirstTranspose(permutation);
} else {
nvinfer1::Dims3 bboxes_expand_dims(bboxes_dims.d[0], 1, bboxes_dims.d[1]);
bboxes_expand_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor);
bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims);
nvinfer1::Permutation permutation{1, 0};
scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor);
scores_transpose_layer->setFirstTranspose(permutation);
}
std::vector<nvinfer1::ITensor*> batch_nms_inputs;
batch_nms_inputs.push_back(bboxes_expand_layer->getOutput(0));
......@@ -101,27 +117,41 @@ class MultiClassNMS3OpConverter : public OpConverter {
fields.size() * sizeof(nvinfer1::PluginField)));
plugin_collections->nbFields = static_cast<int>(fields.size());
plugin_collections->fields = fields.data();
auto creator = GetPluginRegistry()->getPluginCreator("BatchedNMS_TRT", "1");
std::string nms_plugin_name = "BatchedNMS_TRT";
if (engine_->with_dynamic_shape()) {
nms_plugin_name = "BatchedNMSDynamic_TRT";
}
auto creator =
GetPluginRegistry()->getPluginCreator(nms_plugin_name.c_str(), "1");
auto batch_nms_plugin =
creator->createPlugin("BatchNMSPlugin", plugin_collections);
creator->createPlugin(nms_plugin_name.c_str(), plugin_collections);
free(plugin_collections);
auto batch_nms_layer = engine_->network()->addPluginV2(
batch_nms_inputs.data(), batch_nms_inputs.size(), *batch_nms_plugin);
// static shape: [keep_topk, 4], [keep_topk], [keep_topk]
// dynamic shape: [bs, keep_topk, 4], [bs, keep_topk], [bs, keep_topk]
auto nmsed_boxes = batch_nms_layer->getOutput(1);
auto nmsed_scores = batch_nms_layer->getOutput(2);
auto nmsed_classes = batch_nms_layer->getOutput(3);
auto nmsed_scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_scores);
nmsed_scores_transpose_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
auto nmsed_classes_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_classes);
nmsed_classes_reshape_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
if (engine_->with_dynamic_shape()) {
nmsed_scores_transpose_layer->setReshapeDimensions(
nvinfer1::Dims3(bboxes_dims.d[0], keep_top_k, 1));
nmsed_classes_reshape_layer->setReshapeDimensions(
nvinfer1::Dims3(bboxes_dims.d[0], keep_top_k, 1));
} else {
nmsed_scores_transpose_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
nmsed_classes_reshape_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
}
std::vector<nvinfer1::ITensor*> concat_inputs;
concat_inputs.push_back(nmsed_classes_reshape_layer->getOutput(0));
concat_inputs.push_back(nmsed_scores_transpose_layer->getOutput(0));
......@@ -129,7 +159,8 @@ class MultiClassNMS3OpConverter : public OpConverter {
auto nms_concat_layer = TRT_ENGINE_ADD_LAYER(
engine_, Concatenation, concat_inputs.data(), concat_inputs.size());
nms_concat_layer->setAxis(1);
int axis_index = engine_->with_dynamic_shape() ? 1 : 0;
nms_concat_layer->setAxis(axis_index + 1);
// add fake index as output to be consistent with the outputs of
// multiclass_nms3
......
......@@ -52,18 +52,34 @@ class MultiClassNMSOpConverter : public OpConverter {
PADDLE_GET_CONST(float, op_desc.GetAttr("nms_threshold"));
int keep_top_k = PADDLE_GET_CONST(int, op_desc.GetAttr("keep_top_k"));
bool normalized = PADDLE_GET_CONST(bool, op_desc.GetAttr("normalized"));
int num_classes = scores_tensor->getDimensions().d[0];
int class_index = engine_->with_dynamic_shape() ? 1 : 0;
int num_classes = scores_tensor->getDimensions().d[class_index];
auto bboxes_dims = bboxes_tensor->getDimensions();
nvinfer1::Dims3 bboxes_expand_dims(bboxes_dims.d[0], 1, bboxes_dims.d[1]);
auto* bboxes_expand_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor);
bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims);
nvinfer1::Permutation permutation{1, 0};
auto* scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor);
scores_transpose_layer->setFirstTranspose(permutation);
nvinfer1::IShuffleLayer* bboxes_expand_layer = nullptr;
nvinfer1::IShuffleLayer* scores_transpose_layer = nullptr;
if (engine_->with_dynamic_shape()) {
nvinfer1::Dims4 bboxes_expand_dims(
bboxes_dims.d[0], bboxes_dims.d[1], 1, bboxes_dims.d[2]);
bboxes_expand_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor);
bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims);
nvinfer1::Permutation permutation{0, 2, 1};
scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor);
scores_transpose_layer->setFirstTranspose(permutation);
} else {
nvinfer1::Dims3 bboxes_expand_dims(bboxes_dims.d[0], 1, bboxes_dims.d[1]);
bboxes_expand_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *bboxes_tensor);
bboxes_expand_layer->setReshapeDimensions(bboxes_expand_dims);
nvinfer1::Permutation permutation{1, 0};
scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *scores_tensor);
scores_transpose_layer->setFirstTranspose(permutation);
}
std::vector<nvinfer1::ITensor*> batch_nms_inputs;
batch_nms_inputs.push_back(bboxes_expand_layer->getOutput(0));
......@@ -100,9 +116,14 @@ class MultiClassNMSOpConverter : public OpConverter {
plugin_collections->nbFields = static_cast<int>(fields.size());
plugin_collections->fields = fields.data();
auto creator = GetPluginRegistry()->getPluginCreator("BatchedNMS_TRT", "1");
std::string nms_plugin_name = "BatchedNMS_TRT";
if (engine_->with_dynamic_shape()) {
nms_plugin_name = "BatchedNMSDynamic_TRT";
}
auto creator =
GetPluginRegistry()->getPluginCreator(nms_plugin_name.c_str(), "1");
auto batch_nms_plugin =
creator->createPlugin("BatchNMSPlugin", plugin_collections);
creator->createPlugin(nms_plugin_name.c_str(), plugin_collections);
free(plugin_collections);
auto batch_nms_layer = engine_->network()->addPluginV2(
......@@ -113,12 +134,21 @@ class MultiClassNMSOpConverter : public OpConverter {
auto nmsed_scores_transpose_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_scores);
nmsed_scores_transpose_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
auto nmsed_classes_reshape_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *nmsed_classes);
nmsed_classes_reshape_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
if (engine_->with_dynamic_shape()) {
nmsed_scores_transpose_layer->setReshapeDimensions(
nvinfer1::Dims3(bboxes_dims.d[0], keep_top_k, 1));
nmsed_classes_reshape_layer->setReshapeDimensions(
nvinfer1::Dims3(bboxes_dims.d[0], keep_top_k, 1));
} else {
nmsed_scores_transpose_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
nmsed_classes_reshape_layer->setReshapeDimensions(
nvinfer1::Dims2(keep_top_k, 1));
}
std::vector<nvinfer1::ITensor*> concat_inputs;
concat_inputs.push_back(nmsed_classes_reshape_layer->getOutput(0));
......@@ -127,7 +157,8 @@ class MultiClassNMSOpConverter : public OpConverter {
auto nms_concat_layer = TRT_ENGINE_ADD_LAYER(
engine_, Concatenation, concat_inputs.data(), concat_inputs.size());
nms_concat_layer->setAxis(1);
int axis_index = engine_->with_dynamic_shape() ? 1 : 0;
nms_concat_layer->setAxis(axis_index + 1);
RreplenishLayerAndOutput(
nms_concat_layer, "multiclass_nms", {output_name}, test_mode);
......
......@@ -33,7 +33,10 @@ namespace tensorrt {
struct SimpleOpTypeSetTeller : public Teller {
SimpleOpTypeSetTeller() {
#if IS_TRT_VERSION_GE(7130)
// use TensorRT plugin
teller_set.insert("group_norm");
teller_set.insert("multiclass_nms3");
teller_set.insert("multiclass_nms");
#endif
#if IS_TRT_VERSION_GE(7000)
teller_set.insert("tile");
......@@ -278,7 +281,6 @@ struct SimpleOpTypeSetTeller : public Teller {
"c_allreduce_prod",
"roll",
"cast",
"multiclass_nms3",
"transformer_input_convert",
"recover_padding",
"remove_padding",
......@@ -853,7 +855,6 @@ bool OpTeller::Tell(const framework::ir::Node* node,
}
if (op_type == "multiclass_nms" || op_type == "multiclass_nms3") {
if (with_dynamic_shape) return false;
auto* block = desc.Block();
if (block == nullptr) {
VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
......
......@@ -73,7 +73,7 @@ TEST(tensorrt_tester_ppyolo_mbv3, multi_thread4_trt_fp32_bz2) {
FLAGS_modeldir + "/model.pdiparams");
config.EnableUseGpu(100, 0);
config.EnableTensorRtEngine(
1 << 20, 2, 3, paddle_infer::PrecisionType::kFloat32, false, false);
1 << 25, 2, 3, paddle_infer::PrecisionType::kFloat32, false, false);
LOG(INFO) << config.Summary();
// get groudtruth by disbale ir
paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1);
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertMulticlassNMSTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
if use_trt:
config = paddle_infer.Config()
config.disable_glog_info()
config.enable_use_gpu(100, 0)
config.set_optim_cache_dir(self.cache_dir)
config.switch_ir_debug()
config.enable_tensorrt_engine(
max_batch_size=self.trt_param.max_batch_size,
workspace_size=self.trt_param.workspace_size,
min_subgraph_size=self.trt_param.min_subgraph_size,
precision_mode=self.trt_param.precision,
use_static=self.trt_param.use_static,
use_calib_mode=self.trt_param.use_calib_mode)
if len(self.dynamic_shape.min_input_shape
) != 0 and self.dynamic_shape.min_input_shape.keys(
) == self.dynamic_shape.max_input_shape.keys(
) and self.dynamic_shape.min_input_shape.keys(
) == self.dynamic_shape.opt_input_shape.keys():
config.set_trt_dynamic_shape_info(
self.dynamic_shape.min_input_shape,
self.dynamic_shape.max_input_shape,
self.dynamic_shape.opt_input_shape,
self.dynamic_shape.disable_trt_plugin_fp16)
return config
else:
config = paddle_infer.Config()
config.switch_ir_debug(True)
config.set_optim_cache_dir(self.cache_dir)
config.disable_glog_info()
return config
def sample_program_configs(self):
def generate_boxes(batch, num_boxes):
return np.arange(batch * num_boxes * 4,
dtype=np.float32).reshape([batch, num_boxes, 4])
def generate_scores(batch, num_boxes, num_classes):
return np.arange(batch * num_classes * num_boxes,
dtype=np.float32).reshape(
[batch, num_classes, num_boxes])
# return np.random.rand(batch, num_classes, num_boxes).astype(np.float32)
for batch in [1, 2]:
self.batch = batch
for nms_eta in [0.8, 1.1]:
for num_boxes, num_classes in [[80, 100], [40, 200], [20, 400]]:
self.num_boxes, self.num_classes = num_boxes, num_classes
for score_threshold in [
0.01,
]:
ops_config = [{
"op_type": "multiclass_nms",
"op_inputs": {
"BBoxes": ["input_bboxes"],
"Scores": ["input_scores"],
},
"op_outputs": {
"Out": ["nms_output_boxes"],
},
"op_attrs": {
"background_label": -1,
"score_threshold": score_threshold,
"nms_top_k": num_boxes,
"keep_top_k": num_boxes,
"nms_threshold": 0.3,
"normalized": False,
"nms_eta": nms_eta
}
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_bboxes":
TensorConfig(data_gen=partial(
generate_boxes, batch, num_boxes)),
"input_scores":
TensorConfig(
data_gen=partial(generate_scores, batch,
num_boxes, num_classes))
},
outputs=["nms_output_boxes"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input_bboxes should be static.
self.dynamic_shape.min_input_shape = {
"input_bboxes": [1, self.num_boxes, 4],
"input_scores": [1, self.num_classes, self.num_boxes],
}
self.dynamic_shape.max_input_shape = {
"input_bboxes": [8, self.num_boxes, 4],
"input_scores": [8, self.num_classes, self.num_boxes],
}
self.dynamic_shape.opt_input_shape = {
"input_bboxes": [self.batch, self.num_boxes, 4],
"input_scores": [self.batch, self.num_classes, self.num_boxes],
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-2
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
# self.trt_param.precision = paddle_infer.PrecisionType.Half
# yield self.create_inference_config(), generate_trt_nodes_num(
# attrs, True), (1e-2, 1e-2)
def assert_tensors_near(self, atol: float, rtol: float,
tensor: Dict[str, np.array],
baseline: Dict[str, np.array]):
# the order of tensorrt outputs are not consistent with paddle
for key, arr in tensor.items():
if key == "nms_output_boxes":
basline_arr = np.array(
sorted(baseline[key].reshape((-1, 6)),
key=lambda i: [i[0], i[1]]))
arr = np.array(
sorted(arr.reshape((-1, 6)), key=lambda i: [i[0], i[1]]))
else:
basline_arr = np.array(baseline[key].reshape((-1, 1)))
arr = np.array(arr.reshape((-1, 1)))
self.assertTrue(
basline_arr.shape == arr.shape,
"The output shapes are not equal, the baseline shape is " +
str(basline_arr.shape) + ', but got ' + str(arr.shape))
diff = abs(basline_arr - arr)
np.testing.assert_allclose(
basline_arr,
arr,
rtol=rtol,
atol=atol,
err_msg='Output has diff, Maximum absolute error: {}'.format(
np.amax(diff)))
def assert_op_size(self, trt_engine_num, paddle_op_num):
# tensorrt op num is not consistent with paddle
return True
def test(self):
self.trt_param.workspace_size = 1 << 25
self.run_test()
if __name__ == "__main__":
unittest.main()
......@@ -71,8 +71,10 @@ class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest):
# return np.random.rand(batch, num_classes, num_boxes).astype(np.float32)
for batch in [1, 2]:
for num_boxes in [4, 12]:
for num_classes in [2, 6]:
self.batch = batch
for nms_eta in [0.8, 1.1]:
for num_boxes, num_classes in [[80, 100], [40, 200], [20, 400]]:
self.num_boxes, self.num_classes = num_boxes, num_classes
for score_threshold in [
0.01,
]:
......@@ -94,7 +96,7 @@ class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest):
"keep_top_k": num_boxes,
"nms_threshold": 0.3,
"normalized": False,
"nms_eta": 1.1
"nms_eta": nms_eta
}
}]
ops = self.generate_op_config(ops_config)
......@@ -114,12 +116,26 @@ class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest):
"nms_output_boxes", "nms_output_num",
"nms_output_index"
])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
# The last dim of input_bboxes should be static.
self.dynamic_shape.min_input_shape = {
"input_bboxes": [1, self.num_boxes, 4],
"input_scores": [1, self.num_classes, self.num_boxes],
}
self.dynamic_shape.max_input_shape = {
"input_bboxes": [8, self.num_boxes, 4],
"input_scores": [8, self.num_classes, self.num_boxes],
}
self.dynamic_shape.opt_input_shape = {
"input_bboxes": [self.batch, self.num_boxes, 4],
"input_scores": [self.batch, self.num_classes, self.num_boxes],
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
......@@ -141,6 +157,15 @@ class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest):
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-2
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
# self.trt_param.precision = paddle_infer.PrecisionType.Half
# yield self.create_inference_config(), generate_trt_nodes_num(
# attrs, True), (1e-2, 1e-2)
def assert_tensors_near(self, atol: float, rtol: float,
tensor: Dict[str, np.array],
baseline: Dict[str, np.array]):
......@@ -176,7 +201,7 @@ class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest):
return True
def test(self):
self.trt_param.workspace_size = 1 << 20
self.trt_param.workspace_size = 1 << 25
self.run_test()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册