diff --git a/example/auto_compression/semantic_segmentation/README.md b/example/auto_compression/semantic_segmentation/README.md
index adddcc1a511e96608813b001bf9ea2af2b19b0dd..aa3842ebb99094185b854705122dfd5f7f255728 100644
--- a/example/auto_compression/semantic_segmentation/README.md
+++ b/example/auto_compression/semantic_segmentation/README.md
@@ -8,7 +8,8 @@
 - [3.2 准备数据集](#32-准备数据集)
 - [3.3 准备预测模型](#33-准备预测模型)
 - [3.4 自动压缩并产出模型](#34-自动压缩并产出模型)
-- [4.预测部署](#4预测部署)
-- [5.FAQ](5FAQ)
+- [4.评估精度](#4评估精度)
+- [5.预测部署](#5预测部署)
+- [6.FAQ](#6FAQ)
 
 ## 1.简介
@@ -23,13 +24,13 @@
 |:-----:|:-----:|:----------:|:---------:| :------:|:------:|:------:|
 | PP-HumanSeg-Lite | Baseline | 92.87 | 56.363 |-| - | [model](https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz) |
 | PP-HumanSeg-Lite | 非结构化稀疏+蒸馏 | 92.35 | 37.712 |-| [config](./configs/pp_human/pp_human_sparse.yaml)| - |
-| PP-HumanSeg-Lite | 量化+蒸馏 | 92.84 | 49.656 |-| [config](./configs/pp_human/pp_human_qat.yaml) | - |
+| PP-HumanSeg-Lite | 量化+蒸馏 | 92.84 | 49.656 |-| [config](./configs/pp_human/pp_human_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip) (非最佳) |
 | PP-Liteseg | Baseline | 77.04| - | 1.425| - |[model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-PPLIteSegSTDC1.zip)|
-| PP-Liteseg | 量化训练 | 76.93 | - | 1.158|[config](./configs/pp_liteseg/pp_liteseg_qat.yaml) | - |
+| PP-Liteseg | 量化训练 | 76.93 | - | 1.158|[config](./configs/pp_liteseg/pp_liteseg_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp-liteseg.zip) |
 | HRNet | Baseline | 78.97 | - |8.188|-| [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-HRNetW18-Seg.zip)|
-| HRNet | 量化训练 | 78.90 | - |5.812| [config](./configs/hrnet/hrnet_qat.yaml) | - |
+| HRNet | 量化训练 | 78.90 | - |5.812| [config](./configs/hrnet/hrnet_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/hrnet.zip) |
 | UNet | Baseline | 65.00 | - |15.291|-| [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-UNet.zip) |
-| UNet | 量化训练 | 64.93 | - |10.228| [config](./configs/unet/unet_qat.yaml) | - |
+| UNet | 量化训练 | 64.93 | - |10.228| [config](./configs/unet/unet_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/unet.zip) |
 | Deeplabv3-ResNet50 | Baseline | 79.90 | -|12.766| -| [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-Deeplabv3-ResNet50.zip)|
 | Deeplabv3-ResNet50 | 量化训练 | 78.89 | - |8.839|[config](./configs/deeplabv3/deeplabv3_qat.yaml) | - |
@@ -151,10 +152,136 @@ python -m paddle.distributed.launch run.py --config_path='./configs/pp_humanseg/
 
 压缩完成后会在`save_dir`中产出压缩好的预测模型,可直接预测部署。
 
-## 4.预测部署
+## 4.评估精度
+
+本小节以人像分割模型和小数据集为例,介绍如何在测试集上评估压缩后的模型。
+
+下载经过量化训练压缩后的推理模型:
+```
+wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
+unzip pp_humanseg_qat.zip
+```
+
+通过以下命令下载人像分割示例数据:
+
+```shell
+cd ./data
+python download_data.py mini_humanseg
+cd -
+```
+
+执行以下命令评估模型在测试集上的精度:
+
+```
+python eval.py \
+--model_dir ./pp_humanseg_qat \
+--model_filename model.pdmodel \
+--params_filename model.pdiparams \
+--dataset_config configs/dataset/humanseg_dataset.yaml
+```
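+
+其中 `--dataset_config` 指向一份 PaddleSeg 格式的数据集配置文件。下面是一份示意配置(字段与路径仅为示例,请以仓库中 `configs/dataset/humanseg_dataset.yaml` 的实际内容为准):
+
+```yaml
+# 示意:PaddleSeg 验证集配置的大致结构,具体取值以仓库内实际文件为准
+val_dataset:
+  type: Dataset
+  dataset_root: ./data/mini_humanseg       # 示例路径
+  val_path: ./data/mini_humanseg/val.txt   # 示例标注列表
+  num_classes: 2
+  mode: val
+  transforms:
+    - type: PaddingByAspectRatio
+      aspect_ratio: 1.77777778
+    - type: Resize
+      target_size: [398, 224]
+    - type: Normalize
+```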
+
+## 5.预测部署
+
+本小节以人像分割为例,介绍如何使用Paddle Inference推理库执行压缩后的模型。
+
+### 5.1 安装推理库
+
+请参考该链接安装Python版本的Paddle Inference推理库:[推理库安装教程](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python)
+
+### 5.2 准备模型和数据
+
+从 [2.Benchmark](#2Benchmark) 的表格中获取压缩前后推理模型的下载链接,然后执行以下命令下载并解压推理模型。
+
+下载Float32数值类型的模型:
+```
+wget https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz
+tar -xzf ppseg_lite_portrait_398x224_with_softmax.tar.gz
+mv ppseg_lite_portrait_398x224_with_softmax pp_humanseg_fp32
+```
+
+下载经过量化训练压缩后的推理模型:
+```
+wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
+unzip pp_humanseg_qat.zip
+```
+
+准备好需要处理的图片,这里直接使用人像示例图片 `./data/human_demo.jpg`。
+
+### 5.3 执行推理
+
+执行以下命令,直接使用飞桨框架的原生推理(仅支持Float32,无需依赖TensorRT):
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python infer.py \
+--image_file "./data/human_demo.jpg" \
+--model_path "./pp_humanseg_fp32/model.pdmodel" \
+--params_path "./pp_humanseg_fp32/model.pdiparams" \
+--save_file "./humanseg_result_fp32.png" \
+--dataset "human" \
+--benchmark True \
+--precision "fp32"
+```
+
+执行以下命令,使用TensorRT Int8推理:
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python infer.py \
+--image_file "./data/human_demo.jpg" \
+--model_path "./pp_humanseg_qat/model.pdmodel" \
+--params_path "./pp_humanseg_qat/model.pdiparams" \
+--save_file "./humanseg_result_qat.png" \
+--dataset "human" \
+--benchmark True \
+--use_trt True \
+--precision "int8"
+```
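+
+`infer.py` 也支持 TensorRT FP16 推理。下面给出一条示例命令(参数含义与上文相同,`--save_file` 的文件名仅为示意):
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python infer.py \
+--image_file "./data/human_demo.jpg" \
+--model_path "./pp_humanseg_fp32/model.pdmodel" \
+--params_path "./pp_humanseg_fp32/model.pdiparams" \
+--save_file "./humanseg_result_fp16.png" \
+--dataset "human" \
+--benchmark True \
+--use_trt True \
+--precision "fp16"
+```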
+
+<table>
+<tr>
+<td>原始图片</td>
+<td>FP32推理结果</td>
+<td>Int8推理结果</td>
+</tr>
+<tr>
+<td><img src="./data/human_demo.jpg"></td>
+<td><img src="./humanseg_result_fp32.png"></td>
+<td><img src="./humanseg_result_qat.png"></td>
+</tr>
+</table>
+
+执行以下命令查看 `infer.py` 的更多使用说明:
+
+```
+python infer.py --help
+```
+
+### 5.4 更多部署教程
+
 - [Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md)
 - [Paddle Inference C++部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md)
 - [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md)
 
-## 5.FAQ
+## 6.FAQ
diff --git a/example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg b/example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..de0b21d160edc2dd4c7c8d553a5a5f090b4bfd5b
Binary files /dev/null and b/example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg differ
diff --git a/example/auto_compression/semantic_segmentation/data/human_demo.jpg b/example/auto_compression/semantic_segmentation/data/human_demo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9b499ab55085fa6da1d776bca16da94091a14d81
Binary files /dev/null and b/example/auto_compression/semantic_segmentation/data/human_demo.jpg differ
diff --git a/example/auto_compression/semantic_segmentation/eval.py b/example/auto_compression/semantic_segmentation/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..5206200568e5d0780bd23c68beb6951b49e4e1a2
--- /dev/null
+++ b/example/auto_compression/semantic_segmentation/eval.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
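+
+# Evaluate an exported (e.g. quantization-aware trained) segmentation
+# inference model on the validation set described by a PaddleSeg dataset
+# config. Example usage (paths follow Section 4 of the README):
+#   python eval.py \
+#       --model_dir ./pp_humanseg_qat \
+#       --model_filename model.pdmodel \
+#       --params_filename model.pdiparams \
+#       --dataset_config configs/dataset/humanseg_dataset.yaml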
+
+import os
+import argparse
+import random
+import paddle
+import numpy as np
+from tqdm import tqdm
+from paddleseg.cvlibs import Config as PaddleSegDataConfig
+from paddleseg.utils import worker_init_fn
+
+from paddleseg.core.infer import reverse_transform
+from paddleseg.utils import metrics
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Model evaluation')
+    parser.add_argument(
+        '--model_dir',
+        type=str,
+        default=None,
+        help="inference model directory.")
+    parser.add_argument(
+        '--model_filename',
+        type=str,
+        default=None,
+        help="inference model filename.")
+    parser.add_argument(
+        '--params_filename',
+        type=str,
+        default=None,
+        help="inference params filename.")
+    parser.add_argument(
+        '--dataset_config',
+        type=str,
+        default=None,
+        help="path of dataset config.")
+    return parser.parse_args()
+
+
+def eval(args):
+    exe = paddle.static.Executor(paddle.CUDAPlace(0))
+    inference_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
+        args.model_dir,
+        exe,
+        model_filename=args.model_filename,
+        params_filename=args.params_filename)
+
+    data_cfg = PaddleSegDataConfig(args.dataset_config)
+    eval_dataset = data_cfg.val_dataset
+
+    batch_sampler = paddle.io.BatchSampler(
+        eval_dataset, batch_size=1, shuffle=False, drop_last=False)
+    loader = paddle.io.DataLoader(
+        eval_dataset,
+        batch_sampler=batch_sampler,
+        num_workers=1,
+        return_list=True)
+
+    total_iters = len(loader)
+    intersect_area_all = 0
+    pred_area_all = 0
+    label_area_all = 0
+
+    print("Start evaluating (total_samples: {}, total_iters: {})...".format(
+        len(eval_dataset), total_iters))
+
+    for (image, label) in tqdm(loader):
+        label = np.array(label).astype('int64')
+        ori_shape = np.array(label).shape[-2:]
+        image = np.array(image)
+        # Run the static inference program to get the raw logits.
+        logits = exe.run(inference_program,
+                         feed={feed_target_names[0]: image},
+                         fetch_list=fetch_targets,
+                         return_numpy=True)
+
+        # Post-process in dygraph mode so paddle tensor ops can be used.
+        paddle.disable_static()
+        logit = logits[0]
+        # Map the prediction back to the original image size.
+        logit = reverse_transform(
+            paddle.to_tensor(logit),
+            ori_shape,
+            eval_dataset.transforms.transforms,
+            mode='bilinear')
+        pred = paddle.to_tensor(logit)
+        # For humanseg models the prediction is a distribution over classes
+        # rather than a class id, so reduce over the channel axis.
+        if len(pred.shape) == 4:
+            pred = paddle.argmax(pred, axis=1, keepdim=True, dtype='int32')
+
+        intersect_area, pred_area, label_area = metrics.calculate_area(
+            pred,
+            paddle.to_tensor(label),
+            eval_dataset.num_classes,
+            ignore_index=eval_dataset.ignore_index)
+        intersect_area_all = intersect_area_all + intersect_area
+        pred_area_all = pred_area_all + pred_area
+        label_area_all = label_area_all + label_area
+        # Switch back to static mode before the next exe.run() call.
+        paddle.enable_static()
+
+    class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
+                                       label_area_all)
+    class_acc, acc = metrics.accuracy(intersect_area_all, pred_area_all)
+    kappa = metrics.kappa(intersect_area_all, pred_area_all, label_area_all)
+    class_dice, mdice = metrics.dice(intersect_area_all, pred_area_all,
+                                     label_area_all)
+
+    info = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
+        len(eval_dataset), miou, acc, kappa, mdice)
+    print(info)
+
+
+if __name__ == '__main__':
+    rank_id = paddle.distributed.get_rank()
+    place = paddle.CUDAPlace(rank_id)
+    args = parse_args()
+    paddle.enable_static()
+    eval(args)
diff --git a/example/auto_compression/semantic_segmentation/infer.py b/example/auto_compression/semantic_segmentation/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e556256b5c714c630fbcb3813efd6f2ccc6af75
--- /dev/null
+++ b/example/auto_compression/semantic_segmentation/infer.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import numpy as np
+import argparse
+import time
+import PIL
+from PIL import Image
+import paddle
+import paddleseg.transforms as T
+from paddleseg.cvlibs import Config as PaddleSegDataConfig
+from paddleseg.core.infer import reverse_transform
+from paddleseg.utils import get_image_list
+from paddleseg.utils.visualize import get_pseudo_color_map
+from paddle.inference import create_predictor, PrecisionType
+from paddle.inference import Config as PredictConfig
+
+
+def _transforms(dataset):
+    # Build the preprocessing pipeline; return a list so callers can both
+    # compose it and pass it to reverse_transform.
+    transforms = []
+    if dataset == "human":
+        transforms.append(T.PaddingByAspectRatio(aspect_ratio=1.77777778))
+        transforms.append(T.Resize(target_size=[398, 224]))
+        transforms.append(T.Normalize())
+    elif dataset == "cityscape":
+        transforms.append(T.Normalize())
+    return transforms
+
+
+def auto_tune_trt(args):
+    # Run the model once with native GPU inference to record the shape
+    # ranges of all tensors; TensorRT reuses this file when building the
+    # engine with tuned dynamic shapes.
+    auto_tuned_shape_file = "./auto_tuning_shape"
+    pred_cfg = PredictConfig(args.model_path, args.params_path)
+    pred_cfg.enable_use_gpu(100, 0)
+    pred_cfg.collect_shape_range_info(auto_tuned_shape_file)
+    predictor = create_predictor(pred_cfg)
+    input_names = predictor.get_input_names()
+    input_handle = predictor.get_input_handle(input_names[0])
+    transforms = _transforms(args.dataset)
+    transform = T.Compose(transforms)
+    img = cv2.imread(args.image_file).astype('float32')
+    data, _ = transform(img)
+    data = np.array(data)[np.newaxis, :]
+    input_handle.reshape(data.shape)
+    input_handle.copy_from_cpu(data)
+    predictor.run()
+    return auto_tuned_shape_file
+
+
+def load_predictor(args):
+    pred_cfg = PredictConfig(args.model_path, args.params_path)
+    pred_cfg.disable_glog_info()
+    pred_cfg.enable_memory_optim()
+    pred_cfg.switch_ir_optim(True)
+    if args.device == "GPU":
+        pred_cfg.enable_use_gpu(100, 0)
+
+    if args.use_trt:
+        # Collect the dynamic shapes of inputs for the TensorRT engine first.
+        auto_tuned_shape_file = auto_tune_trt(args)
+        precision_map = {
+            "fp16": PrecisionType.Half,
+            "fp32": PrecisionType.Float32,
+            "int8": PrecisionType.Int8
+        }
+        pred_cfg.enable_tensorrt_engine(
+            workspace_size=1 << 30,
+            max_batch_size=1,
+            min_subgraph_size=4,
+            precision_mode=precision_map[args.precision],
+            use_static=False,
+            use_calib_mode=False)
+        allow_build_at_runtime = True
+        pred_cfg.enable_tuned_tensorrt_dynamic_shape(auto_tuned_shape_file,
+                                                     allow_build_at_runtime)
+    predictor = create_predictor(pred_cfg)
+    return predictor
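+
+# Sketch of the overall flow (assuming the default arguments below):
+# load_predictor() builds the Paddle Inference predictor, optionally wrapping
+# subgraphs with TensorRT using the tuned dynamic shapes collected above;
+# predict_image() then feeds one preprocessed image, optionally benchmarks,
+# and saves a pseudo-color mask:
+#   predictor = load_predictor(args)
+#   predict_image(args, predictor)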
+
+def predict_image(args, predictor):
+    transforms = _transforms(args.dataset)
+    transform = T.Compose(transforms)
+
+    # Step1: Load image and preprocess
+    im = cv2.imread(args.image_file).astype('float32')
+    data, _ = transform(im)
+    data = np.array(data)[np.newaxis, :]
+
+    # Step2: Inference
+    input_names = predictor.get_input_names()
+    input_handle = predictor.get_input_handle(input_names[0])
+    output_names = predictor.get_output_names()
+    output_handle = predictor.get_output_handle(output_names[0])
+    input_handle.reshape(data.shape)
+    input_handle.copy_from_cpu(data)
+
+    warmup, repeats = 0, 1
+    if args.benchmark:
+        warmup, repeats = 20, 100
+
+    for i in range(warmup):
+        predictor.run()
+
+    start_time = time.time()
+    for i in range(repeats):
+        predictor.run()
+        results = output_handle.copy_to_cpu()
+    total_time = time.time() - start_time
+    avg_time = float(total_time) / repeats
+    print(f"Average inference time: \033[91m{round(avg_time*1000, 2)}ms\033[0m")
+
+    # Step3: Post process
+    if args.dataset == "human":
+        # Map the prediction back to the original image size for humanseg.
+        results = reverse_transform(
+            paddle.to_tensor(results), im.shape, transforms, mode='bilinear')
+    results = np.argmax(np.array(results), axis=1)
+    result = get_pseudo_color_map(results[0])
+
+    # Step4: Save result to file
+    if args.save_file is not None:
+        result.save(args.save_file)
+        print(f"Saved result to \033[91m{args.save_file}\033[0m")
+
+
+def str2bool(v):
+    # argparse with type=bool treats any non-empty string (even "False") as
+    # True, so parse boolean flags explicitly.
+    return str(v).lower() in ("true", "t", "yes", "1")
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--image_file', type=str, help="Image path to be processed.")
+    parser.add_argument(
+        '--save_file', type=str, help="The path to save the processed image.")
+    parser.add_argument(
+        '--model_path', type=str, help="Inference model filepath.")
+    parser.add_argument(
+        '--params_path', type=str, help="Inference parameters filepath.")
+    parser.add_argument(
+        '--dataset',
+        type=str,
+        default="human",
+        choices=["human", "cityscape"],
+        help="The type of given image which can be 'human' or 'cityscape'.")
+    parser.add_argument(
+        '--benchmark',
+        type=str2bool,
+        default=False,
+        help="Whether to run benchmark or not.")
+    parser.add_argument(
+        '--use_trt',
+        type=str2bool,
+        default=False,
+        help="Whether to use tensorrt engine or not.")
+    parser.add_argument(
+        '--device',
+        type=str,
+        default='GPU',
+        choices=["CPU", "GPU"],
+        help="Choose the device you want to run, it can be: CPU/GPU, default is GPU"
+    )
+    parser.add_argument(
+        '--precision',
+        type=str,
+        default='fp32',
+        choices=["fp32", "fp16", "int8"],
+        help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp32'."
+    )
+    args = parser.parse_args()
+    predictor = load_predictor(args)
+    predict_image(args, predictor)