未验证 提交 1bd8125f 编写于 作者: W wangxinxin08 提交者: GitHub

add fake index and unittest for multiclass_nms3 trt (#41344)

* add fake index and unittest for multiclass_nms3 trt

* modify unittest
上级 d8a10977
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
...@@ -38,6 +38,7 @@ class MultiClassNMS3OpConverter : public OpConverter { ...@@ -38,6 +38,7 @@ class MultiClassNMS3OpConverter : public OpConverter {
std::string scores = op_desc.Input("Scores").front(); std::string scores = op_desc.Input("Scores").front();
std::string output_name = op_desc.Output("Out").front(); std::string output_name = op_desc.Output("Out").front();
std::string rois_num_name = op_desc.Output("NmsRoisNum").front(); std::string rois_num_name = op_desc.Output("NmsRoisNum").front();
std::string index_name = op_desc.Output("Index").front();
auto* bboxes_tensor = engine_->GetITensor(bboxes); auto* bboxes_tensor = engine_->GetITensor(bboxes);
auto* scores_tensor = engine_->GetITensor(scores); auto* scores_tensor = engine_->GetITensor(scores);
...@@ -122,10 +123,20 @@ class MultiClassNMS3OpConverter : public OpConverter { ...@@ -122,10 +123,20 @@ class MultiClassNMS3OpConverter : public OpConverter {
engine_, Concatenation, concat_inputs.data(), concat_inputs.size()); engine_, Concatenation, concat_inputs.data(), concat_inputs.size());
nms_concat_layer->setAxis(1); nms_concat_layer->setAxis(1);
// add fake index as output to be consistent with the outputs of
// multiclass_nms3
std::vector<uint32_t> index(1, 0);
auto constant_layer = TRT_ENGINE_ADD_LAYER(
engine_, Constant, nvinfer1::Dims2(1, 1),
nvinfer1::Weights{nvinfer1::DataType::kINT32,
static_cast<void*>(index.data()), 1});
RreplenishLayerAndOutput(batch_nms_layer, "multiclass_nms3", RreplenishLayerAndOutput(batch_nms_layer, "multiclass_nms3",
{rois_num_name}, test_mode); {rois_num_name}, test_mode);
RreplenishLayerAndOutput(nms_concat_layer, "multiclass_nms3", {output_name}, RreplenishLayerAndOutput(nms_concat_layer, "multiclass_nms3", {output_name},
test_mode); test_mode);
RreplenishLayerAndOutput(constant_layer, "multiclass_nms3", {index_name},
test_mode);
} }
}; };
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertMulticlassNMS3Test(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def create_inference_config(self, use_trt=True) -> paddle_infer.Config:
if use_trt:
config = paddle_infer.Config()
config.disable_glog_info()
config.enable_use_gpu(100, 0)
config.set_optim_cache_dir(self.cache_dir)
config.switch_ir_debug()
config.enable_tensorrt_engine(
max_batch_size=self.trt_param.max_batch_size,
workspace_size=self.trt_param.workspace_size,
min_subgraph_size=self.trt_param.min_subgraph_size,
precision_mode=self.trt_param.precision,
use_static=self.trt_param.use_static,
use_calib_mode=self.trt_param.use_calib_mode)
if len(self.dynamic_shape.min_input_shape
) != 0 and self.dynamic_shape.min_input_shape.keys(
) == self.dynamic_shape.max_input_shape.keys(
) and self.dynamic_shape.min_input_shape.keys(
) == self.dynamic_shape.opt_input_shape.keys():
config.set_trt_dynamic_shape_info(
self.dynamic_shape.min_input_shape,
self.dynamic_shape.max_input_shape,
self.dynamic_shape.opt_input_shape,
self.dynamic_shape.disable_trt_plugin_fp16)
return config
else:
config = paddle_infer.Config()
config.switch_ir_debug(True)
config.set_optim_cache_dir(self.cache_dir)
config.disable_glog_info()
return config
def sample_program_configs(self):
def generate_boxes(batch, num_boxes):
return np.arange(
batch * num_boxes * 4,
dtype=np.float32).reshape([batch, num_boxes, 4])
def generate_scores(batch, num_boxes, num_classes):
return np.arange(
batch * num_classes * num_boxes,
dtype=np.float32).reshape([batch, num_classes, num_boxes])
# return np.random.rand(batch, num_classes, num_boxes).astype(np.float32)
for batch in [1, 2]:
for num_boxes in [4, 12]:
for num_classes in [2, 6]:
for score_threshold in [0.01, ]:
ops_config = [{
"op_type": "multiclass_nms3",
"op_inputs": {
"BBoxes": ["input_bboxes"],
"Scores": ["input_scores"],
},
"op_outputs": {
"Out": ["nms_output_boxes"],
"Index": ["nms_output_index"],
"NmsRoisNum": ["nms_output_num"]
},
"op_attrs": {
"background_label": -1,
"score_threshold": score_threshold,
"nms_top_k": num_boxes,
"keep_top_k": num_boxes,
"nms_threshold": 0.3,
"normalized": False,
"nms_eta": 1.1
}
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_bboxes": TensorConfig(data_gen=partial(
generate_boxes, batch, num_boxes)),
"input_scores": TensorConfig(
data_gen=partial(generate_scores, batch,
num_boxes, num_classes))
},
outputs=[
"nms_output_boxes", "nms_output_num",
"nms_output_index"
])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-2
def assert_tensors_near(self,
atol: float,
rtol: float,
tensor: Dict[str, np.array],
baseline: Dict[str, np.array]):
# the order of tensorrt outputs are not consistent with paddle
for key, arr in tensor.items():
if key == "nms_output_index":
continue
if key == "nms_output_boxes":
basline_arr = np.array(
sorted(
baseline[key].reshape((-1, 6)),
key=lambda i: [i[0], i[1]]))
arr = np.array(
sorted(
arr.reshape((-1, 6)), key=lambda i: [i[0], i[1]]))
else:
basline_arr = np.array(baseline[key].reshape((-1, 1)))
arr = np.array(arr.reshape((-1, 1)))
self.assertTrue(
basline_arr.shape == arr.shape,
"The output shapes are not equal, the baseline shape is " +
str(basline_arr.shape) + ', but got ' + str(arr.shape))
diff = abs(basline_arr - arr)
self.assertTrue(
np.allclose(
basline_arr, arr, atol=atol, rtol=rtol),
"Output has diff, Maximum absolute error: {}".format(
np.amax(diff)))
def assert_op_size(self, trt_engine_num, paddle_op_num):
# tensorrt op num is not consistent with paddle
return True
def test(self):
self.trt_param.workspace_size = 1 << 20
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册