Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
d207622b
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 1 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d207622b
编写于
6月 02, 2022
作者:
G
Guanghua Yu
提交者:
GitHub
6月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add YOLOv5s ACT demo (#1143)
* add YOLOv5s ACT demo * fix comment * fix docs
上级
567a97d9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
637 addition
and
23 deletion
+637
-23
demo/auto_compression/detection/README.md
demo/auto_compression/detection/README.md
+22
-1
demo/auto_compression/detection/configs/yolov5_reader.yml
demo/auto_compression/detection/configs/yolov5_reader.yml
+27
-0
demo/auto_compression/detection/configs/yolov5s_qat_dis.yaml
demo/auto_compression/detection/configs/yolov5s_qat_dis.yaml
+46
-0
demo/auto_compression/detection/infer.py
demo/auto_compression/detection/infer.py
+322
-0
demo/auto_compression/detection/post_process.py
demo/auto_compression/detection/post_process.py
+173
-0
demo/auto_compression/detection/run.py
demo/auto_compression/detection/run.py
+46
-20
paddleslim/auto_compression/compressor.py
paddleslim/auto_compression/compressor.py
+1
-2
未找到文件。
demo/auto_compression/detection/README.md
浏览文件 @
d207622b
...
...
@@ -18,7 +18,7 @@
## 2.Benchmark
-
PP-YOLOE模型
### PP-YOLOE
| 模型 | 策略 | 输入尺寸 | mAP
<sup>
val
<br>
0.5:0.95 | 预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) |预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) | 预测时延
<sup><small>
INT8
</small><sup><br><sup>
(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
...
...
@@ -28,12 +28,33 @@
-
mAP的指标均在COCO val2017数据集中评测得到。
-
PP-YOLOE模型在Tesla V100的GPU环境下测试,测试脚本是
[
benchmark demo
](
https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.4/deploy/python
)
### YOLOv5
| 模型 | 策略 | 输入尺寸 | mAP
<sup>
val
<br>
0.5:0.95 | 预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) |预测时延
<sup><small>
FP32
</small><sup><br><sup>
(ms) | 预测时延
<sup><small>
INT8
</small><sup><br><sup>
(ms) | 配置文件 | Inference模型 |
| :-------- |:-------- |:--------: | :---------------------: | :----------------: | :----------------: | :---------------: | :-----------------------------: | :-----------------------------: |
| YOLOv5s | Base模型 | 640
*
640 | 37.4 | 6.0 | 4.9ms | - | - |
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/detection/yolov5s_infer.tar
)
|
| YOLOv5s | 量化+蒸馏 | 640
*
640 | 36.5 | - | - | 4.5ms |
[
config
](
https://github.com/PaddlePaddle/PaddleSlim/tree/develop/demo/auto_compression/detection/configs/yolov5s_qat_dis.yaml
)
|
[
Model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/yolov5s_quant.tar
)
|
说明:
-
mAP的指标均在COCO val2017数据集中评测得到。
-
YOLOv5s模型在Tesla V100的GPU环境下测试,测试脚本是
[
benchmark demo
](
./infer.py
)
-
YOLOv5模型源自
[
ultralytics/yolov5
](
https://github.com/ultralytics/yolov5
)
,通过
[
X2Paddle
](
https://github.com/PaddlePaddle/X2Paddle
)
工具转换YOLOv5预测模型步骤:
(1) 安装X2Paddle的1.3.6以上版本;(pip install x2paddle)
(2) 转换模型:
```
x2paddle --framework=onnx --model=yolov5s.onnx --save_dir=pd_model
cp -r pd_model/inference_model/ yolov5_inference_model
```
即可得到YOLOv5s模型的预测模型(
`model.pdmodel`
和
`model.pdiparams`
)。如想快速体验,可直接下载上方表格中YOLOv5s的Base预测模型。
## 3. 自动压缩流程
#### 3.1 准备环境
-
PaddlePaddle >= 2.3 (可从
[
Paddle官网
](
https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/linux-pip.html
)
下载安装)
-
PaddleSlim develop版本
-
PaddleDet >= 2.4
-
opencv-python
安装paddlepaddle:
```
shell
...
...
demo/auto_compression/detection/configs/yolov5_reader.yml
0 → 100644
浏览文件 @
d207622b
metric
:
COCO
num_classes
:
80
# Datset configuration
TrainDataset
:
!COCODataSet
image_dir
:
train2017
anno_path
:
annotations/instances_train2017.json
dataset_dir
:
dataset/coco/
EvalDataset
:
!COCODataSet
image_dir
:
val2017
anno_path
:
annotations/instances_val2017.json
dataset_dir
:
dataset/coco/
worker_num
:
4
# preprocess reader in test
EvalReader
:
sample_transforms
:
-
Decode
:
{}
-
Resize
:
{
target_size
:
[
640
,
640
],
keep_ratio
:
True
}
-
Pad
:
{
size
:
[
640
,
640
],
fill_value
:
[
114.
,
114.
,
114.
]}
-
NormalizeImage
:
{
mean
:
[
0
,
0
,
0
],
std
:
[
1
,
1
,
1
],
is_scale
:
True
}
-
Permute
:
{}
batch_size
:
1
demo/auto_compression/detection/configs/yolov5s_qat_dis.yaml
0 → 100644
浏览文件 @
d207622b
Global
:
reader_config
:
configs/yolov5_reader.yml
input_list
:
{
'
image'
:
'
x2paddle_images'
}
Evaluation
:
True
arch
:
'
YOLOv5'
model_dir
:
./yolov5s_infer/
model_filename
:
model.pdmodel
params_filename
:
model.pdiparams
Distillation
:
distill_lambda
:
1.0
distill_loss
:
l2_loss
distill_node_pair
:
-
teacher_conv2d_106.tmp_1
-
conv2d_106.tmp_1
-
teacher_conv2d_113.tmp_1
-
conv2d_113.tmp_1
-
teacher_conv2d_119.tmp_1
-
conv2d_119.tmp_1
merge_feed
:
true
teacher_model_dir
:
./yolov5_inference_model/
teacher_model_filename
:
model.pdmodel
teacher_params_filename
:
model.pdiparams
Quantization
:
use_pact
:
true
activation_bits
:
8
weight_bits
:
8
activation_quantize_type
:
'
range_abs_max'
weight_quantize_type
:
'
channel_wise_abs_max'
is_full_quantize
:
false
not_quant_pattern
:
-
skip_quant
quantize_op_types
:
-
conv2d
-
depthwise_conv2d
TrainConfig
:
epochs
:
1
eval_iter
:
1000
learning_rate
:
0.00001
optimizer
:
SGD
optim_args
:
weight_decay
:
4.0e-05
target_metric
:
0.365
demo/auto_compression/detection/infer.py
0 → 100644
浏览文件 @
d207622b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
numpy
as
np
import
argparse
import
time
from
paddle.inference
import
Config
from
paddle.inference
import
create_predictor
from
post_process
import
YOLOv5PostProcess
CLASS_LABEL
=
[
'person'
,
'bicycle'
,
'car'
,
'motorcycle'
,
'airplane'
,
'bus'
,
'train'
,
'truck'
,
'boat'
,
'traffic light'
,
'fire hydrant'
,
'stop sign'
,
'parking meter'
,
'bench'
,
'bird'
,
'cat'
,
'dog'
,
'horse'
,
'sheep'
,
'cow'
,
'elephant'
,
'bear'
,
'zebra'
,
'giraffe'
,
'backpack'
,
'umbrella'
,
'handbag'
,
'tie'
,
'suitcase'
,
'frisbee'
,
'skis'
,
'snowboard'
,
'sports ball'
,
'kite'
,
'baseball bat'
,
'baseball glove'
,
'skateboard'
,
'surfboard'
,
'tennis racket'
,
'bottle'
,
'wine glass'
,
'cup'
,
'fork'
,
'knife'
,
'spoon'
,
'bowl'
,
'banana'
,
'apple'
,
'sandwich'
,
'orange'
,
'broccoli'
,
'carrot'
,
'hot dog'
,
'pizza'
,
'donut'
,
'cake'
,
'chair'
,
'couch'
,
'potted plant'
,
'bed'
,
'dining table'
,
'toilet'
,
'tv'
,
'laptop'
,
'mouse'
,
'remote'
,
'keyboard'
,
'cell phone'
,
'microwave'
,
'oven'
,
'toaster'
,
'sink'
,
'refrigerator'
,
'book'
,
'clock'
,
'vase'
,
'scissors'
,
'teddy bear'
,
'hair drier'
,
'toothbrush'
]
def
generate_scale
(
im
,
target_shape
,
keep_ratio
=
True
):
"""
Args:
im (np.ndarray): image (np.ndarray)
Returns:
im_scale_x: the resize ratio of X
im_scale_y: the resize ratio of Y
"""
origin_shape
=
im
.
shape
[:
2
]
if
keep_ratio
:
im_size_min
=
np
.
min
(
origin_shape
)
im_size_max
=
np
.
max
(
origin_shape
)
target_size_min
=
np
.
min
(
target_shape
)
target_size_max
=
np
.
max
(
target_shape
)
im_scale
=
float
(
target_size_min
)
/
float
(
im_size_min
)
if
np
.
round
(
im_scale
*
im_size_max
)
>
target_size_max
:
im_scale
=
float
(
target_size_max
)
/
float
(
im_size_max
)
im_scale_x
=
im_scale
im_scale_y
=
im_scale
else
:
resize_h
,
resize_w
=
target_shape
im_scale_y
=
resize_h
/
float
(
origin_shape
[
0
])
im_scale_x
=
resize_w
/
float
(
origin_shape
[
1
])
return
im_scale_y
,
im_scale_x
def
image_preprocess
(
img_path
,
target_shape
):
img
=
cv2
.
imread
(
img_path
)
# Resize
im_scale_y
,
im_scale_x
=
generate_scale
(
img
,
target_shape
)
img
=
cv2
.
resize
(
img
,
None
,
None
,
fx
=
im_scale_x
,
fy
=
im_scale_y
,
interpolation
=
cv2
.
INTER_LINEAR
)
# Pad
im_h
,
im_w
=
img
.
shape
[:
2
]
h
,
w
=
target_shape
[:]
if
h
!=
im_h
or
w
!=
im_w
:
canvas
=
np
.
ones
((
h
,
w
,
3
),
dtype
=
np
.
float32
)
canvas
*=
np
.
array
([
114.0
,
114.0
,
114.0
],
dtype
=
np
.
float32
)
canvas
[
0
:
im_h
,
0
:
im_w
,
:]
=
img
.
astype
(
np
.
float32
)
img
=
canvas
img
=
cv2
.
cvtColor
(
img
,
cv2
.
COLOR_BGR2RGB
)
img
=
np
.
transpose
(
img
,
[
2
,
0
,
1
])
/
255
img
=
np
.
expand_dims
(
img
,
0
)
scale_factor
=
np
.
array
([[
im_scale_y
,
im_scale_x
]])
return
img
.
astype
(
np
.
float32
),
scale_factor
def
get_color_map_list
(
num_classes
):
color_map
=
num_classes
*
[
0
,
0
,
0
]
for
i
in
range
(
0
,
num_classes
):
j
=
0
lab
=
i
while
lab
:
color_map
[
i
*
3
]
|=
(((
lab
>>
0
)
&
1
)
<<
(
7
-
j
))
color_map
[
i
*
3
+
1
]
|=
(((
lab
>>
1
)
&
1
)
<<
(
7
-
j
))
color_map
[
i
*
3
+
2
]
|=
(((
lab
>>
2
)
&
1
)
<<
(
7
-
j
))
j
+=
1
lab
>>=
3
color_map
=
[
color_map
[
i
:
i
+
3
]
for
i
in
range
(
0
,
len
(
color_map
),
3
)]
return
color_map
def
draw_box
(
image_file
,
results
,
class_label
,
threshold
=
0.5
):
srcimg
=
cv2
.
imread
(
image_file
,
1
)
for
i
in
range
(
len
(
results
)):
color_list
=
get_color_map_list
(
len
(
class_label
))
clsid2color
=
{}
classid
,
conf
=
int
(
results
[
i
,
0
]),
results
[
i
,
1
]
if
conf
<
threshold
:
continue
xmin
,
ymin
,
xmax
,
ymax
=
int
(
results
[
i
,
2
]),
int
(
results
[
i
,
3
]),
int
(
results
[
i
,
4
]),
int
(
results
[
i
,
5
])
if
classid
not
in
clsid2color
:
clsid2color
[
classid
]
=
color_list
[
classid
]
color
=
tuple
(
clsid2color
[
classid
])
cv2
.
rectangle
(
srcimg
,
(
xmin
,
ymin
),
(
xmax
,
ymax
),
color
,
thickness
=
2
)
print
(
class_label
[
classid
]
+
': '
+
str
(
round
(
conf
,
3
)))
cv2
.
putText
(
srcimg
,
class_label
[
classid
]
+
':'
+
str
(
round
(
conf
,
3
)),
(
xmin
,
ymin
-
10
),
cv2
.
FONT_HERSHEY_SIMPLEX
,
0.8
,
(
0
,
255
,
0
),
thickness
=
2
)
return
srcimg
def
load_predictor
(
model_dir
,
run_mode
=
'paddle'
,
batch_size
=
1
,
device
=
'CPU'
,
min_subgraph_size
=
3
,
use_dynamic_shape
=
False
,
trt_min_shape
=
1
,
trt_max_shape
=
1280
,
trt_opt_shape
=
640
,
trt_calib_mode
=
False
,
cpu_threads
=
1
,
enable_mkldnn
=
False
,
enable_mkldnn_bfloat16
=
False
,
delete_shuffle_pass
=
False
):
"""set AnalysisConfig, generate AnalysisPredictor
Args:
model_dir (str): root path of __model__ and __params__
device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
use_dynamic_shape (bool): use dynamic shape or not
trt_min_shape (int): min shape for dynamic shape in trt
trt_max_shape (int): max shape for dynamic shape in trt
trt_opt_shape (int): opt shape for dynamic shape in trt
trt_calib_mode (bool): If the model is produced by TRT offline quantitative
calibration, trt_calib_mode need to set True
delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
Used by action model.
Returns:
predictor (PaddlePredictor): AnalysisPredictor
Raises:
ValueError: predict by TensorRT need device == 'GPU'.
"""
if
device
!=
'GPU'
and
run_mode
!=
'paddle'
:
raise
ValueError
(
"Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
.
format
(
run_mode
,
device
))
config
=
Config
(
os
.
path
.
join
(
model_dir
,
'model.pdmodel'
),
os
.
path
.
join
(
model_dir
,
'model.pdiparams'
))
if
device
==
'GPU'
:
# initial GPU memory(M), device ID
config
.
enable_use_gpu
(
200
,
0
)
# optimize graph and fuse op
config
.
switch_ir_optim
(
True
)
elif
device
==
'XPU'
:
config
.
enable_lite_engine
()
config
.
enable_xpu
(
10
*
1024
*
1024
)
else
:
config
.
disable_gpu
()
config
.
set_cpu_math_library_num_threads
(
cpu_threads
)
if
enable_mkldnn
:
try
:
# cache 10 different shapes for mkldnn to avoid memory leak
config
.
set_mkldnn_cache_capacity
(
10
)
config
.
enable_mkldnn
()
if
enable_mkldnn_bfloat16
:
config
.
enable_mkldnn_bfloat16
()
except
Exception
as
e
:
print
(
"The current environment does not support `mkldnn`, so disable mkldnn."
)
pass
precision_map
=
{
'trt_int8'
:
Config
.
Precision
.
Int8
,
'trt_fp32'
:
Config
.
Precision
.
Float32
,
'trt_fp16'
:
Config
.
Precision
.
Half
}
if
run_mode
in
precision_map
.
keys
():
config
.
enable_tensorrt_engine
(
workspace_size
=
(
1
<<
25
)
*
batch_size
,
max_batch_size
=
batch_size
,
min_subgraph_size
=
min_subgraph_size
,
precision_mode
=
precision_map
[
run_mode
],
use_static
=
False
,
use_calib_mode
=
trt_calib_mode
)
if
use_dynamic_shape
:
min_input_shape
=
{
'image'
:
[
batch_size
,
3
,
trt_min_shape
,
trt_min_shape
]
}
max_input_shape
=
{
'image'
:
[
batch_size
,
3
,
trt_max_shape
,
trt_max_shape
]
}
opt_input_shape
=
{
'image'
:
[
batch_size
,
3
,
trt_opt_shape
,
trt_opt_shape
]
}
config
.
set_trt_dynamic_shape_info
(
min_input_shape
,
max_input_shape
,
opt_input_shape
)
print
(
'trt set dynamic shape done!'
)
# disable print log when predict
config
.
disable_glog_info
()
# enable shared memory
config
.
enable_memory_optim
()
# disable feed, fetch OP, needed by zero_copy_run
config
.
switch_use_feed_fetch_ops
(
False
)
if
delete_shuffle_pass
:
config
.
delete_pass
(
"shuffle_channel_detect_pass"
)
predictor
=
create_predictor
(
config
)
return
predictor
def
predict_image
(
predictor
,
image_file
,
image_shape
=
[
640
,
640
],
warmup
=
1
,
repeats
=
1
,
threshold
=
0.5
,
arch
=
'YOLOv5'
):
img
,
scale_factor
=
image_preprocess
(
image_file
,
image_shape
)
inputs
=
{}
if
arch
==
'YOLOv5'
:
inputs
[
'x2paddle_images'
]
=
img
input_names
=
predictor
.
get_input_names
()
for
i
in
range
(
len
(
input_names
)):
input_tensor
=
predictor
.
get_input_handle
(
input_names
[
i
])
input_tensor
.
copy_from_cpu
(
inputs
[
input_names
[
i
]])
for
i
in
range
(
warmup
):
predictor
.
run
()
np_boxes
=
None
predict_time
=
0.
time_min
=
float
(
"inf"
)
time_max
=
float
(
'-inf'
)
for
i
in
range
(
repeats
):
start_time
=
time
.
time
()
predictor
.
run
()
output_names
=
predictor
.
get_output_names
()
boxes_tensor
=
predictor
.
get_output_handle
(
output_names
[
0
])
np_boxes
=
boxes_tensor
.
copy_to_cpu
()
end_time
=
time
.
time
()
timed
=
end_time
-
start_time
time_min
=
min
(
time_min
,
timed
)
time_max
=
max
(
time_max
,
timed
)
predict_time
+=
timed
time_avg
=
predict_time
/
repeats
print
(
'Inference time(ms): min={}, max={}, avg={}'
.
format
(
round
(
time_min
*
1000
,
2
),
round
(
time_max
*
1000
,
1
),
round
(
time_avg
*
1000
,
1
)))
postprocess
=
YOLOv5PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.6
,
multi_label
=
True
)
res
=
postprocess
(
np_boxes
,
scale_factor
)
res_img
=
draw_box
(
image_file
,
res
[
'bbox'
],
CLASS_LABEL
,
threshold
=
threshold
)
cv2
.
imwrite
(
'result.jpg'
,
res_img
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--image_file'
,
type
=
str
,
default
=
None
,
help
=
"image path"
)
parser
.
add_argument
(
'--model_path'
,
type
=
str
,
help
=
"inference model filepath"
)
parser
.
add_argument
(
'--benchmark'
,
type
=
bool
,
default
=
False
,
help
=
"Whether run benchmark or not."
)
parser
.
add_argument
(
'--run_mode'
,
type
=
str
,
default
=
'paddle'
,
help
=
"mode of running(paddle/trt_fp32/trt_fp16/trt_int8)"
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'CPU'
,
help
=
"Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU"
)
parser
.
add_argument
(
'--img_shape'
,
type
=
int
,
default
=
640
,
help
=
"input_size"
)
args
=
parser
.
parse_args
()
predictor
=
load_predictor
(
args
.
model_path
,
run_mode
=
args
.
run_mode
,
device
=
args
.
device
)
warmup
,
repeats
=
1
,
1
if
args
.
benchmark
:
warmup
,
repeats
=
50
,
100
predict_image
(
predictor
,
args
.
image_file
,
image_shape
=
[
args
.
img_shape
,
args
.
img_shape
],
warmup
=
warmup
,
repeats
=
repeats
)
demo/auto_compression/detection/post_process.py
0 → 100644
浏览文件 @
d207622b
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
numpy
as
np
import
cv2
def
box_area
(
boxes
):
"""
Args:
boxes(np.ndarray): [N, 4]
return: [N]
"""
return
(
boxes
[:,
2
]
-
boxes
[:,
0
])
*
(
boxes
[:,
3
]
-
boxes
[:,
1
])
def
box_iou
(
box1
,
box2
):
"""
Args:
box1(np.ndarray): [N, 4]
box2(np.ndarray): [M, 4]
return: [N, M]
"""
area1
=
box_area
(
box1
)
area2
=
box_area
(
box2
)
lt
=
np
.
maximum
(
box1
[:,
np
.
newaxis
,
:
2
],
box2
[:,
:
2
])
rb
=
np
.
minimum
(
box1
[:,
np
.
newaxis
,
2
:],
box2
[:,
2
:])
wh
=
rb
-
lt
wh
=
np
.
maximum
(
0
,
wh
)
inter
=
wh
[:,
:,
0
]
*
wh
[:,
:,
1
]
iou
=
inter
/
(
area1
[:,
np
.
newaxis
]
+
area2
-
inter
)
return
iou
def
nms
(
boxes
,
scores
,
iou_threshold
):
"""
Non Max Suppression numpy implementation.
args:
boxes(np.ndarray): [N, 4]
scores(np.ndarray): [N, 1]
iou_threshold(float): Threshold of IoU.
"""
idxs
=
scores
.
argsort
()
keep
=
[]
while
idxs
.
size
>
0
:
max_score_index
=
idxs
[
-
1
]
max_score_box
=
boxes
[
max_score_index
][
None
,
:]
keep
.
append
(
max_score_index
)
if
idxs
.
size
==
1
:
break
idxs
=
idxs
[:
-
1
]
other_boxes
=
boxes
[
idxs
]
ious
=
box_iou
(
max_score_box
,
other_boxes
)
idxs
=
idxs
[
ious
[
0
]
<=
iou_threshold
]
keep
=
np
.
array
(
keep
)
return
keep
class
YOLOv5PostProcess
(
object
):
"""
Post process of YOLOv5 network.
args:
score_threshold(float): Threshold to filter out bounding boxes with low
confidence score. If not provided, consider all boxes.
nms_threshold(float): The threshold to be used in NMS.
multi_label(bool): Whether keep multi label in boxes.
keep_top_k(int): Number of total bboxes to be kept per image after NMS
step. -1 means keeping all bboxes after NMS step.
"""
def
__init__
(
self
,
score_threshold
=
0.25
,
nms_threshold
=
0.5
,
multi_label
=
False
,
keep_top_k
=
300
):
self
.
score_threshold
=
score_threshold
self
.
nms_threshold
=
nms_threshold
self
.
multi_label
=
multi_label
self
.
keep_top_k
=
keep_top_k
def
_xywh2xyxy
(
self
,
x
):
# Convert from [x, y, w, h] to [x1, y1, x2, y2]
y
=
np
.
copy
(
x
)
y
[:,
0
]
=
x
[:,
0
]
-
x
[:,
2
]
/
2
# top left x
y
[:,
1
]
=
x
[:,
1
]
-
x
[:,
3
]
/
2
# top left y
y
[:,
2
]
=
x
[:,
0
]
+
x
[:,
2
]
/
2
# bottom right x
y
[:,
3
]
=
x
[:,
1
]
+
x
[:,
3
]
/
2
# bottom right y
return
y
def
_non_max_suppression
(
self
,
prediction
):
max_wh
=
4096
# (pixels) minimum and maximum box width and height
nms_top_k
=
30000
cand_boxes
=
prediction
[...,
4
]
>
self
.
score_threshold
# candidates
output
=
[
np
.
zeros
((
0
,
6
))]
*
prediction
.
shape
[
0
]
for
batch_id
,
boxes
in
enumerate
(
prediction
):
# Apply constraints
boxes
=
boxes
[
cand_boxes
[
batch_id
]]
if
not
boxes
.
shape
[
0
]:
continue
# Compute conf (conf = obj_conf * cls_conf)
boxes
[:,
5
:]
*=
boxes
[:,
4
:
5
]
# Box (center x, center y, width, height) to (x1, y1, x2, y2)
convert_box
=
self
.
_xywh2xyxy
(
boxes
[:,
:
4
])
# Detections matrix nx6 (xyxy, conf, cls)
if
self
.
multi_label
:
i
,
j
=
(
boxes
[:,
5
:]
>
self
.
score_threshold
).
nonzero
()
boxes
=
np
.
concatenate
(
(
convert_box
[
i
],
boxes
[
i
,
j
+
5
,
None
],
j
[:,
None
].
astype
(
np
.
float32
)),
axis
=
1
)
else
:
conf
=
np
.
max
(
boxes
[:,
5
:],
axis
=
1
)
j
=
np
.
argmax
(
boxes
[:,
5
:],
axis
=
1
)
re
=
np
.
array
(
conf
.
reshape
(
-
1
)
>
self
.
score_threshold
)
conf
=
conf
.
reshape
(
-
1
,
1
)
j
=
j
.
reshape
(
-
1
,
1
)
boxes
=
np
.
concatenate
((
convert_box
,
conf
,
j
),
axis
=
1
)[
re
]
num_box
=
boxes
.
shape
[
0
]
if
not
num_box
:
continue
elif
num_box
>
nms_top_k
:
boxes
=
boxes
[
boxes
[:,
4
].
argsort
()[::
-
1
][:
nms_top_k
]]
# Batched NMS
c
=
boxes
[:,
5
:
6
]
*
max_wh
clean_boxes
,
scores
=
boxes
[:,
:
4
]
+
c
,
boxes
[:,
4
]
keep
=
nms
(
clean_boxes
,
scores
,
self
.
nms_threshold
)
# limit detection box num
if
keep
.
shape
[
0
]
>
self
.
keep_top_k
:
keep
=
keep
[:
self
.
keep_top_k
]
output
[
batch_id
]
=
boxes
[
keep
]
return
output
def
__call__
(
self
,
outs
,
scale_factor
):
preds
=
self
.
_non_max_suppression
(
outs
)
bboxs
,
box_nums
=
[],
[]
for
i
,
pred
in
enumerate
(
preds
):
if
len
(
pred
.
shape
)
>
2
:
pred
=
np
.
squeeze
(
pred
)
if
len
(
pred
.
shape
)
==
1
:
pred
=
pred
[
np
.
newaxis
,
:]
pred_bboxes
=
pred
[:,
:
4
]
scale_factor
=
np
.
tile
(
scale_factor
[
i
][::
-
1
],
(
1
,
2
))
pred_bboxes
/=
scale_factor
bbox
=
np
.
concatenate
(
[
pred
[:,
-
1
][:,
np
.
newaxis
],
pred
[:,
-
2
][:,
np
.
newaxis
],
pred_bboxes
],
axis
=-
1
)
bboxs
.
append
(
bbox
)
box_num
=
bbox
.
shape
[
0
]
box_nums
.
append
(
box_num
)
bboxs
=
np
.
concatenate
(
bboxs
,
axis
=
0
)
box_nums
=
np
.
array
(
box_nums
)
return
{
'bbox'
:
bboxs
,
'bbox_num'
:
box_nums
}
demo/auto_compression/detection/run.py
浏览文件 @
d207622b
...
...
@@ -23,6 +23,8 @@ from ppdet.metrics import COCOMetric
from
paddleslim.auto_compression.config_helpers
import
load_config
as
load_slim_config
from
paddleslim.auto_compression
import
AutoCompression
from
post_process
import
YOLOv5PostProcess
def
argsparser
():
parser
=
argparse
.
ArgumentParser
(
description
=
__doc__
)
...
...
@@ -59,8 +61,12 @@ def reader_wrapper(reader, input_list):
def
gen
():
for
data
in
reader
:
in_dict
=
{}
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
if
isinstance
(
input_list
,
list
):
for
input_name
in
input_list
:
in_dict
[
input_name
]
=
data
[
input_name
]
elif
isinstance
(
input_list
,
dict
):
for
input_name
in
input_list
.
keys
():
in_dict
[
input_list
[
input_name
]]
=
data
[
input_name
]
yield
in_dict
return
gen
...
...
@@ -80,24 +86,34 @@ def eval(config):
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
bias
=
0
,
IouType
=
'bbox'
)
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
k
in
config
[
'input_list'
]:
data_input
[
k
]
=
np
.
array
(
v
)
if
isinstance
(
config
[
'input_list'
],
list
):
if
k
in
config
[
'input_list'
]:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
config
[
'input_list'
],
dict
):
if
k
in
config
[
'input_list'
].
keys
():
data_input
[
config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
val_program
,
feed
=
data_input
,
fetch_list
=
fetch_targets
,
return_numpy
=
False
)
res
=
{}
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
if
'arch'
in
config
and
config
[
'arch'
]
==
'YOLOv5'
:
postprocess
=
YOLOv5PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.6
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
...
...
@@ -112,24 +128,33 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
anno_file
=
dataset
.
get_anno
()
metric
=
COCOMetric
(
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
bias
=
1
,
IouType
=
'bbox'
)
anno_file
=
anno_file
,
clsid2catid
=
clsid2catid
,
IouType
=
'bbox'
)
for
batch_id
,
data
in
enumerate
(
val_loader
):
data_all
=
{
k
:
np
.
array
(
v
)
for
k
,
v
in
data
.
items
()}
data_input
=
{}
for
k
,
v
in
data
.
items
():
if
k
in
test_feed_names
:
data_input
[
k
]
=
np
.
array
(
v
)
if
isinstance
(
global_config
[
'input_list'
],
list
):
if
k
in
test_feed_names
:
data_input
[
k
]
=
np
.
array
(
v
)
elif
isinstance
(
global_config
[
'input_list'
],
dict
):
if
k
in
global_config
[
'input_list'
].
keys
():
data_input
[
global_config
[
'input_list'
][
k
]]
=
np
.
array
(
v
)
outs
=
exe
.
run
(
compiled_test_program
,
feed
=
data_input
,
fetch_list
=
test_fetch_list
,
return_numpy
=
False
)
res
=
{}
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
if
'arch'
in
global_config
and
global_config
[
'arch'
]
==
'YOLOv5'
:
postprocess
=
YOLOv5PostProcess
(
score_threshold
=
0.001
,
nms_threshold
=
0.6
,
multi_label
=
True
)
res
=
postprocess
(
np
.
array
(
outs
[
0
]),
data_all
[
'scale_factor'
])
else
:
for
out
in
outs
:
v
=
np
.
array
(
out
)
if
len
(
v
.
shape
)
>
1
:
res
[
'bbox'
]
=
v
else
:
res
[
'bbox_num'
]
=
v
metric
.
update
(
data_all
,
res
)
if
batch_id
%
100
==
0
:
...
...
@@ -142,6 +167,7 @@ def eval_function(exe, compiled_test_program, test_feed_names, test_fetch_list):
def
main
():
global
global_config
compress_config
,
train_config
,
global_config
=
load_slim_config
(
FLAGS
.
config_path
)
reader_cfg
=
load_config
(
global_config
[
'reader_config'
])
...
...
paddleslim/auto_compression/compressor.py
浏览文件 @
d207622b
...
...
@@ -571,10 +571,9 @@ class AutoCompression:
os
.
remove
(
os
.
path
.
join
(
self
.
tmp_dir
,
'best_model.pdparams'
))
if
'qat'
in
strategy
:
float_program
,
int8_program
=
convert
(
test_program_info
.
program
.
_program
,
self
.
_places
,
self
.
_quant_config
,
\
test_program
,
int8_program
=
convert
(
test
_program
,
self
.
_places
,
self
.
_quant_config
,
\
scope
=
paddle
.
static
.
global_scope
(),
\
save_int8
=
True
)
test_program_info
.
program
=
float_program
model_dir
=
os
.
path
.
join
(
self
.
tmp_dir
,
'strategy_{}'
.
format
(
str
(
strategy_idx
+
1
)))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录