提交 3d5ba524 编写于 作者: G gaotingquan 提交者: cuicheng01

support image_orientation

上级 dec008d6
......@@ -39,6 +39,29 @@ def build_postprocess(config):
return PostProcesser(func_list, main_indicator)
def parse_class_id_map(class_id_map_file, delimiter):
if class_id_map_file is None:
return None
if not os.path.exists(class_id_map_file):
print(
"Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
)
return None
try:
class_id_map = {}
with open(class_id_map_file, "r") as fin:
lines = fin.readlines()
for line in lines:
partition = line.split("\n")[0].partition(delimiter)
class_id_map[int(partition[0])] = str(partition[-1])
except Exception as ex:
print(ex)
class_id_map = None
return class_id_map
class PostProcesser(object):
def __init__(self, func_list, main_indicator="Topk"):
self.func_list = func_list
......@@ -54,27 +77,45 @@ class PostProcesser(object):
class ThreshOutput(object):
def __init__(self, threshold, label_0="0", label_1="1"):
def __init__(self,
threshold=0,
default_label_index=0,
class_id_map_file=None,
delimiter=None,
label_0=None,
label_1=None):
self.threshold = threshold
self.label_0 = label_0
self.label_1 = label_1
self.default_label_index = default_label_index
delimiter = delimiter if delimiter is not None else " "
self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
if label_0 is not None or label_1 is not None:
print(
"[WARNING] The arguments \"label_0\" and \"label_1\" have been deprecated. Please use \"default_label_index\" instead."
)
def __call__(self, x, file_names=None):
if file_names is not None:
assert x.shape[0] == len(file_names)
y = []
for idx, probs in enumerate(x):
score = probs[1]
if score < self.threshold:
result = {
"class_ids": [0],
"scores": [1 - score],
"label_names": [self.label_0]
}
index = probs.argsort(axis=0)[::-1].astype("int32")
top1_id = index[0]
top1_score = probs[top1_id]
if top1_score > self.threshold:
rtn_id = top1_id
else:
result = {
"class_ids": [1],
"scores": [score],
"label_names": [self.label_1]
}
rtn_id = self.default_label_index
label_name = self.class_id_map[
rtn_id] if self.class_id_map is not None else ""
result = {
"class_ids": [rtn_id],
"scores": [probs[rtn_id]],
"label_names": [label_name]
}
if file_names is not None:
result["file_name"] = file_names[idx]
y.append(result)
......@@ -85,30 +126,8 @@ class Topk(object):
def __init__(self, topk=1, class_id_map_file=None, delimiter=None):
assert isinstance(topk, (int, ))
self.topk = topk
self.delimiter = delimiter if delimiter is not None else " "
self.class_id_map = self.parse_class_id_map(class_id_map_file)
def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None:
return None
if not os.path.exists(class_id_map_file):
print(
"Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
)
return None
try:
class_id_map = {}
with open(class_id_map_file, "r") as fin:
lines = fin.readlines()
for line in lines:
partition = line.split("\n")[0].partition(self.delimiter)
class_id_map[int(partition[0])] = str(partition[-1])
except Exception as ex:
print(ex)
class_id_map = None
return class_id_map
delimiter = delimiter if delimiter is not None else " "
self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
def __call__(self, x, file_names=None):
if file_names is not None:
......@@ -140,30 +159,8 @@ class Topk(object):
class MultiLabelThreshOutput(object):
def __init__(self, threshold=0.5, class_id_map_file=None, delimiter=None):
self.threshold = threshold
self.delimiter = delimiter if delimiter is not None else " "
self.class_id_map = self.parse_class_id_map(class_id_map_file)
def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None:
return None
if not os.path.exists(class_id_map_file):
print(
"Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!"
)
return None
try:
class_id_map = {}
with open(class_id_map_file, "r") as fin:
lines = fin.readlines()
for line in lines:
partition = line.split("\n")[0].partition(self.delimiter)
class_id_map[int(partition[0])] = str(partition[-1])
except Exception as ex:
print(ex)
class_id_map = None
return class_id_map
delimiter = delimiter if delimiter is not None else " "
self.class_id_map = parse_class_id_map(class_id_map_file, delimiter)
def __call__(self, x, file_names=None):
y = []
......
# PULC 图像方向分类模型
## 目录
- [1. 模型和应用场景介绍](#1)
- [2. 模型快速体验](#2)
- [2.1 安装 paddlepaddle](#2.1)
- [2.2 安装 paddleclas](#2.2)
- [2.3 预测](#2.3)
- [3. 模型训练、评估和预测](#3)
- [3.1 环境配置](#3.1)
- [3.2 数据准备](#3.2)
- [3.2.1 数据集来源](#3.2.1)
- [3.2.2 数据集获取](#3.2.2)
- [3.3 模型训练](#3.3)
- [3.4 模型评估](#3.4)
- [3.5 模型预测](#3.5)
- [4. 模型压缩](#4)
- [4.1 SKL-UGI 知识蒸馏](#4.1)
- [4.1.1 教师模型训练](#4.1.1)
- [4.1.2 蒸馏训练](#4.1.2)
- [5. 超参搜索](#5)
- [6. 模型推理部署](#6)
- [6.1 推理模型准备](#6.1)
- [6.1.1 基于训练得到的权重导出 inference 模型](#6.1.1)
- [6.1.2 直接下载 inference 模型](#6.1.2)
- [6.2 基于 Python 预测引擎推理](#6.2)
- [6.2.1 预测单张图片](#6.2.1)
- [6.2.2 基于文件夹的批量预测](#6.2.2)
- [6.3 基于 C++ 预测引擎推理](#6.3)
- [6.4 服务化部署](#6.4)
- [6.5 端侧部署](#6.5)
- [6.6 Paddle2ONNX 模型转换与预测](#6.6)
<a name="1"></a>
## 1. 模型和应用场景介绍
该案例提供了用户使用 PaddleClas 的超轻量图像分类方案(PULC,Practical Ultra Lightweight image Classification)快速构建轻量级、高精度、可落地的图像方向分类模型(),该模型能够广泛应用于多种视觉任务中。下表列出了图像方向分类模型的相关指标。
<!-- 前两行展现了使用 SwinTranformer_tiny 和 MobileNetV3_small_x0_35 作为 backbone 训练得到的模型的相关指标,第三行至第五行依次展现了替换 backbone 为 PPLCNet_x1_0、使用 SSLD 预训练模型、使用EDA策略训练得到的模型的相关指标。 -->
| 模型 | 精度(%) | 延时(ms) | 存储(M) | 策略 |
| ----------------------- | --------- | ---------- | --------- | -------------------------- |
<!-- | SwinTranformer_tiny | 99.12 | 89.65 | 111 | 使用ImageNet预训练模型 | -->
<!-- | MobileNetV3_small_x0_35 | 83.61 | 2.95 | 2.6 | 使用ImageNet预训练模型 | -->
<!-- | PPLCNet_x1_0 | 97.85 | 2.16 | 7.1 | 使用ImageNet预训练模型 | -->
| PPLCNet_x1_0 | 89.99 | 2.16 | 7.1 | 使用SSLD预训练模型 |
<!-- | **PPLCNet_x1_0** | **99.06** | **2.16** | **7.1** | 使用SSLD预训练模型+EDA策略 | -->
<!-- 从表中可以看出,backbone 为 SwinTranformer_tiny 时精度比较高,但是推理速度较慢。将 backbone 替换为轻量级模型 MobileNetV3_small_x0_35 后,速度提升明显,但精度有了大幅下降。将 backbone 替换为 PPLCNet_x1_0 时,速度略为提升,同时精度较 MobileNetV3_small_x0_35 高了 14.24 个百分点。在此基础上,使用 SSLD 预训练模型后,在不改变推理速度的前提下,精度可以提升 1.17 个百分点,进一步地使用 EDA 策略后,精度可以再提升 0.04 个百分点。此时,PPLCNet_x1_0 与 SwinTranformer_tiny 的精度差别不大,但是速度明显变快。关于 PULC 的训练方法和推理部署方法将在下面详细介绍。 -->
**备注:**
* 关于PP-LCNet的介绍可以参考[PP-LCNet介绍](../ImageNet1k/PP-LCNet.md),相关论文可以查阅[PP-LCNet paper](https://arxiv.org/abs/2109.15099)
<a name="2"></a>
## 2. 模型快速体验
<a name="2.1"></a>
### 2.1 安装 paddlepaddle
- 您的机器安装的是 CUDA9 或 CUDA10,请运行以下命令安装
```bash
python3 -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
```
- 您的机器是CPU,请运行以下命令安装
```bash
python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```
更多的版本需求,请参照[飞桨官网安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
<a name="2.2"></a>
### 2.2 安装 paddleclas
使用如下命令快速安装 paddleclas
```
pip3 install paddleclas
```
<a name="2.3"></a>
### 2.3 预测
点击[这里](https://paddleclas.bj.bcebos.com/data/PULC/pulc_demo_imgs.zip)下载 demo 数据并解压,然后在终端中切换到相应目录。
* 使用命令行快速预测
```bash
paddleclas --model_name=image_orientation --infer_imgs=pulc_demo_imgs/image_orientation/1.jpg
```
结果如下:
```
>>> result
class_ids: [1], scores: [0.9346007], label_names: ['90°'], filename: pulc_demo_imgs/image_orientation/1.jpg
Predict complete!
```
**备注**: 更换其他预测的数据时,只需要改变 `--infer_imgs=xx` 中的字段即可,支持传入整个文件夹。
* 在 Python 代码中预测
```python
import paddleclas
model = paddleclas.PaddleClas(model_name="image_orientation")
result = model.predict(input_data="pulc_demo_imgs/image_orientation/1.jpg")
print(next(result))
```
**备注**`model.predict()` 为可迭代对象(`generator`),因此需要使用 `next()` 函数或 `for` 循环对其迭代调用。每次调用将以 `batch_size` 为单位进行一次预测,并返回预测结果, 默认 `batch_size` 为 1,如果需要更改 `batch_size`,实例化模型时,需要指定 `batch_size`,如 `model = paddleclas.PaddleClas(model_name="image_orientation", batch_size=2)`, 使用默认的代码返回结果示例如下:
```
>>> result
[{'class_ids': [1], 'scores': [0.9346007], 'label_names': ['90°'], 'filename': 'pulc_demo_imgs/image_orientation/1.jpg'}]
```
<a name="3"></a>
## 3. 模型训练、评估和预测
敬请期待。
<a name="4"></a>
## 4. 模型压缩
敬请期待。
<a name="6"></a>
## 6. 模型推理部署
<a name="6.1"></a>
### 6.1 推理模型准备
Paddle Inference 是飞桨的原生推理库, 作用于服务器端和云端,提供高性能的推理能力。相比于直接基于预训练模型进行预测,Paddle Inference可使用MKLDNN、CUDNN、TensorRT 进行预测加速,从而实现更优的推理性能。更多关于Paddle Inference推理引擎的介绍,可以参考[Paddle Inference官网教程](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/infer/inference/inference_cn.html)
当使用 Paddle Inference 推理时,加载的模型类型为 inference 模型。本案例提供了两种获得 inference 模型的方法,如果希望得到和文档相同的结果,请选择[直接下载 inference 模型](#6.1.2)的方式。
<a name="6.1.1"></a>
#### 6.1.1 基于训练得到的权重导出 inference 模型
此处,我们提供了将权重和模型转换的脚本,执行该脚本可以得到对应的 inference 模型:
```bash
python3 tools/export_model.py \
-c ./ppcls/configs/PULC/image_orientation/PPLCNet_x1_0.yaml \
-o Global.pretrained_model=output/DistillationModel/best_model_student \
-o Global.save_inference_dir=deploy/models/PPLCNet_x1_0_image_orientation_infer
```
执行完该脚本后会在`deploy/models/`下生成`PPLCNet_x1_0_image_orientation_infer`文件夹,`models` 文件夹下应有如下文件结构:
```
├── PPLCNet_x1_0_image_orientation_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
**备注:** 此处的最佳权重是经过知识蒸馏后的权重路径,如果没有执行知识蒸馏的步骤,最佳模型保存在`output/PPLCNet_x1_0/best_model.pdparams`中。
<a name="6.1.2"></a>
#### 6.1.2 直接下载 inference 模型
[6.1.1 小节](#6.1.1)提供了导出 inference 模型的方法,此处也提供了该场景可以下载的 inference 模型,可以直接下载体验。
```
cd deploy/models
# 下载inference 模型并解压
wget https://paddleclas.bj.bcebos.com/models/PULC/inference/image_orientation_infer.tar && tar -xf image_orientation_infer.tar
```
解压完毕后,`models` 文件夹下应有如下文件结构:
```
├── image_orientation_infer
│ ├── inference.pdiparams
│ ├── inference.pdiparams.info
│ └── inference.pdmodel
```
<a name="6.2"></a>
### 6.2 基于 Python 预测引擎推理
<a name="6.2.1"></a>
#### 6.2.1 预测单张图像
返回 `deploy` 目录:
```
cd ../
```
运行下面的命令,对图像 `./images/PULC/image_orientation/1.jpg` 进行含文字图像方向分类。
```shell
# 使用下面的命令使用 GPU 进行预测
python3.7 python/predict_cls.py -c configs/PULC/image_orientation/inference_image_orientation.yaml
# 使用下面的命令使用 CPU 进行预测
python3.7 python/predict_cls.py -c configs/PULC/image_orientation/inference_image_orientation.yaml -o Global.use_gpu=False
```
输出结果如下。
```
1.jpg: class id(s): [1], score(s): [0.93], label_name(s): ['90°']
```
其中,输出为top1的预测结果,`0°` 表示该图未旋转,`90°` 表示该图方向为逆时针90度,`180°` 表示该图文本方向为逆时针180度,`270°` 表示该图文本方向为逆时针270度。
<a name="6.2.2"></a>
#### 6.2.2 基于文件夹的批量预测
如果希望预测文件夹内的图像,可以直接修改配置文件中的 `Global.infer_imgs` 字段,也可以通过下面的 `-o` 参数修改对应的配置。
```shell
# 使用下面的命令使用 GPU 进行预测,如果希望使用 CPU 预测,可以在命令后面添加 -o Global.use_gpu=False
python3.7 python/predict_cls.py -c configs/PULC/image_orientation/inference_image_orientation.yaml -o Global.infer_imgs="./images/PULC/image_orientation/"
```
终端中会输出该文件夹内所有图像的分类结果,如下所示。
```
0.jpg: class id(s): [0], score(s): [0.93], label_name(s): ['0°']
1.jpg: class id(s): [1], score(s): [0.93], label_name(s): ['90°']
2.jpg: class id(s): [2], score(s): [0.92], label_name(s): ['180°']
3.jpg: class id(s): [3], score(s): [0.92], label_name(s): ['270°']
```
<a name="6.3"></a>
### 6.3 基于 C++ 预测引擎推理
PaddleClas 提供了基于 C++ 预测引擎推理的示例,您可以参考[服务器端 C++ 预测](../../deployment/image_classification/cpp/linux.md)来完成相应的推理部署。如果您使用的是 Windows 平台,可以参考[基于 Visual Studio 2019 Community CMake 编译指南](../../deployment/image_classification/cpp/windows.md)完成相应的预测库编译和模型预测工作。
<a name="6.4"></a>
### 6.4 服务化部署
Paddle Serving 提供高性能、灵活易用的工业级在线推理服务。Paddle Serving 支持 RESTful、gRPC、bRPC 等多种协议,提供多种异构硬件和多种操作系统环境下推理解决方案。更多关于Paddle Serving 的介绍,可以参考[Paddle Serving 代码仓库](https://github.com/PaddlePaddle/Serving)
PaddleClas 提供了基于 Paddle Serving 来完成模型服务化部署的示例,您可以参考[模型服务化部署](../../deployment/image_classification/paddle_serving.md)来完成相应的部署工作。
<a name="6.5"></a>
### 6.5 端侧部署
Paddle Lite 是一个高性能、轻量级、灵活性强且易于扩展的深度学习推理框架,定位于支持包括移动端、嵌入式以及服务器端在内的多硬件平台。更多关于 Paddle Lite 的介绍,可以参考[Paddle Lite 代码仓库](https://github.com/PaddlePaddle/Paddle-Lite)
PaddleClas 提供了基于 Paddle Lite 来完成模型端侧部署的示例,您可以参考[端侧部署](../../deployment/image_classification/paddle_lite.md)来完成相应的部署工作。
<a name="6.6"></a>
### 6.6 Paddle2ONNX 模型转换与预测
Paddle2ONNX 支持将 PaddlePaddle 模型格式转化到 ONNX 模型格式。通过 ONNX 可以完成将 Paddle 模型到多种推理引擎的部署,包括TensorRT/OpenVINO/MNN/TNN/NCNN,以及其它对 ONNX 开源格式进行支持的推理引擎或硬件。更多关于 Paddle2ONNX 的介绍,可以参考[Paddle2ONNX 代码仓库](https://github.com/PaddlePaddle/Paddle2ONNX)
PaddleClas 提供了基于 Paddle2ONNX 来完成 inference 模型转换 ONNX 模型并作推理预测的示例,您可以参考[Paddle2ONNX 模型转换与预测](../../deployment/image_classification/paddle2onnx.md)来完成相应的部署工作。
......@@ -19,6 +19,7 @@
| table_attribute |[PULC表格属性识别](PULC_table_attribute.md)|表格属性识别,可以识别表格是否为拍照、表格数量、表格颜色、表格清晰度、表格有无干扰、表格角度6个属性| 88.1 |7.1M|2.58ms|[推理模型](https://paddleclas.bj.bcebos.com/models/PULC/inference/table_attribute_infer.tar) / [预训练模型](https://paddleclas.bj.bcebos.com/models/PULC/pretrained/table_attribute_pretrained.pdparams)|
| code_exists |[PULC有无广告码](PULC_code_exists.md)|判断图片中有无广告码,其中,这里广告码包含二维码、条形码、小程序码| 94.9 |7.0M|2.13ms|[推理模型](https://paddleclas.bj.bcebos.com/models/PULC/inference/code_exists_infer.tar) / [预训练模型](https://paddleclas.bj.bcebos.com/models/PULC/pretrained/code_exists_pretrained.pdparams)|
| clarity_assessment |[PULC清晰度评估](PULC_clarity_assessment.md)|判断图片的清晰度| 95.3 |7.0M|2.13ms|[推理模型](https://paddleclas.bj.bcebos.com/models/PULC/inference/clarity_assessment_infer.tar) / [预训练模型](https://paddleclas.bj.bcebos.com/models/PULC/pretrained/clarity_assessment_pretrained.pdparams)|
| image_orientation |[PULC图像方向分类](PULC_image_orientation.md)|判断图片的清晰度| 89.99 |7.1M|2.16ms|[推理模型](https://paddleclas.bj.bcebos.com/models/PULC/inference/image_orientation_infer.tar) / [预训练模型](https://paddleclas.bj.bcebos.com/models/PULC/pretrained/image_orientation_pretrained.pdparams)|
......@@ -26,4 +27,4 @@
* 以上所有的模型的 backbone 均为 PPLCNet_x1_0,部分模型大小不同是由于分类的输出大小不同导致的,推理耗时是基于Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz 测试得到,其中测试过程开启 MKLDNN 加速策略,线程数为10。速度测试过程会有轻微波动。
* person_exists、safety_helmet、car_exists 的评测指标为 TprAtFpr,person_attribute、vehicle_attribute、table_attribute 的评测指标为mA、traffic_sign、text_image_orientation、textline_orientation、language_classification、code_exists、clarity_assessment 的评测指标为Top-1 Acc。
* person_exists、safety_helmet、car_exists 的评测指标为 TprAtFpr;person_attribute、vehicle_attribute、table_attribute 的评测指标为mA;traffic_sign、text_image_orientation、textline_orientation、language_classification、code_exists、clarity_assessment、image_orientation 的评测指标为Top-1 Acc。
......@@ -197,8 +197,8 @@ PULC_MODEL_BASE_DOWNLOAD_URL = "https://paddleclas.bj.bcebos.com/models/PULC/inf
PULC_MODELS = [
"car_exists", "language_classification", "person_attribute",
"person_exists", "safety_helmet", "text_image_orientation",
"textline_orientation", "traffic_sign", "vehicle_attribute",
"table_attribute"
"image_orientation", "textline_orientation", "traffic_sign",
"vehicle_attribute", "table_attribute"
]
SHITU_MODEL_BASE_DOWNLOAD_URL = "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/rec/models/inference/{}_infer.tar"
......@@ -284,10 +284,6 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
"crop_size"]
# TODO(gaotingquan): not robust
if "thresh" in kwargs and kwargs[
"thresh"] and "ThreshOutput" in cfg.PostProcess:
cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
if cfg.get("PostProcess"):
if "Topk" in cfg.PostProcess:
if "topk" in kwargs and kwargs["topk"]:
......@@ -300,6 +296,17 @@ def init_config(model_type, model_name, inference_model_dir, **kwargs):
cfg.PostProcess.Topk.class_id_map_file, "../")
cfg.PostProcess.Topk.class_id_map_file = os.path.join(
__dir__, class_id_map_file_path)
if "ThreshOutput" in cfg.PostProcess:
if "thresh" in kwargs and kwargs["thresh"]:
cfg.PostProcess.ThreshOutput.thresh = kwargs["thresh"]
if "class_id_map_file" in kwargs and kwargs["class_id_map_file"]:
cfg.PostProcess.ThreshOutput["class_id_map_file"] = kwargs[
"class_id_map_file"]
elif "class_id_map_file" in cfg.PostProcess.ThreshOutput:
class_id_map_file_path = os.path.relpath(
cfg.PostProcess.ThreshOutput.class_id_map_file, "../")
cfg.PostProcess.ThreshOutput.class_id_map_file = os.path.join(
__dir__, class_id_map_file_path)
if "VehicleAttribute" in cfg.PostProcess:
if "color_threshold" in kwargs and kwargs["color_threshold"]:
cfg.PostProcess.VehicleAttribute.color_threshold = kwargs[
......
0 0°
1 90°
2 180°
3 270°
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册