Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
sxaah
PaddleDetection
提交
5e19955b
P
PaddleDetection
项目概览
sxaah
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
5e19955b
编写于
5月 20, 2021
作者:
C
cnn
提交者:
GitHub
5月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[dev] inference support bs > 1 (#3003)
* bs>1 for YOLO
上级
fd494657
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
346 addition
and
190 deletion
+346
-190
deploy/README.md
deploy/README.md
+2
-0
deploy/cpp/include/object_detector.h
deploy/cpp/include/object_detector.h
+6
-3
deploy/cpp/src/main.cc
deploy/cpp/src/main.cc
+94
-49
deploy/cpp/src/object_detector.cc
deploy/cpp/src/object_detector.cc
+109
-75
deploy/cpp/src/preprocess_op.cc
deploy/cpp/src/preprocess_op.cc
+0
-1
deploy/python/infer.py
deploy/python/infer.py
+93
-32
deploy/python/utils.py
deploy/python/utils.py
+2
-0
ppdet/engine/trainer.py
ppdet/engine/trainer.py
+1
-1
ppdet/modeling/architectures/s2anet.py
ppdet/modeling/architectures/s2anet.py
+2
-0
ppdet/modeling/layers.py
ppdet/modeling/layers.py
+5
-2
ppdet/modeling/post_process.py
ppdet/modeling/post_process.py
+1
-2
ppdet/modeling/proposal_generator/rpn_head.py
ppdet/modeling/proposal_generator/rpn_head.py
+31
-25
未找到文件。
deploy/README.md
浏览文件 @
5e19955b
...
...
@@ -28,6 +28,8 @@ python tools/export_model.py -c configs/yolov3/yolov3_mobilenet_v1_roadsign.yml
*
C++部署 支持
`CPU`
、
`GPU`
和
`XPU`
环境,支持 Windows、Linux 系统,并支持在 NV Jetson 嵌入式设备上部署。参考文档
[
C++部署
](
cpp/README.md
)
*
PaddleDetection支持TensorRT加速,相关文档请参考
[
TensorRT预测部署教程
](
TENSOR_RT.md
)
**注意:**
Paddle预测库版本需要>=2.1,batch_size>1仅支持YOLOv3和PP-YOLO。
## 2.PaddleServing部署
### 2.1 导出模型
...
...
deploy/cpp/include/object_detector.h
浏览文件 @
5e19955b
...
...
@@ -50,7 +50,7 @@ std::vector<int> GenerateColorMap(int num_class);
// Visualize detection results
cv
::
Mat
VisualizeResult
(
const
cv
::
Mat
&
img
,
const
std
::
vector
<
ObjectResult
>&
results
,
const
std
::
vector
<
std
::
string
>&
lable
_list
,
const
std
::
vector
<
std
::
string
>&
lable
s
,
const
std
::
vector
<
int
>&
colormap
,
const
bool
is_rbox
);
...
...
@@ -93,11 +93,12 @@ class ObjectDetector {
const
std
::
string
&
run_mode
=
"fluid"
);
// Run predictor
void
Predict
(
const
cv
::
Mat
&
im
,
void
Predict
(
const
std
::
vector
<
cv
::
Mat
>
imgs
,
const
double
threshold
=
0.5
,
const
int
warmup
=
0
,
const
int
repeats
=
1
,
std
::
vector
<
ObjectResult
>*
result
=
nullptr
,
std
::
vector
<
int
>*
bbox_num
=
nullptr
,
std
::
vector
<
double
>*
times
=
nullptr
);
// Get Model Label list
...
...
@@ -120,14 +121,16 @@ class ObjectDetector {
void
Preprocess
(
const
cv
::
Mat
&
image_mat
);
// Postprocess result
void
Postprocess
(
const
cv
::
Mat
&
raw_mat
,
const
std
::
vector
<
cv
::
Mat
>
mats
,
std
::
vector
<
ObjectResult
>*
result
,
std
::
vector
<
int
>
bbox_num
,
bool
is_rbox
);
std
::
shared_ptr
<
Predictor
>
predictor_
;
Preprocessor
preprocessor_
;
ImageBlob
inputs_
;
std
::
vector
<
float
>
output_data_
;
std
::
vector
<
int
>
out_bbox_num_data_
;
float
threshold_
;
ConfigPaser
config_
;
std
::
vector
<
int
>
image_shape_
;
...
...
deploy/cpp/src/main.cc
浏览文件 @
5e19955b
...
...
@@ -21,6 +21,7 @@
#include <numeric>
#include <sys/types.h>
#include <sys/stat.h>
#include <math.h>
#ifdef _WIN32
#include <direct.h>
...
...
@@ -37,6 +38,7 @@
DEFINE_string
(
model_dir
,
""
,
"Path of inference model"
);
DEFINE_string
(
image_file
,
""
,
"Path of input image"
);
DEFINE_string
(
image_dir
,
""
,
"Dir of input image, `image_file` has a higher priority."
);
DEFINE_int32
(
batch_size
,
1
,
"batch_size"
);
DEFINE_string
(
video_file
,
""
,
"Path of input video, `video_file` or `camera_id` has a highest priority."
);
DEFINE_int32
(
camera_id
,
-
1
,
"Device id of camera to predict"
);
DEFINE_bool
(
use_gpu
,
false
,
"Infering with GPU or CPU"
);
...
...
@@ -189,6 +191,7 @@ void PredictVideo(const std::string& video_path,
}
std
::
vector
<
PaddleDetection
::
ObjectResult
>
result
;
std
::
vector
<
int
>
bbox_num
;
std
::
vector
<
double
>
det_times
;
auto
labels
=
det
->
GetLabelList
();
auto
colormap
=
PaddleDetection
::
GenerateColorMap
(
labels
.
size
());
...
...
@@ -200,8 +203,9 @@ void PredictVideo(const std::string& video_path,
if
(
frame
.
empty
())
{
break
;
}
det
->
Predict
(
frame
,
0.5
,
0
,
1
,
&
result
,
&
det_times
);
std
::
vector
<
cv
::
Mat
>
imgs
;
imgs
.
push_back
(
frame
);
det
->
Predict
(
imgs
,
0.5
,
0
,
1
,
&
result
,
&
bbox_num
,
&
det_times
);
for
(
const
auto
&
item
:
result
)
{
if
(
item
.
rect
.
size
()
>
6
){
is_rbox
=
true
;
...
...
@@ -238,70 +242,107 @@ void PredictVideo(const std::string& video_path,
video_out
.
release
();
}
void
PredictImage
(
const
std
::
vector
<
std
::
string
>
all_img_list
,
void
PredictImage
(
const
std
::
vector
<
std
::
string
>
all_img_paths
,
const
int
batch_size
,
const
double
threshold
,
const
bool
run_benchmark
,
PaddleDetection
::
ObjectDetector
*
det
,
const
std
::
string
&
output_dir
=
"output"
)
{
std
::
vector
<
double
>
det_t
=
{
0
,
0
,
0
};
for
(
auto
image_file
:
all_img_list
)
{
// Open input image as an opencv cv::Mat object
cv
::
Mat
im
=
cv
::
imread
(
image_file
,
1
);
int
steps
=
ceil
(
float
(
all_img_paths
.
size
())
/
batch_size
);
printf
(
"total images = %d, batch_size = %d, total steps = %d
\n
"
,
all_img_paths
.
size
(),
batch_size
,
steps
);
for
(
int
idx
=
0
;
idx
<
steps
;
idx
++
)
{
std
::
vector
<
cv
::
Mat
>
batch_imgs
;
int
left_image_cnt
=
all_img_paths
.
size
()
-
idx
*
batch_size
;
if
(
left_image_cnt
>
batch_size
)
{
left_image_cnt
=
batch_size
;
}
for
(
int
bs
=
0
;
bs
<
left_image_cnt
;
bs
++
)
{
std
::
string
image_file_path
=
all_img_paths
.
at
(
idx
*
batch_size
+
bs
);
cv
::
Mat
im
=
cv
::
imread
(
image_file_path
,
1
);
batch_imgs
.
insert
(
batch_imgs
.
end
(),
im
);
}
// Store all detected result
std
::
vector
<
PaddleDetection
::
ObjectResult
>
result
;
std
::
vector
<
int
>
bbox_num
;
std
::
vector
<
double
>
det_times
;
bool
is_rbox
=
false
;
if
(
run_benchmark
)
{
det
->
Predict
(
im
,
threshold
,
10
,
10
,
&
result
,
&
det_times
);
det
->
Predict
(
batch_imgs
,
threshold
,
10
,
10
,
&
result
,
&
bbox_num
,
&
det_times
);
}
else
{
det
->
Predict
(
im
,
0.5
,
0
,
1
,
&
result
,
&
det_times
);
for
(
const
auto
&
item
:
result
)
{
if
(
item
.
rect
.
size
()
>
6
){
is_rbox
=
true
;
printf
(
"class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]
\n
"
,
item
.
class_id
,
item
.
confidence
,
item
.
rect
[
0
],
item
.
rect
[
1
],
item
.
rect
[
2
],
item
.
rect
[
3
],
item
.
rect
[
4
],
item
.
rect
[
5
],
item
.
rect
[
6
],
item
.
rect
[
7
]);
det
->
Predict
(
batch_imgs
,
0.5
,
0
,
1
,
&
result
,
&
bbox_num
,
&
det_times
);
// get labels and colormap
auto
labels
=
det
->
GetLabelList
();
auto
colormap
=
PaddleDetection
::
GenerateColorMap
(
labels
.
size
());
int
item_start_idx
=
0
;
for
(
int
i
=
0
;
i
<
left_image_cnt
;
i
++
)
{
std
::
cout
<<
all_img_paths
.
at
(
idx
*
batch_size
+
i
)
<<
"result"
<<
std
::
endl
;
if
(
bbox_num
[
i
]
<=
1
)
{
continue
;
}
else
{
printf
(
"class=%d confidence=%.4f rect=[%d %d %d %d]
\n
"
,
item
.
class_id
,
item
.
confidence
,
item
.
rect
[
0
],
item
.
rect
[
1
],
item
.
rect
[
2
],
item
.
rect
[
3
]);
for
(
int
j
=
0
;
j
<
bbox_num
[
i
];
j
++
)
{
PaddleDetection
::
ObjectResult
item
=
result
[
item_start_idx
+
j
];
if
(
item
.
rect
.
size
()
>
6
){
is_rbox
=
true
;
printf
(
"class=%d confidence=%.4f rect=[%d %d %d %d %d %d %d %d]
\n
"
,
item
.
class_id
,
item
.
confidence
,
item
.
rect
[
0
],
item
.
rect
[
1
],
item
.
rect
[
2
],
item
.
rect
[
3
],
item
.
rect
[
4
],
item
.
rect
[
5
],
item
.
rect
[
6
],
item
.
rect
[
7
]);
}
else
{
printf
(
"class=%d confidence=%.4f rect=[%d %d %d %d]
\n
"
,
item
.
class_id
,
item
.
confidence
,
item
.
rect
[
0
],
item
.
rect
[
1
],
item
.
rect
[
2
],
item
.
rect
[
3
]);
}
}
item_start_idx
=
item_start_idx
+
bbox_num
[
i
];
}
// Visualization result
auto
labels
=
det
->
GetLabelList
();
auto
colormap
=
PaddleDetection
::
GenerateColorMap
(
labels
.
size
());
cv
::
Mat
vis_img
=
PaddleDetection
::
VisualizeResult
(
im
,
result
,
labels
,
colormap
,
is_rbox
);
std
::
vector
<
int
>
compression_params
;
compression_params
.
push_back
(
CV_IMWRITE_JPEG_QUALITY
);
compression_params
.
push_back
(
95
);
std
::
string
output_path
(
output_dir
);
if
(
output_dir
.
rfind
(
OS_PATH_SEP
)
!=
output_dir
.
size
()
-
1
)
{
output_path
+=
OS_PATH_SEP
;
int
bbox_idx
=
0
;
for
(
int
bs
=
0
;
bs
<
batch_imgs
.
size
();
bs
++
)
{
if
(
bbox_num
[
bs
]
<=
1
)
{
continue
;
}
cv
::
Mat
im
=
batch_imgs
[
bs
];
std
::
vector
<
PaddleDetection
::
ObjectResult
>
im_result
;
for
(
int
k
=
0
;
k
<
bbox_num
[
bs
];
k
++
)
{
im_result
.
push_back
(
result
[
bbox_idx
+
k
]);
}
bbox_idx
+=
bbox_num
[
bs
];
cv
::
Mat
vis_img
=
PaddleDetection
::
VisualizeResult
(
im
,
im_result
,
labels
,
colormap
,
is_rbox
);
std
::
vector
<
int
>
compression_params
;
compression_params
.
push_back
(
CV_IMWRITE_JPEG_QUALITY
);
compression_params
.
push_back
(
95
);
std
::
string
output_path
(
output_dir
);
if
(
output_dir
.
rfind
(
OS_PATH_SEP
)
!=
output_dir
.
size
()
-
1
)
{
output_path
+=
OS_PATH_SEP
;
}
std
::
string
image_file_path
=
all_img_paths
.
at
(
idx
*
batch_size
+
bs
);
output_path
+=
image_file_path
.
substr
(
image_file_path
.
find_last_of
(
'/'
)
+
1
);
cv
::
imwrite
(
output_path
,
vis_img
,
compression_params
);
printf
(
"Visualized output saved as %s
\n
"
,
output_path
.
c_str
());
}
;
output_path
+=
image_file
.
substr
(
image_file
.
find_last_of
(
'/'
)
+
1
);
cv
::
imwrite
(
output_path
,
vis_img
,
compression_params
);
printf
(
"Visualized output saved as %s
\n
"
,
output_path
.
c_str
());
}
det_t
[
0
]
+=
det_times
[
0
];
det_t
[
1
]
+=
det_times
[
1
];
det_t
[
2
]
+=
det_times
[
2
];
}
PrintBenchmarkLog
(
det_t
,
all_img_
list
.
size
());
PrintBenchmarkLog
(
det_t
,
all_img_
paths
.
size
());
}
int
main
(
int
argc
,
char
**
argv
)
{
...
...
@@ -329,13 +370,17 @@ int main(int argc, char** argv) {
if
(
!
PathExists
(
FLAGS_output_dir
))
{
MkDirs
(
FLAGS_output_dir
);
}
std
::
vector
<
std
::
string
>
all_img
_list
;
std
::
vector
<
std
::
string
>
all_img
s
;
if
(
!
FLAGS_image_file
.
empty
())
{
all_img_list
.
push_back
(
FLAGS_image_file
);
all_imgs
.
push_back
(
FLAGS_image_file
);
if
(
FLAGS_batch_size
>
1
)
{
std
::
cout
<<
"batch_size should be 1, when image_file is not None"
<<
std
::
endl
;
FLAGS_batch_size
=
1
;
}
}
else
{
GetAllFiles
((
char
*
)
FLAGS_image_dir
.
c_str
(),
all_img
_list
);
GetAllFiles
((
char
*
)
FLAGS_image_dir
.
c_str
(),
all_img
s
);
}
PredictImage
(
all_img
_list
,
FLAGS_threshold
,
FLAGS_run_benchmark
,
&
det
,
FLAGS_output_dir
);
PredictImage
(
all_img
s
,
FLAGS_batch_size
,
FLAGS_threshold
,
FLAGS_run_benchmark
,
&
det
,
FLAGS_output_dir
);
}
return
0
;
}
deploy/cpp/src/object_detector.cc
浏览文件 @
5e19955b
...
...
@@ -93,7 +93,7 @@ void ObjectDetector::LoadModel(const std::string& model_dir,
// Visualization of MaskDetector results
cv
::
Mat
VisualizeResult
(
const
cv
::
Mat
&
img
,
const
std
::
vector
<
ObjectResult
>&
results
,
const
std
::
vector
<
std
::
string
>&
lable
_list
,
const
std
::
vector
<
std
::
string
>&
lable
s
,
const
std
::
vector
<
int
>&
colormap
,
const
bool
is_rbox
=
false
)
{
cv
::
Mat
vis_img
=
img
.
clone
();
...
...
@@ -101,7 +101,7 @@ cv::Mat VisualizeResult(const cv::Mat& img,
// Configure color and text size
std
::
ostringstream
oss
;
oss
<<
std
::
setiosflags
(
std
::
ios
::
fixed
)
<<
std
::
setprecision
(
4
);
oss
<<
lable
_list
[
results
[
i
].
class_id
]
<<
" "
;
oss
<<
lable
s
[
results
[
i
].
class_id
]
<<
" "
;
oss
<<
results
[
i
].
confidence
;
std
::
string
text
=
oss
.
str
();
int
c1
=
colormap
[
3
*
results
[
i
].
class_id
+
0
];
...
...
@@ -121,20 +121,20 @@ cv::Mat VisualizeResult(const cv::Mat& img,
if
(
is_rbox
)
{
// Draw object, text, and background
for
(
int
k
=
0
;
k
<
4
;
k
++
)
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
cv
::
Point
pt1
=
cv
::
Point
(
results
[
i
].
rect
[(
k
*
2
)
%
8
],
results
[
i
].
rect
[(
k
*
2
+
1
)
%
8
]);
cv
::
Point
pt2
=
cv
::
Point
(
results
[
i
].
rect
[(
k
*
2
+
2
)
%
8
],
results
[
i
].
rect
[(
k
*
2
+
3
)
%
8
]);
cv
::
Point
pt1
=
cv
::
Point
(
results
[
i
].
rect
[(
k
*
2
)
%
8
],
results
[
i
].
rect
[(
k
*
2
+
1
)
%
8
]);
cv
::
Point
pt2
=
cv
::
Point
(
results
[
i
].
rect
[(
k
*
2
+
2
)
%
8
],
results
[
i
].
rect
[(
k
*
2
+
3
)
%
8
]);
cv
::
line
(
vis_img
,
pt1
,
pt2
,
roi_color
,
2
);
}
}
else
{
int
w
=
results
[
i
].
rect
[
1
]
-
results
[
i
].
rect
[
0
];
int
h
=
results
[
i
].
rect
[
3
]
-
results
[
i
].
rect
[
2
];
cv
::
Rect
roi
=
cv
::
Rect
(
results
[
i
].
rect
[
0
],
results
[
i
].
rect
[
2
],
w
,
h
);
int
w
=
results
[
i
].
rect
[
2
]
-
results
[
i
].
rect
[
0
];
int
h
=
results
[
i
].
rect
[
3
]
-
results
[
i
].
rect
[
1
];
cv
::
Rect
roi
=
cv
::
Rect
(
results
[
i
].
rect
[
0
],
results
[
i
].
rect
[
1
],
w
,
h
);
// Draw roi object, text, and background
cv
::
rectangle
(
vis_img
,
roi
,
roi_color
,
2
);
}
...
...
@@ -144,7 +144,7 @@ cv::Mat VisualizeResult(const cv::Mat& img,
// Configure text background
cv
::
Rect
text_back
=
cv
::
Rect
(
results
[
i
].
rect
[
0
],
results
[
i
].
rect
[
2
]
-
text_size
.
height
,
results
[
i
].
rect
[
1
]
-
text_size
.
height
,
text_size
.
width
,
text_size
.
height
);
// Draw text, and background
...
...
@@ -168,76 +168,100 @@ void ObjectDetector::Preprocess(const cv::Mat& ori_im) {
}
void
ObjectDetector
::
Postprocess
(
const
cv
::
Mat
&
raw_mat
,
const
std
::
vector
<
cv
::
Mat
>
mats
,
std
::
vector
<
ObjectResult
>*
result
,
std
::
vector
<
int
>
bbox_num
,
bool
is_rbox
=
false
)
{
result
->
clear
();
int
rh
=
1
;
int
rw
=
1
;
if
(
config_
.
arch_
==
"Face"
)
{
rh
=
raw_mat
.
rows
;
rw
=
raw_mat
.
cols
;
}
int
start_idx
=
0
;
for
(
int
im_id
=
0
;
im_id
<
bbox_num
.
size
();
im_id
++
)
{
cv
::
Mat
raw_mat
=
mats
[
im_id
];
for
(
int
j
=
start_idx
;
j
<
start_idx
+
bbox_num
[
im_id
];
j
++
)
{
int
rh
=
1
;
int
rw
=
1
;
if
(
config_
.
arch_
==
"Face"
)
{
rh
=
raw_mat
.
rows
;
rw
=
raw_mat
.
cols
;
}
if
(
is_rbox
)
{
int
total_size
=
output_data_
.
size
()
/
10
;
for
(
int
j
=
0
;
j
<
total_size
;
++
j
)
{
// Class id
int
class_id
=
static_cast
<
int
>
(
round
(
output_data_
[
0
+
j
*
10
]))
;
// Confidence score
float
score
=
output_data_
[
1
+
j
*
10
]
;
int
x1
=
(
output_data_
[
2
+
j
*
10
]
*
rw
);
int
y1
=
(
output_data_
[
3
+
j
*
10
]
*
rh
);
int
x2
=
(
output_data_
[
4
+
j
*
10
]
*
rw
);
int
y2
=
(
output_data_
[
5
+
j
*
10
]
*
rh
);
int
x3
=
(
output_data_
[
6
+
j
*
10
]
*
rw
);
int
y3
=
(
output_data_
[
7
+
j
*
10
]
*
rh
);
int
x4
=
(
output_data_
[
8
+
j
*
10
]
*
rw
);
int
y4
=
(
output_data_
[
9
+
j
*
10
]
*
rh
)
;
if
(
score
>
threshold_
&&
class_id
>
-
1
)
{
ObjectResult
result_item
;
result_item
.
rect
=
{
x1
,
y1
,
x2
,
y2
,
x3
,
y3
,
x4
,
y4
}
;
result_item
.
class_id
=
class_id
;
result_item
.
confidence
=
score
;
result
->
push_back
(
result_item
);
if
(
is_rbox
)
{
for
(
int
j
=
0
;
j
<
bbox_num
[
im_id
];
++
j
)
{
// Class id
int
class_id
=
static_cast
<
int
>
(
round
(
output_data_
[
0
+
j
*
10
]));
// Confidence score
float
score
=
output_data_
[
1
+
j
*
10
]
;
int
x1
=
(
output_data_
[
2
+
j
*
10
]
*
rw
);
int
y1
=
(
output_data_
[
3
+
j
*
10
]
*
rh
)
;
int
x2
=
(
output_data_
[
4
+
j
*
10
]
*
rw
);
int
y2
=
(
output_data_
[
5
+
j
*
10
]
*
rh
);
int
x3
=
(
output_data_
[
6
+
j
*
10
]
*
rw
);
int
y3
=
(
output_data_
[
7
+
j
*
10
]
*
rh
);
int
x4
=
(
output_data_
[
8
+
j
*
10
]
*
rw
);
int
y4
=
(
output_data_
[
9
+
j
*
10
]
*
rh
);
if
(
score
>
threshold_
&&
class_id
>
-
1
)
{
ObjectResult
result_item
;
result_item
.
rect
=
{
x1
,
y1
,
x2
,
y2
,
x3
,
y3
,
x4
,
y4
};
result_item
.
class_id
=
class_id
;
result_item
.
confidence
=
score
;
result
->
push_back
(
result_item
)
;
}
}
}
}
}
else
{
int
total_size
=
output_data_
.
size
()
/
6
;
for
(
int
j
=
0
;
j
<
total_size
;
++
j
)
{
// Class id
int
class_id
=
static_cast
<
int
>
(
round
(
output_data_
[
0
+
j
*
6
]));
// Confidence score
float
score
=
output_data_
[
1
+
j
*
6
];
int
xmin
=
(
output_data_
[
2
+
j
*
6
]
*
rw
);
int
ymin
=
(
output_data_
[
3
+
j
*
6
]
*
rh
);
int
xmax
=
(
output_data_
[
4
+
j
*
6
]
*
rw
);
int
ymax
=
(
output_data_
[
5
+
j
*
6
]
*
rh
);
int
wd
=
xmax
-
xmin
;
int
hd
=
ymax
-
ymin
;
if
(
score
>
threshold_
&&
class_id
>
-
1
)
{
ObjectResult
result_item
;
result_item
.
rect
=
{
xmin
,
xmax
,
ymin
,
ymax
};
result_item
.
class_id
=
class_id
;
result_item
.
confidence
=
score
;
result
->
push_back
(
result_item
);
else
{
for
(
int
j
=
0
;
j
<
bbox_num
[
im_id
];
++
j
)
{
// Class id
int
class_id
=
static_cast
<
int
>
(
round
(
output_data_
[
0
+
j
*
6
]));
// Confidence score
float
score
=
output_data_
[
1
+
j
*
6
];
int
xmin
=
(
output_data_
[
2
+
j
*
6
]
*
rw
);
int
ymin
=
(
output_data_
[
3
+
j
*
6
]
*
rh
);
int
xmax
=
(
output_data_
[
4
+
j
*
6
]
*
rw
);
int
ymax
=
(
output_data_
[
5
+
j
*
6
]
*
rh
);
int
wd
=
xmax
-
xmin
;
int
hd
=
ymax
-
ymin
;
if
(
score
>
threshold_
&&
class_id
>
-
1
)
{
ObjectResult
result_item
;
result_item
.
rect
=
{
xmin
,
ymin
,
xmax
,
ymax
};
result_item
.
class_id
=
class_id
;
result_item
.
confidence
=
score
;
result
->
push_back
(
result_item
);
}
}
}
}
start_idx
+=
bbox_num
[
im_id
];
}
}
void
ObjectDetector
::
Predict
(
const
cv
::
Mat
&
im
,
void
ObjectDetector
::
Predict
(
const
std
::
vector
<
cv
::
Mat
>
imgs
,
const
double
threshold
,
const
int
warmup
,
const
int
repeats
,
std
::
vector
<
ObjectResult
>*
result
,
std
::
vector
<
int
>*
bbox_num
,
std
::
vector
<
double
>*
times
)
{
auto
preprocess_start
=
std
::
chrono
::
steady_clock
::
now
();
int
batch_size
=
imgs
.
size
();
// in_data_batch
std
::
vector
<
float
>
in_data_all
;
std
::
vector
<
float
>
im_shape_all
(
batch_size
*
2
);
std
::
vector
<
float
>
scale_factor_all
(
batch_size
*
2
);
// Preprocess image
Preprocess
(
im
);
for
(
int
bs_idx
=
0
;
bs_idx
<
batch_size
;
bs_idx
++
)
{
cv
::
Mat
im
=
imgs
.
at
(
bs_idx
);
Preprocess
(
im
);
im_shape_all
[
bs_idx
*
2
]
=
inputs_
.
im_shape_
[
0
];
im_shape_all
[
bs_idx
*
2
+
1
]
=
inputs_
.
im_shape_
[
1
];
scale_factor_all
[
bs_idx
*
2
]
=
inputs_
.
scale_factor_
[
0
];
scale_factor_all
[
bs_idx
*
2
+
1
]
=
inputs_
.
scale_factor_
[
1
];
// TODO: reduce cost time
in_data_all
.
insert
(
in_data_all
.
end
(),
inputs_
.
im_data_
.
begin
(),
inputs_
.
im_data_
.
end
());
}
// Prepare input tensor
auto
input_names
=
predictor_
->
GetInputNames
();
for
(
const
auto
&
tensor_name
:
input_names
)
{
...
...
@@ -245,14 +269,14 @@ void ObjectDetector::Predict(const cv::Mat& im,
if
(
tensor_name
==
"image"
)
{
int
rh
=
inputs_
.
in_net_shape_
[
0
];
int
rw
=
inputs_
.
in_net_shape_
[
1
];
in_tensor
->
Reshape
({
1
,
3
,
rh
,
rw
});
in_tensor
->
CopyFromCpu
(
in
puts_
.
im_data_
.
data
());
in_tensor
->
Reshape
({
batch_size
,
3
,
rh
,
rw
});
in_tensor
->
CopyFromCpu
(
in
_data_all
.
data
());
}
else
if
(
tensor_name
==
"im_shape"
)
{
in_tensor
->
Reshape
({
1
,
2
});
in_tensor
->
CopyFromCpu
(
i
nputs_
.
im_shape_
.
data
());
in_tensor
->
Reshape
({
batch_size
,
2
});
in_tensor
->
CopyFromCpu
(
i
m_shape_all
.
data
());
}
else
if
(
tensor_name
==
"scale_factor"
)
{
in_tensor
->
Reshape
({
1
,
2
});
in_tensor
->
CopyFromCpu
(
inputs_
.
scale_factor_
.
data
());
in_tensor
->
Reshape
({
batch_size
,
2
});
in_tensor
->
CopyFromCpu
(
scale_factor_all
.
data
());
}
}
auto
preprocess_end
=
std
::
chrono
::
steady_clock
::
now
();
...
...
@@ -266,10 +290,6 @@ void ObjectDetector::Predict(const cv::Mat& im,
std
::
vector
<
int
>
output_shape
=
out_tensor
->
shape
();
// Calculate output length
int
output_size
=
1
;
for
(
int
j
=
0
;
j
<
output_shape
.
size
();
++
j
)
{
output_size
*=
output_shape
[
j
];
}
if
(
output_size
<
6
)
{
std
::
cerr
<<
"[WARNING] No object detected."
<<
std
::
endl
;
}
...
...
@@ -286,6 +306,8 @@ void ObjectDetector::Predict(const cv::Mat& im,
auto
output_names
=
predictor_
->
GetOutputNames
();
auto
out_tensor
=
predictor_
->
GetOutputHandle
(
output_names
[
0
]);
std
::
vector
<
int
>
output_shape
=
out_tensor
->
shape
();
auto
out_bbox_num
=
predictor_
->
GetOutputHandle
(
output_names
[
1
]);
std
::
vector
<
int
>
out_bbox_num_shape
=
out_bbox_num
->
shape
();
// Calculate output length
int
output_size
=
1
;
for
(
int
j
=
0
;
j
<
output_shape
.
size
();
++
j
)
{
...
...
@@ -298,11 +320,23 @@ void ObjectDetector::Predict(const cv::Mat& im,
}
output_data_
.
resize
(
output_size
);
out_tensor
->
CopyToCpu
(
output_data_
.
data
());
int
out_bbox_num_size
=
1
;
for
(
int
j
=
0
;
j
<
out_bbox_num_shape
.
size
();
++
j
)
{
out_bbox_num_size
*=
out_bbox_num_shape
[
j
];
}
out_bbox_num_data_
.
resize
(
out_bbox_num_size
);
out_bbox_num
->
CopyToCpu
(
out_bbox_num_data_
.
data
());
}
auto
inference_end
=
std
::
chrono
::
steady_clock
::
now
();
auto
postprocess_start
=
std
::
chrono
::
steady_clock
::
now
();
// Postprocessing result
Postprocess
(
im
,
result
,
is_rbox
);
Postprocess
(
imgs
,
result
,
out_bbox_num_data_
,
is_rbox
);
bbox_num
->
clear
();
for
(
int
k
=
0
;
k
<
out_bbox_num_data_
.
size
();
k
++
)
{
int
tmp
=
out_bbox_num_data_
[
k
];
bbox_num
->
push_back
(
tmp
);
}
auto
postprocess_end
=
std
::
chrono
::
steady_clock
::
now
();
std
::
chrono
::
duration
<
float
>
preprocess_diff
=
preprocess_end
-
preprocess_start
;
...
...
deploy/cpp/src/preprocess_op.cc
浏览文件 @
5e19955b
...
...
@@ -129,7 +129,6 @@ void PadStride::Run(cv::Mat* im, ImageBlob* data) {
static_cast
<
float
>
(
im
->
rows
),
static_cast
<
float
>
(
im
->
cols
),
};
}
...
...
deploy/python/infer.py
浏览文件 @
5e19955b
...
...
@@ -21,6 +21,7 @@ from functools import reduce
from
PIL
import
Image
import
cv2
import
numpy
as
np
import
math
import
paddle
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
...
...
@@ -85,18 +86,29 @@ class Detector(object):
self
.
det_times
=
Timer
()
self
.
cpu_mem
,
self
.
gpu_mem
,
self
.
gpu_util
=
0
,
0
,
0
def
preprocess
(
self
,
im
):
def
preprocess
(
self
,
im
age_list
):
preprocess_ops
=
[]
for
op_info
in
self
.
pred_config
.
preprocess_infos
:
new_op_info
=
op_info
.
copy
()
op_type
=
new_op_info
.
pop
(
'type'
)
preprocess_ops
.
append
(
eval
(
op_type
)(
**
new_op_info
))
im
,
im_info
=
preprocess
(
im
,
preprocess_ops
,
self
.
pred_config
.
input_shape
)
inputs
=
create_inputs
(
im
,
im_info
)
input_im_lst
=
[]
input_im_info_lst
=
[]
for
im_path
in
image_list
:
im
,
im_info
=
preprocess
(
im_path
,
preprocess_ops
,
self
.
pred_config
.
input_shape
)
input_im_lst
.
append
(
im
)
input_im_info_lst
.
append
(
im_info
)
inputs
=
create_inputs
(
input_im_lst
,
input_im_info_lst
)
return
inputs
def
postprocess
(
self
,
np_boxes
,
np_masks
,
inputs
,
threshold
=
0.5
):
def
postprocess
(
self
,
np_boxes
,
np_masks
,
inputs
,
np_boxes_num
,
threshold
=
0.5
):
# postprocess output of predictor
results
=
{}
if
self
.
pred_config
.
arch
in
[
'Face'
]:
...
...
@@ -108,14 +120,15 @@ class Detector(object):
np_boxes
[:,
4
]
*=
h
np_boxes
[:,
5
]
*=
w
results
[
'boxes'
]
=
np_boxes
results
[
'boxes_num'
]
=
np_boxes_num
if
np_masks
is
not
None
:
results
[
'masks'
]
=
np_masks
return
results
def
predict
(
self
,
image
,
threshold
=
0.5
,
warmup
=
0
,
repeats
=
1
):
def
predict
(
self
,
image
_list
,
threshold
=
0.5
,
warmup
=
0
,
repeats
=
1
):
'''
Args:
image
(str/np.ndarray): path of image/ np.ndarray read by cv2
image
_list (list): list of images
threshold (float): threshold of predicted box' score
Returns:
results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
...
...
@@ -124,7 +137,7 @@ class Detector(object):
shape: [N, im_h, im_w]
'''
self
.
det_times
.
preprocess_time_s
.
start
()
inputs
=
self
.
preprocess
(
image
)
inputs
=
self
.
preprocess
(
image
_list
)
np_boxes
,
np_masks
=
None
,
None
input_names
=
self
.
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
...
...
@@ -146,6 +159,8 @@ class Detector(object):
output_names
=
self
.
predictor
.
get_output_names
()
boxes_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
0
])
np_boxes
=
boxes_tensor
.
copy_to_cpu
()
boxes_num
=
self
.
predictor
.
get_output_handle
(
output_names
[
1
])
np_boxes_num
=
boxes_num
.
copy_to_cpu
()
if
self
.
pred_config
.
mask
:
masks_tensor
=
self
.
predictor
.
get_output_handle
(
output_names
[
2
])
np_masks
=
masks_tensor
.
copy_to_cpu
()
...
...
@@ -155,12 +170,12 @@ class Detector(object):
results
=
[]
if
reduce
(
lambda
x
,
y
:
x
*
y
,
np_boxes
.
shape
)
<
6
:
print
(
'[WARNNING] No object detected.'
)
results
=
{
'boxes'
:
np
.
array
([])}
results
=
{
'boxes'
:
np
.
array
([])
,
'boxes_num'
:
[
0
]
}
else
:
results
=
self
.
postprocess
(
np_boxes
,
np_masks
,
inputs
,
threshold
=
threshold
)
np_boxes
,
np_masks
,
inputs
,
np_boxes_num
,
threshold
=
threshold
)
self
.
det_times
.
postprocess_time_s
.
end
()
self
.
det_times
.
img_num
+=
1
self
.
det_times
.
img_num
+=
len
(
image_list
)
return
results
...
...
@@ -249,21 +264,45 @@ class DetectorSOLOv2(Detector):
return
dict
(
segm
=
np_segms
,
label
=
np_label
,
score
=
np_score
)
def
create_inputs
(
im
,
im_info
):
def
create_inputs
(
im
gs
,
im_info
):
"""generate input for different model type
Args:
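imgs (list[np.ndarray]): list of images, each indexed as (channel, height, width)
im_info (list[dict]): info of each image, with 'im_shape' and 'scale_factor' entries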
model_arch (str): model type
Returns:
inputs (dict): input of model
"""
inputs
=
{}
inputs
[
'image'
]
=
np
.
array
((
im
,
)).
astype
(
'float32'
)
inputs
[
'im_shape'
]
=
np
.
array
((
im_info
[
'im_shape'
],
)).
astype
(
'float32'
)
inputs
[
'scale_factor'
]
=
np
.
array
(
(
im_info
[
'scale_factor'
],
)).
astype
(
'float32'
)
im_shape
=
[]
scale_factor
=
[]
for
e
in
im_info
:
im_shape
.
append
(
np
.
array
((
e
[
'im_shape'
],
)).
astype
(
'float32'
))
scale_factor
.
append
(
np
.
array
((
e
[
'scale_factor'
],
)).
astype
(
'float32'
))
origin_scale_factor
=
np
.
concatenate
(
scale_factor
,
axis
=
0
)
imgs_shape
=
[[
e
.
shape
[
1
],
e
.
shape
[
2
]]
for
e
in
imgs
]
max_shape_h
=
max
([
e
[
0
]
for
e
in
imgs_shape
])
max_shape_w
=
max
([
e
[
1
]
for
e
in
imgs_shape
])
padding_imgs
=
[]
padding_imgs_shape
=
[]
padding_imgs_scale
=
[]
for
img
in
imgs
:
im_c
,
im_h
,
im_w
=
img
.
shape
[:]
padding_im
=
np
.
zeros
(
(
im_c
,
max_shape_h
,
max_shape_w
),
dtype
=
np
.
float32
)
padding_im
[:,
:
im_h
,
:
im_w
]
=
img
padding_imgs
.
append
(
padding_im
)
padding_imgs_shape
.
append
(
np
.
array
([
max_shape_h
,
max_shape_w
]).
astype
(
'float32'
))
rescale
=
[
float
(
max_shape_h
)
/
float
(
im_h
),
float
(
max_shape_w
)
/
float
(
im_w
)
]
padding_imgs_scale
.
append
(
np
.
array
(
rescale
).
astype
(
'float32'
))
inputs
[
'image'
]
=
np
.
stack
(
padding_imgs
,
axis
=
0
)
inputs
[
'im_shape'
]
=
np
.
stack
(
padding_imgs_shape
,
axis
=
0
)
inputs
[
'scale_factor'
]
=
origin_scale_factor
return
inputs
...
...
@@ -426,15 +465,30 @@ def get_test_images(infer_dir, infer_img):
return
images
def
visualize
(
image_
file
,
results
,
labels
,
output_dir
=
'output/'
,
threshold
=
0.5
):
def
visualize
(
image_
list
,
results
,
labels
,
output_dir
=
'output/'
,
threshold
=
0.5
):
# visualize the predict result
im
=
visualize_box_mask
(
image_file
,
results
,
labels
,
threshold
=
threshold
)
img_name
=
os
.
path
.
split
(
image_file
)[
-
1
]
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
out_path
=
os
.
path
.
join
(
output_dir
,
img_name
)
im
.
save
(
out_path
,
quality
=
95
)
print
(
"save result to: "
+
out_path
)
start_idx
=
0
for
idx
,
image_file
in
enumerate
(
image_list
):
im_bboxes_num
=
results
[
'boxes_num'
][
idx
]
im_results
=
{}
if
'boxes'
in
results
:
im_results
[
'boxes'
]
=
results
[
'boxes'
][
start_idx
:
start_idx
+
im_bboxes_num
,
:]
if
'masks'
in
results
:
im_results
[
'masks'
]
=
results
[
'masks'
][
start_idx
:
start_idx
+
im_bboxes_num
,
:]
if
'segm'
in
results
:
im_results
[
'segm'
]
=
results
[
'segm'
][
start_idx
:
start_idx
+
im_bboxes_num
,
:]
start_idx
+=
im_bboxes_num
im
=
visualize_box_mask
(
image_file
,
im_results
,
labels
,
threshold
=
threshold
)
img_name
=
os
.
path
.
split
(
image_file
)[
-
1
]
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
out_path
=
os
.
path
.
join
(
output_dir
,
img_name
)
im
.
save
(
out_path
,
quality
=
95
)
print
(
"save result to: "
+
out_path
)
def
print_arguments
(
args
):
...
...
@@ -444,19 +498,24 @@ def print_arguments(args):
print
(
'------------------------------------------'
)
def
predict_image
(
detector
,
image_list
):
for
i
,
img_file
in
enumerate
(
image_list
):
def
predict_image
(
detector
,
image_list
,
batch_size
=
1
):
batch_loop_cnt
=
math
.
ceil
(
float
(
len
(
image_list
))
/
batch_size
)
for
i
in
range
(
batch_loop_cnt
):
start_index
=
i
*
batch_size
end_index
=
min
((
i
+
1
)
*
batch_size
,
len
(
image_list
))
batch_image_list
=
image_list
[
start_index
:
end_index
]
if
FLAGS
.
run_benchmark
:
detector
.
predict
(
img_file
,
FLAGS
.
threshold
,
warmup
=
10
,
repeats
=
10
)
detector
.
predict
(
batch_image_list
,
FLAGS
.
threshold
,
warmup
=
10
,
repeats
=
10
)
cm
,
gm
,
gu
=
get_current_memory_mb
()
detector
.
cpu_mem
+=
cm
detector
.
gpu_mem
+=
gm
detector
.
gpu_util
+=
gu
print
(
'Test iter {}
, file name:{}'
.
format
(
i
,
img_file
))
print
(
'Test iter {}
'
.
format
(
i
))
else
:
results
=
detector
.
predict
(
img_file
,
FLAGS
.
threshold
)
results
=
detector
.
predict
(
batch_image_list
,
FLAGS
.
threshold
)
visualize
(
img_file
,
batch_image_list
,
results
,
detector
.
pred_config
.
labels
,
output_dir
=
FLAGS
.
output_dir
,
...
...
@@ -535,8 +594,10 @@ def main():
predict_video
(
detector
,
FLAGS
.
camera_id
)
else
:
# predict from image
if
FLAGS
.
image_dir
is
None
and
FLAGS
.
image_file
is
not
None
:
assert
FLAGS
.
batch_size
==
1
,
"batch_size should be 1, when image_file is not None"
img_list
=
get_test_images
(
FLAGS
.
image_dir
,
FLAGS
.
image_file
)
predict_image
(
detector
,
img_list
)
predict_image
(
detector
,
img_list
,
FLAGS
.
batch_size
)
if
not
FLAGS
.
run_benchmark
:
detector
.
det_times
.
info
(
average
=
True
)
else
:
...
...
deploy/python/utils.py
浏览文件 @
5e19955b
...
...
@@ -34,6 +34,8 @@ def argsparser():
type
=
str
,
default
=
None
,
help
=
"Dir of image file, `image_file` has a higher priority."
)
parser
.
add_argument
(
"--batch_size"
,
type
=
int
,
default
=
1
,
help
=
"batch_size for infer."
)
parser
.
add_argument
(
"--video_file"
,
type
=
str
,
...
...
ppdet/engine/trainer.py
浏览文件 @
5e19955b
...
...
@@ -436,7 +436,7 @@ class Trainer(object):
image
=
visualize_results
(
image
,
bbox_res
,
mask_res
,
segm_res
,
keypoint_res
,
int
(
outs
[
'im_id'
]
),
catid2name
,
draw_threshold
)
int
(
im_id
),
catid2name
,
draw_threshold
)
self
.
status
[
'result_image'
]
=
np
.
array
(
image
.
copy
())
if
self
.
_compose_callback
:
self
.
_compose_callback
.
on_step_end
(
self
.
status
)
...
...
ppdet/modeling/architectures/s2anet.py
浏览文件 @
5e19955b
...
...
@@ -83,11 +83,13 @@ class S2ANet(BaseArch):
nms_pre
=
self
.
s2anet_bbox_post_process
.
nms_pre
pred_scores
,
pred_bboxes
=
self
.
s2anet_head
.
get_prediction
(
nms_pre
)
# post_process
pred_bboxes
,
bbox_num
=
self
.
s2anet_bbox_post_process
(
pred_scores
,
pred_bboxes
)
# rescale the prediction back to origin image
pred_bboxes
=
self
.
s2anet_bbox_post_process
.
get_pred
(
pred_bboxes
,
bbox_num
,
im_shape
,
scale_factor
)
# output
output
=
{
'bbox'
:
pred_bboxes
,
'bbox_num'
:
bbox_num
}
return
output
...
...
ppdet/modeling/layers.py
浏览文件 @
5e19955b
...
...
@@ -334,8 +334,11 @@ class RCNNBox(object):
self
.
num_classes
=
num_classes
def
__call__
(
self
,
bbox_head_out
,
rois
,
im_shape
,
scale_factor
):
bbox_pred
,
cls_prob
=
bbox_head_out
roi
,
rois_num
=
rois
bbox_pred
=
bbox_head_out
[
0
]
cls_prob
=
bbox_head_out
[
1
]
roi
=
rois
[
0
]
rois_num
=
rois
[
1
]
origin_shape
=
paddle
.
floor
(
im_shape
/
scale_factor
+
0.5
)
scale_list
=
[]
origin_shape_list
=
[]
...
...
ppdet/modeling/post_process.py
浏览文件 @
5e19955b
...
...
@@ -264,7 +264,6 @@ class S2ANetBBoxPostProcess(nn.Layer):
bbox_num
=
self
.
fake_bbox_num
pred_cls_score_bbox
=
paddle
.
reshape
(
pred_cls_score_bbox
,
[
-
1
,
10
])
assert
pred_cls_score_bbox
.
shape
[
1
]
==
10
return
pred_cls_score_bbox
,
bbox_num
def
get_pred
(
self
,
bboxes
,
bbox_num
,
im_shape
,
scale_factor
):
...
...
@@ -281,7 +280,6 @@ class S2ANetBBoxPostProcess(nn.Layer):
including labels, scores and bboxes. The size of
bboxes are corresponding to the original image.
"""
assert
bboxes
.
shape
[
1
]
==
10
origin_shape
=
paddle
.
floor
(
im_shape
/
scale_factor
+
0.5
)
origin_shape_list
=
[]
...
...
@@ -307,6 +305,7 @@ class S2ANetBBoxPostProcess(nn.Layer):
pred_bbox
=
bboxes
[:,
2
:]
# rescale bbox to original image
pred_bbox
=
pred_bbox
.
reshape
([
-
1
,
8
])
scaled_bbox
=
pred_bbox
/
scale_factor_list
origin_h
=
origin_shape_list
[:,
0
]
origin_w
=
origin_shape_list
[:,
1
]
...
...
ppdet/modeling/proposal_generator/rpn_head.py
浏览文件 @
5e19955b
...
...
@@ -62,11 +62,11 @@ class RPNHead(nn.Layer):
Args:
anchor_generator (dict): configure of anchor generation
rpn_target_assign (dict): configure of rpn targets assignment
train_proposal (dict): configure of proposals generation
train_proposal (dict): configure of proposals generation
at the stage of training
test_proposal (dict): configure of proposals generation
at the stage of prediction
in_channel (int): channel of input feature maps which can be
in_channel (int): channel of input feature maps which can be
derived by from_config
"""
...
...
@@ -156,31 +156,35 @@ class RPNHead(nn.Layer):
"""
prop_gen
=
self
.
train_proposal
if
self
.
training
else
self
.
test_proposal
im_shape
=
inputs
[
'im_shape'
]
rpn_rois_list
=
[[]
for
i
in
range
(
batch_size
)]
rpn_prob_list
=
[[]
for
i
in
range
(
batch_size
)]
rpn_rois_num_list
=
[[]
for
i
in
range
(
batch_size
)]
# Collect multi-level proposals for each batch
# Get 'topk' of them as final output
bs_rois_collect
=
[]
bs_rois_num_collect
=
[]
# Generate proposals for each level and each batch.
# Discard batch-computing to avoid sorting bbox cross different batches.
for
rpn_score
,
rpn_delta
,
anchor
in
zip
(
scores
,
bbox_deltas
,
anchors
):
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
rpn_rois_list
=
[]
rpn_prob_list
=
[]
rpn_rois_num_list
=
[]
for
rpn_score
,
rpn_delta
,
anchor
in
zip
(
scores
,
bbox_deltas
,
anchors
):
rpn_rois
,
rpn_rois_prob
,
rpn_rois_num
,
post_nms_top_n
=
prop_gen
(
scores
=
rpn_score
[
i
:
i
+
1
],
bbox_deltas
=
rpn_delta
[
i
:
i
+
1
],
anchors
=
anchor
,
im_shape
=
im_shape
[
i
:
i
+
1
])
if
rpn_rois
.
shape
[
0
]
>
0
:
rpn_rois_list
[
i
].
append
(
rpn_rois
)
rpn_prob_list
[
i
].
append
(
rpn_rois_prob
)
rpn_rois_num_list
[
i
].
append
(
rpn_rois_num
)
# Collect multi-level proposals for each batch
# Get 'topk' of them as final output
rois_collect
=
[]
rois_num_collect
=
[]
for
i
in
range
(
batch_size
):
rpn_rois_list
.
append
(
rpn_rois
)
rpn_prob_list
.
append
(
rpn_rois_prob
)
rpn_rois_num_list
.
append
(
rpn_rois_num
)
if
len
(
scores
)
>
1
:
rpn_rois
=
paddle
.
concat
(
rpn_rois_list
[
i
])
rpn_prob
=
paddle
.
concat
(
rpn_prob_list
[
i
]).
flatten
()
rpn_rois
=
paddle
.
concat
(
rpn_rois_list
)
rpn_prob
=
paddle
.
concat
(
rpn_prob_list
).
flatten
()
if
rpn_prob
.
shape
[
0
]
>
post_nms_top_n
:
topk_prob
,
topk_inds
=
paddle
.
topk
(
rpn_prob
,
post_nms_top_n
)
topk_rois
=
paddle
.
gather
(
rpn_rois
,
topk_inds
)
...
...
@@ -188,17 +192,19 @@ class RPNHead(nn.Layer):
topk_rois
=
rpn_rois
topk_prob
=
rpn_prob
else
:
topk_rois
=
rpn_rois_list
[
i
][
0
]
topk_prob
=
rpn_prob_list
[
i
][
0
].
flatten
()
rois_collect
.
append
(
topk_rois
)
rois_num_collect
.
append
(
paddle
.
shape
(
topk_rois
)[
0
])
rois_num_collect
=
paddle
.
concat
(
rois_num_collect
)
topk_rois
=
rpn_rois_list
[
0
]
topk_prob
=
rpn_prob_list
[
0
].
flatten
()
bs_rois_collect
.
append
(
topk_rois
)
bs_rois_num_collect
.
append
(
paddle
.
shape
(
topk_rois
)[
0
])
bs_rois_num_collect
=
paddle
.
concat
(
bs_rois_num_collect
)
return
rois_collect
,
rois_num_collect
return
bs_rois_collect
,
bs_
rois_num_collect
def
get_loss
(
self
,
pred_scores
,
pred_deltas
,
anchors
,
inputs
):
"""
pred_scores (list[Tensor]): Multi-level scores prediction
pred_scores (list[Tensor]): Multi-level scores prediction
pred_deltas (list[Tensor]): Multi-level deltas prediction
anchors (list[Tensor]): Multi-level anchors
inputs (dict): ground truth info, including im, gt_bbox, gt_score
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录