Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
71cb3ff8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
71cb3ff8
编写于
10月 11, 2021
作者:
W
wangxinxin08
提交者:
GitHub
10月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
enhance yolobox trt plugin (#34128)
* enhance yolobox plugin
上级
7850f7ce
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
111 addition
and
17 deletion
+111
-17
paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc
paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc
+8
-1
paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu
paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu
+49
-16
paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h
paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h
+3
-0
python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py
...luid/tests/unittests/ir/inference/test_trt_yolo_box_op.py
+51
-0
未找到文件。
paddle/fluid/inference/tensorrt/convert/yolo_box_op.cc
浏览文件 @
71cb3ff8
...
...
@@ -48,13 +48,20 @@ class YoloBoxOpConverter : public OpConverter {
float
conf_thresh
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"conf_thresh"
));
bool
clip_bbox
=
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"clip_bbox"
));
float
scale_x_y
=
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"scale_x_y"
));
bool
iou_aware
=
op_desc
.
HasAttr
(
"iou_aware"
)
?
BOOST_GET_CONST
(
bool
,
op_desc
.
GetAttr
(
"iou_aware"
))
:
false
;
float
iou_aware_factor
=
op_desc
.
HasAttr
(
"iou_aware_factor"
)
?
BOOST_GET_CONST
(
float
,
op_desc
.
GetAttr
(
"iou_aware_factor"
))
:
0.5
;
int
type_id
=
static_cast
<
int
>
(
engine_
->
WithFp16
());
auto
input_dim
=
X_tensor
->
getDimensions
();
auto
*
yolo_box_plugin
=
new
plugin
::
YoloBoxPlugin
(
type_id
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
anchors
,
class_num
,
conf_thresh
,
downsample_ratio
,
clip_bbox
,
scale_x_y
,
input_dim
.
d
[
1
],
input_dim
.
d
[
2
]);
i
ou_aware
,
iou_aware_factor
,
i
nput_dim
.
d
[
1
],
input_dim
.
d
[
2
]);
std
::
vector
<
nvinfer1
::
ITensor
*>
yolo_box_inputs
;
yolo_box_inputs
.
push_back
(
X_tensor
);
...
...
paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.cu
浏览文件 @
71cb3ff8
...
...
@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <algorithm>
#include <cassert>
...
...
@@ -29,7 +27,8 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
const
std
::
vector
<
int
>&
anchors
,
const
int
class_num
,
const
float
conf_thresh
,
const
int
downsample_ratio
,
const
bool
clip_bbox
,
const
float
scale_x_y
,
const
int
input_h
,
const
float
scale_x_y
,
const
bool
iou_aware
,
const
float
iou_aware_factor
,
const
int
input_h
,
const
int
input_w
)
:
data_type_
(
data_type
),
class_num_
(
class_num
),
...
...
@@ -37,6 +36,8 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
downsample_ratio_
(
downsample_ratio
),
clip_bbox_
(
clip_bbox
),
scale_x_y_
(
scale_x_y
),
iou_aware_
(
iou_aware
),
iou_aware_factor_
(
iou_aware_factor
),
input_h_
(
input_h
),
input_w_
(
input_w
)
{
anchors_
.
insert
(
anchors_
.
end
(),
anchors
.
cbegin
(),
anchors
.
cend
());
...
...
@@ -45,6 +46,7 @@ YoloBoxPlugin::YoloBoxPlugin(const nvinfer1::DataType data_type,
assert
(
class_num_
>
0
);
assert
(
input_h_
>
0
);
assert
(
input_w_
>
0
);
assert
((
iou_aware_factor_
>
0
&&
iou_aware_factor_
<
1
));
cudaMalloc
(
&
anchors_device_
,
anchors
.
size
()
*
sizeof
(
int
));
cudaMemcpy
(
anchors_device_
,
anchors
.
data
(),
anchors
.
size
()
*
sizeof
(
int
),
...
...
@@ -59,6 +61,8 @@ YoloBoxPlugin::YoloBoxPlugin(const void* data, size_t length) {
DeserializeValue
(
&
data
,
&
length
,
&
downsample_ratio_
);
DeserializeValue
(
&
data
,
&
length
,
&
clip_bbox_
);
DeserializeValue
(
&
data
,
&
length
,
&
scale_x_y_
);
DeserializeValue
(
&
data
,
&
length
,
&
iou_aware_
);
DeserializeValue
(
&
data
,
&
length
,
&
iou_aware_factor_
);
DeserializeValue
(
&
data
,
&
length
,
&
input_h_
);
DeserializeValue
(
&
data
,
&
length
,
&
input_w_
);
}
...
...
@@ -133,8 +137,19 @@ __device__ inline void GetYoloBox(float* box, const T* x, const int* anchors,
__device__
inline
int
GetEntryIndex
(
int
batch
,
int
an_idx
,
int
hw_idx
,
int
an_num
,
int
an_stride
,
int
stride
,
int
entry
)
{
return
(
batch
*
an_num
+
an_idx
)
*
an_stride
+
entry
*
stride
+
hw_idx
;
int
entry
,
bool
iou_aware
)
{
if
(
iou_aware
)
{
return
(
batch
*
an_num
+
an_idx
)
*
an_stride
+
(
batch
*
an_num
+
an_num
+
entry
)
*
stride
+
hw_idx
;
}
else
{
return
(
batch
*
an_num
+
an_idx
)
*
an_stride
+
entry
*
stride
+
hw_idx
;
}
}
__device__
inline
int
GetIoUIndex
(
int
batch
,
int
an_idx
,
int
hw_idx
,
int
an_num
,
int
an_stride
,
int
stride
)
{
return
batch
*
an_num
*
an_stride
+
(
batch
*
an_num
+
an_idx
)
*
stride
+
hw_idx
;
}
template
<
typename
T
>
...
...
@@ -178,7 +193,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
const
int
w
,
const
int
an_num
,
const
int
class_num
,
const
int
box_num
,
int
input_size_h
,
int
input_size_w
,
bool
clip_bbox
,
const
float
scale
,
const
float
bias
)
{
const
float
bias
,
bool
iou_aware
,
const
float
iou_aware_factor
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
blockDim
.
x
*
gridDim
.
x
;
float
box
[
4
];
...
...
@@ -193,11 +209,16 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
int
img_height
=
imgsize
[
2
*
i
];
int
img_width
=
imgsize
[
2
*
i
+
1
];
int
obj_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
4
);
int
obj_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
4
,
iou_aware
);
float
conf
=
sigmoid
(
static_cast
<
float
>
(
input
[
obj_idx
]));
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
0
);
if
(
iou_aware
)
{
int
iou_idx
=
GetIoUIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
);
float
iou
=
sigmoid
<
float
>
(
input
[
iou_idx
]);
conf
=
powf
(
conf
,
1.
-
iou_aware_factor
)
*
powf
(
iou
,
iou_aware_factor
);
}
int
box_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
0
,
iou_aware
);
if
(
conf
<
conf_thresh
)
{
for
(
int
i
=
0
;
i
<
4
;
++
i
)
{
...
...
@@ -212,8 +233,8 @@ __global__ void KeYoloBoxFw(const T* const input, const int* const imgsize,
box_idx
=
(
i
*
box_num
+
j
*
grid_num
+
k
*
w
+
l
)
*
4
;
CalcDetectionBox
<
T
>
(
boxes
,
box
,
box_idx
,
img_height
,
img_width
,
clip_bbox
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
5
);
int
label_idx
=
GetEntryIndex
(
i
,
j
,
k
*
w
+
l
,
an_num
,
an_stride
,
grid_num
,
5
,
iou_aware
);
int
score_idx
=
(
i
*
box_num
+
j
*
grid_num
+
k
*
w
+
l
)
*
class_num
;
CalcLabelScore
<
T
>
(
scores
,
input
,
label_idx
,
score_idx
,
class_num
,
conf
,
grid_num
);
...
...
@@ -240,7 +261,8 @@ int YoloBoxPlugin::enqueue_impl(int batch_size, const void* const* inputs,
reinterpret_cast
<
const
int
*
const
>
(
inputs
[
1
]),
reinterpret_cast
<
T
*>
(
outputs
[
0
]),
reinterpret_cast
<
T
*>
(
outputs
[
1
]),
conf_thresh_
,
anchors_device_
,
n
,
h
,
w
,
an_num
,
class_num_
,
box_num
,
input_size_h
,
input_size_w
,
clip_bbox_
,
scale_x_y_
,
bias
);
input_size_h
,
input_size_w
,
clip_bbox_
,
scale_x_y_
,
bias
,
iou_aware_
,
iou_aware_factor_
);
return
cudaGetLastError
()
!=
cudaSuccess
;
}
...
...
@@ -274,6 +296,8 @@ size_t YoloBoxPlugin::getSerializationSize() const TRT_NOEXCEPT {
serialize_size
+=
SerializedSize
(
scale_x_y_
);
serialize_size
+=
SerializedSize
(
input_h_
);
serialize_size
+=
SerializedSize
(
input_w_
);
serialize_size
+=
SerializedSize
(
iou_aware_
);
serialize_size
+=
SerializedSize
(
iou_aware_factor_
);
return
serialize_size
;
}
...
...
@@ -285,6 +309,8 @@ void YoloBoxPlugin::serialize(void* buffer) const TRT_NOEXCEPT {
SerializeValue
(
&
buffer
,
downsample_ratio_
);
SerializeValue
(
&
buffer
,
clip_bbox_
);
SerializeValue
(
&
buffer
,
scale_x_y_
);
SerializeValue
(
&
buffer
,
iou_aware_
);
SerializeValue
(
&
buffer
,
iou_aware_factor_
);
SerializeValue
(
&
buffer
,
input_h_
);
SerializeValue
(
&
buffer
,
input_w_
);
}
...
...
@@ -326,8 +352,8 @@ void YoloBoxPlugin::configurePlugin(
nvinfer1
::
IPluginV2Ext
*
YoloBoxPlugin
::
clone
()
const
TRT_NOEXCEPT
{
return
new
YoloBoxPlugin
(
data_type_
,
anchors_
,
class_num_
,
conf_thresh_
,
downsample_ratio_
,
clip_bbox_
,
scale_x_y_
,
input_h_
,
input_w_
);
downsample_ratio_
,
clip_bbox_
,
scale_x_y_
,
i
ou_aware_
,
iou_aware_factor_
,
input_h_
,
i
nput_w_
);
}
YoloBoxPluginCreator
::
YoloBoxPluginCreator
()
{}
...
...
@@ -367,6 +393,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
float
scale_x_y
=
1.
;
int
h
=
-
1
;
int
w
=
-
1
;
bool
iou_aware
=
false
;
float
iou_aware_factor
=
0.5
;
for
(
int
i
=
0
;
i
<
fc
->
nbFields
;
++
i
)
{
const
std
::
string
field_name
(
fc
->
fields
[
i
].
name
);
...
...
@@ -386,6 +414,10 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
clip_bbox
=
*
static_cast
<
const
bool
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"scale_x_y"
))
{
scale_x_y
=
*
static_cast
<
const
float
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"iou_aware"
))
{
iou_aware
=
*
static_cast
<
const
bool
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"iou_aware_factor"
))
{
iou_aware_factor
=
*
static_cast
<
const
float
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"h"
))
{
h
=
*
static_cast
<
const
int
*>
(
fc
->
fields
[
i
].
data
);
}
else
if
(
field_name
.
compare
(
"w"
))
{
...
...
@@ -397,7 +429,8 @@ nvinfer1::IPluginV2Ext* YoloBoxPluginCreator::createPlugin(
return
new
YoloBoxPlugin
(
type_id
?
nvinfer1
::
DataType
::
kHALF
:
nvinfer1
::
DataType
::
kFLOAT
,
anchors
,
class_num
,
conf_thresh
,
downsample_ratio
,
clip_bbox
,
scale_x_y
,
h
,
w
);
class_num
,
conf_thresh
,
downsample_ratio
,
clip_bbox
,
scale_x_y
,
iou_aware
,
iou_aware_factor
,
h
,
w
);
}
nvinfer1
::
IPluginV2Ext
*
YoloBoxPluginCreator
::
deserializePlugin
(
...
...
paddle/fluid/inference/tensorrt/plugin/yolo_box_op_plugin.h
浏览文件 @
71cb3ff8
...
...
@@ -31,6 +31,7 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
const
std
::
vector
<
int
>&
anchors
,
const
int
class_num
,
const
float
conf_thresh
,
const
int
downsample_ratio
,
const
bool
clip_bbox
,
const
float
scale_x_y
,
const
bool
iou_aware
,
const
float
iou_aware_factor
,
const
int
input_h
,
const
int
input_w
);
YoloBoxPlugin
(
const
void
*
data
,
size_t
length
);
~
YoloBoxPlugin
()
override
;
...
...
@@ -89,6 +90,8 @@ class YoloBoxPlugin : public nvinfer1::IPluginV2Ext {
float
scale_x_y_
;
int
input_h_
;
int
input_w_
;
bool
iou_aware_
;
float
iou_aware_factor_
;
std
::
string
namespace_
;
};
...
...
python/paddle/fluid/tests/unittests/ir/inference/test_trt_yolo_box_op.py
浏览文件 @
71cb3ff8
...
...
@@ -116,5 +116,56 @@ class TRTYoloBoxFP16Test(InferencePassTest):
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
class
TRTYoloBoxIoUAwareTest
(
InferencePassTest
):
def
setUp
(
self
):
self
.
set_params
()
with
fluid
.
program_guard
(
self
.
main_program
,
self
.
startup_program
):
image_shape
=
[
self
.
bs
,
self
.
channel
,
self
.
height
,
self
.
width
]
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
image_shape
,
dtype
=
'float32'
)
image_size
=
fluid
.
data
(
name
=
'image_size'
,
shape
=
[
self
.
bs
,
2
],
dtype
=
'int32'
)
boxes
,
scores
=
self
.
append_yolobox
(
image
,
image_size
)
self
.
feeds
=
{
'image'
:
np
.
random
.
random
(
image_shape
).
astype
(
'float32'
),
'image_size'
:
np
.
random
.
randint
(
32
,
64
,
size
=
(
self
.
bs
,
2
)).
astype
(
'int32'
),
}
self
.
enable_trt
=
True
self
.
trt_parameters
=
TRTYoloBoxTest
.
TensorRTParam
(
1
<<
30
,
self
.
bs
,
1
,
AnalysisConfig
.
Precision
.
Float32
,
False
,
False
)
self
.
fetch_list
=
[
scores
,
boxes
]
def
set_params
(
self
):
self
.
bs
=
4
self
.
channel
=
258
self
.
height
=
64
self
.
width
=
64
self
.
class_num
=
80
self
.
anchors
=
[
10
,
13
,
16
,
30
,
33
,
23
]
self
.
conf_thresh
=
.
1
self
.
downsample_ratio
=
32
self
.
iou_aware
=
True
self
.
iou_aware_factor
=
0.5
def
append_yolobox
(
self
,
image
,
image_size
):
return
fluid
.
layers
.
yolo_box
(
x
=
image
,
img_size
=
image_size
,
class_num
=
self
.
class_num
,
anchors
=
self
.
anchors
,
conf_thresh
=
self
.
conf_thresh
,
downsample_ratio
=
self
.
downsample_ratio
,
iou_aware
=
self
.
iou_aware
,
iou_aware_factor
=
self
.
iou_aware_factor
)
def
test_check_output
(
self
):
if
core
.
is_compiled_with_cuda
():
use_gpu
=
True
self
.
check_output_with_option
(
use_gpu
,
flatten
=
True
)
self
.
assertTrue
(
PassVersionChecker
.
IsCompatible
(
'tensorrt_subgraph_pass'
))
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录