diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
index 27ba4fd70b9a7ee7d4d905b3948f6cbf2b7e9469..38f77f7372c4e422b5601deb5119c24fd1e3f787 100644
--- a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
+++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
@@ -4,7 +4,7 @@ Global:
epoch_num: 800
log_smooth_window: 20
print_batch_step: 10
- save_model_dir: ./output/rec_chinese_lite_distillation_v2.1
+ save_model_dir: ./output/rec_mobile_pp-OCRv2
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
@@ -19,7 +19,7 @@ Global:
infer_mode: false
use_space_char: true
distributed: true
- save_res_path: ./output/rec/predicts_chinese_lite_distillation_v2.1.txt
+ save_res_path: ./output/rec/predicts_mobile_pp-OCRv2.txt
Optimizer:
@@ -35,79 +35,32 @@ Optimizer:
name: L2
factor: 2.0e-05
+
Architecture:
- model_type: &model_type "rec"
- name: DistillationModel
- algorithm: Distillation
- Models:
- Teacher:
- pretrained:
- freeze_params: false
- return_all_feats: true
- model_type: *model_type
- algorithm: CRNN
- Transform:
- Backbone:
- name: MobileNetV1Enhance
- scale: 0.5
- Neck:
- name: SequenceEncoder
- encoder_type: rnn
- hidden_size: 64
- Head:
- name: CTCHead
- mid_channels: 96
- fc_decay: 0.00002
- Student:
- pretrained:
- freeze_params: false
- return_all_feats: true
- model_type: *model_type
- algorithm: CRNN
- Transform:
- Backbone:
- name: MobileNetV1Enhance
- scale: 0.5
- Neck:
- name: SequenceEncoder
- encoder_type: rnn
- hidden_size: 64
- Head:
- name: CTCHead
- mid_channels: 96
- fc_decay: 0.00002
-
+ model_type: rec
+ algorithm: CRNN
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
Loss:
- name: CombinedLoss
- loss_config_list:
- - DistillationCTCLoss:
- weight: 1.0
- model_name_list: ["Student", "Teacher"]
- key: head_out
- - DistillationDMLLoss:
- weight: 1.0
- act: "softmax"
- model_name_pairs:
- - ["Student", "Teacher"]
- key: head_out
- - DistillationDistanceLoss:
- weight: 1.0
- mode: "l2"
- model_name_pairs:
- - ["Student", "Teacher"]
- key: backbone_out
+ name: CTCLoss
PostProcess:
- name: DistillationCTCLabelDecode
- model_name: ["Student", "Teacher"]
- key: head_out
+ name: CTCLabelDecode
Metric:
- name: DistillationMetric
- base_metric_name: RecMetric
+ name: RecMetric
main_indicator: acc
- key: "Student"
Train:
dataset:
@@ -132,7 +85,6 @@ Train:
shuffle: true
batch_size_per_card: 128
drop_last: true
- num_sections: 1
num_workers: 8
Eval:
dataset:
diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d2308fd5747f3fadf3bb1c98c5602c67d5e63eca
--- /dev/null
+++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
@@ -0,0 +1,160 @@
+Global:
+ debug: false
+ use_gpu: true
+ epoch_num: 800
+ log_smooth_window: 20
+ print_batch_step: 10
+ save_model_dir: ./output/rec_pp-OCRv2_distillation
+ save_epoch_step: 3
+ eval_batch_step: [0, 2000]
+ cal_metric_during_train: true
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: false
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ character_dict_path: ppocr/utils/ppocr_keys_v1.txt
+ character_type: ch
+ max_text_length: 25
+ infer_mode: false
+ use_space_char: true
+ distributed: true
+ save_res_path: ./output/rec/predicts_pp-OCRv2_distillation.txt
+
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ lr:
+ name: Piecewise
+ decay_epochs : [700, 800]
+ values : [0.001, 0.0001]
+ warmup_epoch: 5
+ regularizer:
+ name: L2
+ factor: 2.0e-05
+
+Architecture:
+ model_type: &model_type "rec"
+ name: DistillationModel
+ algorithm: Distillation
+ Models:
+ Teacher:
+ pretrained:
+ freeze_params: false
+ return_all_feats: true
+ model_type: *model_type
+ algorithm: CRNN
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
+ Student:
+ pretrained:
+ freeze_params: false
+ return_all_feats: true
+ model_type: *model_type
+ algorithm: CRNN
+ Transform:
+ Backbone:
+ name: MobileNetV1Enhance
+ scale: 0.5
+ Neck:
+ name: SequenceEncoder
+ encoder_type: rnn
+ hidden_size: 64
+ Head:
+ name: CTCHead
+ mid_channels: 96
+ fc_decay: 0.00002
+
+
+Loss:
+ name: CombinedLoss
+ loss_config_list:
+ - DistillationCTCLoss:
+ weight: 1.0
+ model_name_list: ["Student", "Teacher"]
+ key: head_out
+ - DistillationDMLLoss:
+ weight: 1.0
+ act: "softmax"
+ use_log: true
+ model_name_pairs:
+ - ["Student", "Teacher"]
+ key: head_out
+ - DistillationDistanceLoss:
+ weight: 1.0
+ mode: "l2"
+ model_name_pairs:
+ - ["Student", "Teacher"]
+ key: backbone_out
+
+PostProcess:
+ name: DistillationCTCLabelDecode
+ model_name: ["Student", "Teacher"]
+ key: head_out
+
+Metric:
+ name: DistillationMetric
+ base_metric_name: RecMetric
+ main_indicator: acc
+ key: "Student"
+
+Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data/
+ label_file_list:
+ - ./train_data/train_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - RecAug:
+ - CTCLabelEncode:
+ - RecResizeImg:
+ image_shape: [3, 32, 320]
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label
+ - length
+ loader:
+ shuffle: true
+ batch_size_per_card: 128
+ drop_last: true
+ num_sections: 1
+ num_workers: 8
+Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: ./train_data
+ label_file_list:
+ - ./train_data/val_list.txt
+ transforms:
+ - DecodeImage:
+ img_mode: BGR
+ channel_first: false
+ - CTCLabelEncode:
+ - RecResizeImg:
+ image_shape: [3, 32, 320]
+ - KeepKeys:
+ keep_keys:
+ - image
+ - label
+ - length
+ loader:
+ shuffle: false
+ drop_last: false
+ batch_size_per_card: 128
+ num_workers: 8
diff --git a/doc/doc_ch/angle_class.md b/doc/doc_ch/angle_class.md
index ad25a6661817623419af0c0c7a139dd4bfaeb08c..321b32ba48e599fb6e72f697fa438a3a54e33337 100644
--- a/doc/doc_ch/angle_class.md
+++ b/doc/doc_ch/angle_class.md
@@ -1,14 +1,24 @@
-## 文字角度分类
-### 方法介绍
-文字角度分类主要用于图片非0度的场景下,在这种场景下需要对图片里检测到的文本行进行一个转正的操作。在PaddleOCR系统内,
+# 文本方向分类器
+
+- [方法介绍](#方法介绍)
+- [数据准备](#数据准备)
+- [启动训练](#启动训练)
+- [训练](#训练)
+- [评估](#评估)
+- [预测](#预测)
+
+
+## 方法介绍
+文本方向分类器主要用于图片非0度的场景下,在这种场景下需要对图片里检测到的文本行进行一个转正的操作。在PaddleOCR系统内,
文字检测之后得到的文本行图片经过仿射变换之后送入识别模型,此时只需要对文字进行一个0和180度的角度分类,因此PaddleOCR内置的
-文字角度分类器**只支持了0和180度的分类**。如果想支持更多角度,可以自己修改算法进行支持。
+文本方向分类器**只支持了0和180度的分类**。如果想支持更多角度,可以自己修改算法进行支持。
0和180度数据样本例子:
![](../imgs_results/angle_class_example.jpg)
-### 数据准备
+
+## 数据准备
请按如下步骤设置数据集:
@@ -59,8 +69,8 @@ train/cls/train/word_002.jpg 180
|- word_003.jpg
| ...
```
-
-### 启动训练
+
+## 启动训练
将准备好的txt文件和图片文件夹路径分别写入配置文件的 `Train/Eval.dataset.label_file_list` 和 `Train/Eval.dataset.data_dir` 字段下,`Train/Eval.dataset.data_dir`字段下的路径和文件里记载的图片名构成了图片的绝对路径。
@@ -88,7 +98,8 @@ PaddleOCR提供了多种数据增强方式,如果您希望在训练时加入
*由于OpenCV的兼容性问题,扰动操作暂时只支持linux*
-### 训练
+
+## 训练
PaddleOCR支持训练和评估交替进行, 可以在 `configs/cls/cls_mv3.yml` 中修改 `eval_batch_step` 设置评估频率,默认每1000个iter评估一次。训练过程中将会保存如下内容:
```bash
@@ -106,7 +117,8 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/cls/cls_mv3.yml`
**注意,预测/评估时的配置文件请务必与训练一致。**
-### 评估
+
+## 评估
评估数据集可以通过修改`configs/cls/cls_mv3.yml`文件里的`Eval.dataset.label_file_list` 字段设置。
@@ -116,7 +128,8 @@ export CUDA_VISIBLE_DEVICES=0
python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
-### 预测
+
+## 预测
* 训练引擎的预测
diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md
index 66295b25252e3906b4d3e6ffb30b135f0c6bdf6c..089b3b8f40e6af7bc10b0c0191bf553be8812ddd 100644
--- a/doc/doc_ch/detection.md
+++ b/doc/doc_ch/detection.md
@@ -1,22 +1,24 @@
-# 目录
-- [1. 文字检测](#1-----)
- * [1.1 数据准备](#11-----)
- * [1.2 下载预训练模型](#12--------)
- * [1.3 启动训练](#13-----)
- * [1.4 断点训练](#14-----)
- * [1.5 更换Backbone 训练](#15---backbone---)
- * [1.6 指标评估](#16-----)
- * [1.7 测试检测效果](#17-------)
- * [1.8 转inference模型测试](#18--inference----)
-- [2. FAQ](#2-faq)
-
-
-
-# 1. 文字检测
+# 文字检测
本节以icdar2015数据集为例,介绍PaddleOCR中检测模型训练、评估、测试的使用方式。
+- [1. 准备数据和模型](#1--------)
+ * [1.1 数据准备](#11-----)
+ * [1.2 下载预训练模型](#12--------)
+- [2. 开始训练](#2-----)
+ * [2.1 启动训练](#21-----)
+ * [2.2 断点训练](#22-----)
+ * [2.3 更换Backbone 训练](#23---backbone---)
+- [3. 模型评估与预测](#3--------)
+ * [3.1 指标评估](#31-----)
+ * [3.2 测试检测效果](#32-------)
+- [4. 模型导出与预测](#4--------)
+- [5. FAQ](#5-faq)
+
+
+# 1. 准备数据和模型
+
## 1.1 数据准备
@@ -83,8 +85,11 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dyg
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
```
-
-## 1.3 启动训练
+
+# 2. 开始训练
+
+
+## 2.1 启动训练
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
@@ -106,8 +111,8 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
```
-
-## 1.4 断点训练
+
+## 2.2 断点训练
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
```shell
@@ -116,8 +121,8 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./you
**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
-
-## 1.5 更换Backbone 训练
+
+## 2.3 更换Backbone 训练
PaddleOCR将网络划分为四部分,分别在[ppocr/modeling](../../ppocr/modeling)下。 进入网络的数据将按照顺序(transforms->backbones->
necks->heads)依次通过这四个部分。
@@ -164,8 +169,11 @@ args1: args1
**注意**:如果要更换网络的其他模块,可以参考[文档](./add_new_algorithm.md)。
-
-## 1.6 指标评估
+
+# 3. 模型评估与预测
+
+
+## 3.1 指标评估
PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean(F-Score)。
@@ -177,8 +185,8 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat
* 注:`box_thresh`、`unclip_ratio`是DB后处理所需要的参数,在评估EAST模型时不需要设置
-
-## 1.7 测试检测效果
+
+## 3.2 测试检测效果
测试单张图像的检测效果
```shell
@@ -195,8 +203,8 @@ python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy"
```
-
-## 1.8 转inference模型测试
+
+# 4. 模型导出与预测
inference 模型(`paddle.jit.save`保存的模型)
一般是模型训练,把模型结构和模型参数保存在文件中的固化模型,多用于预测部署场景。
@@ -218,8 +226,8 @@ python3 tools/infer/predict_det.py --det_algorithm="DB" --det_model_dir="./outpu
python3 tools/infer/predict_det.py --det_algorithm="EAST" --det_model_dir="./output/det_db_inference/" --image_dir="./doc/imgs/" --use_gpu=True
```
-
-# 2. FAQ
+
+# 5. FAQ
Q1: 训练模型转inference 模型之后预测效果不一致?
**A**:此类问题出现较多,问题多是trained model预测时候的预处理、后处理参数和inference model预测的时候的预处理、后处理参数不一致导致的。以det_mv3_db.yml配置文件训练的模型为例,训练模型、inference模型预测结果不一致问题解决方式如下:
diff --git a/doc/doc_ch/knowledge_distillation.md b/doc/doc_ch/knowledge_distillation.md
index 5827f48c81d51a674011e2df40c798e0548fb0a1..b2772454d90ba40e5d65e035d083f8fcd79f69af 100644
--- a/doc/doc_ch/knowledge_distillation.md
+++ b/doc/doc_ch/knowledge_distillation.md
@@ -39,7 +39,7 @@ PaddleOCR中集成了知识蒸馏的算法,具体地,有以下几个主要
### 2.1 识别配置文件解析
-配置文件在[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)。
+配置文件在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)。
#### 2.1.1 模型结构
@@ -246,6 +246,39 @@ Metric:
关于`DistillationMetric`更加具体的实现可以参考: [distillation_metric.py](../../ppocr/metrics/distillation_metric.py#L24)。
+#### 2.1.5 蒸馏模型微调
+
+对蒸馏得到的识别蒸馏进行微调有2种方式。
+
+(1)基于知识蒸馏的微调:这种情况比较简单,下载预训练模型,在[ch_PP-OCRv2_rec_distillation.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml)中配置好预训练模型路径以及自己的数据路径,即可进行模型微调训练。
+
+(2)微调时不使用知识蒸馏:这种情况,需要首先将预训练模型中的学生模型参数提取出来,具体步骤如下。
+
+* 首先下载预训练模型并解压。
+```shell
+# 下面预训练模型并解压
+wget https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar
+tar -xf ch_PP-OCRv2_rec_train.tar
+```
+
+* 然后使用python,对其中的学生模型参数进行提取
+
+```python
+import paddle
+# 加载预训练模型
+all_params = paddle.load("ch_PP-OCRv2_rec_train/best_accuracy.pdparams")
+# 查看权重参数的keys
+print(all_params.keys())
+# 学生模型的权重提取
+s_params = {key[len("Student."):]: all_params[key] for key in all_params if "Student." in key}
+# 查看学生模型权重参数的keys
+print(s_params.keys())
+# 保存
+paddle.save(s_params, "ch_PP-OCRv2_rec_train/student.pdparams")
+```
+
+转化完成之后,使用[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml),修改预训练模型的路径(为导出的`student.pdparams`模型路径)以及自己的数据路径,即可进行模型微调。
+
### 2.2 检测配置文件解析
* coming soon!
diff --git a/doc/doc_en/angle_class_en.md b/doc/doc_en/angle_class_en.md
index 0044d85ac0a43529c67746d25118bd80ee52be9a..dd7cc1e4b916b9cdb7f99600710bcb844e790f90 100644
--- a/doc/doc_en/angle_class_en.md
+++ b/doc/doc_en/angle_class_en.md
@@ -1,13 +1,22 @@
-## TEXT ANGLE CLASSIFICATION
+# TEXT ANGLE CLASSIFICATION
-### Method introduction
+- [Method Introduction](#method-introduction)
+- [Data Preparation](#data-preparation)
+- [Training](#training)
+- [Evaluation](#evaluation)
+- [Prediction](#prediction)
+
+
+## Method Introduction
The angle classification is used in the scene where the image is not 0 degrees. In this scene, it is necessary to perform a correction operation on the text line detected in the picture. In the PaddleOCR system,
The text line image obtained after text detection is sent to the recognition model after affine transformation. At this time, only a 0 and 180 degree angle classification of the text is required, so the built-in PaddleOCR text angle classifier **only supports 0 and 180 degree classification**. If you want to support more angles, you can modify the algorithm yourself to support.
Example of 0 and 180 degree data samples:
![](../imgs_results/angle_class_example.jpg)
-### DATA PREPARATION
+
+
+## Data Preparation
Please organize the dataset as follows:
@@ -62,8 +71,8 @@ containing all images (test) and a cls_gt_test.txt. The structure of the test se
|- word_003.jpg
| ...
```
-
-### TRAINING
+
+## Training
Write the prepared txt file and image folder path into the configuration file under the `Train/Eval.dataset.label_file_list` and `Train/Eval.dataset.data_dir` fields, the absolute path of the image consists of the `Train/Eval.dataset.data_dir` field and the image name recorded in the txt file.
PaddleOCR provides training scripts, evaluation scripts, and prediction scripts.
@@ -107,7 +116,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
**Note that the configuration file for prediction/evaluation must be consistent with the training.**
-### EVALUATION
+
+## Evaluation
The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/cls/cls_mv3.yml` file.
@@ -116,8 +126,8 @@ export CUDA_VISIBLE_DEVICES=0
# GPU evaluation, Global.checkpoints is the weight to be tested
python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
-
-### PREDICTION
+
+## Prediction
* Training engine prediction
diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md
index d3f6f3da102d06c53e4e179a0bd89670536e1af7..016cf929283c0b8dac5cd1f0b3c808c398186917 100644
--- a/doc/doc_en/detection_en.md
+++ b/doc/doc_en/detection_en.md
@@ -1,21 +1,21 @@
-# CONTENT
+# TEXT DETECTION
-- [Paste Your Document In Here](#paste-your-document-in-here)
-- [1. TEXT DETECTION](#1-text-detection)
+This section uses the icdar2015 dataset as an example to introduce the training, evaluation, and testing of the detection model in PaddleOCR.
+
+- [1. DATA AND WEIGHTS PREPARATIO](#1-data-and-weights-preparatio)
* [1.1 DATA PREPARATION](#11-data-preparation)
* [1.2 DOWNLOAD PRETRAINED MODEL](#12-download-pretrained-model)
- * [1.3 START TRAINING](#13-start-training)
- * [1.4 LOAD TRAINED MODEL AND CONTINUE TRAINING](#14-load-trained-model-and-continue-training)
- * [1.5 TRAINING WITH NEW BACKBONE](#15-training-with-new-backbone)
- * [1.6 EVALUATION](#16-evaluation)
- * [1.7 TEST](#17-test)
- * [1.8 INFERENCE MODEL PREDICTION](#18-inference-model-prediction)
+- [2. TRAINING](#2-training)
+ * [2.1 START TRAINING](#21-start-training)
+ * [2.2 LOAD TRAINED MODEL AND CONTINUE TRAINING](#22-load-trained-model-and-continue-training)
+ * [2.3 TRAINING WITH NEW BACKBONE](#23-training-with-new-backbone)
+- [3. EVALUATION AND TEST](#3-evaluation-and-test)
+ * [3.1 EVALUATION](#31-evaluation)
+ * [3.2 TEST](#32-test)
+- [4. INFERENCE](#4-inference)
- [2. FAQ](#2-faq)
-
-# 1. TEXT DETECTION
-
-This section uses the icdar2015 dataset as an example to introduce the training, evaluation, and testing of the detection model in PaddleOCR.
+# 1 DATA AND WEIGHTS PREPARATIO
## 1.1 DATA PREPARATION
@@ -75,7 +75,10 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dyg
```
-## 1.3 START TRAINING
+# 2. TRAINING
+
+## 2.1 START TRAINING
+
*If CPU version installed, please set the parameter `use_gpu` to `false` in the configuration.*
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml \
@@ -98,7 +101,7 @@ python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs
```
-## 1.4 LOAD TRAINED MODEL AND CONTINUE TRAINING
+## 2.2 LOAD TRAINED MODEL AND CONTINUE TRAINING
If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
For example:
@@ -109,7 +112,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./you
**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
-## 1.5 TRAINING WITH NEW BACKBONE
+## 2.3 TRAINING WITH NEW BACKBONE
The network part completes the construction of the network, and PaddleOCR divides the network into four parts, which are under [ppocr/modeling](../../ppocr/modeling). The data entering the network will pass through these four parts in sequence(transforms->backbones->
necks->heads).
@@ -159,7 +162,9 @@ After adding the four-part modules of the network, you only need to configure th
**NOTE**: More details about replace Backbone and other mudule can be found in [doc](add_new_algorithm_en.md).
-## 1.6 EVALUATION
+# 3. EVALUATION AND TEST
+
+## 3.1 EVALUATION
PaddleOCR calculates three indicators for evaluating performance of OCR detection task: Precision, Recall, and Hmean(F-Score).
@@ -174,7 +179,7 @@ python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{pat
* Note: `box_thresh` and `unclip_ratio` are parameters required for DB post-processing, and not need to be set when evaluating the EAST and SAST model.
-## 1.7 TEST
+## 3.2 TEST
Test the detection result on a single image:
```shell
@@ -192,7 +197,7 @@ Test the detection result on all images in the folder:
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy"
```
-## 1.8 INFERENCE MODEL PREDICTION
+# 4. INFERENCE
The inference model (the model saved by `paddle.jit.save`) is generally a solidified model saved after the model training is completed, and is mostly used to give prediction in deployment.
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index 8306523ac1a933f0c664fc0b4cf077659cccdee3..d2ef5e5ac9692eec5bc30774c4451eab7706705d 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -56,31 +56,34 @@ class CELoss(nn.Layer):
class KLJSLoss(object):
def __init__(self, mode='kl'):
- assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+ assert mode in ['kl', 'js', 'KL', 'JS'
+ ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
- loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
+ loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js":
- loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
+ loss += paddle.multiply(
+ p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
- loss = paddle.mean(loss, axis=[1,2])
- elif reduction=="none" or reduction is None:
- return loss
+ loss = paddle.mean(loss, axis=[1, 2])
+ elif reduction == "none" or reduction is None:
+ return loss
else:
- loss = paddle.sum(loss, axis=[1,2])
+ loss = paddle.sum(loss, axis=[1, 2])
+
+ return loss
- return loss
class DMLLoss(nn.Layer):
"""
DMLLoss
"""
- def __init__(self, act=None):
+ def __init__(self, act=None, use_log=False):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
@@ -90,20 +93,24 @@ class DMLLoss(nn.Layer):
self.act = nn.Sigmoid()
else:
self.act = None
-
+
+ self.use_log = use_log
+
self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
- if len(out1.shape) < 2:
+ if self.use_log:
+ # for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
else:
+ # for detection distillation log is not needed
loss = self.jskl_loss(out1, out2)
return loss
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index 0d6fe968d0d7733200a4cfd21d779196cccaba03..f3bb36cf5ac751e6c27e4aa29a46fc5f913f7d05 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -49,11 +49,15 @@ class CombinedLoss(nn.Layer):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
+
weight = self.loss_weight[idx]
- for key in loss.keys():
- if key == "loss":
- loss_all += loss[key] * weight
- else:
- loss_dict["{}_{}".format(key, idx)] = loss[key]
+
+ loss = {key: loss[key] * weight for key in loss}
+
+ if "loss" in loss:
+ loss_all += loss["loss"]
+ else:
+ loss_all += paddle.add_n(list(loss.values()))
+ loss_dict.update(loss)
loss_dict["loss"] = loss_all
return loss_dict
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 75f0a773152e52c98ada5c1907f1c8cc2f72d8f3..73d3ae2ad2499607f897a102f6ea25e4cb7f297f 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -44,20 +44,22 @@ class DistillationDMLLoss(DMLLoss):
def __init__(self,
model_name_pairs=[],
act=None,
+ use_log=False,
key=None,
maps_name=None,
name="dml"):
- super().__init__(act=act)
+ super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
-
+
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
- elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
+ elif isinstance(model_name_pairs[0], list) and isinstance(
+ model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
@@ -112,9 +114,9 @@ class DistillationDMLLoss(DMLLoss):
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
0], pair[1], map_name, idx)] = loss[key]
else:
- loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
- idx)] = loss
-
+ loss_dict["{}_{}_{}".format(self.name, self.maps_name[
+ _c], idx)] = loss
+
loss_dict = _sum_loss(loss_dict)
return loss_dict
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 3bb022ed98b140995b79ceea93d7f494d3f5930d..a7d24dd71a6e35ca619c2a3f90df3a202b8ad94b 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
- else:
- logger.info(
- f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
- )
+ else:
+ logger.info(
+ f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+ )
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
+
def load_pretrained_params(model, path):
if path is None:
return False
@@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
print(f"load pretrain successful from {path}")
return model
+
def save_model(model,
optimizer,
model_path,