Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
e77549b5
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e77549b5
编写于
7月 08, 2022
作者:
W
whs
提交者:
GitHub
7月 08, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add inference and evaluation scripts for demo of segmentation (#1275)
上级
d3d78272
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
445 addition
and
7 deletion
+445
-7
example/auto_compression/semantic_segmentation/README.md
example/auto_compression/semantic_segmentation/README.md
+134
-7
example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg
...compression/semantic_segmentation/data/cityscape_demo.jpg
+0
-0
example/auto_compression/semantic_segmentation/data/human_demo.jpg
...uto_compression/semantic_segmentation/data/human_demo.jpg
+0
-0
example/auto_compression/semantic_segmentation/eval.py
example/auto_compression/semantic_segmentation/eval.py
+129
-0
example/auto_compression/semantic_segmentation/infer.py
example/auto_compression/semantic_segmentation/infer.py
+182
-0
未找到文件。
example/auto_compression/semantic_segmentation/README.md
浏览文件 @
e77549b5
...
...
@@ -8,7 +8,8 @@
-
[
3.2 准备数据集
](
#32-准备数据集
)
-
[
3.3 准备预测模型
](
#33-准备预测模型
)
-
[
3.4 自动压缩并产出模型
](
#34-自动压缩并产出模型
)
-
[
4.预测部署
](
#4预测部署
)
-
[
4.评估精度
](
#4评估精度
)
-
[
5.预测部署
](
#5预测部署
)
-
[
5.FAQ
](
5FAQ
)
## 1.简介
...
...
@@ -23,13 +24,13 @@
|:-----:|:-----:|:----------:|:---------:| :------:|:------:|:------:|
| PP-HumanSeg-Lite | Baseline | 92.87 | 56.363 |-| - |
[
model
](
https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz
)
|
| PP-HumanSeg-Lite | 非结构化稀疏+蒸馏 | 92.35 | 37.712 |-|
[
config
](
./configs/pp_human/pp_human_sparse.yaml
)
| - |
| PP-HumanSeg-Lite | 量化+蒸馏 | 92.84 | 49.656 |-|
[
config
](
./configs/pp_human/pp_human_qat.yaml
)
|
-
|
| PP-HumanSeg-Lite | 量化+蒸馏 | 92.84 | 49.656 |-|
[
config
](
./configs/pp_human/pp_human_qat.yaml
)
|
[
model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
)
(
非最佳
)
|
| PP-Liteseg | Baseline | 77.04| - | 1.425| - |
[
model
](
https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-PPLIteSegSTDC1.zip
)
|
| PP-Liteseg | 量化训练 | 76.93 | - | 1.158|
[
config
](
./configs/pp_liteseg/pp_liteseg_qat.yaml
)
|
-
|
| PP-Liteseg | 量化训练 | 76.93 | - | 1.158|
[
config
](
./configs/pp_liteseg/pp_liteseg_qat.yaml
)
|
[
model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp-liteseg.zip
)
|
| HRNet | Baseline | 78.97 | - |8.188|-|
[
model
](
https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-HRNetW18-Seg.zip
)
|
| HRNet | 量化训练 | 78.90 | - |5.812|
[
config
](
./configs/hrnet/hrnet_qat.yaml
)
|
-
|
| HRNet | 量化训练 | 78.90 | - |5.812|
[
config
](
./configs/hrnet/hrnet_qat.yaml
)
|
[
model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/hrnet.zip
)
|
| UNet | Baseline | 65.00 | - |15.291|-|
[
model
](
https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-UNet.zip
)
|
| UNet | 量化训练 | 64.93 | - |10.228|
[
config
](
./configs/unet/unet_qat.yaml
)
|
-
|
| UNet | 量化训练 | 64.93 | - |10.228|
[
config
](
./configs/unet/unet_qat.yaml
)
|
[
model
](
https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/unet.zip
)
|
| Deeplabv3-ResNet50 | Baseline | 79.90 | -|12.766| -|
[
model
](
https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-Deeplabv3-ResNet50.zip
)
|
| Deeplabv3-ResNet50 | 量化训练 | 78.89 | - |8.839|
[
config
](
./configs/deeplabv3/deeplabv3_qat.yaml
)
| - |
...
...
@@ -187,10 +188,136 @@ python -m paddle.distributed.launch run.py \
压缩完成后会在
`save_dir`
中产出压缩好的预测模型,可直接预测部署。
## 4.预测部署
## 4.评估精度
本小节以人像分割模型和小数据集为例, 介绍如何在测试集上评估压缩后的模型.
下载经过量化训练压缩后的推理模型:
```
wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
unzip pp_humanseg_qat.zip
```
通过以下命令下载人像分割示例数据:
```
shell
cd
./data
python download_data.py mini_humanseg
cd
-
```
执行以下命令评估模型在测试集上的精度:
```
python eval.py \
--model_dir ./pp_humanseg_qat \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--dataset_config configs/dataset/humanseg_dataset.yaml
```
## 5.预测部署
本小节以人像分割为例, 介绍如何使用Paddle Inference推理库执行压缩后的模型.
### 5.1 安装推理库
请参考该链接安装Python版本的PaddleInference推理库:
[
推理库安装教程
](
https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python
)
### 5.2 准备模型和数据
从
[
2.Benchmark
](
#2Benchmark
)
的表格中获得压缩前后的推理模型的下载链接,执行以下命令下载并解压推理模型:
下载Float32数值类型的模型:
```
wget https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz
tar -xzf ppseg_lite_portrait_398x224_with_softmax.tar.gz
mv ppseg_lite_portrait_398x224_with_softmax pp_humanseg_fp32
```
下载经过量化训练压缩后的推理模型:
```
wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
unzip pp_humanseg_qat.zip
```
准备好需要处理的图片,这里直接使用人像示例图片
`./data/human_demo.jpg`
。
### 5.3 执行推理
执行以下命令,直接使用飞桨框架的原生推理(仅支持Float32, 无需依赖TensorRT):
```
export CUDA_VISIBLE_DEVICES=0
python infer.py \
--image_file "./data/human_demo.jpg" \
--model_path "./pp_humanseg_fp32/model.pdmodel" \
--params_path "./pp_humanseg_fp32/model.pdiparams" \
--save_file "./humanseg_result_fp32.png" \
--dataset "human" \
--benchmark True \
--precision "fp32"
```
执行以下命令,使用Int8推理:
```
export CUDA_VISIBLE_DEVICES=0
python infer.py \
--image_file "./data/human_demo.jpg" \
--model_path "./pp_humanseg_qat/model.pdmodel" \
--params_path "./pp_humanseg_qat/model.pdiparams" \
--save_file "./humanseg_result_qat.png" \
--dataset "human" \
--benchmark True \
--use_trt True \
--precision "int8"
```
<table><tbody>
<tr>
<td>
原始图片
</td>
<td>
<img
src=
"https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/images/humanseg_demo.jpeg"
width=
"340"
height=
"200"
>
</td>
</tr>
<tr>
<td>
FP32推理结果
</td>
<td>
<img
src=
"https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/images/humanseg_result_fp32_demo.png"
width=
"340"
height=
"200"
>
</td>
</tr>
<tr>
<td>
Int8推理结果
</td>
<td>
<img
src=
"https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/images/humanseg_result_qat_demo.png"
width=
"340"
height=
"200"
>
</td>
</tr>
</tbody></table>
执行以下命令查看更多关于
`infer.py`
使用说明:
```
python infer.py --help
```
### 5.4 更多部署教程
-
[
Paddle Inference Python部署
](
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md
)
-
[
Paddle Inference C++部署
](
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md
)
-
[
Paddle Lite部署
](
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md
)
##
5
.FAQ
##
6
.FAQ
example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg
0 → 100644
浏览文件 @
e77549b5
1.9 MB
example/auto_compression/semantic_segmentation/data/human_demo.jpg
0 → 100644
浏览文件 @
e77549b5
207.0 KB
example/auto_compression/semantic_segmentation/eval.py
0 → 100644
浏览文件 @
e77549b5
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
argparse
import
random
import
paddle
import
numpy
as
np
from
tqdm
import
tqdm
from
paddleseg.cvlibs
import
Config
as
PaddleSegDataConfig
from
paddleseg.utils
import
worker_init_fn
from
paddleseg.core.infer
import
reverse_transform
from
paddleseg.utils
import
metrics
def
parse_args
():
parser
=
argparse
.
ArgumentParser
(
description
=
'Model evaluation'
)
parser
.
add_argument
(
'--model_dir'
,
type
=
str
,
default
=
None
,
help
=
"inference model directory."
)
parser
.
add_argument
(
'--model_filename'
,
type
=
str
,
default
=
None
,
help
=
"inference model filename."
)
parser
.
add_argument
(
'--params_filename'
,
type
=
str
,
default
=
None
,
help
=
"inference params filename."
)
parser
.
add_argument
(
'--dataset_config'
,
type
=
str
,
default
=
None
,
help
=
"path of dataset config."
)
return
parser
.
parse_args
()
def
eval
(
args
):
exe
=
paddle
.
static
.
Executor
(
paddle
.
CUDAPlace
(
0
))
inference_program
,
feed_target_names
,
fetch_targets
=
paddle
.
static
.
load_inference_model
(
args
.
model_dir
,
exe
,
model_filename
=
args
.
model_filename
,
params_filename
=
args
.
params_filename
)
data_cfg
=
PaddleSegDataConfig
(
args
.
dataset_config
)
eval_dataset
=
data_cfg
.
val_dataset
batch_sampler
=
paddle
.
io
.
BatchSampler
(
eval_dataset
,
batch_size
=
1
,
shuffle
=
False
,
drop_last
=
False
)
loader
=
paddle
.
io
.
DataLoader
(
eval_dataset
,
batch_sampler
=
batch_sampler
,
num_workers
=
1
,
return_list
=
True
,
)
total_iters
=
len
(
loader
)
intersect_area_all
=
0
pred_area_all
=
0
label_area_all
=
0
print
(
"Start evaluating (total_samples: {}, total_iters: {})..."
.
format
(
len
(
eval_dataset
),
total_iters
))
for
(
image
,
label
)
in
tqdm
(
loader
):
label
=
np
.
array
(
label
).
astype
(
'int64'
)
ori_shape
=
np
.
array
(
label
).
shape
[
-
2
:]
image
=
np
.
array
(
image
)
logits
=
exe
.
run
(
inference_program
,
feed
=
{
feed_target_names
[
0
]:
image
},
fetch_list
=
fetch_targets
,
return_numpy
=
True
)
paddle
.
disable_static
()
logit
=
logits
[
0
]
logit
=
reverse_transform
(
paddle
.
to_tensor
(
logit
),
ori_shape
,
eval_dataset
.
transforms
.
transforms
,
mode
=
'bilinear'
)
pred
=
paddle
.
to_tensor
(
logit
)
if
len
(
pred
.
shape
)
==
4
:
# for humanseg model whose prediction is distribution but not class id
pred
=
paddle
.
argmax
(
pred
,
axis
=
1
,
keepdim
=
True
,
dtype
=
'int32'
)
intersect_area
,
pred_area
,
label_area
=
metrics
.
calculate_area
(
pred
,
paddle
.
to_tensor
(
label
),
eval_dataset
.
num_classes
,
ignore_index
=
eval_dataset
.
ignore_index
)
intersect_area_all
=
intersect_area_all
+
intersect_area
pred_area_all
=
pred_area_all
+
pred_area
label_area_all
=
label_area_all
+
label_area
class_iou
,
miou
=
metrics
.
mean_iou
(
intersect_area_all
,
pred_area_all
,
label_area_all
)
class_acc
,
acc
=
metrics
.
accuracy
(
intersect_area_all
,
pred_area_all
)
kappa
=
metrics
.
kappa
(
intersect_area_all
,
pred_area_all
,
label_area_all
)
class_dice
,
mdice
=
metrics
.
dice
(
intersect_area_all
,
pred_area_all
,
label_area_all
)
infor
=
"[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}"
.
format
(
len
(
eval_dataset
),
miou
,
acc
,
kappa
,
mdice
)
print
(
infor
)
if
__name__
==
'__main__'
:
rank_id
=
paddle
.
distributed
.
get_rank
()
place
=
paddle
.
CUDAPlace
(
rank_id
)
args
=
parse_args
()
paddle
.
enable_static
()
eval
(
args
)
example/auto_compression/semantic_segmentation/infer.py
0 → 100644
浏览文件 @
e77549b5
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
cv2
import
numpy
as
np
import
argparse
import
time
import
PIL
from
PIL
import
Image
import
paddle
import
paddleseg.transforms
as
T
from
paddleseg.cvlibs
import
Config
as
PaddleSegDataConfig
from
paddleseg.core.infer
import
reverse_transform
from
paddleseg.utils
import
get_image_list
from
paddleseg.utils.visualize
import
get_pseudo_color_map
from
paddle.inference
import
create_predictor
,
PrecisionType
from
paddle.inference
import
Config
as
PredictConfig
def
_transforms
(
dataset
):
transforms
=
[]
if
dataset
==
"human"
:
transforms
.
append
(
T
.
PaddingByAspectRatio
(
aspect_ratio
=
1.77777778
))
transforms
.
append
(
T
.
Resize
(
target_size
=
[
398
,
224
]))
transforms
.
append
(
T
.
Normalize
())
elif
dataset
==
"cityscape"
:
transforms
.
append
(
T
.
Normalize
())
return
transforms
return
T
.
Compose
(
transforms
)
def
auto_tune_trt
(
args
):
auto_tuned_shape_file
=
"./auto_tuning_shape"
pred_cfg
=
PredictConfig
(
args
.
model_path
,
args
.
params_path
)
pred_cfg
.
enable_use_gpu
(
100
,
0
)
pred_cfg
.
collect_shape_range_info
(
"./auto_tuning_shape"
)
predictor
=
create_predictor
(
pred_cfg
)
input_names
=
predictor
.
get_input_names
()
input_handle
=
predictor
.
get_input_handle
(
input_names
[
0
])
transforms
=
_transforms
(
args
.
dataset
)
transform
=
T
.
Compose
(
transforms
)
img
=
cv2
.
imread
(
args
.
image_file
).
astype
(
'float32'
)
data
,
_
=
transform
(
img
)
data
=
np
.
array
(
data
)[
np
.
newaxis
,
:]
input_handle
.
reshape
(
data
.
shape
)
input_handle
.
copy_from_cpu
(
data
)
predictor
.
run
()
return
auto_tuned_shape_file
def
load_predictor
(
args
):
pred_cfg
=
PredictConfig
(
args
.
model_path
,
args
.
params_path
)
pred_cfg
.
disable_glog_info
()
pred_cfg
.
enable_memory_optim
()
pred_cfg
.
switch_ir_optim
(
True
)
if
args
.
device
==
"GPU"
:
pred_cfg
.
enable_use_gpu
(
100
,
0
)
if
args
.
use_trt
:
# To collect the dynamic shapes of inputs for TensorRT engine
auto_tuned_shape_file
=
auto_tune_trt
(
args
)
precision_map
=
{
"fp16"
:
PrecisionType
.
Half
,
"fp32"
:
PrecisionType
.
Float32
,
"int8"
:
PrecisionType
.
Int8
}
pred_cfg
.
enable_tensorrt_engine
(
workspace_size
=
1
<<
30
,
max_batch_size
=
1
,
min_subgraph_size
=
4
,
precision_mode
=
precision_map
[
args
.
precision
],
use_static
=
False
,
use_calib_mode
=
False
)
allow_build_at_runtime
=
True
pred_cfg
.
enable_tuned_tensorrt_dynamic_shape
(
auto_tuned_shape_file
,
allow_build_at_runtime
)
predictor
=
create_predictor
(
pred_cfg
)
return
predictor
def
predict_image
(
args
,
predictor
):
transforms
=
_transforms
(
args
.
dataset
)
transform
=
T
.
Compose
(
transforms
)
# Step1: Load image and preprocess
im
=
cv2
.
imread
(
args
.
image_file
).
astype
(
'float32'
)
data
,
_
=
transform
(
im
)
data
=
np
.
array
(
data
)[
np
.
newaxis
,
:]
# Step2: Inference
input_names
=
predictor
.
get_input_names
()
input_handle
=
predictor
.
get_input_handle
(
input_names
[
0
])
output_names
=
predictor
.
get_output_names
()
output_handle
=
predictor
.
get_output_handle
(
output_names
[
0
])
input_handle
.
reshape
(
data
.
shape
)
input_handle
.
copy_from_cpu
(
data
)
warmup
,
repeats
=
0
,
1
if
args
.
benchmark
:
warmup
,
repeats
=
20
,
100
for
i
in
range
(
warmup
):
predictor
.
run
()
start_time
=
time
.
time
()
for
i
in
range
(
repeats
):
predictor
.
run
()
results
=
output_handle
.
copy_to_cpu
()
total_time
=
time
.
time
()
-
start_time
avg_time
=
float
(
total_time
)
/
repeats
print
(
f
"Average inference time:
\033
[91m
{
round
(
avg_time
*
1000
,
2
)
}
ms
\033
[0m"
)
# Step3: Post process
if
args
.
dataset
==
"human"
:
results
=
reverse_transform
(
paddle
.
to_tensor
(
results
),
im
.
shape
,
transforms
,
mode
=
'bilinear'
)
results
=
np
.
argmax
(
results
,
axis
=
1
)
result
=
get_pseudo_color_map
(
results
[
0
])
# Step4: Save result to file
if
args
.
save_file
is
not
None
:
result
.
save
(
args
.
save_file
)
print
(
f
"Saved result to
\033
[91m
{
args
.
save_file
}
\033
[0m"
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--image_file'
,
type
=
str
,
help
=
"Image path to be processed."
)
parser
.
add_argument
(
'--save_file'
,
type
=
str
,
help
=
"The path to save the processed image."
)
parser
.
add_argument
(
'--model_path'
,
type
=
str
,
help
=
"Inference model filepath."
)
parser
.
add_argument
(
'--params_path'
,
type
=
str
,
help
=
"Inference parameters filepath."
)
parser
.
add_argument
(
'--dataset'
,
type
=
str
,
default
=
"human"
,
choices
=
[
"human"
,
"cityscape"
],
help
=
"The type of given image which can be 'human' or 'cityscape'."
)
parser
.
add_argument
(
'--benchmark'
,
type
=
bool
,
default
=
False
,
help
=
"Whether to run benchmark or not."
)
parser
.
add_argument
(
'--use_trt'
,
type
=
bool
,
default
=
False
,
help
=
"Whether to use tensorrt engine or not."
)
parser
.
add_argument
(
'--device'
,
type
=
str
,
default
=
'GPU'
,
choices
=
[
"CPU"
,
"GPU"
],
help
=
"Choose the device you want to run, it can be: CPU/GPU, default is GPU"
)
parser
.
add_argument
(
'--precision'
,
type
=
str
,
default
=
'fp32'
,
choices
=
[
"fp32"
,
"fp16"
,
"int8"
],
help
=
"The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'."
)
args
=
parser
.
parse_args
()
predictor
=
load_predictor
(
args
)
predict_image
(
args
,
predictor
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录