Unverified commit f3f5467f authored by whs, committed by GitHub

Add inference and evaluation scripts for demo of segmentation (#1271)

Parent 1d435df9
...@@ -8,7 +8,8 @@
  - [3.2 Prepare the Dataset](#32-prepare-the-dataset)
  - [3.3 Prepare the Inference Model](#33-prepare-the-inference-model)
  - [3.4 Run Auto Compression and Export the Model](#34-run-auto-compression-and-export-the-model)
- [4. Evaluate Accuracy](#4-evaluate-accuracy)
- [5. Inference Deployment](#5-inference-deployment)
- [6. FAQ](#6-faq)
## 1. Introduction
...@@ -23,13 +24,13 @@
| Model | Strategy | mIoU (%) | ARM CPU Latency (ms) | Nvidia GPU Latency (ms) | Config | Inference Model |
|:-----:|:-----:|:----------:|:---------:|:------:|:------:|:------:|
| PP-HumanSeg-Lite | Baseline | 92.87 | 56.363 | - | - | [model](https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz) |
| PP-HumanSeg-Lite | Unstructured sparsity + distillation | 92.35 | 37.712 | - | [config](./configs/pp_human/pp_human_sparse.yaml) | - |
| PP-HumanSeg-Lite | Quantization + distillation | 92.84 | 49.656 | - | [config](./configs/pp_human/pp_human_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip) (not the best) |
| PP-Liteseg | Baseline | 77.04 | - | 1.425 | - | [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-PPLIteSegSTDC1.zip) |
| PP-Liteseg | Quantization-aware training | 76.93 | - | 1.158 | [config](./configs/pp_liteseg/pp_liteseg_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp-liteseg.zip) |
| HRNet | Baseline | 78.97 | - | 8.188 | - | [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-HRNetW18-Seg.zip) |
| HRNet | Quantization-aware training | 78.90 | - | 5.812 | [config](./configs/hrnet/hrnet_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/hrnet.zip) |
| UNet | Baseline | 65.00 | - | 15.291 | - | [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-UNet.zip) |
| UNet | Quantization-aware training | 64.93 | - | 10.228 | [config](./configs/unet/unet_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/unet.zip) |
| Deeplabv3-ResNet50 | Baseline | 79.90 | - | 12.766 | - | [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-Deeplabv3-ResNet50.zip) |
| Deeplabv3-ResNet50 | Quantization-aware training | 78.89 | - | 8.839 | [config](./configs/deeplabv3/deeplabv3_qat.yaml) | - |
...@@ -151,10 +152,136 @@ python -m paddle.distributed.launch run.py --config_path='./configs/pp_humanseg/
After compression finishes, the compressed inference model is exported to `save_dir` and can be deployed directly.
## 4. Evaluate Accuracy
This section uses the portrait segmentation model and a small dataset as an example to show how to evaluate a compressed model on the test set.
Download the inference model produced by quantization-aware training:
```
wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
unzip pp_humanseg_qat.zip
```
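The archive should extract to a directory containing a program file and a weights file; the layout below is inferred from the `--model_dir`, `--model_filename`, and `--params_filename` arguments used in the commands in this document:
```
pp_humanseg_qat/
├── model.pdmodel    # inference program (graph)
└── model.pdiparams  # model weights
```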
Download the portrait segmentation demo data with the following command:
```shell
cd ./data
python download_data.py mini_humanseg
cd -
```
Run the following command to evaluate the model's accuracy on the test set:
```
python eval.py \
--model_dir ./pp_humanseg_qat \
--model_filename model.pdmodel \
--params_filename model.pdiparams \
--dataset_config configs/dataset/humanseg_dataset.yaml
```
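`eval.py` iterates over the validation set and prints the accumulated metrics at the end. The line below shows the output format (taken from the script's final `print`); the numbers are illustrative:
```
[EVAL] #Images: 100 mIoU: 0.9284 Acc: 0.9700 Kappa: 0.9100 Dice: 0.9600
```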
## 5.预测部署
本小节以人像分割为例, 介绍如何使用Paddle Inference推理库执行压缩后的模型.
### 5.1 安装推理库
请参考该链接安装Python版本的PaddleInference推理库: [推理库安装教程](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python)
### 5.2 准备模型和数据
[2.Benchmark](#2Benchmark) 的表格中获得压缩前后的推理模型的下载链接,执行以下命令下载并解压推理模型:
下载Float32数值类型的模型:
```
wget https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz
tar -xzf ppseg_lite_portrait_398x224_with_softmax.tar.gz
mv ppseg_lite_portrait_398x224_with_softmax pp_humanseg_fp32
```
下载经过量化训练压缩后的推理模型:
```
wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
unzip pp_humanseg_qat.zip
```
准备好需要处理的图片,这里直接使用人像示例图片 `./data/human_demo.jpg`
### 5.3 执行推理
执行以下命令,直接使用飞桨框架的原生推理(仅支持Float32, 无需依赖TensorRT):
```
export CUDA_VISIBLE_DEVICES=0
python infer.py \
--image_file "./data/human_demo.jpg" \
--model_path "./pp_humanseg_fp32/model.pdmodel" \
--params_path "./pp_humanseg_fp32/model.pdiparams" \
--save_file "./humanseg_result_fp32.png" \
--dataset "human" \
--benchmark True \
--precision "fp32"
```
执行以下命令,使用Int8推理:
```
export CUDA_VISIBLE_DEVICES=0
python infer.py \
--image_file "./data/human_demo.jpg" \
--model_path "./pp_humanseg_qat/model.pdmodel" \
--params_path "./pp_humanseg_qat/model.pdiparams" \
--save_file "./humanseg_result_qat.png" \
--dataset "human" \
--benchmark True \
--use_trt True \
--precision "int8"
```
<table><tbody>
<tr>
<td>
Original image
</td>
<td>
<img src="https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/images/humanseg_demo.jpeg" width="340" height="200">
</td>
</tr>
<tr>
<td>
FP32 inference result
</td>
<td>
<img src="https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/images/humanseg_result_fp32_demo.png" width="340" height="200">
</td>
</tr>
<tr>
<td>
Int8 inference result
</td>
<td>
<img src="https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/images/humanseg_result_qat_demo.png" width="340" height="200">
</td>
</tr>
</tbody></table>
Run the following command to see more usage information for `infer.py`:
```
python infer.py --help
```
### 5.4 More Deployment Tutorials
- [Paddle Inference Python deployment](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md)
- [Paddle Inference C++ deployment](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md)
- [Paddle Lite deployment](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md)
## 6. FAQ
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse

import numpy as np
import paddle
from paddleseg.core.infer import reverse_transform
from paddleseg.cvlibs import Config as PaddleSegDataConfig
from paddleseg.utils import metrics
from tqdm import tqdm


def parse_args():
    parser = argparse.ArgumentParser(description='Model evaluation')
    parser.add_argument(
        '--model_dir',
        type=str,
        default=None,
        help="inference model directory.")
    parser.add_argument(
        '--model_filename',
        type=str,
        default=None,
        help="inference model filename.")
    parser.add_argument(
        '--params_filename',
        type=str,
        default=None,
        help="inference params filename.")
    parser.add_argument(
        '--dataset_config',
        type=str,
        default=None,
        help="path of dataset config.")
    return parser.parse_args()


def eval(args):
    # Load the inference model into a static-graph executor.
    exe = paddle.static.Executor(paddle.CUDAPlace(0))
    inference_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
        args.model_dir,
        exe,
        model_filename=args.model_filename,
        params_filename=args.params_filename)

    # Build the validation dataset from a PaddleSeg config file.
    data_cfg = PaddleSegDataConfig(args.dataset_config)
    eval_dataset = data_cfg.val_dataset
    batch_sampler = paddle.io.BatchSampler(
        eval_dataset, batch_size=1, shuffle=False, drop_last=False)
    loader = paddle.io.DataLoader(
        eval_dataset,
        batch_sampler=batch_sampler,
        num_workers=1,
        return_list=True)

    total_iters = len(loader)
    intersect_area_all = 0
    pred_area_all = 0
    label_area_all = 0
    print("Start evaluating (total_samples: {}, total_iters: {})...".format(
        len(eval_dataset), total_iters))
    for (image, label) in tqdm(loader):
        label = np.array(label).astype('int64')
        ori_shape = label.shape[-2:]
        image = np.array(image)
        # The executor requires static mode, while the metric ops below run
        # in dygraph mode, so the two modes are toggled on every iteration.
        paddle.enable_static()
        logits = exe.run(inference_program,
                         feed={feed_target_names[0]: image},
                         fetch_list=fetch_targets,
                         return_numpy=True)
        paddle.disable_static()
        # Map the logits back to the original image size.
        pred = reverse_transform(
            paddle.to_tensor(logits[0]),
            ori_shape,
            eval_dataset.transforms.transforms,
            mode='bilinear')
        if len(pred.shape) == 4:
            # For models (e.g. humanseg) whose prediction is a class
            # distribution rather than class ids, take the channel argmax.
            pred = paddle.argmax(pred, axis=1, keepdim=True, dtype='int32')
        intersect_area, pred_area, label_area = metrics.calculate_area(
            pred,
            paddle.to_tensor(label),
            eval_dataset.num_classes,
            ignore_index=eval_dataset.ignore_index)
        intersect_area_all = intersect_area_all + intersect_area
        pred_area_all = pred_area_all + pred_area
        label_area_all = label_area_all + label_area

    # Reduce the accumulated areas into the final metrics.
    class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
                                       label_area_all)
    class_acc, acc = metrics.accuracy(intersect_area_all, pred_area_all)
    kappa = metrics.kappa(intersect_area_all, pred_area_all, label_area_all)
    class_dice, mdice = metrics.dice(intersect_area_all, pred_area_all,
                                     label_area_all)
    info = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
        len(eval_dataset), miou, acc, kappa, mdice)
    print(info)


if __name__ == '__main__':
    args = parse_args()
    paddle.enable_static()
    eval(args)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import time

import cv2
import numpy as np
import paddle
import paddleseg.transforms as T
from paddle.inference import Config as PredictConfig
from paddle.inference import create_predictor, PrecisionType
from paddleseg.core.infer import reverse_transform
from paddleseg.utils.visualize import get_pseudo_color_map


def _transforms(dataset):
    # Build the preprocessing pipeline for the given dataset type.
    transforms = []
    if dataset == "human":
        transforms.append(T.PaddingByAspectRatio(aspect_ratio=1.77777778))
        transforms.append(T.Resize(target_size=[398, 224]))
        transforms.append(T.Normalize())
    elif dataset == "cityscape":
        transforms.append(T.Normalize())
    # Return the raw list; callers wrap it with T.Compose themselves and
    # also pass it to reverse_transform for post-processing.
    return transforms


def auto_tune_trt(args):
    # Run the model once in native GPU mode while recording the shape
    # ranges of all intermediate tensors, so TensorRT can later build an
    # engine with tuned dynamic shapes.
    auto_tuned_shape_file = "./auto_tuning_shape"
    pred_cfg = PredictConfig(args.model_path, args.params_path)
    pred_cfg.enable_use_gpu(100, 0)
    pred_cfg.collect_shape_range_info(auto_tuned_shape_file)
    predictor = create_predictor(pred_cfg)
    input_names = predictor.get_input_names()
    input_handle = predictor.get_input_handle(input_names[0])

    transforms = _transforms(args.dataset)
    transform = T.Compose(transforms)
    img = cv2.imread(args.image_file).astype('float32')
    data, _ = transform(img)
    data = np.array(data)[np.newaxis, :]
    input_handle.reshape(data.shape)
    input_handle.copy_from_cpu(data)
    predictor.run()
    return auto_tuned_shape_file


def load_predictor(args):
    pred_cfg = PredictConfig(args.model_path, args.params_path)
    pred_cfg.disable_glog_info()
    pred_cfg.enable_memory_optim()
    pred_cfg.switch_ir_optim(True)
    if args.device == "GPU":
        pred_cfg.enable_use_gpu(100, 0)

    if args.use_trt:
        # Collect the dynamic shapes of inputs for the TensorRT engine first.
        auto_tuned_shape_file = auto_tune_trt(args)
        precision_map = {
            "fp16": PrecisionType.Half,
            "fp32": PrecisionType.Float32,
            "int8": PrecisionType.Int8
        }
        pred_cfg.enable_tensorrt_engine(
            workspace_size=1 << 30,
            max_batch_size=1,
            min_subgraph_size=4,
            precision_mode=precision_map[args.precision],
            use_static=False,
            use_calib_mode=False)
        allow_build_at_runtime = True
        pred_cfg.enable_tuned_tensorrt_dynamic_shape(auto_tuned_shape_file,
                                                     allow_build_at_runtime)
    predictor = create_predictor(pred_cfg)
    return predictor


def predict_image(args, predictor):
    transforms = _transforms(args.dataset)
    transform = T.Compose(transforms)

    # Step 1: Load the image and preprocess it.
    im = cv2.imread(args.image_file).astype('float32')
    data, _ = transform(im)
    data = np.array(data)[np.newaxis, :]

    # Step 2: Run inference, optionally timing it over repeated runs.
    input_names = predictor.get_input_names()
    input_handle = predictor.get_input_handle(input_names[0])
    output_names = predictor.get_output_names()
    output_handle = predictor.get_output_handle(output_names[0])
    input_handle.reshape(data.shape)
    input_handle.copy_from_cpu(data)
    warmup, repeats = 0, 1
    if args.benchmark:
        warmup, repeats = 20, 100
    for _ in range(warmup):
        predictor.run()
    start_time = time.time()
    for _ in range(repeats):
        predictor.run()
        results = output_handle.copy_to_cpu()
    total_time = time.time() - start_time
    avg_time = float(total_time) / repeats
    print(f"Average inference time: \033[91m{round(avg_time*1000, 2)}ms\033[0m")

    # Step 3: Post-process. For the humanseg model the logits are mapped back
    # to the original image size before taking the per-pixel argmax.
    if args.dataset == "human":
        results = reverse_transform(
            paddle.to_tensor(results), im.shape, transforms, mode='bilinear')
    results = np.argmax(results, axis=1)
    result = get_pseudo_color_map(results[0])

    # Step 4: Save the visualization to file.
    if args.save_file is not None:
        result.save(args.save_file)
        print(f"Saved result to \033[91m{args.save_file}\033[0m")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--image_file', type=str, help="Image path to be processed.")
    parser.add_argument(
        '--save_file', type=str, help="The path to save the processed image.")
    parser.add_argument(
        '--model_path', type=str, help="Inference model filepath.")
    parser.add_argument(
        '--params_path', type=str, help="Inference parameters filepath.")
    parser.add_argument(
        '--dataset',
        type=str,
        default="human",
        choices=["human", "cityscape"],
        help="The type of the given image, either 'human' or 'cityscape'.")
    parser.add_argument(
        '--benchmark',
        type=bool,
        default=False,
        help="Whether to run the benchmark. Note: with type=bool, any "
        "non-empty string (e.g. 'True') parses as True.")
    parser.add_argument(
        '--use_trt',
        type=bool,
        default=False,
        help="Whether to use the TensorRT engine or not.")
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=["CPU", "GPU"],
        help="The device to run on, either CPU or GPU. Default is GPU.")
    parser.add_argument(
        '--precision',
        type=str,
        default='fp32',
        choices=["fp32", "fp16", "int8"],
        help="The precision of inference: 'fp32', 'fp16' or 'int8'. "
        "Default is 'fp32'.")
    args = parser.parse_args()

    predictor = load_predictor(args)
    predict_image(args, predictor)