Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
41ca7cad
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41ca7cad
编写于
12月 07, 2022
作者:
W
wangxinxin08
提交者:
GitHub
12月 07, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add onnx infer for ppyoloe_r (#7457)
上级
964643ee
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
340 addition
and
2 deletion
+340
-2
configs/rotate/ppyoloe_r/README.md
configs/rotate/ppyoloe_r/README.md
+15
-0
configs/rotate/ppyoloe_r/README_en.md
configs/rotate/ppyoloe_r/README_en.md
+17
-1
configs/rotate/tools/onnx_infer.py
configs/rotate/tools/onnx_infer.py
+302
-0
ppdet/modeling/heads/ppyoloe_r_head.py
ppdet/modeling/heads/ppyoloe_r_head.py
+6
-1
未找到文件。
configs/rotate/ppyoloe_r/README.md
浏览文件 @
41ca7cad
...
...
@@ -123,6 +123,21 @@ python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_di
**注意:**
-
使用Paddle-TRT使用确保
**PaddlePaddle版本为develop版本且TensorRT版本大于8.2**
.
**使用ONNX Runtime进行部署**
,执行以下命令:
```
# 导出模型
python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams export_onnx=True
# 安装paddle2onnx
pip install paddle2onnx
# 转换成onnx模型
paddle2onnx --model_dir output_inference/ppyoloe_r_crn_l_3x_dota --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 11 --save_file ppyoloe_r_crn_l_3x_dota.onnx
# 预测图片
python configs/rotate/tools/onnx_infer.py --infer_cfg output_inference/ppyoloe_r_crn_l_3x_dota/infer_cfg.yml --onnx_file ppyoloe_r_crn_l_3x_dota.onnx --image_file demo/P0072__1.0__0___0.png
```
## 附录
...
...
configs/rotate/ppyoloe_r/README_en.md
浏览文件 @
41ca7cad
...
...
@@ -114,7 +114,7 @@ python tools/export_model.py -c configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota
python deploy/python/infer.py
--image_file
demo/P0072__1.0__0___0.png
--model_dir
=
output_inference/ppyoloe_r_crn_l_3x_dota
--run_mode
=
paddle
--device
=
gpu
```
**Using Paddle-TRT**
to
for deployment, run following command
**Using Paddle-TRT**
for deployment, run following command
```
bash
# export inference model
...
...
@@ -126,6 +126,22 @@ python deploy/python/infer.py --image_file demo/P0072__1.0__0___0.png --model_di
**Notes:**
-
When using Paddle-TRT for speed testing, make sure that
**the version of TensorRT is larger than 8.2 and the version of PaddlePaddle is the develop version**
**Using ONNX Runtime**
for deployment, run following command
```
bash
# export inference model
python tools/export_model.py
-c
configs/rotate/ppyoloe_r/ppyoloe_r_crn_l_3x_dota.yml
-o
weights
=
https://paddledet.bj.bcebos.com/models/ppyoloe_r_crn_l_3x_dota.pdparams
export_onnx
=
True
# install paddle2onnx
pip
install
paddle2onnx
# convert to onnx model
paddle2onnx
--model_dir
output_inference/ppyoloe_r_crn_l_3x_dota
--model_filename
model.pdmodel
--params_filename
model.pdiparams
--opset_version
11
--save_file
ppyoloe_r_crn_l_3x_dota.onnx
# inference single image
python configs/rotate/tools/onnx_infer.py
--infer_cfg
output_inference/ppyoloe_r_crn_l_3x_dota/infer_cfg.yml
--onnx_file
ppyoloe_r_crn_l_3x_dota.onnx
--image_file
demo/P0072__1.0__0___0.png
```
## Appendix
Ablation experiments of PP-YOLOE-R
...
...
configs/rotate/tools/onnx_infer.py
0 → 100644
浏览文件 @
41ca7cad
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
os
import
sys
import
six
import
glob
import
copy
import
yaml
import
argparse
import
cv2
import
numpy
as
np
from
shapely.geometry
import
Polygon
from
onnxruntime
import
InferenceSession
# preprocess ops
def
decode_image
(
img_path
):
with
open
(
img_path
,
'rb'
)
as
f
:
im_read
=
f
.
read
()
data
=
np
.
frombuffer
(
im_read
,
dtype
=
'uint8'
)
im
=
cv2
.
imdecode
(
data
,
1
)
# BGR mode, but need RGB mode
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
img_info
=
{
"im_shape"
:
np
.
array
(
im
.
shape
[:
2
],
dtype
=
np
.
float32
),
"scale_factor"
:
np
.
array
(
[
1.
,
1.
],
dtype
=
np
.
float32
)
}
return
im
,
img_info
class
Resize
(
object
):
def
__init__
(
self
,
target_size
,
keep_ratio
=
True
,
interp
=
cv2
.
INTER_LINEAR
):
if
isinstance
(
target_size
,
int
):
target_size
=
[
target_size
,
target_size
]
self
.
target_size
=
target_size
self
.
keep_ratio
=
keep_ratio
self
.
interp
=
interp
def
__call__
(
self
,
im
,
im_info
):
assert
len
(
self
.
target_size
)
==
2
assert
self
.
target_size
[
0
]
>
0
and
self
.
target_size
[
1
]
>
0
im_channel
=
im
.
shape
[
2
]
im_scale_y
,
im_scale_x
=
self
.
generate_scale
(
im
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
im_info
[
'im_shape'
]
=
np
.
array
(
im
.
shape
[:
2
]).
astype
(
'float32'
)
im_info
[
'scale_factor'
]
=
np
.
array
(
[
im_scale_y
,
im_scale_x
]).
astype
(
'float32'
)
return
im
,
im_info
def
generate_scale
(
self
,
im
):
origin_shape
=
im
.
shape
[:
2
]
im_c
=
im
.
shape
[
2
]
if
self
.
keep_ratio
:
im_size_min
=
np
.
min
(
origin_shape
)
im_size_max
=
np
.
max
(
origin_shape
)
target_size_min
=
np
.
min
(
self
.
target_size
)
target_size_max
=
np
.
max
(
self
.
target_size
)
im_scale
=
float
(
target_size_min
)
/
float
(
im_size_min
)
if
np
.
round
(
im_scale
*
im_size_max
)
>
target_size_max
:
im_scale
=
float
(
target_size_max
)
/
float
(
im_size_max
)
im_scale_x
=
im_scale
im_scale_y
=
im_scale
else
:
resize_h
,
resize_w
=
self
.
target_size
im_scale_y
=
resize_h
/
float
(
origin_shape
[
0
])
im_scale_x
=
resize_w
/
float
(
origin_shape
[
1
])
return
im_scale_y
,
im_scale_x
class
Permute
(
object
):
def
__init__
(
self
,
):
super
(
Permute
,
self
).
__init__
()
def
__call__
(
self
,
im
,
im_info
):
im
=
im
.
transpose
((
2
,
0
,
1
))
return
im
,
im_info
class
NormalizeImage
(
object
):
def
__init__
(
self
,
mean
,
std
,
is_scale
=
True
,
norm_type
=
'mean_std'
):
self
.
mean
=
mean
self
.
std
=
std
self
.
is_scale
=
is_scale
self
.
norm_type
=
norm_type
def
__call__
(
self
,
im
,
im_info
):
im
=
im
.
astype
(
np
.
float32
,
copy
=
False
)
if
self
.
is_scale
:
scale
=
1.0
/
255.0
im
*=
scale
if
self
.
norm_type
==
'mean_std'
:
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
im
-=
mean
im
/=
std
return
im
,
im_info
class
PadStride
(
object
):
def
__init__
(
self
,
stride
=
0
):
self
.
coarsest_stride
=
stride
def
__call__
(
self
,
im
,
im_info
):
coarsest_stride
=
self
.
coarsest_stride
if
coarsest_stride
<=
0
:
return
im
,
im_info
im_c
,
im_h
,
im_w
=
im
.
shape
pad_h
=
int
(
np
.
ceil
(
float
(
im_h
)
/
coarsest_stride
)
*
coarsest_stride
)
pad_w
=
int
(
np
.
ceil
(
float
(
im_w
)
/
coarsest_stride
)
*
coarsest_stride
)
padding_im
=
np
.
zeros
((
im_c
,
pad_h
,
pad_w
),
dtype
=
np
.
float32
)
padding_im
[:,
:
im_h
,
:
im_w
]
=
im
return
padding_im
,
im_info
class
Compose
:
def
__init__
(
self
,
transforms
):
self
.
transforms
=
[]
for
op_info
in
transforms
:
new_op_info
=
op_info
.
copy
()
op_type
=
new_op_info
.
pop
(
'type'
)
self
.
transforms
.
append
(
eval
(
op_type
)(
**
new_op_info
))
def
__call__
(
self
,
img_path
):
img
,
im_info
=
decode_image
(
img_path
)
for
t
in
self
.
transforms
:
img
,
im_info
=
t
(
img
,
im_info
)
inputs
=
copy
.
deepcopy
(
im_info
)
inputs
[
'image'
]
=
img
return
inputs
# postprocess
def
rbox_iou
(
g
,
p
):
g
=
np
.
array
(
g
)
p
=
np
.
array
(
p
)
g
=
Polygon
(
g
[:
8
].
reshape
((
4
,
2
)))
p
=
Polygon
(
p
[:
8
].
reshape
((
4
,
2
)))
g
=
g
.
buffer
(
0
)
p
=
p
.
buffer
(
0
)
if
not
g
.
is_valid
or
not
p
.
is_valid
:
return
0
inter
=
Polygon
(
g
).
intersection
(
Polygon
(
p
)).
area
union
=
g
.
area
+
p
.
area
-
inter
if
union
==
0
:
return
0
else
:
return
inter
/
union
def
multiclass_nms_rotated
(
pred_bboxes
,
pred_scores
,
iou_threshlod
=
0.1
,
score_threshold
=
0.1
):
"""
Args:
pred_bboxes (numpy.ndarray): [B, N, 8]
pred_scores (numpy.ndarray): [B, C, N]
Return:
bboxes (numpy.ndarray): [N, 10]
bbox_num (numpy.ndarray): [B]
"""
bbox_num
=
[]
bboxes
=
[]
for
bbox_per_img
,
score_per_img
in
zip
(
pred_bboxes
,
pred_scores
):
num_per_img
=
0
for
cls_id
,
score_per_cls
in
enumerate
(
score_per_img
):
keep_mask
=
score_per_cls
>
score_threshold
bbox
=
bbox_per_img
[
keep_mask
]
score
=
score_per_cls
[
keep_mask
]
idx
=
score
.
argsort
()[::
-
1
]
bbox
=
bbox
[
idx
]
score
=
score
[
idx
]
keep_idx
=
[]
for
i
,
b
in
enumerate
(
bbox
):
supressed
=
False
for
gi
in
keep_idx
:
g
=
bbox
[
gi
]
if
rbox_iou
(
b
,
g
)
>
iou_threshlod
:
supressed
=
True
break
if
supressed
:
continue
keep_idx
.
append
(
i
)
keep_box
=
bbox
[
keep_idx
]
keep_score
=
score
[
keep_idx
]
keep_cls_ids
=
np
.
ones
(
len
(
keep_idx
))
*
cls_id
bboxes
.
append
(
np
.
concatenate
(
[
keep_cls_ids
[:,
None
],
keep_score
[:,
None
],
keep_box
],
axis
=-
1
))
num_per_img
+=
len
(
keep_idx
)
bbox_num
.
append
(
num_per_img
)
return
np
.
concatenate
(
bboxes
,
axis
=
0
),
np
.
array
(
bbox_num
)
def
get_test_images
(
infer_dir
,
infer_img
):
"""
Get image path list in TEST mode
"""
assert
infer_img
is
not
None
or
infer_dir
is
not
None
,
\
"--image_file or --image_dir should be set"
assert
infer_img
is
None
or
os
.
path
.
isfile
(
infer_img
),
\
"{} is not a file"
.
format
(
infer_img
)
assert
infer_dir
is
None
or
os
.
path
.
isdir
(
infer_dir
),
\
"{} is not a directory"
.
format
(
infer_dir
)
# infer_img has a higher priority
if
infer_img
and
os
.
path
.
isfile
(
infer_img
):
return
[
infer_img
]
images
=
set
()
infer_dir
=
os
.
path
.
abspath
(
infer_dir
)
assert
os
.
path
.
isdir
(
infer_dir
),
\
"infer_dir {} is not a directory"
.
format
(
infer_dir
)
exts
=
[
'jpg'
,
'jpeg'
,
'png'
,
'bmp'
]
exts
+=
[
ext
.
upper
()
for
ext
in
exts
]
for
ext
in
exts
:
images
.
update
(
glob
.
glob
(
'{}/*.{}'
.
format
(
infer_dir
,
ext
)))
images
=
list
(
images
)
assert
len
(
images
)
>
0
,
"no image found in {}"
.
format
(
infer_dir
)
print
(
"Found {} inference images in total."
.
format
(
len
(
images
)))
return
images
def
predict_image
(
infer_config
,
predictor
,
img_list
):
# load preprocess transforms
transforms
=
Compose
(
infer_config
[
'Preprocess'
])
# predict image
for
img_path
in
img_list
:
inputs
=
transforms
(
img_path
)
inputs_name
=
[
var
.
name
for
var
in
predictor
.
get_inputs
()]
inputs
=
{
k
:
inputs
[
k
][
None
,
]
for
k
in
inputs_name
}
outputs
=
predictor
.
run
(
output_names
=
None
,
input_feed
=
inputs
)
bboxes
,
bbox_num
=
multiclass_nms_rotated
(
np
.
array
(
outputs
[
0
]),
np
.
array
(
outputs
[
1
]))
print
(
"ONNXRuntime predict: "
)
for
bbox
in
bboxes
:
if
bbox
[
0
]
>
-
1
and
bbox
[
1
]
>
infer_config
[
'draw_threshold'
]:
print
(
f
"
{
int
(
bbox
[
0
])
}
{
bbox
[
1
]
}
"
f
"
{
bbox
[
2
]
}
{
bbox
[
3
]
}
{
bbox
[
4
]
}
{
bbox
[
5
]
}
"
f
"
{
bbox
[
6
]
}
{
bbox
[
7
]
}
{
bbox
[
8
]
}
{
bbox
[
9
]
}
"
)
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"--infer_cfg"
,
type
=
str
,
help
=
"infer_cfg.yml"
)
parser
.
add_argument
(
'--onnx_file'
,
type
=
str
,
default
=
"model.onnx"
,
help
=
"onnx model file path"
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--image_file"
,
type
=
str
)
return
parser
.
parse_args
()
if
__name__
==
'__main__'
:
FLAGS
=
parse_args
()
# load image list
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
# load predictor
predictor
=
InferenceSession
(
FLAGS
.
onnx_file
)
# load infer config
with
open
(
FLAGS
.
infer_cfg
)
as
f
:
infer_config
=
yaml
.
safe_load
(
f
)
predict_image
(
infer_config
,
predictor
,
img_list
)
ppdet/modeling/heads/ppyoloe_r_head.py
浏览文件 @
41ca7cad
...
...
@@ -44,7 +44,7 @@ class ESEAttn(nn.Layer):
@
register
class
PPYOLOERHead
(
nn
.
Layer
):
__shared__
=
[
'num_classes'
,
'trt'
]
__shared__
=
[
'num_classes'
,
'trt'
,
'export_onnx'
]
__inject__
=
[
'static_assigner'
,
'assigner'
,
'nms'
]
def
__init__
(
self
,
...
...
@@ -57,6 +57,7 @@ class PPYOLOERHead(nn.Layer):
use_varifocal_loss
=
True
,
static_assigner_epoch
=
4
,
trt
=
False
,
export_onnx
=
False
,
static_assigner
=
'ATSSAssigner'
,
assigner
=
'TaskAlignedAssigner'
,
nms
=
'MultiClassNMS'
,
...
...
@@ -84,6 +85,8 @@ class PPYOLOERHead(nn.Layer):
self
.
stem_cls
=
nn
.
LayerList
()
self
.
stem_reg
=
nn
.
LayerList
()
self
.
stem_angle
=
nn
.
LayerList
()
trt
=
False
if
export_onnx
else
trt
self
.
export_onnx
=
export_onnx
act
=
get_act_fn
(
act
,
trt
=
trt
)
if
act
is
None
or
isinstance
(
act
,
(
str
,
dict
))
else
act
...
...
@@ -415,5 +418,7 @@ class PPYOLOERHead(nn.Layer):
],
axis
=-
1
).
reshape
([
-
1
,
1
,
8
])
pred_bboxes
/=
scale_factor
if
self
.
export_onnx
:
return
pred_bboxes
,
pred_scores
bbox_pred
,
bbox_num
,
_
=
self
.
nms
(
pred_bboxes
,
pred_scores
)
return
bbox_pred
,
bbox_num
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录