未验证 提交 edff1d7a 编写于 作者: M MissPenguin 提交者: GitHub

Merge pull request #3519 from MissPenguin/release/2.2

update 2.2
English | [简体中文](README_ch.md)
------------------------------------------------------------------------------------------
<p align="left">
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
<a href="https://github.com/PaddlePaddle/PaddleOCR/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleOCR?color=ffa"></a>
<a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
<a href=""><img src="https://img.shields.io/pypi/format/PaddleOCR?color=c77"></a>
<a href="https://github.com/PaddlePaddle/PaddleOCR/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/PaddleOCR?color=9ea"></a>
<a href="https://pypi.org/project/PaddleOCR/"><img src="https://img.shields.io/pypi/dm/PaddleOCR?color=9cf"></a>
<a href="https://github.com/PaddlePaddle/PaddleOCR/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleOCR?color=ccf"></a>
</p>
## Introduction
PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools that help users train better models and apply them into practice.
......@@ -80,7 +93,8 @@ For a new language request, please refer to [Guideline for new language_requests
## Tutorials
- [Installation](./doc/doc_en/installation_en.md)
- [Quick Start](./doc/doc_en/quickstart_en.md)
- [Quick Start(Chinese)](./doc/doc_en/quickstart_en.md)
- [Quick Start(English&Multi-languages)](./doc/doc_en/multi_languages_en.md)
- [Code Structure](./doc/doc_en/tree_en.md)
- Algorithm Introduction
- [Text Detection Algorithm](./doc/doc_en/algorithm_overview_en.md)
......
[English](README.md) | 简体中文
------------------------------------------------------------------------------------------
<p align="left">
<a href="./LICENSE"><img src="https://img.shields.io/badge/license-Apache%202-dfd.svg"></a>
<a href="https://github.com/PaddlePaddle/PaddleOCR/releases"><img src="https://img.shields.io/github/v/release/PaddlePaddle/PaddleOCR?color=ffa"></a>
<a href=""><img src="https://img.shields.io/badge/python-3.7+-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/os-linux%2C%20win%2C%20mac-pink.svg"></a>
<a href=""><img src="https://img.shields.io/pypi/format/PaddleOCR?color=c77"></a>
<a href="https://github.com/PaddlePaddle/PaddleOCR/graphs/contributors"><img src="https://img.shields.io/github/contributors/PaddlePaddle/PaddleOCR?color=9ea"></a>
<a href="https://pypi.org/project/PaddleOCR/"><img src="https://img.shields.io/pypi/dm/PaddleOCR?color=9cf"></a>
<a href="https://github.com/PaddlePaddle/PaddleOCR/stargazers"><img src="https://img.shields.io/github/stars/PaddlePaddle/PaddleOCR?color=ccf"></a>
</p>
## 简介
PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力使用者训练出更好的模型,并应用落地。
## 注意
PaddleOCR同时支持动态图与静态图两种编程范式
- 动态图版本:dygraph分支(默认),需将paddle版本升级至2.0.0[快速安装](./doc/doc_ch/installation.md)
- 动态图版本:release/2.1(默认分支,开发分支为dygraph分支),需将paddle版本升级至2.0.0或以上版本[快速安装](./doc/doc_ch/installation.md)
- 静态图版本:develop分支
**近期更新**
- 2021.6.29 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数248个,每周一都会更新,欢迎大家持续关注。
- PaddleOCR研发团队对最新发版内容技术深入解读,4月13日晚上19:00,[直播地址](https://live.bilibili.com/21689802)
- 2021.4.8 release 2.1版本,新增AAAI 2021论文[端到端识别算法PGNet](./doc/doc_ch/pgnet.md)开源,[多语言模型](./doc/doc_ch/multi_languages.md)支持种类增加到80+。
- 2021.2.1 [FAQ](./doc/doc_ch/FAQ.md)新增5个高频问题,总数162个,每周一都会更新,欢迎大家持续关注。
- 2021.1.21 更新多语言识别模型,目前支持语种超过27种,包括中文简体、中文繁体、英文、法文、德文、韩文、日文、意大利文、西班牙文、葡萄牙文、俄罗斯文、阿拉伯文等,后续计划可以参考[多语言研发计划](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)
- 2020.12.15 更新数据合成工具[Style-Text](./StyleText/README_ch.md),可以批量合成大量与目标场景类似的图像,在多个场景验证,效果明显提升。
- 2020.11.25 更新半自动标注工具[PPOCRLabel](./PPOCRLabel/README_ch.md),辅助开发者高效完成标注任务,输出格式与PP-OCR训练任务完美衔接。
- 2020.9.22 更新PP-OCR技术文章,https://arxiv.org/abs/2009.09941
- 2021.2.8 正式发布PaddleOCRv2.0(branch release/2.0)并设置为推荐用户使用的默认分支. 发布的详细内容,请参考: https://github.com/PaddlePaddle/PaddleOCR/releases/tag/v2.0.0
- 2021.1.26,28,29 PaddleOCR官方研发团队带来技术深入解读三日直播课,1月26日、28日、29日晚上19:30,[直播地址](https://live.bilibili.com/21689802)
- [More](./doc/doc_ch/update.md)
......@@ -24,7 +36,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
- 超轻量ppocr_mobile移动端系列:检测(3.0M)+方向分类器(1.4M)+ 识别(5.0M)= 9.4M
- 通用ppocr_server系列:检测(47.1M)+方向分类器(1.4M)+ 识别(94.9M)= 143.4M
- 支持中英文数字组合识别、竖排文本识别、长文本识别
- 支持多语言识别:韩语、日语、德语、法语
- 支持80+多语言识别,详见[多语言模型](./doc/doc_ch/multi_languages.md)
- 丰富易用的OCR相关工具组件
- 半自动数据标注工具PPOCRLabel:支持快速高效的数据标注
- 数据合成工具Style-Text:批量合成大量与目标场景类似的图像
......@@ -90,7 +102,7 @@ PaddleOCR同时支持动态图与静态图两种编程范式
- [基于pip安装whl包快速推理](./doc/doc_ch/whl.md)
- [基于Python脚本预测引擎推理](./doc/doc_ch/inference.md)
- [基于C++预测引擎推理](./deploy/cpp_infer/readme.md)
- [服务化部署](./deploy/pdserving/README_CN.md)
- [服务化部署](./deploy/hubserving/readme.md)
- [端侧部署](./deploy/lite/readme.md)
- [Benchmark](./doc/doc_ch/benchmark.md)
- 数据集
......@@ -105,8 +117,8 @@ PaddleOCR同时支持动态图与静态图两种编程范式
- [效果展示](#效果展示)
- FAQ
- [【精选】OCR精选10个问题](./doc/doc_ch/FAQ.md)
- [【理论篇】OCR通用32个问题](./doc/doc_ch/FAQ.md)
- [【实战篇】PaddleOCR实战110个问题](./doc/doc_ch/FAQ.md)
- [【理论篇】OCR通用50个问题](./doc/doc_ch/FAQ.md)
- [【实战篇】PaddleOCR实战183个问题](./doc/doc_ch/FAQ.md)
- [技术交流群](#欢迎加入PaddleOCR技术交流群)
- [参考文献](./doc/doc_ch/reference.md)
- [许可证书](#许可证书)
......
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
- ["Student2", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
# act: None
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
# name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2", "Teacher"]
# key: maps
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Teacher:
pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params: true
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: ResNet
layers: 18
Neck:
name: DBFPN
out_channels: 256
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDilaDBLoss:
weight: 1.0
model_name_pairs:
- ["Student", "Teacher"]
key: maps
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
# key: maps
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
use_gpu: true
epoch_num: 1200
log_smooth_window: 20
print_batch_step: 2
save_model_dir: ./output/ch_db_mv3/
save_epoch_step: 1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step: [3000, 2000]
cal_metric_during_train: False
pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_en/img_10.jpg
save_res_path: ./output/det_db/predicts_db.txt
Architecture:
name: DistillationModel
algorithm: Distillation
Models:
Student:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Student2:
pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params: false
return_all_feats: false
model_type: det
algorithm: DB
Transform:
Backbone:
name: MobileNetV3
scale: 0.5
model_name: large
disable_se: True
Neck:
name: DBFPN
out_channels: 96
Head:
name: DBHead
k: 50
Loss:
name: CombinedLoss
loss_config_list:
- DistillationDMLLoss:
model_name_pairs:
- ["Student", "Student2"]
maps_name: "thrink_maps"
weight: 1.0
act: "softmax"
model_name_pairs: ["Student", "Student2"]
key: maps
- DistillationDBLoss:
weight: 1.0
model_name_list: ["Student", "Student2"]
# key: maps
name: DBLoss
balance_loss: true
main_loss_type: DiceLoss
alpha: 5
beta: 10
ohem_ratio: 3
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 2
regularizer:
name: 'L2'
factor: 0
PostProcess:
name: DistillationDBPostProcess
model_name: ["Student", "Student2"]
key: head_out
thresh: 0.3
box_thresh: 0.6
max_candidates: 1000
unclip_ratio: 1.5
Metric:
name: DistillationMetric
base_metric_name: DetMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list: [1.0]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- IaaAugment:
augmenter_args:
- { 'type': Fliplr, 'args': { 'p': 0.5 } }
- { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
- { 'type': Resize, 'args': { 'size': [0.5, 3] } }
- EastRandomCropData:
size: [960, 960]
max_tries: 50
keep_ratio: true
- MakeBorderMap:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MakeShrinkMap:
shrink_ratio: 0.4
min_text_size: 8
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/icdar2015/text_localization/
label_file_list:
- ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- DetLabelEncode: # Class handling label
- DetResizeForTest:
# image_shape: [736, 1280]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 1 # must be 1
num_workers: 2
Global:
debug: false
use_gpu: true
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
distributed: true
save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: Piecewise
decay_epochs : [700, 800]
values : [0.001, 0.0001]
warmup_epoch: 5
regularizer:
name: L2
factor: 2.0e-05
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Loss:
name: CombinedLoss
loss_config_list:
- DistillationCTCLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: head_out
- DistillationDMLLoss:
weight: 1.0
act: "softmax"
model_name_pairs:
- ["Student", "Teacher"]
key: head_out
- DistillationDistanceLoss:
weight: 1.0
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
PostProcess:
name: DistillationCTCLabelDecode
model_name: ["Student", "Teacher"]
key: head_out
Metric:
name: DistillationMetric
base_metric_name: RecMetric
main_indicator: acc
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/
label_file_list:
- ./train_data/train_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug:
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: true
batch_size_per_card: 128
drop_last: true
num_sections: 1
num_workers: 8
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data
label_file_list:
- ./train_data/val_list.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
keep_keys:
- image
- label
- length
loader:
shuffle: false
drop_last: false
batch_size_per_card: 128
num_workers: 8
此差异已折叠。
# 知识蒸馏
## 1. 简介
### 1.1 知识蒸馏介绍
近年来,深度神经网络在计算机视觉、自然语言处理等领域被验证是一种极其有效的解决问题的方法。通过构建合适的神经网络,加以训练,最终网络模型的性能指标基本上都会超过传统算法。
在数据量足够大的情况下,通过合理构建网络模型的方式增加其参数量,可以显著改善模型性能,但是这又带来了模型复杂度急剧提升的问题。大模型在实际场景中使用的成本较高。
深度神经网络一般有较多的参数冗余,目前有几种主要的方法对模型进行压缩,减小其参数量。如裁剪、量化、知识蒸馏等,其中知识蒸馏是指使用教师模型(teacher model)去指导学生模型(student model)学习特定任务,保证小模型在参数量不变的情况下,得到比较大的性能提升。
此外,在知识蒸馏任务中,也衍生出了互学习的模型训练方法,论文[Deep Mutual Learning](https://arxiv.org/abs/1706.00384)中指出,使用两个完全相同的模型在训练的过程中互相监督,可以达到比单个模型训练更好的效果。
### 1.2 PaddleOCR知识蒸馏简介
无论是大模型蒸馏小模型,还是小模型之间互相学习,更新参数,他们本质上是都是不同模型之间输出或者特征图(feature map)之间的相互监督,区别仅在于 (1) 模型是否需要固定参数。(2) 模型是否需要加载预训练模型。
对于大模型蒸馏小模型的情况,大模型一般需要加载预训练模型并固定参数;对于小模型之间互相蒸馏的情况,小模型一般都不加载预训练模型,参数也都是可学习的状态。
在知识蒸馏任务中,不只有2个模型之间进行蒸馏的情况,多个模型之间互相学习的情况也非常普遍。因此在知识蒸馏代码框架中,也有必要支持该种类别的蒸馏方法。
PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要的特点:
- 支持任意网络的互相学习,不要求子网络结构完全一致或者具有预训练模型;同时子网络数量也没有任何限制,只需要在配置文件中添加即可。
- 支持loss函数通过配置文件任意配置,不仅可以使用某种loss,也可以使用多种loss的组合
- 支持知识蒸馏训练、预测、评估与导出等所有模型相关的环境,方便使用与部署。
通过知识蒸馏,在中英文通用文字识别任务中,不增加任何预测耗时的情况下,可以给模型带来3%以上的精度提升,结合学习率调整策略以及模型结构微调策略,最终提升提升超过5%。
## 2. 配置文件解析
在知识蒸馏训练的过程中,数据预处理、优化器、学习率、全局的一些属性没有任何变化。模型结构、损失函数、后处理、指标计算等模块的配置文件需要进行微调。
下面以识别与检测的知识蒸馏配置文件为例,对知识蒸馏的训练与配置进行解析。
### 2.1 识别配置文件解析
配置文件在[rec_chinese_lite_train_distillation_v2.1.yml](../../configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml)
#### 2.1.1 模型结构
知识蒸馏任务中,模型结构配置如下所示。
```yaml
Architecture:
model_type: &model_type "rec" # 模型类别,rec、det等,每个子网络的的模型类别都与
name: DistillationModel # 结构名称,蒸馏任务中,为DistillationModel,用于构建对应的结构
algorithm: Distillation # 算法名称
Models: # 模型,包含子网络的配置信息
Teacher: # 子网络名称,至少需要包含`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
pretrained: # 该子网络是否需要加载预训练模型
freeze_params: false # 是否需要固定参数
return_all_feats: true # 子网络的参数,表示是否需要返回所有的features,如果为False,则只返回最后的输出
model_type: *model_type # 模型类别
algorithm: CRNN # 子网络的算法名称,该子网络剩余参与均为构造参数,与普通的模型训练配置一致
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student: # 另外一个子网络,这里给的是DML的蒸馏示例,两个子网络结构相同,均需要学习参数
pretrained: # 下面的组网参数同上
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
```
当然,这里如果希望添加更多的子网络进行训练,也可以按照`Student``Teacher`的添加方式,在配置文件中添加相应的字段。比如说如果希望有3个模型互相监督,共同训练,那么`Architecture`可以写为如下格式。
```yaml
Architecture:
model_type: &model_type "rec"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
Student2: # 知识蒸馏任务中引入的新的子网络,其他部分与上述配置相同
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: CRNN
Transform:
Backbone:
name: MobileNetV1Enhance
scale: 0.5
Neck:
name: SequenceEncoder
encoder_type: rnn
hidden_size: 64
Head:
name: CTCHead
mid_channels: 96
fc_decay: 0.00002
```
最终该模型训练时,包含3个子网络:`Teacher`, `Student`, `Student2`
蒸馏模型`DistillationModel`类的具体实现代码可以参考[distillation_model.py](../../ppocr/modeling/architectures/distillation_model.py)
最终模型`forward`输出为一个字典,key为所有的子网络名称,例如这里为`Student``Teacher`,value为对应子网络的输出,可以为`Tensor`(只返回该网络的最后一层)和`dict`(也返回了中间的特征信息)。
在识别任务中,为了添加更多损失函数,保证蒸馏方法的可扩展性,将每个子网络的输出保存为`dict`,其中包含子模块输出。以该识别模型为例,每个子网络的输出结果均为`dict`,key包含`backbone_out`,`neck_out`, `head_out``value`为对应模块的tensor,最终对于上述配置文件,`DistillationModel`的输出格式如下。
```json
{
"Teacher": {
"backbone_out": tensor,
"neck_out": tensor,
"head_out": tensor,
},
"Student": {
"backbone_out": tensor,
"neck_out": tensor,
"head_out": tensor,
}
}
```
#### 2.1.2 损失函数
知识蒸馏任务中,损失函数配置如下所示。
```yaml
Loss:
name: CombinedLoss # 损失函数名称,基于改名称,构建用于损失函数的类
loss_config_list: # 损失函数配置文件列表,为CombinedLoss的必备函数
- DistillationCTCLoss: # 基于蒸馏的CTC损失函数,继承自标准的CTC loss
weight: 1.0 # 损失函数的权重,loss_config_list中,每个损失函数的配置都必须包含该字段
model_name_list: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,与gt计算CTC loss
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDMLLoss: # 蒸馏的DML损失函数,继承自标准的DMLLoss
weight: 1.0 # 权重
act: "softmax" # 激活函数,对输入使用激活函数处理,可以为softmax, sigmoid或者为None,默认为None
model_name_pairs: # 用于计算DML loss的子网络名称对,如果希望计算其他子网络的DML loss,可以在列表下面继续填充
- ["Student", "Teacher"]
key: head_out # 取子网络输出dict中,该key对应的tensor
- DistillationDistanceLoss: # 蒸馏的距离损失函数
weight: 1.0 # 权重
mode: "l2" # 距离计算方法,目前支持l1, l2, smooth_l1
model_name_pairs: # 用于计算distance loss的子网络名称对
- ["Student", "Teacher"]
key: backbone_out # 取子网络输出dict中,该key对应的tensor
```
上述损失函数中,所有的蒸馏损失函数均继承自标准的损失函数类,主要功能为: 对蒸馏模型的输出进行解析,找到用于计算损失的中间节点(tensor),再使用标准的损失函数类去计算。
以上述配置为例,最终蒸馏训练的损失函数包含下面3个部分。
- `Student``Teacher`的最终输出(`head_out`)与gt的CTC loss,权重为1。在这里因为2个子网络都需要更新参数,因此2者都需要计算与g的loss。
- `Student``Teacher`的最终输出(`head_out`)之间的DML loss,权重为1。
- `Student``Teacher`的骨干网络输出(`backbone_out`)之间的l2 loss,权重为1。
关于`CombinedLoss`更加具体的实现可以参考: [combined_loss.py](../../ppocr/losses/combined_loss.py#L23)。关于`DistillationCTCLoss`等蒸馏损失函数更加具体的实现可以参考[distillation_loss.py](../../ppocr/losses/distillation_loss.py)
#### 2.1.3 后处理
知识蒸馏任务中,后处理配置如下所示。
```yaml
PostProcess:
name: DistillationCTCLabelDecode # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
model_name: ["Student", "Teacher"] # 对于蒸馏模型的预测结果,提取这两个子网络的输出,进行解码
key: head_out # 取子网络输出dict中,该key对应的tensor
```
以上述配置为例,最终会同时计算`Student``Teahcer` 2个子网络的CTC解码输出,返回一个`dict``key`为用于处理的子网络名称,`value`为用于处理的子网络列表。
关于`DistillationCTCLabelDecode`更加具体的实现可以参考: [rec_postprocess.py](../../ppocr/postprocess/rec_postprocess.py#L128)
#### 2.1.4 指标计算
知识蒸馏任务中,指标计算配置如下所示。
```yaml
Metric:
name: DistillationMetric # 蒸馏任务的CTC解码后处理,继承自标准的CTCLabelDecode类
base_metric_name: RecMetric # 指标计算的基类,对于模型的输出,会基于该类,计算指标
main_indicator: acc # 指标的名称
key: "Student" # 选取该子网络的 main_indicator 作为作为保存保存best model的判断标准
```
以上述配置为例,最终会使用`Student`子网络的acc指标作为保存best model的判断指标,同时,日志中也会打印出所有子网络的acc指标。
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)
### 2.2 检测配置文件解析
* coming soon!
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册