Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
bf7b674c
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bf7b674c
编写于
6月 02, 2022
作者:
S
shangliang Xu
提交者:
GitHub
6月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[TIPC] add onnx infer (#6119)
上级
1e70fffb
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
595 addition
and
19 deletion
+595
-19
deploy/serving/python/preprocess_ops.py
deploy/serving/python/preprocess_ops.py
+10
-12
deploy/serving/python/web_service.py
deploy/serving/python/web_service.py
+4
-4
deploy/third_engine/demo_onnxruntime/infer_demo.py
deploy/third_engine/demo_onnxruntime/infer_demo.py
+3
-2
deploy/third_engine/onnx/infer.py
deploy/third_engine/onnx/infer.py
+161
-0
deploy/third_engine/onnx/preprocess.py
deploy/third_engine/onnx/preprocess.py
+416
-0
ppdet/modeling/assigners/task_aligned_assigner.py
ppdet/modeling/assigners/task_aligned_assigner.py
+1
-1
未找到文件。
deploy/serving/python/preprocess_ops.py
浏览文件 @
bf7b674c
...
@@ -3,10 +3,14 @@ import cv2
...
@@ -3,10 +3,14 @@ import cv2
import
copy
import
copy
def
decode_image
(
im
,
img_info
):
def
decode_image
(
im
):
im
=
np
.
array
(
im
)
im
=
np
.
array
(
im
)
img_info
[
'im_shape'
]
=
np
.
array
(
im
.
shape
[:
2
],
dtype
=
np
.
float32
)
img_info
=
{
img_info
[
'scale_factor'
]
=
np
.
array
([
1.
,
1.
],
dtype
=
np
.
float32
)
"im_shape"
:
np
.
array
(
im
.
shape
[:
2
],
dtype
=
np
.
float32
),
"scale_factor"
:
np
.
array
(
[
1.
,
1.
],
dtype
=
np
.
float32
)
}
return
im
,
img_info
return
im
,
img_info
...
@@ -399,16 +403,10 @@ class Compose:
...
@@ -399,16 +403,10 @@ class Compose:
op_type
=
new_op_info
.
pop
(
'type'
)
op_type
=
new_op_info
.
pop
(
'type'
)
self
.
transforms
.
append
(
eval
(
op_type
)(
**
new_op_info
))
self
.
transforms
.
append
(
eval
(
op_type
)(
**
new_op_info
))
self
.
im_info
=
{
'scale_factor'
:
np
.
array
(
[
1.
,
1.
],
dtype
=
np
.
float32
),
'im_shape'
:
None
}
def
__call__
(
self
,
img
):
def
__call__
(
self
,
img
):
img
,
self
.
im_info
=
decode_image
(
img
,
self
.
im_info
)
img
,
im_info
=
decode_image
(
img
)
for
t
in
self
.
transforms
:
for
t
in
self
.
transforms
:
img
,
self
.
im_info
=
t
(
img
,
self
.
im_info
)
img
,
im_info
=
t
(
img
,
im_info
)
inputs
=
copy
.
deepcopy
(
self
.
im_info
)
inputs
=
copy
.
deepcopy
(
im_info
)
inputs
[
'image'
]
=
img
inputs
[
'image'
]
=
img
return
inputs
return
inputs
deploy/serving/python/web_service.py
浏览文件 @
bf7b674c
...
@@ -132,7 +132,7 @@ class PredictConfig(object):
...
@@ -132,7 +132,7 @@ class PredictConfig(object):
self
.
arch
=
yml_conf
[
'arch'
]
self
.
arch
=
yml_conf
[
'arch'
]
self
.
preprocess_infos
=
yml_conf
[
'Preprocess'
]
self
.
preprocess_infos
=
yml_conf
[
'Preprocess'
]
self
.
min_subgraph_size
=
yml_conf
[
'min_subgraph_size'
]
self
.
min_subgraph_size
=
yml_conf
[
'min_subgraph_size'
]
self
.
label
s
=
yml_conf
[
'label_list'
]
self
.
label
_list
=
yml_conf
[
'label_list'
]
self
.
use_dynamic_shape
=
yml_conf
[
'use_dynamic_shape'
]
self
.
use_dynamic_shape
=
yml_conf
[
'use_dynamic_shape'
]
self
.
draw_threshold
=
yml_conf
.
get
(
"draw_threshold"
,
0.5
)
self
.
draw_threshold
=
yml_conf
.
get
(
"draw_threshold"
,
0.5
)
self
.
mask
=
yml_conf
.
get
(
"mask"
,
False
)
self
.
mask
=
yml_conf
.
get
(
"mask"
,
False
)
...
@@ -189,8 +189,8 @@ class DetectorOp(Op):
...
@@ -189,8 +189,8 @@ class DetectorOp(Op):
result
=
{}
result
=
{}
for
k
,
num
in
zip
(
input_dict
.
keys
(),
bboxes_num
):
for
k
,
num
in
zip
(
input_dict
.
keys
(),
bboxes_num
):
bbox
=
bboxes
[
idx
:
idx
+
num
]
bbox
=
bboxes
[
idx
:
idx
+
num
]
result
[
k
]
=
self
.
parse_det_result
(
bbox
,
draw_threshold
,
result
[
k
]
=
self
.
parse_det_result
(
GLOBAL_VAR
[
'model_config'
].
labels
)
bbox
,
draw_threshold
,
GLOBAL_VAR
[
'model_config'
].
label_list
)
return
result
,
None
,
""
return
result
,
None
,
""
def
collate_inputs
(
self
,
inputs
):
def
collate_inputs
(
self
,
inputs
):
...
@@ -206,7 +206,7 @@ class DetectorOp(Op):
...
@@ -206,7 +206,7 @@ class DetectorOp(Op):
def
parse_det_result
(
self
,
bbox
,
draw_threshold
,
label_list
):
def
parse_det_result
(
self
,
bbox
,
draw_threshold
,
label_list
):
result
=
[]
result
=
[]
for
line
in
bbox
:
for
line
in
bbox
:
if
line
[
1
]
>
draw_threshold
:
if
line
[
0
]
>
-
1
and
line
[
1
]
>
draw_threshold
:
result
.
append
(
f
"
{
label_list
[
int
(
line
[
0
])]
}
{
line
[
1
]
}
"
result
.
append
(
f
"
{
label_list
[
int
(
line
[
0
])]
}
{
line
[
1
]
}
"
f
"
{
line
[
2
]
}
{
line
[
3
]
}
{
line
[
4
]
}
{
line
[
5
]
}
"
)
f
"
{
line
[
2
]
}
{
line
[
3
]
}
{
line
[
4
]
}
{
line
[
5
]
}
"
)
return
result
return
result
...
...
deploy/third_engine/demo_onnxruntime/infer_demo.py
浏览文件 @
bf7b674c
...
@@ -55,8 +55,9 @@ class PicoDet():
...
@@ -55,8 +55,9 @@ class PicoDet():
origin_shape
=
srcimg
.
shape
[:
2
]
origin_shape
=
srcimg
.
shape
[:
2
]
im_scale_y
=
newh
/
float
(
origin_shape
[
0
])
im_scale_y
=
newh
/
float
(
origin_shape
[
0
])
im_scale_x
=
neww
/
float
(
origin_shape
[
1
])
im_scale_x
=
neww
/
float
(
origin_shape
[
1
])
img_shape
=
np
.
array
([[
float
(
origin_shape
[
0
]),
float
(
origin_shape
[
1
])]
img_shape
=
np
.
array
([
]).
astype
(
'float32'
)
[
float
(
self
.
input_shape
[
0
]),
float
(
self
.
input_shape
[
1
])]
]).
astype
(
'float32'
)
scale_factor
=
np
.
array
([[
im_scale_y
,
im_scale_x
]]).
astype
(
'float32'
)
scale_factor
=
np
.
array
([[
im_scale_y
,
im_scale_x
]]).
astype
(
'float32'
)
if
keep_ratio
and
srcimg
.
shape
[
0
]
!=
srcimg
.
shape
[
1
]:
if
keep_ratio
and
srcimg
.
shape
[
0
]
!=
srcimg
.
shape
[
1
]:
...
...
deploy/third_engine/onnx/infer.py
0 → 100644
浏览文件 @
bf7b674c
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
yaml
import
argparse
import
numpy
as
np
import
glob
from
onnxruntime
import
InferenceSession
from
preprocess
import
Compose
# Global dictionary
SUPPORT_MODELS
=
{
'YOLO'
,
'RCNN'
,
'SSD'
,
'Face'
,
'FCOS'
,
'SOLOv2'
,
'TTFNet'
,
'S2ANet'
,
'JDE'
,
'FairMOT'
,
'DeepSORT'
,
'GFL'
,
'PicoDet'
,
'CenterNet'
,
'TOOD'
,
'RetinaNet'
,
'StrongBaseline'
,
'STGCN'
,
'YOLOX'
,
}
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
parser
.
add_argument
(
"-c"
,
"--config"
,
type
=
str
,
help
=
"infer_cfg.yml"
)
parser
.
add_argument
(
'--onnx_file'
,
type
=
str
,
default
=
"model.onnx"
,
help
=
"onnx model file path"
)
parser
.
add_argument
(
"--image_dir"
,
type
=
str
)
parser
.
add_argument
(
"--image_file"
,
type
=
str
)
def
get_test_images
(
infer_dir
,
infer_img
):
"""
Get image path list in TEST mode
"""
assert
infer_img
is
not
None
or
infer_dir
is
not
None
,
\
"--image_file or --image_dir should be set"
assert
infer_img
is
None
or
os
.
path
.
isfile
(
infer_img
),
\
"{} is not a file"
.
format
(
infer_img
)
assert
infer_dir
is
None
or
os
.
path
.
isdir
(
infer_dir
),
\
"{} is not a directory"
.
format
(
infer_dir
)
# infer_img has a higher priority
if
infer_img
and
os
.
path
.
isfile
(
infer_img
):
return
[
infer_img
]
images
=
set
()
infer_dir
=
os
.
path
.
abspath
(
infer_dir
)
assert
os
.
path
.
isdir
(
infer_dir
),
\
"infer_dir {} is not a directory"
.
format
(
infer_dir
)
exts
=
[
'jpg'
,
'jpeg'
,
'png'
,
'bmp'
]
exts
+=
[
ext
.
upper
()
for
ext
in
exts
]
for
ext
in
exts
:
images
.
update
(
glob
.
glob
(
'{}/*.{}'
.
format
(
infer_dir
,
ext
)))
images
=
list
(
images
)
assert
len
(
images
)
>
0
,
"no image found in {}"
.
format
(
infer_dir
)
print
(
"Found {} inference images in total."
.
format
(
len
(
images
)))
return
images
class
PredictConfig
(
object
):
"""set config of preprocess, postprocess and visualize
Args:
model_dir (str): root path of infer_cfg.yml
"""
def
__init__
(
self
,
infer_config
):
# parsing Yaml config for Preprocess
with
open
(
infer_config
)
as
f
:
yml_conf
=
yaml
.
safe_load
(
f
)
self
.
check_model
(
yml_conf
)
self
.
arch
=
yml_conf
[
'arch'
]
self
.
preprocess_infos
=
yml_conf
[
'Preprocess'
]
self
.
min_subgraph_size
=
yml_conf
[
'min_subgraph_size'
]
self
.
label_list
=
yml_conf
[
'label_list'
]
self
.
use_dynamic_shape
=
yml_conf
[
'use_dynamic_shape'
]
self
.
draw_threshold
=
yml_conf
.
get
(
"draw_threshold"
,
0.5
)
self
.
mask
=
yml_conf
.
get
(
"mask"
,
False
)
self
.
tracker
=
yml_conf
.
get
(
"tracker"
,
None
)
self
.
nms
=
yml_conf
.
get
(
"NMS"
,
None
)
self
.
fpn_stride
=
yml_conf
.
get
(
"fpn_stride"
,
None
)
if
self
.
arch
==
'RCNN'
and
yml_conf
.
get
(
'export_onnx'
,
False
):
print
(
'The RCNN export model is used for ONNX and it only supports batch_size = 1'
)
self
.
print_config
()
def
check_model
(
self
,
yml_conf
):
"""
Raises:
ValueError: loaded model not in supported model type
"""
for
support_model
in
SUPPORT_MODELS
:
if
support_model
in
yml_conf
[
'arch'
]:
return
True
raise
ValueError
(
"Unsupported arch: {}, expect {}"
.
format
(
yml_conf
[
'arch'
],
SUPPORT_MODELS
))
def
print_config
(
self
):
print
(
'----------- Model Configuration -----------'
)
print
(
'%s: %s'
%
(
'Model Arch'
,
self
.
arch
))
print
(
'%s: '
%
(
'Transform Order'
))
for
op_info
in
self
.
preprocess_infos
:
print
(
'--%s: %s'
%
(
'transform op'
,
op_info
[
'type'
]))
print
(
'--------------------------------------------'
)
def
predict_image
(
infer_config
,
predictor
,
img_list
):
# load preprocess transforms
transforms
=
Compose
(
infer_config
.
preprocess_infos
)
# predict image
for
img_path
in
img_list
:
inputs
=
transforms
(
img_path
)
inputs_name
=
[
var
.
name
for
var
in
predictor
.
get_inputs
()]
inputs
=
{
k
:
inputs
[
k
][
None
,
]
for
k
in
inputs_name
}
outputs
=
predictor
.
run
(
output_names
=
None
,
input_feed
=
inputs
)
print
(
"ONNXRuntime predict: "
)
bboxes
=
np
.
array
(
outputs
[
0
])
for
bbox
in
bboxes
:
if
bbox
[
0
]
>
-
1
and
bbox
[
1
]
>
infer_config
.
draw_threshold
:
print
(
f
"
{
infer_config
.
label_list
[
int
(
bbox
[
0
])]
}
{
bbox
[
1
]
}
"
f
"
{
bbox
[
2
]
}
{
bbox
[
3
]
}
{
bbox
[
4
]
}
{
bbox
[
5
]
}
"
)
if
__name__
==
'__main__'
:
FLAGS
=
parser
.
parse_args
()
# load image list
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
# load predictor
predictor
=
InferenceSession
(
FLAGS
.
onnx_file
)
# load infer config
infer_config
=
PredictConfig
(
FLAGS
.
config
)
predict_image
(
infer_config
,
predictor
,
img_list
)
deploy/third_engine/onnx/preprocess.py
0 → 100644
浏览文件 @
bf7b674c
import
numpy
as
np
import
cv2
import
copy
def
decode_image
(
img_path
):
with
open
(
img_path
,
'rb'
)
as
f
:
im_read
=
f
.
read
()
data
=
np
.
frombuffer
(
im_read
,
dtype
=
'uint8'
)
im
=
cv2
.
imdecode
(
data
,
1
)
# BGR mode, but need RGB mode
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
img_info
=
{
"im_shape"
:
np
.
array
(
im
.
shape
[:
2
],
dtype
=
np
.
float32
),
"scale_factor"
:
np
.
array
(
[
1.
,
1.
],
dtype
=
np
.
float32
)
}
return
im
,
img_info
class
Resize
(
object
):
"""resize image by target_size and max_size
Args:
target_size (int): the target size of image
keep_ratio (bool): whether keep_ratio or not, default true
interp (int): method of resize
"""
def
__init__
(
self
,
target_size
,
keep_ratio
=
True
,
interp
=
cv2
.
INTER_LINEAR
):
if
isinstance
(
target_size
,
int
):
target_size
=
[
target_size
,
target_size
]
self
.
target_size
=
target_size
self
.
keep_ratio
=
keep_ratio
self
.
interp
=
interp
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert
len
(
self
.
target_size
)
==
2
assert
self
.
target_size
[
0
]
>
0
and
self
.
target_size
[
1
]
>
0
im_channel
=
im
.
shape
[
2
]
im_scale_y
,
im_scale_x
=
self
.
generate_scale
(
im
)
im
=
cv2
.
resize
(
im
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
self
.
interp
)
im_info
[
'im_shape'
]
=
np
.
array
(
im
.
shape
[:
2
]).
astype
(
'float32'
)
im_info
[
'scale_factor'
]
=
np
.
array
(
[
im_scale_y
,
im_scale_x
]).
astype
(
'float32'
)
return
im
,
im_info
def
generate_scale
(
self
,
im
):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape
=
im
.
shape
[:
2
]
im_c
=
im
.
shape
[
2
]
if
self
.
keep_ratio
:
im_size_min
=
np
.
min
(
origin_shape
)
im_size_max
=
np
.
max
(
origin_shape
)
target_size_min
=
np
.
min
(
self
.
target_size
)
target_size_max
=
np
.
max
(
self
.
target_size
)
im_scale
=
float
(
target_size_min
)
/
float
(
im_size_min
)
if
np
.
round
(
im_scale
*
im_size_max
)
>
target_size_max
:
im_scale
=
float
(
target_size_max
)
/
float
(
im_size_max
)
im_scale_x
=
im_scale
im_scale_y
=
im_scale
else
:
resize_h
,
resize_w
=
self
.
target_size
im_scale_y
=
resize_h
/
float
(
origin_shape
[
0
])
im_scale_x
=
resize_w
/
float
(
origin_shape
[
1
])
return
im_scale_y
,
im_scale_x
class
NormalizeImage
(
object
):
"""normalize image
Args:
mean (list): im - mean
std (list): im / std
is_scale (bool): whether need im / 255
is_channel_first (bool): if True: image shape is CHW, else: HWC
"""
def
__init__
(
self
,
mean
,
std
,
is_scale
=
True
):
self
.
mean
=
mean
self
.
std
=
std
self
.
is_scale
=
is_scale
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im
=
im
.
astype
(
np
.
float32
,
copy
=
False
)
mean
=
np
.
array
(
self
.
mean
)[
np
.
newaxis
,
np
.
newaxis
,
:]
std
=
np
.
array
(
self
.
std
)[
np
.
newaxis
,
np
.
newaxis
,
:]
if
self
.
is_scale
:
im
=
im
/
255.0
im
-=
mean
im
/=
std
return
im
,
im_info
class
Permute
(
object
):
"""permute image
Args:
to_bgr (bool): whether convert RGB to BGR
channel_first (bool): whether convert HWC to CHW
"""
def
__init__
(
self
,
):
super
(
Permute
,
self
).
__init__
()
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
im
=
im
.
transpose
((
2
,
0
,
1
)).
copy
()
return
im
,
im_info
class
PadStride
(
object
):
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
Args:
stride (bool): model with FPN need image shape % stride == 0
"""
def
__init__
(
self
,
stride
=
0
):
self
.
coarsest_stride
=
stride
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
coarsest_stride
=
self
.
coarsest_stride
if
coarsest_stride
<=
0
:
return
im
,
im_info
im_c
,
im_h
,
im_w
=
im
.
shape
pad_h
=
int
(
np
.
ceil
(
float
(
im_h
)
/
coarsest_stride
)
*
coarsest_stride
)
pad_w
=
int
(
np
.
ceil
(
float
(
im_w
)
/
coarsest_stride
)
*
coarsest_stride
)
padding_im
=
np
.
zeros
((
im_c
,
pad_h
,
pad_w
),
dtype
=
np
.
float32
)
padding_im
[:,
:
im_h
,
:
im_w
]
=
im
return
padding_im
,
im_info
class
LetterBoxResize
(
object
):
def
__init__
(
self
,
target_size
):
"""
Resize image to target size, convert normalized xywh to pixel xyxy
format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
Args:
target_size (int|list): image target size.
"""
super
(
LetterBoxResize
,
self
).
__init__
()
if
isinstance
(
target_size
,
int
):
target_size
=
[
target_size
,
target_size
]
self
.
target_size
=
target_size
def
letterbox
(
self
,
img
,
height
,
width
,
color
=
(
127.5
,
127.5
,
127.5
)):
# letterbox: resize a rectangular image to a padded rectangular
shape
=
img
.
shape
[:
2
]
# [height, width]
ratio_h
=
float
(
height
)
/
shape
[
0
]
ratio_w
=
float
(
width
)
/
shape
[
1
]
ratio
=
min
(
ratio_h
,
ratio_w
)
new_shape
=
(
round
(
shape
[
1
]
*
ratio
),
round
(
shape
[
0
]
*
ratio
))
# [width, height]
padw
=
(
width
-
new_shape
[
0
])
/
2
padh
=
(
height
-
new_shape
[
1
])
/
2
top
,
bottom
=
round
(
padh
-
0.1
),
round
(
padh
+
0.1
)
left
,
right
=
round
(
padw
-
0.1
),
round
(
padw
+
0.1
)
img
=
cv2
.
resize
(
img
,
new_shape
,
interpolation
=
cv2
.
INTER_AREA
)
# resized, no border
img
=
cv2
.
copyMakeBorder
(
img
,
top
,
bottom
,
left
,
right
,
cv2
.
BORDER_CONSTANT
,
value
=
color
)
# padded rectangular
return
img
,
ratio
,
padw
,
padh
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
assert
len
(
self
.
target_size
)
==
2
assert
self
.
target_size
[
0
]
>
0
and
self
.
target_size
[
1
]
>
0
height
,
width
=
self
.
target_size
h
,
w
=
im
.
shape
[:
2
]
im
,
ratio
,
padw
,
padh
=
self
.
letterbox
(
im
,
height
=
height
,
width
=
width
)
new_shape
=
[
round
(
h
*
ratio
),
round
(
w
*
ratio
)]
im_info
[
'im_shape'
]
=
np
.
array
(
new_shape
,
dtype
=
np
.
float32
)
im_info
[
'scale_factor'
]
=
np
.
array
([
ratio
,
ratio
],
dtype
=
np
.
float32
)
return
im
,
im_info
class
Pad
(
object
):
def
__init__
(
self
,
size
,
fill_value
=
[
114.0
,
114.0
,
114.0
]):
"""
Pad image to a specified size.
Args:
size (list[int]): image target size
fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
"""
super
(
Pad
,
self
).
__init__
()
if
isinstance
(
size
,
int
):
size
=
[
size
,
size
]
self
.
size
=
size
self
.
fill_value
=
fill_value
def
__call__
(
self
,
im
,
im_info
):
im_h
,
im_w
=
im
.
shape
[:
2
]
h
,
w
=
self
.
size
if
h
==
im_h
and
w
==
im_w
:
im
=
im
.
astype
(
np
.
float32
)
return
im
,
im_info
canvas
=
np
.
ones
((
h
,
w
,
3
),
dtype
=
np
.
float32
)
canvas
*=
np
.
array
(
self
.
fill_value
,
dtype
=
np
.
float32
)
canvas
[
0
:
im_h
,
0
:
im_w
,
:]
=
im
.
astype
(
np
.
float32
)
im
=
canvas
return
im
,
im_info
def
rotate_point
(
pt
,
angle_rad
):
"""Rotate a point by an angle.
Args:
pt (list[float]): 2 dimensional point to be rotated
angle_rad (float): rotation angle by radian
Returns:
list[float]: Rotated point.
"""
assert
len
(
pt
)
==
2
sn
,
cs
=
np
.
sin
(
angle_rad
),
np
.
cos
(
angle_rad
)
new_x
=
pt
[
0
]
*
cs
-
pt
[
1
]
*
sn
new_y
=
pt
[
0
]
*
sn
+
pt
[
1
]
*
cs
rotated_pt
=
[
new_x
,
new_y
]
return
rotated_pt
def
_get_3rd_point
(
a
,
b
):
"""To calculate the affine matrix, three pairs of points are required. This
function is used to get the 3rd point, given 2D points a & b.
The 3rd point is defined by rotating vector `a - b` by 90 degrees
anticlockwise, using b as the rotation center.
Args:
a (np.ndarray): point(x,y)
b (np.ndarray): point(x,y)
Returns:
np.ndarray: The 3rd point.
"""
assert
len
(
a
)
==
2
assert
len
(
b
)
==
2
direction
=
a
-
b
third_pt
=
b
+
np
.
array
([
-
direction
[
1
],
direction
[
0
]],
dtype
=
np
.
float32
)
return
third_pt
def
get_affine_transform
(
center
,
input_size
,
rot
,
output_size
,
shift
=
(
0.
,
0.
),
inv
=
False
):
"""Get the affine transform matrix, given the center/scale/rot/output_size.
Args:
center (np.ndarray[2, ]): Center of the bounding box (x, y).
scale (np.ndarray[2, ]): Scale of the bounding box
wrt [width, height].
rot (float): Rotation angle (degree).
output_size (np.ndarray[2, ]): Size of the destination heatmaps.
shift (0-100%): Shift translation ratio wrt the width/height.
Default (0., 0.).
inv (bool): Option to inverse the affine transform direction.
(inv=False: src->dst or inv=True: dst->src)
Returns:
np.ndarray: The transform matrix.
"""
assert
len
(
center
)
==
2
assert
len
(
output_size
)
==
2
assert
len
(
shift
)
==
2
if
not
isinstance
(
input_size
,
(
np
.
ndarray
,
list
)):
input_size
=
np
.
array
([
input_size
,
input_size
],
dtype
=
np
.
float32
)
scale_tmp
=
input_size
shift
=
np
.
array
(
shift
)
src_w
=
scale_tmp
[
0
]
dst_w
=
output_size
[
0
]
dst_h
=
output_size
[
1
]
rot_rad
=
np
.
pi
*
rot
/
180
src_dir
=
rotate_point
([
0.
,
src_w
*
-
0.5
],
rot_rad
)
dst_dir
=
np
.
array
([
0.
,
dst_w
*
-
0.5
])
src
=
np
.
zeros
((
3
,
2
),
dtype
=
np
.
float32
)
src
[
0
,
:]
=
center
+
scale_tmp
*
shift
src
[
1
,
:]
=
center
+
src_dir
+
scale_tmp
*
shift
src
[
2
,
:]
=
_get_3rd_point
(
src
[
0
,
:],
src
[
1
,
:])
dst
=
np
.
zeros
((
3
,
2
),
dtype
=
np
.
float32
)
dst
[
0
,
:]
=
[
dst_w
*
0.5
,
dst_h
*
0.5
]
dst
[
1
,
:]
=
np
.
array
([
dst_w
*
0.5
,
dst_h
*
0.5
])
+
dst_dir
dst
[
2
,
:]
=
_get_3rd_point
(
dst
[
0
,
:],
dst
[
1
,
:])
if
inv
:
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
dst
),
np
.
float32
(
src
))
else
:
trans
=
cv2
.
getAffineTransform
(
np
.
float32
(
src
),
np
.
float32
(
dst
))
return
trans
class
WarpAffine
(
object
):
"""Warp affine the image
"""
def
__init__
(
self
,
keep_res
=
False
,
pad
=
31
,
input_h
=
512
,
input_w
=
512
,
scale
=
0.4
,
shift
=
0.1
):
self
.
keep_res
=
keep_res
self
.
pad
=
pad
self
.
input_h
=
input_h
self
.
input_w
=
input_w
self
.
scale
=
scale
self
.
shift
=
shift
def
__call__
(
self
,
im
,
im_info
):
"""
Args:
im (np.ndarray): image (np.ndarray)
im_info (dict): info of image
Returns:
im (np.ndarray): processed image (np.ndarray)
im_info (dict): info of processed image
"""
img
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_RGB2BGR
)
h
,
w
=
img
.
shape
[:
2
]
if
self
.
keep_res
:
input_h
=
(
h
|
self
.
pad
)
+
1
input_w
=
(
w
|
self
.
pad
)
+
1
s
=
np
.
array
([
input_w
,
input_h
],
dtype
=
np
.
float32
)
c
=
np
.
array
([
w
//
2
,
h
//
2
],
dtype
=
np
.
float32
)
else
:
s
=
max
(
h
,
w
)
*
1.0
input_h
,
input_w
=
self
.
input_h
,
self
.
input_w
c
=
np
.
array
([
w
/
2.
,
h
/
2.
],
dtype
=
np
.
float32
)
trans_input
=
get_affine_transform
(
c
,
s
,
0
,
[
input_w
,
input_h
])
img
=
cv2
.
resize
(
img
,
(
w
,
h
))
inp
=
cv2
.
warpAffine
(
img
,
trans_input
,
(
input_w
,
input_h
),
flags
=
cv2
.
INTER_LINEAR
)
return
inp
,
im_info
class
Compose
:
def
__init__
(
self
,
transforms
):
self
.
transforms
=
[]
for
op_info
in
transforms
:
new_op_info
=
op_info
.
copy
()
op_type
=
new_op_info
.
pop
(
'type'
)
self
.
transforms
.
append
(
eval
(
op_type
)(
**
new_op_info
))
def
__call__
(
self
,
img_path
):
img
,
im_info
=
decode_image
(
img_path
)
for
t
in
self
.
transforms
:
img
,
im_info
=
t
(
img
,
im_info
)
inputs
=
copy
.
deepcopy
(
im_info
)
inputs
[
'image'
]
=
img
return
inputs
ppdet/modeling/assigners/task_aligned_assigner.py
浏览文件 @
bf7b674c
...
@@ -93,7 +93,7 @@ class TaskAlignedAssigner(nn.Layer):
...
@@ -93,7 +93,7 @@ class TaskAlignedAssigner(nn.Layer):
return
assigned_labels
,
assigned_bboxes
,
assigned_scores
return
assigned_labels
,
assigned_bboxes
,
assigned_scores
# compute iou between gt and pred bbox, [B, n, L]
# compute iou between gt and pred bbox, [B, n, L]
ious
=
iou_similarity
(
gt_bboxes
,
pred_bboxes
)
ious
=
batch_
iou_similarity
(
gt_bboxes
,
pred_bboxes
)
# gather pred bboxes class score
# gather pred bboxes class score
pred_scores
=
pred_scores
.
transpose
([
0
,
2
,
1
])
pred_scores
=
pred_scores
.
transpose
([
0
,
2
,
1
])
batch_ind
=
paddle
.
arange
(
batch_ind
=
paddle
.
arange
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录