diff --git a/example/auto_compression/semantic_segmentation/README.md b/example/auto_compression/semantic_segmentation/README.md
index adddcc1a511e96608813b001bf9ea2af2b19b0dd..aa3842ebb99094185b854705122dfd5f7f255728 100644
--- a/example/auto_compression/semantic_segmentation/README.md
+++ b/example/auto_compression/semantic_segmentation/README.md
@@ -8,7 +8,8 @@
- [3.2 Prepare the Dataset](#32-准备数据集)
- [3.3 Prepare the Inference Model](#33-准备预测模型)
- [3.4 Auto Compression and Model Export](#34-自动压缩并产出模型)
-- [4.Inference Deployment](#4inference-deployment)
+- [4.Model Evaluation](#4model-evaluation)
+- [5.Inference Deployment](#5inference-deployment)
-- [5.FAQ](5FAQ)
+- [6.FAQ](#6faq)
## 1.Introduction
@@ -23,13 +24,13 @@
|:-----:|:-----:|:----------:|:---------:| :------:|:------:|:------:|
| PP-HumanSeg-Lite | Baseline | 92.87 | 56.363 |-| - | [model](https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz) |
| PP-HumanSeg-Lite | Unstructured pruning + distillation | 92.35 | 37.712 |-| [config](./configs/pp_human/pp_human_sparse.yaml)| - |
-| PP-HumanSeg-Lite | Quantization + distillation | 92.84 | 49.656 |-| [config](./configs/pp_human/pp_human_qat.yaml) | - |
+| PP-HumanSeg-Lite | Quantization + distillation | 92.84 | 49.656 |-| [config](./configs/pp_human/pp_human_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip) (not the best checkpoint) |
| PP-Liteseg | Baseline | 77.04| - | 1.425| - |[model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-PPLIteSegSTDC1.zip)|
-| PP-Liteseg | Quantization-aware training | 76.93 | - | 1.158|[config](./configs/pp_liteseg/pp_liteseg_qat.yaml) | - |
+| PP-Liteseg | Quantization-aware training | 76.93 | - | 1.158|[config](./configs/pp_liteseg/pp_liteseg_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp-liteseg.zip) |
| HRNet | Baseline | 78.97 | - |8.188|-| [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-HRNetW18-Seg.zip)|
-| HRNet | Quantization-aware training | 78.90 | - |5.812| [config](./configs/hrnet/hrnet_qat.yaml) | - |
+| HRNet | Quantization-aware training | 78.90 | - |5.812| [config](./configs/hrnet/hrnet_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/hrnet.zip) |
| UNet | Baseline | 65.00 | - |15.291|-| [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-UNet.zip) |
-| UNet | Quantization-aware training | 64.93 | - |10.228| [config](./configs/unet/unet_qat.yaml) | - |
+| UNet | Quantization-aware training | 64.93 | - |10.228| [config](./configs/unet/unet_qat.yaml) | [model](https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/unet.zip) |
| Deeplabv3-ResNet50 | Baseline | 79.90 | -|12.766| -| [model](https://paddleseg.bj.bcebos.com/tipc/easyedge/RES-paddle2-Deeplabv3-ResNet50.zip)|
| Deeplabv3-ResNet50 | Quantization-aware training | 78.89 | - |8.839|[config](./configs/deeplabv3/deeplabv3_qat.yaml) | - |
@@ -151,10 +152,136 @@ python -m paddle.distributed.launch run.py --config_path='./configs/pp_humanseg/
After compression finishes, the compressed inference model is produced under `save_dir` and can be deployed directly.
-## 4.Inference Deployment
+## 4.Model Evaluation
+
+This section shows how to evaluate a compressed model on the test set, using the human segmentation model and a small demo dataset as an example.
+
+Download the inference model produced by quantization-aware training:
+```
+wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
+unzip pp_humanseg_qat.zip
+```
+
+Download the human segmentation demo data with the following commands:
+
+```shell
+cd ./data
+python download_data.py mini_humanseg
+cd -
+```
+
+Run the following command to evaluate the model's accuracy on the test set:
+
+```
+python eval.py \
+--model_dir ./pp_humanseg_qat \
+--model_filename model.pdmodel \
+--params_filename model.pdiparams \
+--dataset_config configs/dataset/humanseg_dataset.yaml
+```
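+
+For reference, `eval.py` builds the validation set from the PaddleSeg dataset config. A simplified excerpt of its loading logic:
+
+```
+import paddle
+from paddleseg.cvlibs import Config as PaddleSegDataConfig
+
+# Parse the dataset yaml and materialize the validation split.
+data_cfg = PaddleSegDataConfig("configs/dataset/humanseg_dataset.yaml")
+eval_dataset = data_cfg.val_dataset
+
+# Iterate samples one at a time, as eval.py does.
+loader = paddle.io.DataLoader(eval_dataset, batch_size=1, num_workers=1)
+```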
+
+## 5.Inference Deployment
+
+This section shows how to run the compressed model with the Paddle Inference library, using human segmentation as an example.
+
+### 5.1 Install the Inference Library
+
+Please install the Python version of the Paddle Inference library by following this guide: [installation guide](https://www.paddlepaddle.org.cn/inference/user_guides/download_lib.html#python)
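+
+After installation, you can sanity-check the environment with Paddle's built-in self-check (a minimal sketch):
+
+```
+import paddle
+from paddle.inference import Config, create_predictor  # should import cleanly
+
+paddle.utils.run_check()  # verifies that PaddlePaddle works on this machine
+print(paddle.__version__)
+```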
+
+### 5.2 Prepare the Model and Data
+
+Get the download links for the inference models before and after compression from the table in [2.Benchmark](#2benchmark), then download and extract them with the following commands:
+
+Download the Float32 model:
+```
+wget https://paddleseg.bj.bcebos.com/dygraph/ppseg/ppseg_lite_portrait_398x224_with_softmax.tar.gz
+tar -xzf ppseg_lite_portrait_398x224_with_softmax.tar.gz
+mv ppseg_lite_portrait_398x224_with_softmax pp_humanseg_fp32
+```
+
+Download the inference model produced by quantization-aware training:
+```
+wget https://bj.bcebos.com/v1/paddle-slim-models/act/PaddleSeg/qat/pp_humanseg_qat.zip
+unzip pp_humanseg_qat.zip
+```
+
+Prepare the image to be processed; here we simply use the portrait demo image `./data/human_demo.jpg`.
+
+### 5.3 Run Inference
+
+Run the following commands to use the Paddle framework's native inference (Float32 only; no TensorRT dependency required):
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python infer.py \
+--image_file "./data/human_demo.jpg" \
+--model_path "./pp_humanseg_fp32/model.pdmodel" \
+--params_path "./pp_humanseg_fp32/model.pdiparams" \
+--save_file "./humanseg_result_fp32.png" \
+--dataset "human" \
+--benchmark True \
+--precision "fp32"
+```
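+
+For reference, the native inference flow wrapped by `infer.py` boils down to the standard Paddle Inference Python API. A simplified sketch (preprocessing omitted; the input shape assumes the 398x224 PP-HumanSeg-Lite model):
+
+```
+import numpy as np
+from paddle.inference import Config, create_predictor
+
+config = Config("./pp_humanseg_fp32/model.pdmodel",
+                "./pp_humanseg_fp32/model.pdiparams")
+config.enable_use_gpu(100, 0)  # 100 MB initial GPU memory on device 0
+predictor = create_predictor(config)
+
+# Feed a preprocessed NCHW float32 image (random data here for brevity).
+data = np.random.rand(1, 3, 224, 398).astype("float32")
+input_handle = predictor.get_input_handle(predictor.get_input_names()[0])
+input_handle.reshape(data.shape)
+input_handle.copy_from_cpu(data)
+
+predictor.run()
+output = predictor.get_output_handle(predictor.get_output_names()[0]).copy_to_cpu()
+print(output.shape)
+```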
+
+Run the following commands to use Int8 inference (via TensorRT):
+
+```
+export CUDA_VISIBLE_DEVICES=0
+python infer.py \
+--image_file "./data/human_demo.jpg" \
+--model_path "./pp_humanseg_qat/model.pdmodel" \
+--params_path "./pp_humanseg_qat/model.pdiparams" \
+--save_file "./humanseg_result_qat.png" \
+--dataset "human" \
+--benchmark True \
+--use_trt True \
+--precision "int8"
+```
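+
+The `--use_trt True --precision int8` switches map to the TensorRT settings applied in `load_predictor` of `infer.py`; roughly:
+
+```
+from paddle.inference import Config, PrecisionType
+
+config = Config("./pp_humanseg_qat/model.pdmodel",
+                "./pp_humanseg_qat/model.pdiparams")
+config.enable_use_gpu(100, 0)
+config.enable_tensorrt_engine(
+    workspace_size=1 << 30,
+    max_batch_size=1,
+    min_subgraph_size=4,
+    precision_mode=PrecisionType.Int8,
+    use_static=False,
+    use_calib_mode=False)  # no calibration pass needed for QAT models
+# Reuse the input shape ranges collected beforehand (see auto_tune_trt in infer.py).
+config.enable_tuned_tensorrt_dynamic_shape("./auto_tuning_shape", True)
+```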
+
+(Result figures: original image, FP32 inference result, and Int8 inference result.)
+
+Run the following command to see more usage details of `infer.py`:
+
+```
+python infer.py --help
+```
+
+### 5.4 More Deployment Tutorials
- [Paddle Inference Python部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/python_inference.md)
- [Paddle Inference C++部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/inference/cpp_inference.md)
- [Paddle Lite部署](https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/deployment/lite/lite.md)
-## 5.FAQ
+## 6.FAQ
diff --git a/example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg b/example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..de0b21d160edc2dd4c7c8d553a5a5f090b4bfd5b
Binary files /dev/null and b/example/auto_compression/semantic_segmentation/data/cityscape_demo.jpg differ
diff --git a/example/auto_compression/semantic_segmentation/data/human_demo.jpg b/example/auto_compression/semantic_segmentation/data/human_demo.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9b499ab55085fa6da1d776bca16da94091a14d81
Binary files /dev/null and b/example/auto_compression/semantic_segmentation/data/human_demo.jpg differ
diff --git a/example/auto_compression/semantic_segmentation/eval.py b/example/auto_compression/semantic_segmentation/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..5206200568e5d0780bd23c68beb6951b49e4e1a2
--- /dev/null
+++ b/example/auto_compression/semantic_segmentation/eval.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import argparse
+import random
+import paddle
+import numpy as np
+from tqdm import tqdm
+from paddleseg.cvlibs import Config as PaddleSegDataConfig
+from paddleseg.utils import worker_init_fn
+
+from paddleseg.core.infer import reverse_transform
+from paddleseg.utils import metrics
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Model evaluation')
+ parser.add_argument(
+ '--model_dir',
+ type=str,
+ default=None,
+ help="inference model directory.")
+ parser.add_argument(
+ '--model_filename',
+ type=str,
+ default=None,
+ help="inference model filename.")
+ parser.add_argument(
+ '--params_filename',
+ type=str,
+ default=None,
+ help="inference params filename.")
+ parser.add_argument(
+ '--dataset_config',
+ type=str,
+ default=None,
+ help="path of dataset config.")
+ return parser.parse_args()
+
+
+def eval(args):
+ exe = paddle.static.Executor(paddle.CUDAPlace(0))
+ inference_program, feed_target_names, fetch_targets = paddle.static.load_inference_model(
+ args.model_dir,
+ exe,
+ model_filename=args.model_filename,
+ params_filename=args.params_filename)
+
+ data_cfg = PaddleSegDataConfig(args.dataset_config)
+ eval_dataset = data_cfg.val_dataset
+
+ batch_sampler = paddle.io.BatchSampler(
+ eval_dataset, batch_size=1, shuffle=False, drop_last=False)
+ loader = paddle.io.DataLoader(
+ eval_dataset,
+ batch_sampler=batch_sampler,
+ num_workers=1,
+ return_list=True, )
+
+ total_iters = len(loader)
+ intersect_area_all = 0
+ pred_area_all = 0
+ label_area_all = 0
+
+ print("Start evaluating (total_samples: {}, total_iters: {})...".format(
+ len(eval_dataset), total_iters))
+
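+    # Intersection/prediction/label areas are accumulated over all batches;
+    # the final metrics are computed from these running totals after the loop.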
+ for (image, label) in tqdm(loader):
+        label = np.array(label).astype('int64')
+        ori_shape = label.shape[-2:]
+        image = np.array(image)
+ logits = exe.run(inference_program,
+ feed={feed_target_names[0]: image},
+ fetch_list=fetch_targets,
+ return_numpy=True)
+
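+        # Switch to dygraph mode for the tensor-based post-processing below:
+        # undo the preprocessing transforms so the prediction matches the
+        # original label resolution, then compute the confusion areas.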
+ paddle.disable_static()
+ logit = logits[0]
+
+ logit = reverse_transform(
+ paddle.to_tensor(logit),
+ ori_shape,
+ eval_dataset.transforms.transforms,
+ mode='bilinear')
+ pred = paddle.to_tensor(logit)
+        # For models (e.g. humanseg) whose prediction is a per-class
+        # distribution rather than class ids, take the channel-wise argmax.
+        if len(pred.shape) == 4:
+            pred = paddle.argmax(pred, axis=1, keepdim=True, dtype='int32')
+
+ intersect_area, pred_area, label_area = metrics.calculate_area(
+ pred,
+ paddle.to_tensor(label),
+ eval_dataset.num_classes,
+ ignore_index=eval_dataset.ignore_index)
+ intersect_area_all = intersect_area_all + intersect_area
+ pred_area_all = pred_area_all + pred_area
+ label_area_all = label_area_all + label_area
+
+ class_iou, miou = metrics.mean_iou(intersect_area_all, pred_area_all,
+ label_area_all)
+ class_acc, acc = metrics.accuracy(intersect_area_all, pred_area_all)
+ kappa = metrics.kappa(intersect_area_all, pred_area_all, label_area_all)
+ class_dice, mdice = metrics.dice(intersect_area_all, pred_area_all,
+ label_area_all)
+
+    info = "[EVAL] #Images: {} mIoU: {:.4f} Acc: {:.4f} Kappa: {:.4f} Dice: {:.4f}".format(
+        len(eval_dataset), miou, acc, kappa, mdice)
+    print(info)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    paddle.enable_static()
+    eval(args)
diff --git a/example/auto_compression/semantic_segmentation/infer.py b/example/auto_compression/semantic_segmentation/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e556256b5c714c630fbcb3813efd6f2ccc6af75
--- /dev/null
+++ b/example/auto_compression/semantic_segmentation/infer.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import numpy as np
+import argparse
+import time
+import PIL
+from PIL import Image
+import paddle
+import paddleseg.transforms as T
+from paddleseg.cvlibs import Config as PaddleSegDataConfig
+from paddleseg.core.infer import reverse_transform
+from paddleseg.utils import get_image_list
+from paddleseg.utils.visualize import get_pseudo_color_map
+from paddle.inference import create_predictor, PrecisionType
+from paddle.inference import Config as PredictConfig
+
+
+def str2bool(v):
+    # argparse's `type=bool` treats any non-empty string (even "False") as
+    # True, so boolean flags are parsed explicitly instead.
+    return str(v).lower() in ("true", "t", "yes", "y", "1")
+
+
+def _transforms(dataset):
+    # Return the transform list itself (not a composed object): callers
+    # compose it for preprocessing and also pass it to reverse_transform.
+    transforms = []
+    if dataset == "human":
+        transforms.append(T.PaddingByAspectRatio(aspect_ratio=1.77777778))
+        transforms.append(T.Resize(target_size=[398, 224]))
+        transforms.append(T.Normalize())
+    elif dataset == "cityscape":
+        transforms.append(T.Normalize())
+    return transforms
+
+
+def auto_tune_trt(args):
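+    # Runs one forward pass while recording the shape range of every tensor;
+    # the resulting file feeds enable_tuned_tensorrt_dynamic_shape later on.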
+ auto_tuned_shape_file = "./auto_tuning_shape"
+ pred_cfg = PredictConfig(args.model_path, args.params_path)
+ pred_cfg.enable_use_gpu(100, 0)
+    pred_cfg.collect_shape_range_info(auto_tuned_shape_file)
+ predictor = create_predictor(pred_cfg)
+ input_names = predictor.get_input_names()
+ input_handle = predictor.get_input_handle(input_names[0])
+ transforms = _transforms(args.dataset)
+ transform = T.Compose(transforms)
+ img = cv2.imread(args.image_file).astype('float32')
+ data, _ = transform(img)
+ data = np.array(data)[np.newaxis, :]
+ input_handle.reshape(data.shape)
+ input_handle.copy_from_cpu(data)
+ predictor.run()
+ return auto_tuned_shape_file
+
+
+def load_predictor(args):
+ pred_cfg = PredictConfig(args.model_path, args.params_path)
+ pred_cfg.disable_glog_info()
+ pred_cfg.enable_memory_optim()
+ pred_cfg.switch_ir_optim(True)
+ if args.device == "GPU":
+ pred_cfg.enable_use_gpu(100, 0)
+
+ if args.use_trt:
+            # Collect the dynamic input shape ranges required by the TensorRT engine
+ auto_tuned_shape_file = auto_tune_trt(args)
+ precision_map = {
+ "fp16": PrecisionType.Half,
+ "fp32": PrecisionType.Float32,
+ "int8": PrecisionType.Int8
+ }
+ pred_cfg.enable_tensorrt_engine(
+ workspace_size=1 << 30,
+ max_batch_size=1,
+ min_subgraph_size=4,
+ precision_mode=precision_map[args.precision],
+ use_static=False,
+ use_calib_mode=False)
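+        # Reuse the shape ranges collected by auto_tune_trt so TensorRT can
+        # handle dynamic input shapes, and allow building new engines at
+        # runtime if an unseen shape appears.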
+ allow_build_at_runtime = True
+ pred_cfg.enable_tuned_tensorrt_dynamic_shape(auto_tuned_shape_file,
+ allow_build_at_runtime)
+ predictor = create_predictor(pred_cfg)
+ return predictor
+
+
+def predict_image(args, predictor):
+
+ transforms = _transforms(args.dataset)
+ transform = T.Compose(transforms)
+
+ # Step1: Load image and preprocess
+ im = cv2.imread(args.image_file).astype('float32')
+ data, _ = transform(im)
+ data = np.array(data)[np.newaxis, :]
+
+ # Step2: Inference
+ input_names = predictor.get_input_names()
+ input_handle = predictor.get_input_handle(input_names[0])
+ output_names = predictor.get_output_names()
+ output_handle = predictor.get_output_handle(output_names[0])
+ input_handle.reshape(data.shape)
+ input_handle.copy_from_cpu(data)
+
+ warmup, repeats = 0, 1
+ if args.benchmark:
+ warmup, repeats = 20, 100
+
+ for i in range(warmup):
+ predictor.run()
+
+ start_time = time.time()
+ for i in range(repeats):
+ predictor.run()
+ results = output_handle.copy_to_cpu()
+ total_time = time.time() - start_time
+ avg_time = float(total_time) / repeats
+ print(f"Average inference time: \033[91m{round(avg_time*1000, 2)}ms\033[0m")
+
+ # Step3: Post process
+ if args.dataset == "human":
+ results = reverse_transform(
+ paddle.to_tensor(results), im.shape, transforms, mode='bilinear')
+ results = np.argmax(results, axis=1)
+ result = get_pseudo_color_map(results[0])
+
+ # Step4: Save result to file
+ if args.save_file is not None:
+ result.save(args.save_file)
+ print(f"Saved result to \033[91m{args.save_file}\033[0m")
+
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--image_file', type=str, help="Image path to be processed.")
+ parser.add_argument(
+ '--save_file', type=str, help="The path to save the processed image.")
+ parser.add_argument(
+ '--model_path', type=str, help="Inference model filepath.")
+ parser.add_argument(
+ '--params_path', type=str, help="Inference parameters filepath.")
+ parser.add_argument(
+ '--dataset',
+ type=str,
+ default="human",
+ choices=["human", "cityscape"],
+ help="The type of given image which can be 'human' or 'cityscape'.")
+ parser.add_argument(
+ '--benchmark',
+        type=str2bool,
+ default=False,
+ help="Whether to run benchmark or not.")
+ parser.add_argument(
+ '--use_trt',
+        type=str2bool,
+ default=False,
+ help="Whether to use tensorrt engine or not.")
+ parser.add_argument(
+ '--device',
+ type=str,
+ default='GPU',
+ choices=["CPU", "GPU"],
+ help="Choose the device you want to run, it can be: CPU/GPU, default is GPU"
+ )
+ parser.add_argument(
+ '--precision',
+ type=str,
+ default='fp32',
+ choices=["fp32", "fp16", "int8"],
+ help="The precision of inference. It can be 'fp32', 'fp16' or 'int8'. Default is 'fp16'."
+ )
+ args = parser.parse_args()
+ predictor = load_predictor(args)
+ predict_image(args, predictor)