```
-#### 1.1 Custom Dataset
+### 1.1 Custom Dataset
The following takes a general dataset as an example to introduce how to prepare the dataset:
* Training set
@@ -82,13 +80,15 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
-1.2 Data Download
+### 1.2 Data Download
+
+- ICDAR2015
-If you do not have a dataset locally, you can download the [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) data from the official website for quick verification. You can also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) to download the lmdb-format datasets required by the benchmark.
+If you do not have a dataset locally, you can download the [ICDAR2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) data from the official website for quick verification. You can also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) to download the lmdb-format datasets required by the benchmark.
-If you are using the public icdar2015 dataset, PaddleOCR provides a label file for training the icdar2015 dataset, which can be downloaded as follows:
+If you want to reproduce the results of the SAR paper, you need to download the extra dataset [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg), extraction code: 627x. In addition, the real datasets icdar2013, icdar2015, cocotext and IIIT5k are also used as part of the training data. For details, please refer to the SAR paper.
-If you want to reproduce the results of the SRN paper, you need to download the offline [augmented data](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA), extraction code: y3ry. The augmented data is obtained by rotating and perturbing MJSynth and SynthText. After downloading, please unzip it to the {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ directory.
+If you are using the public ICDAR2015 dataset, PaddleOCR provides a label file for training the ICDAR2015 dataset, which can be downloaded as follows:
```
# Training set label
@@ -97,15 +97,25 @@ wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_t
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
```
-PaddleOCR also provides a data format conversion script that can convert the official labels into a supported data format. The data conversion tool is in `ppocr/utils/gen_label.py`. Here is the training set as an example:
+PaddleOCR also provides a data format conversion script that can convert the labels from the ICDAR official website into the data format supported by PaddleOCR. The data conversion tool is in `ppocr/utils/gen_label.py`. Here is the training set as an example:
```
# Convert the label file downloaded from the official website to rec_gt_label.txt
python gen_label.py --mode="rec" --input_path="{path/of/origin/label}" --output_label="rec_gt_label.txt"
```
+The data format is as follows: (a) is the original image, and (b) is the Ground Truth text file corresponding to each image:
+
+
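+As a small illustration (assuming the image path and the transcription in each label line are separated by a tab character `\t`, as described above), such a label file can be parsed like this:
+
+```python
+# Minimal sketch: parse a recognition label file whose lines look like
+# "<relative image path>\t<transcription>" (separator assumed to be a tab).
+def load_rec_labels(label_path):
+    samples = []
+    with open(label_path, "r", encoding="utf-8") as f:
+        for line in f:
+            line = line.rstrip("\n")
+            if not line:
+                continue
+            img_path, text = line.split("\t", 1)
+            samples.append((img_path, text))
+    return samples
+
+# e.g. samples = load_rec_labels("./train_data/ic15_data/rec_gt_train.txt")
+```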
+- Multi-language datasets
+
+The training data of the multi-language models consists of 1 million synthetic samples per language, generated with the open-source synthesis tool [text_renderer](https://github.com/Sanster/text_renderer). A small number of fonts can be downloaded in the following two ways.
+* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi
+* [google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
+
+
-1.3 Dictionary
+### 1.3 Dictionary
Finally, a dictionary ({word_dict_name}.txt) needs to be provided so that during training the model can map every character that appears to an index in the dictionary.
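As an illustration, here is a minimal sketch (not PaddleOCR's actual implementation) of the character-to-index mapping implied by such a dictionary file, with one character per line and the line number used as the index:

```python
# Minimal sketch: build the char -> index mapping implied by a
# {word_dict_name}.txt file that stores one character per line.
def load_char_dict(dict_path, use_space_char=False):
    with open(dict_path, "r", encoding="utf-8") as f:
        chars = [line.rstrip("\r\n") for line in f]
    if use_space_char:   # mirrors the use_space_char option described in 1.4
        chars.append(" ")
    return {c: i for i, c in enumerate(chars)}

# char2idx = load_char_dict("ppocr/utils/ppocr_keys_v1.txt")
# "韩国小馆" would then be encoded as [char2idx[c] for c in "韩国小馆"]
```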
@@ -149,16 +159,29 @@ PaddleOCR内置了一部分字典,可以按需使用。
- Custom dictionary
To use a custom dic file, add the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml` and point it to your dictionary path.
-and set `character_type` to `ch`.
-1.4 Add Space Category
+### 1.4 Add Space Category
If you want to support recognition of the "space" category, set the `use_space_char` field in the yml file to `True`.
-### 2. Start Training
+## 2. Start Training
+
+
+### 2.1 Data Augmentation
+
+PaddleOCR provides a variety of data augmentation methods, and augmentation is already enabled in the default configuration files.
+
+The default perturbation methods are: color space conversion (cvtColor), blur, jitter, Gaussian noise, random crop, perspective, color reverse, and TIA augmentation.
+
+During training, each perturbation is selected with a probability of 40%. For the implementation details, please refer to: [rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+
+*Due to OpenCV compatibility issues, the perturbation operations currently only support Linux.*
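+
+The snippet below is a simplified sketch of this idea (not the actual code in rec_img_aug.py): each perturbation is applied independently with a probability of 0.4:
+
+```python
+import random
+
+import cv2
+import numpy as np
+
+# Simplified sketch: apply each perturbation independently with prob=0.4.
+# See rec_img_aug.py for the real augmentation pipeline and TIA augmentation.
+def augment(img, prob=0.4):
+    if random.random() < prob:                      # blur
+        img = cv2.GaussianBlur(img, (5, 5), 1)
+    if random.random() < prob:                      # Gaussian noise
+        noise = np.random.normal(0, 10, img.shape).astype(np.float32)
+        img = np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.uint8)
+    if random.random() < prob:                      # color reverse
+        img = 255 - img
+    return img
+```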
+
+
+### 2.2 General Model Training
PaddleOCR provides training, evaluation and prediction scripts. This section takes the CRNN recognition model as an example:
@@ -178,23 +201,16 @@ tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc
*If you have installed the CPU version, please change the `use_gpu` field in the configuration file to false*
```
-# GPU training. Supports single-card and multi-card training; specify the card numbers via --gpus
+# GPU training. Supports single-card and multi-card training
# Train on the icdar15 English data. The training log is automatically saved as train.log under "{save_model_dir}"
-python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
-```
-
-#### 2.1 Data Augmentation
-PaddleOCR provides a variety of data augmentation methods. If you want to add perturbation during training, set `distort: true` in the configuration file.
+# Single-card training (long training time, not recommended)
+python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
-The default perturbation methods are: color space conversion (cvtColor), blur, jitter, Gaussian noise, random crop, perspective, and color reverse.
-
-During training, each perturbation is selected with a probability of 50%. For the implementation details, please refer to: [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
-
-*Due to OpenCV compatibility issues, the perturbation operations currently only support Linux*
+# Multi-card training, specify the card numbers via --gpus
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
+```
-
-#### 2.2 Training
PaddleOCR supports alternating training and evaluation. You can set the evaluation frequency by modifying `eval_batch_step` in `configs/rec/rec_icdar15_train.yml`; by default, evaluation is performed every 500 iterations. During evaluation, the model with the best accuracy is saved as `output/rec_CRNN/best_accuracy` by default.
@@ -215,6 +231,11 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
+| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
+| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
+| rec_resnet_stn_bilstm_att.yml | SEED | Aster_Resnet | STN | BiLSTM | att |
+
+*The SEED model additionally needs to load a FastText-trained [language model](https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.en.300.bin.gz)*
For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try other algorithms on a Chinese dataset, please refer to the following instructions to modify the configuration file:
@@ -224,8 +245,6 @@ Global:
...
# Add a custom dictionary. To change the dictionary, point this path to the new dictionary file
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- # Modify the character type
- character_type: ch
...
# Recognize spaces
use_space_char: True
@@ -282,105 +301,28 @@ Eval:
```
**Note: the configuration file used for prediction/evaluation must be consistent with the one used for training.**
-
-#### 2.3 Multi-language
+
+### 2.3 Multi-language Model Training
PaddleOCR currently supports recognition for 80 languages (besides Chinese). A multi-language configuration file template is provided under `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml).
-There are two ways to create the required configuration file:
-
-1. Automatically generated by script
-
-[generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) can help you generate configuration files for multi-language models
-
-- Take Italian as an example. If your data is prepared in the following format:
-    ```
-    |-train_data
-        |- it_train.txt # train set label
-        |- it_val.txt # val set label
-        |- data
-            |- word_001.jpg
-            |- word_002.jpg
-            |- word_003.jpg
-            | ...
-    ```
-
-    You can generate a configuration file with the default parameters:
-
-    ```bash
-    # The code needs to be run in the specified directory
-    cd PaddleOCR/configs/rec/multi_language/
-    # Set the language of the configuration file to be generated via -l or --language; this command writes the default parameters into the configuration file
-    python3 generate_multi_language_configs.py -l it
-    ```
-
-- If your data is placed elsewhere, or you want to use your own dictionary, you can generate the configuration file by specifying the relevant parameters:
-
-    ```bash
-    # The -l or --language field is required
-    # --train modifies the training set, --val the validation set, --data_dir the dataset directory, --dict the dictionary path, -o the corresponding default parameters
-    cd PaddleOCR/configs/rec/multi_language/
-    python3 generate_multi_language_configs.py -l it \  # language
-    --train {path/of/train_label.txt} \ # path of the training label file
-    --val {path/of/val_label.txt} \     # path of the validation label file
-    --data_dir {train_data/path} \      # root directory of the training data
-    --dict {path/of/dict} \             # path of the dictionary file
-    -o Global.use_gpu=False             # whether to use gpu
-    ...
-
-    ```
-
-Italian is written with Latin letters, so after the command finishes you will get a configuration file named rec_latin_lite_train.yml.
-
-2. Manually modify the configuration file
-
-    You can also manually modify the following fields in the template:
-
-    ```
-    Global:
-      use_gpu: True
-      epoch_num: 500
-      ...
-      character_type: it # the language to be recognized
-      character_dict_path: {path/of/dict} # path of the dictionary file
-
-    Train:
-      dataset:
-        name: SimpleDataSet
-        data_dir: train_data/ # root directory of the training data
-        label_file_list: ["./train_data/train_list.txt"] # train label path
-      ...
-
-    Eval:
-      dataset:
-        name: SimpleDataSet
-        data_dir: train_data/ # root directory of the data
-        label_file_list: ["./train_data/val_list.txt"] # val label path
-      ...
-
-    ```
-
-The multi-language algorithms currently supported by PaddleOCR are:
-
-| Configuration file | Algorithm | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Chinese traditional | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Cyrillic | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Devanagari | devanagari |
+Grouped by script and language family, the languages currently supported by PaddleOCR are:
+
+| Configuration file | Algorithm | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Devanagari |
For more supported languages, please refer to: [Multi-language models](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
-The multi-language models are trained in the same way as the Chinese model. The training set consists of 1 million synthetic samples per language; a small number of fonts can be downloaded in the following two ways.
-* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi.
-* [google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
-
If you want to fine-tune on the basis of the existing models, please refer to the following instructions to modify the configuration file:
Take `rec_french_lite_train` as an example:
@@ -416,7 +358,7 @@ Eval:
...
```
-### 3 Evaluation
+## 3 Evaluation
The evaluation dataset can be set by modifying the `label_file_path` field under Eval in `configs/rec/rec_icdar15_train.yml`.
@@ -426,14 +368,29 @@ python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec
```
-### 4 Prediction
-
-
-#### 4.1 Prediction with the Training Engine
+## 4 Prediction
Using a model trained with PaddleOCR, you can quickly run prediction with the following scripts.
-By default the prediction images are stored in `infer_img`, and the weights are specified via `-o Global.checkpoints`:
+By default the prediction images are stored in `infer_img`; the trained parameter file is loaded via `-o Global.checkpoints`:
+
+According to the `save_model_dir` and `save_epoch_step` fields set in the configuration file, the following files will be saved:
+
+```
+output/rec/
+├── best_accuracy.pdopt
+├── best_accuracy.pdparams
+├── best_accuracy.states
+├── config.yml
+├── iter_epoch_3.pdopt
+├── iter_epoch_3.pdparams
+├── iter_epoch_3.states
+├── latest.pdopt
+├── latest.pdparams
+├── latest.states
+└── train.log
+```
+Here best_accuracy.* is the best model on the evaluation set, iter_epoch_x.* are the models saved at intervals of `save_epoch_step`, and latest.* is the model of the last epoch.
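+
+As a rough sketch (the paths are just examples), the `.pdparams` file holds the model weights and the `.pdopt` file the optimizer state; both can be inspected or reloaded with `paddle.load`:
+
+```python
+import paddle
+
+# Sketch: inspect / reload a saved checkpoint (paths are examples).
+params = paddle.load("output/rec/best_accuracy.pdparams")    # model weights
+opt_state = paddle.load("output/rec/best_accuracy.pdopt")    # optimizer state
+print(list(params.keys())[:5])
+# In a training script one would typically restore them with
+#   model.set_state_dict(params); optimizer.set_state_dict(opt_state)
+```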
```
# Predict English results
@@ -469,3 +426,37 @@ python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v
infer_img: doc/imgs_words/ch/word_1.jpg
result: ('韩国小馆', 0.997218)
```
+
+
+
+## 5. Converting to an Inference Model for Testing
+
+Converting a recognition model to an inference model works the same way as for detection, as follows:
+
+```
+# -c specifies the yml configuration file of the training algorithm
+# -o sets optional parameters
+# Global.pretrained_model specifies the path of the trained model to be converted; do not add the file suffix .pdmodel, .pdopt or .pdparams.
+# Global.save_inference_dir specifies the path where the converted model will be saved.
+
+python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_rec_train/best_accuracy Global.save_inference_dir=./inference/rec_crnn/
+```
+
+**Note:** If you trained the model on your own dataset and adjusted the Chinese character dictionary, please check that `character_dict_path` in the configuration file points to the dictionary file you need.
+
+After a successful conversion, there are three files in the directory:
+
+```
+/inference/rec_crnn/
+    ├── inference.pdiparams         # parameter file of the recognition inference model
+    ├── inference.pdiparams.info    # parameter information of the recognition inference model, can be ignored
+    └── inference.pdmodel           # program file of the recognition inference model
+```
+
+- Custom model inference
+
+  If the text dictionary was modified during training, you need to specify the dictionary path via `--rec_char_dict_path` when predicting with the inference model, and set `rec_char_type=ch`.
+
+ ```
+ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
+ ```
diff --git a/doc/doc_ch/training.md b/doc/doc_ch/training.md
new file mode 100644
index 0000000000000000000000000000000000000000..c6c7b87d9925197b36a246c651ab7179ff9d2e81
--- /dev/null
+++ b/doc/doc_ch/training.md
@@ -0,0 +1,137 @@
+# Model Training
+
+This document introduces the basic concepts involved in model training and the tuning methods used during training.
+
+It also briefly describes the components of the PaddleOCR training data and how to prepare data for finetuning the model in vertical scenarios.
+
+- [1. Basic Concepts](#1-basic-concepts)
+  * [1.1 Learning Rate](#11-learning-rate)
+  * [1.2 Regularization](#12-regularization)
+  * [1.3 Evaluation Metrics](#13-evaluation-metrics)
+- [2. Data and Vertical Scenarios](#2-data-and-vertical-scenarios)
+  * [2.1 Training Data](#21-training-data)
+  * [2.2 Vertical Scenarios](#22-vertical-scenarios)
+  * [2.3 Building Your Own Dataset](#23-building-your-own-dataset)
+- [3. FAQ](#3-faq)
+
+
+## 1. Basic Concepts
+
+OCR (Optical Character Recognition) refers to analyzing and recognizing images to obtain the text and layout information they contain. It is a typical computer vision task
+and usually consists of two sub-tasks: text detection and text recognition.
+
+The following parameters need attention when tuning a model:
+
+
+### 1.1 Learning Rate
+
+The learning rate is one of the most important hyperparameters for training neural networks. It represents the step size by which the gradient moves toward the optimal solution of the loss function at each iteration.
+PaddleOCR provides a variety of learning rate update strategies, which can be configured in the configuration file, for example:
+
+```
+Optimizer:
+ ...
+ lr:
+ name: Piecewise
+ decay_epochs : [700, 800]
+ values : [0.001, 0.0001]
+ warmup_epoch: 5
+```
+
+Piecewise means piecewise constant decay: a different learning rate is specified for each stage, and the learning rate stays constant within each stage.
+warmup_epoch means that during the first 5 epochs the learning rate gradually increases from 0 to base_lr. For all available strategies, please refer to the code in [learning_rate.py](../../ppocr/optimizer/learning_rate.py).
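+
+As a small numerical sketch (an interpretation of the configuration above, not PaddleOCR's scheduler code), the learning rate implied for a given epoch would be:
+
+```python
+# Sketch: learning rate implied by the Piecewise config above,
+# with a linear warmup over the first warmup_epoch epochs (assumed behaviour).
+def piecewise_lr(epoch, decay_epochs=(700, 800), values=(0.001, 0.0001),
+                 warmup_epoch=5):
+    base_lr = values[0]
+    if epoch < warmup_epoch:            # linear warmup from 0 towards base_lr
+        return base_lr * epoch / warmup_epoch
+    lr = values[-1]                     # learning rate after the last boundary
+    for boundary, value in zip(decay_epochs, values):
+        if epoch < boundary:
+            lr = value
+            break
+    return lr
+
+# piecewise_lr(3) -> 0.0006, piecewise_lr(100) -> 0.001, piecewise_lr(750) -> 0.0001
+```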
+
+
+### 1.2 Regularization
+
+Regularization can effectively prevent the algorithm from overfitting. PaddleOCR provides the L1 and L2 regularization methods, which are the most commonly used ones. L1 regularization adds a penalty term to the objective function to reduce the sum of the absolute values of the parameters, while L2 regularization adds a term to reduce the sum of the squares of the parameters. The configuration is as follows:
+
+```
+Optimizer:
+ ...
+ regularizer:
+ name: L2
+ factor: 2.0e-05
+```
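+
+Conceptually, the `factor` above scales a penalty that is added to the training loss; a minimal sketch of the L2 case:
+
+```python
+import numpy as np
+
+# Sketch: L2 regularization adds factor * sum(w^2) to the loss;
+# L1 would use the sum of absolute values instead.
+def l2_penalty(weights, factor=2.0e-05):
+    return factor * sum(float((w ** 2).sum()) for w in weights)
+
+weights = [np.ones((3, 3)), np.full((2,), 2.0)]
+# total_loss = task_loss + l2_penalty(weights)  ->  task_loss + 2e-05 * (9 + 8)
+```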
+
+
+### 1.3 Evaluation Metrics
+
+(1) Detection stage: the evaluation is first based on the IoU between the detected box and the ground-truth box; a detection is judged correct when the IoU exceeds a certain threshold. Note that, unlike boxes in generic object detection, the boxes here are represented as polygons. Detection precision: the proportion of correct detection boxes among all detection boxes, which mainly measures detection quality. Detection recall: the proportion of correct detection boxes among all ground-truth boxes, which mainly measures missed detections.
+
+(2) Recognition stage: text-line recognition accuracy, i.e. the proportion of correctly recognized text lines among all labeled text lines; a line counts as correct only if the entire line is recognized correctly.
+
+(3) End-to-end statistics: end-to-end recall: the proportion of text lines that are both correctly detected and correctly recognized among all labeled text lines; end-to-end precision: the proportion of text lines that are both correctly detected and correctly recognized among all detected text lines. A detection is counted as correct when the IoU between the detection box and the ground-truth box exceeds a certain threshold, and the recognition is correct when the text in the correctly detected box matches the ground-truth text.
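+
+A simplified numerical sketch of the detection metrics (using axis-aligned IoU and ignoring one-to-one matching, whereas the real evaluation uses polygon IoU):
+
+```python
+# Simplified sketch: a detection is "correct" when its IoU with some
+# ground-truth box exceeds a threshold (axis-aligned boxes here for brevity).
+def iou(a, b):
+    ax1, ay1, ax2, ay2 = a
+    bx1, by1, bx2, by2 = b
+    iw = max(0.0, min(ax2, bx2) - max(ax1, bx1))
+    ih = max(0.0, min(ay2, by2) - max(ay1, by1))
+    inter = iw * ih
+    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
+    return inter / union if union > 0 else 0.0
+
+def det_metrics(pred_boxes, gt_boxes, thresh=0.5):
+    correct = sum(any(iou(p, g) >= thresh for g in gt_boxes) for p in pred_boxes)
+    precision = correct / len(pred_boxes) if pred_boxes else 0.0
+    recall = correct / len(gt_boxes) if gt_boxes else 0.0
+    hmean = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
+    return precision, recall, hmean
+```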
+
+
+
+## 2. Data and Vertical Scenarios
+
+
+### 2.1 Training Data
+The datasets and data volumes used by the currently open-sourced models are as follows:
+
+  - Detection:
+    - English dataset: ICDAR2015
+    - Chinese dataset: LSVT street-view dataset, 30k images of training data
+
+  - Recognition:
+    - English dataset: MJSynth and SynthText synthetic data, more than ten million samples.
+    - Chinese dataset: text lines cropped from the LSVT street-view dataset according to the ground truth and position-calibrated, 300k images in total. In addition, 5 million samples were synthesized based on the LSVT corpus.
+    - Minority-language datasets: 1 million synthetic samples were generated per language using different corpora and fonts, with ICDAR-MLT used as the validation set.
+
+Among them, the public datasets are all open source; users can search for and download them by themselves, or refer to [Chinese datasets](./datasets.md). The synthetic data is not open-sourced for now; users can generate their own with open-source synthesis tools such as [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText) and [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator).
+
+
+### 2.2 Vertical Scenarios
+
+PaddleOCR mainly focuses on general-purpose OCR. For vertical-domain requirements, you can train your own model with PaddleOCR plus vertical-domain data;
+if labeled data is lacking, or you do not want to invest in R&D, it is recommended to directly call an open API, which covers most of the common vertical scenarios.
+
+
+### 2.3 Building Your Own Dataset
+
+There are a few rules of thumb for building a dataset:
+
+(1) Amount of training data:
+
+  a. Detection needs relatively little data. Fine-tuning on top of a PaddleOCR model generally requires about 500 images to achieve good results.
+  b. Recognition differs between English and Chinese: English scenarios generally need hundreds of thousands of samples to achieve good results, while Chinese needs several million or even more.
+
+
+(2) When the amount of training data is small, you can try the following three ways to obtain more data:
+
+  a. Manually collect more training data; this is the most direct and effective way.
+  b. Basic image processing or transformations based on PIL and OpenCV, e.g. using the ImageFont, Image and ImageDraw modules of PIL to write text onto a background, or OpenCV rotation, affine transformation and Gaussian filtering (see the sketch after this list).
+  c. Synthesize data with data-generation algorithms such as pix2pix or StyleText.
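+
+For item (2)b above, a minimal sketch of writing text onto a background with PIL (the font path is a placeholder):
+
+```python
+from PIL import Image, ImageDraw, ImageFont
+
+# Sketch: synthesize a recognition sample by drawing text onto a background.
+# The font file path below is a placeholder; use a font that covers your charset.
+bg = Image.new("RGB", (320, 48), color=(240, 240, 240))   # or Image.open("bg.jpg")
+font = ImageFont.truetype("path/to/your_font.ttf", 32)
+draw = ImageDraw.Draw(bg)
+draw.text((10, 8), "用科技让复杂的世界更简单", fill=(30, 30, 30), font=font)
+bg.save("synth_word_001.jpg")
+# The label file then gets a line such as: synth_word_001.jpg\t用科技让复杂的世界更简单
+```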
+
+
+
+## 3. FAQ
+
+**Q**: When training CRNN recognition, how should I choose a suitable network input shape?
+
+  A: The height is generally set to 32. There are two ways to choose the maximum width:
+
+  (1) Collect the aspect-ratio distribution of the training images and choose the maximum aspect ratio that covers about 80% of the training samples.
+
+  (2) Collect the number of characters in the training samples and choose the maximum character count that covers about 80% of the training samples; then estimate a maximum width by taking the aspect ratio of a Chinese character as roughly 1:1 and that of an English character as roughly 3:1.
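+
+A small sketch of method (1): resize every sample to height 32 and take the 80th-percentile width as the network input width:
+
+```python
+import numpy as np
+from PIL import Image
+
+# Sketch of method (1): scale every image to height 32 and pick the
+# 80th-percentile width as the maximum input width.
+def suggest_input_width(image_paths, target_h=32, percentile=80):
+    widths = []
+    for p in image_paths:
+        w, h = Image.open(p).size
+        widths.append(w * target_h / h)
+    return int(np.percentile(widths, percentile))
+
+# e.g. input shape = (3, 32, suggest_input_width(train_image_paths))
+```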
+
+**Q**: During recognition training, the accuracy on the training set has reached 90 but the accuracy on the validation set stays around 70 and will not improve. What should I do?
+
+  A: With about 90 on the training set and 70 on the test set, the model is probably overfitting. There are two things to try:
+
+  (1) Add more augmentation methods or increase the augmentation [probability](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/ppocr/data/imaug/rec_img_aug.py#L341), which defaults to 0.4.
+
+  (2) Increase the [l2 decay value](https://github.com/PaddlePaddle/PaddleOCR/blob/a501603d54ff5513fc4fc760319472e59da25424/configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml#L47) of the system.
+
+**Q**: When training the recognition model, the loss decreases normally but the acc stays at 0.
+
+  A: It is normal for the acc to be 0 at the early stage of recognition training; train for longer and the metric will go up.
+
+
+***
+For detailed training tutorials, please follow the links below:
+- [Text detection model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
+- [Text recognition model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
+- [Text direction classifier training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
\ No newline at end of file
diff --git a/doc/doc_ch/update.md b/doc/doc_ch/update.md
index 3fe8a0c9ace4be31882b22fe75b88f18848e1ad9..0852e240886b4ca736a830c8c44651ca35ec1f25 100644
--- a/doc/doc_ch/update.md
+++ b/doc/doc_ch/update.md
@@ -1,4 +1,8 @@
# Update
+- 2021.9.7 Release PaddleOCR v2.3 with [PP-OCRv2](#PP-OCRv2): CPU inference speed improved by 220% compared with PP-OCR server, and accuracy improved by 7% compared with PP-OCR mobile.
+- 2021.8.3 Release PaddleOCR v2.2, adding the document structure analysis toolkit [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README_ch.md), supporting layout analysis and table recognition (including Excel export).
+- 2021.6.29 Add 5 frequently asked questions to the [FAQ](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/doc/doc_ch/FAQ.md), 248 in total, updated every Monday. Stay tuned.
+- 2021.4.8 Release version 2.1, open-sourcing the AAAI 2021 paper [end-to-end recognition algorithm PGNet](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/doc/doc_ch/pgnet.md) and increasing the number of supported [multi-language models](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/doc/doc_ch/multi_languages.md) to 80+.
- 2020.12.15 Update the data synthesis tool [Style-Text](../../StyleText/README_ch.md), which can batch-synthesize a large number of images similar to the target scene; verified in multiple scenarios with significant improvement.
- 2020.12.07 Add 5 frequently asked questions to the [FAQ](../../doc/doc_ch/FAQ.md), 124 in total, planned to be updated every Monday. Stay tuned.
- 2020.11.25 Update the semi-automatic annotation tool [PPOCRLabel](../../PPOCRLabel/README_ch.md), which helps developers complete annotation tasks efficiently; its output format connects seamlessly with PP-OCR training tasks.
diff --git a/doc/doc_ch/visualization.md b/doc/doc_ch/visualization.md
index f2ea2b09d9431ebd710f2d7ccac0bd73c50b558e..99d071ec22daccaa295b5087760c5fc0d45f9802 100644
--- a/doc/doc_ch/visualization.md
+++ b/doc/doc_ch/visualization.md
@@ -1,7 +1,13 @@
# Visualization
+
+## Ultra-lightweight PP-OCRv2 Results
+
+
+
+
-## General ppocr_server_2.0 Results
+## General PP-OCR server Results
diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md
index 167ed7b2b8a13706dfe1533265b6d96560265511..ba5bbae6255382d0c7fa5be319946d6242b1a544 100644
--- a/doc/doc_ch/whl.md
+++ b/doc/doc_ch/whl.md
@@ -210,7 +210,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --use_angle_cls true
```bash
[[[24.0, 36.0], [304.0, 34.0], [304.0, 72.0], [24.0, 74.0]], ['纯臻营养护发素', 0.964739]]
[[[24.0, 80.0], [172.0, 80.0], [172.0, 104.0], [24.0, 104.0]], ['产品信息/参数', 0.98069626]]
-[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
+[[[24.0, 109.0], [333.0, 109.0], [333.0, 136.0], [24.0, 136.0]], ['(45元/每公斤,100公斤起订)', 0.9676722]]
......
```
diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md
index d70f99bb5c5b0bdcb7d39209dfc9a77c56918260..df8a4ce3ef5fbcadb7ebdfd8ddf2bdf59637783e 100755
--- a/doc/doc_en/algorithm_overview_en.md
+++ b/doc/doc_en/algorithm_overview_en.md
@@ -11,9 +11,10 @@ This tutorial lists the text detection algorithms and text recognition algorithm
### 1. Text Detection Algorithm
PaddleOCR open source text detection algorithms list:
-- [x] EAST([paper](https://arxiv.org/abs/1704.03155))[2]
-- [x] DB([paper](https://arxiv.org/abs/1911.08947))[1]
-- [x] SAST([paper](https://arxiv.org/abs/1908.05498))[4]
+- [x] EAST([paper](https://arxiv.org/abs/1704.03155))
+- [x] DB([paper](https://arxiv.org/abs/1911.08947))
+- [x] SAST([paper](https://arxiv.org/abs/1908.05498))
+- [x] PSE([paper](https://arxiv.org/abs/1903.12473v2))
On the ICDAR2015 dataset, the text detection result is as follows:
@@ -24,6 +25,8 @@ On the ICDAR2015 dataset, the text detection result is as follows:
|DB|ResNet50_vd|86.41%|78.72%|82.38%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)|
|DB|MobileNetV3|77.29%|73.08%|75.12%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_mv3_db_v2.0_train.tar)|
|SAST|ResNet50_vd|91.39%|83.77%|87.42%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)|
+|PSE|ResNet50_vd|85.81%|79.53%|82.55%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_r50_vd_pse_v2.0_train.tar)|
+|PSE|MobileNetV3|82.20%|70.48%|75.89%|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/det_mv3_pse_v2.0_train.tar)|
On Total-Text dataset, the text detection result is as follows:
@@ -41,11 +44,13 @@ For the training guide and use of PaddleOCR text detection algorithms, please re
### 2. Text Recognition Algorithm
PaddleOCR open-source text recognition algorithms list:
-- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))[7]
-- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10]
-- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11]
-- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12]
-- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5]
+- [x] CRNN([paper](https://arxiv.org/abs/1507.05717))
+- [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))
+- [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))
+- [x] RARE([paper](https://arxiv.org/abs/1603.03915v1))
+- [x] SRN([paper](https://arxiv.org/abs/2003.12294))
+- [x] NRTR([paper](https://arxiv.org/abs/1806.00926v2))
+- [x] SAR([paper](https://arxiv.org/abs/1811.00751v2))
Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow:
@@ -60,5 +65,7 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
|RARE|MobileNetV3|82.5%|rec_mv3_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_att_v2.0_train.tar)|
|RARE|Resnet34_vd|83.6%|rec_r34_vd_tps_bilstm_att |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_att_v2.0_train.tar)|
|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)|
+|NRTR|NRTR_MTB| 84.3% | rec_mtb_nrtr | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mtb_nrtr_train.tar) |
+|SAR|Resnet31| 87.2% | rec_r31_sar | [Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.1/rec/rec_r31_sar_train.tar) |
Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md)
diff --git a/doc/doc_en/angle_class_en.md b/doc/doc_en/angle_class_en.md
index 0044d85ac0a43529c67746d25118bd80ee52be9a..46d91bee43de3af99659651b7f31cf1148e7b294 100644
--- a/doc/doc_en/angle_class_en.md
+++ b/doc/doc_en/angle_class_en.md
@@ -1,6 +1,14 @@
-## TEXT ANGLE CLASSIFICATION
+# Text Direction Classification
-### Method introduction
+- [1. Method Introduction](#method-introduction)
+- [2. Data Preparation](#data-preparation)
+- [3. Training](#training)
+- [4. Evaluation](#evaluation)
+- [5. Prediction](#prediction)
+
+
+
+## 1. Method Introduction
The angle classification is used in the scene where the image is not 0 degrees. In this scene, it is necessary to perform a correction operation on the text line detected in the picture. In the PaddleOCR system,
the text line image obtained after text detection is sent to the recognition model after an affine transformation. At this point, only a 0/180 degree angle classification of the text is needed, so the built-in PaddleOCR text angle classifier **only supports 0 and 180 degree classification**. If you want to support more angles, you can modify the algorithm yourself.
@@ -9,6 +17,9 @@ Example of 0 and 180 degree data samples:

-### DATA PREPARATION
+
+## 2. Data Preparation
+
Please organize the dataset as follows:
The default storage path for training data is `PaddleOCR/train_data/cls`, if you already have a dataset on your disk, just create a soft link to the dataset directory:
@@ -62,8 +73,8 @@ containing all images (test) and a cls_gt_test.txt. The structure of the test se
|- word_003.jpg
| ...
```
-
-### TRAINING
+
+## 3. Training
Write the prepared txt file and image folder path into the configuration file under the `Train/Eval.dataset.label_file_list` and `Train/Eval.dataset.data_dir` fields, the absolute path of the image consists of the `Train/Eval.dataset.data_dir` field and the image name recorded in the txt file.
PaddleOCR provides training scripts, evaluation scripts, and prediction scripts.
@@ -107,7 +118,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
**Note that the configuration file for prediction/evaluation must be consistent with the training.**
-### EVALUATION
+
+## 4. Evaluation
The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/cls/cls_mv3.yml` file.
@@ -116,6 +128,8 @@ export CUDA_VISIBLE_DEVICES=0
# GPU evaluation, Global.checkpoints is the weight to be tested
python3 tools/eval.py -c configs/cls/cls_mv3.yml -o Global.checkpoints={path/to/weights}/best_accuracy
```
+
+## 5. Prediction
-### PREDICTION
diff --git a/doc/doc_en/benchmark_en.md b/doc/doc_en/benchmark_en.md
index 91b015941924add81f8b4f0d9d9ca13274348131..70b33aebd95cfa6e02122c6816cd3863d2b584ab 100755
--- a/doc/doc_en/benchmark_en.md
+++ b/doc/doc_en/benchmark_en.md
@@ -1,8 +1,8 @@
-# BENCHMARK
+# Benchmark
This document gives the performance of the series models for Chinese and English recognition.
-## TEST DATA
+## Test Data
We collected 300 images for different real application scenarios to evaluate the overall OCR system, including contract samples, license plates, nameplates, train tickets, test sheets, forms, certificates, street view images, business cards, digital meter, etc. The following figure shows some images of the test set.
@@ -10,10 +10,9 @@ We collected 300 images for different real application scenarios to evaluate the
-## MEASUREMENT
+## Measurement
Explanation:
-- v1.0 indicates DB+CRNN models without the strategies. v1.1 indicates the PP-OCR models with the strategies and the direction classifier. slim_v1.1 indicates the PP-OCR models with pruning or quantization.
- The long size of the input for the text detector is 960.
@@ -27,30 +26,16 @@ Compares the model size and F-score:
| Model Name | Model Size
of the
Whole System\(M\) | Model Size
of the Text
Detector\(M\) | Model Size
of the Direction
Classifier\(M\) | Model Size
of the Text
Recognizer \(M\) | F\-score |
|:-:|:-:|:-:|:-:|:-:|:-:|
-| ch\_ppocr\_mobile\_v1\.1 | 8\.1 | 2\.6 | 0\.9 | 4\.6 | 0\.5193 |
-| ch\_ppocr\_server\_v1\.1 | 155\.1 | 47\.2 | 0\.9 | 107 | 0\.5414 |
-| ch\_ppocr\_mobile\_v1\.0 | 8\.6 | 4\.1 | \- | 4\.5 | 0\.393 |
-| ch\_ppocr\_server\_v1\.0 | 203\.8 | 98\.5 | \- | 105\.3 | 0\.4436 |
+| PP-OCRv2 | 11\.6 | 3\.0 | 0\.9 | 8\.6 | 0\.5224 |
+| PP-OCR mobile | 8\.1 | 2\.6 | 0\.9 | 4\.6 | 0\.503 |
+| PP-OCR server | 155\.1 | 47\.2 | 0\.9 | 107 | 0\.570 |
-Compares the time-consuming on T4 GPU (ms):
+Compares the time-consuming on CPU and T4 GPU (ms):
-| Model Name | Overall | Text Detector | Direction Classifier | Text Recognizer |
-|:-:|:-:|:-:|:-:|:-:|
-| ch\_ppocr\_mobile\_v1\.1 | 137 | 35 | 24 | 78 |
-| ch\_ppocr\_server\_v1\.1 | 204 | 39 | 25 | 140 |
-| ch\_ppocr\_mobile\_v1\.0 | 117 | 41 | \- | 76 |
-| ch\_ppocr\_server\_v1\.0 | 199 | 52 | \- | 147 |
+| Model Name | CPU | T4 GPU |
+|:-:|:-:|:-:|
+| PP-OCRv2 | 330 | 111 |
+| PP-OCR mobile | 356 | 116|
+| PP-OCR server | 1056 | 200 |
-Compares the time-consuming on CPU (ms):
-
-| Model Name | Overall | Text Detector | Direction Classifier | Text Recognizer |
-|:-:|:-:|:-:|:-:|:-:|
-| ch\_ppocr\_mobile\_v1\.1 | 421 | 164 | 51 | 206 |
-| ch\_ppocr\_mobile\_v1\.0 | 398 | 219 | \- | 179 |
-
-Compares the model size, F-score, the time-consuming on SD 855 of between the slim models and the original models:
-
-| Model Name | Model Size
of the
Whole System\(M\) | Model Size
of the Text
Detector\(M\) | Model Size
of the Direction
Classifier\(M\) | Model Size
of the Text
Recognizer \(M\) | F\-score | SD 855
\(ms\) |
-|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
-| ch\_ppocr\_mobile\_v1\.1 | 8\.1 | 2\.6 | 0\.9 | 4\.6 | 0\.5193 | 306 |
-| ch\_ppocr\_mobile\_slim\_v1\.1 | 3\.5 | 1\.4 | 0\.5 | 1\.6 | 0\.521 | 268 |
+More indicators of PP-OCR series models can be referred to [PP-OCR Benchmark](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/doc/doc_en/benchmark_en.md)
diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md
index 5e5847c4b298553b2d376b90196b61b7e0286efe..ce76da9b2f39532b387e3e45ca2ff497b0408635 100644
--- a/doc/doc_en/config_en.md
+++ b/doc/doc_en/config_en.md
@@ -1,4 +1,12 @@
-## Optional parameter list
+# Configuration
+
+- [1. Optional Parameter List](#1-optional-parameter-list)
+- [2. Introduction to Global Parameters of Configuration File](#2-introduction-to-global-parameters-of-configuration-file)
+- [3. Multilingual Config File Generation](#3-multilingual-config-file-generation)
+
+
+
+## 1. Optional Parameter List
The following list can be viewed through `--help`
@@ -7,7 +15,9 @@ The following list can be viewed through `--help`
| -c | ALL | Specify configuration file to use | None | **Please refer to the parameter introduction for configuration file usage** |
| -o | ALL | set configuration options | None | Configuration using -o has higher priority than the configuration file selected with -c. E.g: -o Global.use_gpu=false |
-## INTRODUCTION TO GLOBAL PARAMETERS OF CONFIGURATION FILE
+
+
+## 2. Introduction to Global Parameters of Configuration File
Take rec_chinese_lite_train_v2.0.yml as an example
### Global
@@ -27,9 +37,8 @@ Take rec_chinese_lite_train_v2.0.yml as an example
| checkpoints | set model parameter path | None | Used to load parameters after interruption to continue training|
| use_visualdl | Set whether to enable visualdl for visual log display | False | [Tutorial](https://www.paddlepaddle.org.cn/paddle/visualdl) |
| infer_img | Set inference image path or folder path | ./infer_img | \|
-| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | \ |
+| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | If character_dict_path is None, the model can only recognize numbers and lowercase letters |
| max_text_length | Set the maximum length of text | 25 | \ |
-| character_type | Set character type | ch | en/ch, the default dict will be used for en, and the custom dict will be used for ch |
| use_space_char | Set whether to recognize spaces | True | Only support in character_type=ch mode |
| label_list | Set the angle supported by the direction classifier | ['0','180'] | Only valid in angle classifier model |
| save_res_path | Set the save address of the test model results | ./output/det_db/predicts_db.txt | Only valid in the text detection model |
@@ -51,7 +60,7 @@ Take rec_chinese_lite_train_v2.0.yml as an example
### Architecture ([ppocr/modeling](../../ppocr/modeling))
-In ppocr, the network is divided into four stages: Transform, Backbone, Neck and Head
+In PaddleOCR, the network is divided into four stages: Transform, Backbone, Neck and Head
| Parameter | Use | Defaults | Note |
| :---------------------: | :---------------------: | :--------------: | :--------------------: |
@@ -120,3 +129,108 @@ In ppocr, the network is divided into four stages: Transform, Backbone, Neck and
| batch_size_per_card | Single card batch size during training | 256 | \ |
| drop_last | Whether to discard the last incomplete mini-batch because the number of samples in the data set cannot be divisible by batch_size | True | \ |
| num_workers | The number of sub-processes used to load data, if it is 0, the sub-process is not started, and the data is loaded in the main process | 8 | \ |
+
+
+
+## 3. Multilingual Config File Generation
+
+PaddleOCR currently supports 80 (except Chinese) language recognition. A multi-language configuration file template is
+provided under the path `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml).
+
+There are two ways to create the required configuration file:
+
+1. Automatically generated by script
+
+[generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) can help you generate configuration files for multi-language models
+
+- Take Italian as an example, if your data is prepared in the following format:
+ ```
+ |-train_data
+ |- it_train.txt # train_set label
+ |- it_val.txt # val_set label
+ |- data
+ |- word_001.jpg
+ |- word_002.jpg
+ |- word_003.jpg
+ | ...
+ ```
+
+ You can use the default parameters to generate a configuration file:
+
+ ```bash
+ # The code needs to be run in the specified directory
+ cd PaddleOCR/configs/rec/multi_language/
+ # Set the configuration file of the language to be generated through the -l or --language parameter.
+ # This command will write the default parameters into the configuration file
+ python3 generate_multi_language_configs.py -l it
+ ```
+
+- If your data is placed in another location, or you want to use your own dictionary, you can generate the configuration file by specifying the relevant parameters:
+
+ ```bash
+ # -l or --language field is required
+ # --train to modify the training set
+ # --val to modify the validation set
+ # --data_dir to modify the data set directory
+ # --dict to modify the dict path
+ # -o to modify the corresponding default parameters
+ cd PaddleOCR/configs/rec/multi_language/
+ python3 generate_multi_language_configs.py -l it \ # language
+ --train {path/of/train_label.txt} \ # path of train_label
+ --val {path/of/val_label.txt} \ # path of val_label
+ --data_dir {train_data/path} \ # root directory of training data
+ --dict {path/of/dict} \ # path of dict
+ -o Global.use_gpu=False # whether to use gpu
+ ...
+
+ ```
+Italian is made up of Latin letters, so after executing the command, you will get the rec_latin_lite_train.yml.
+
+2. Manually modify the configuration file
+
+ You can also manually modify the following fields in the template:
+
+ ```
+ Global:
+ use_gpu: True
+ epoch_num: 500
+ ...
+ character_dict_path: {path/of/dict} # path of dict
+
+ Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: train_data/ # root directory of training data
+ label_file_list: ["./train_data/train_list.txt"] # train label path
+ ...
+
+ Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: train_data/ # root directory of val data
+ label_file_list: ["./train_data/val_list.txt"] # val label path
+ ...
+
+ ```
+
+
+Currently, the multi-language algorithms supported by PaddleOCR are:
+
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
+
+For more supported languages, please refer to: [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
+
+The multi-language models are trained in the same way as the Chinese model. The training set consists of 1 million synthetic samples per language. A small number of fonts and test data can be downloaded in the following two ways.
+* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi.
+* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md
index b736beb55d79db02bf4d4301a74c685537fce249..df96fd5336cd64049e8f5d9b898f60c55b82b7b4 100644
--- a/doc/doc_en/detection_en.md
+++ b/doc/doc_en/detection_en.md
@@ -1,9 +1,32 @@
-# TEXT DETECTION
+# Text Detection
This section uses the icdar2015 dataset as an example to introduce the training, evaluation, and testing of the detection model in PaddleOCR.
-## DATA PREPARATION
-The icdar2015 dataset can be obtained from [official website](https://rrc.cvc.uab.es/?ch=4&com=downloads). Registration is required for downloading.
+- [1. Data and Weights Preparation](#1-data-and-weights-preparation)
+ * [1.1 Data Preparation](#11-data-preparation)
+ * [1.2 Download Pretrained Model](#12-download-pretrained-model)
+- [2. Training](#2-training)
+ * [2.1 Start Training](#21-start-training)
+ * [2.2 Load Trained Model and Continue Training](#22-load-trained-model-and-continue-training)
+ * [2.3 Training with New Backbone](#23-training-with-new-backbone)
+- [3. Evaluation and Test](#3-evaluation-and-test)
+ * [3.1 Evaluation](#31-evaluation)
+ * [3.2 Test](#32-test)
+- [4. Inference](#4-inference)
+- [5. FAQ](#5-faq)
+
+## 1. Data and Weights Preparation
+
+### 1.1 Data Preparation
+
+The icdar2015 dataset contains a training set of 1000 images and a test set of 500 images, all captured with wearable cameras. The icdar2015 dataset can be obtained from the [official website](https://rrc.cvc.uab.es/?ch=4&com=downloads). Registration is required for downloading.
+
+
+After registering and logging in, download the part marked in the red box in the figure below. The content downloaded via `Training Set Images` should be saved into the folder `icdar_c4_train_imgs`, and the content downloaded via `Test Set Images` into the folder `ch4_test_images`.
+
+
+
+
Decompress the downloaded dataset to the working directory, assuming it is decompressed under PaddleOCR/train_data/. In addition, PaddleOCR organizes many scattered annotation files into two separate annotation files for train and test respectively, which can be downloaded by wget:
```shell
@@ -36,10 +59,11 @@ The `points` in the dictionary represent the coordinates (x, y) of the four poin
If you want to train PaddleOCR on other datasets, please build the annotation file according to the above format.
-## TRAINING
+### 1.2 Download Pretrained Model
+
+First download the pretrained model. The detection model of PaddleOCR currently supports 3 backbones, namely MobileNetV3, ResNet18_vd and ResNet50_vd. You can use the model in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/release/2.0/ppcls/modeling/architectures) to replace backbone according to your needs.
+The corresponding download links for the backbone pretrained weights can be found [here](https://github.com/PaddlePaddle/PaddleClas/blob/release%2F2.0/README_cn.md#resnet%E5%8F%8A%E5%85%B6vd%E7%B3%BB%E5%88%97).
-First download the pretrained model. The detection model of PaddleOCR currently supports 3 backbones, namely MobileNetV3, ResNet18_vd and ResNet50_vd. You can use the model in [PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/develop/ppcls/modeling/architectures) to replace backbone according to your needs.
-And the responding download link of backbone pretrain weights can be found in [PaddleClas repo](https://github.com/PaddlePaddle/PaddleClas#mobile-series).
```shell
cd PaddleOCR/
# Download the pre-trained model of MobileNetV3
@@ -49,11 +73,16 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dyg
# or, download the pre-trained model of ResNet50_vd
wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/ResNet50_vd_ssld_pretrained.pdparams
+```
+
+## 2. Training
+
+### 2.1 Start Training
-#### START TRAINING
*If CPU version installed, please set the parameter `use_gpu` to `false` in the configuration.*
```shell
-python3 tools/train.py -c configs/det/det_mv3_db.yml
+python3 tools/train.py -c configs/det/det_mv3_db.yml \
+ -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
```
In the above instruction, use `-c` to select the training to use the `configs/det/det_db_mv3.yml` configuration file.
@@ -62,16 +91,17 @@ For a detailed explanation of the configuration file, please refer to [config](.
You can also use `-o` to change the training parameters without modifying the yml file. For example, adjust the training learning rate to 0.0001
```shell
# single GPU training
-python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
+python3 tools/train.py -c configs/det/det_mv3_db.yml -o \
+ Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained \
+ Optimizer.base_lr=0.0001
# multi-GPU training
# Set the GPU ID used by the '--gpus' parameter.
-python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
-
+python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
```
-#### load trained model and continue training
+### 2.2 Load Trained Model and Continue Training
If you expect to load trained model and continue the training again, you can specify the parameter `Global.checkpoints` as the model path to be loaded.
For example:
@@ -79,12 +109,64 @@ For example:
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model
```
-**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrain_weights`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrain_weights` will be loaded.
+**Note**: The priority of `Global.checkpoints` is higher than that of `Global.pretrained_model`, that is, when two parameters are specified at the same time, the model specified by `Global.checkpoints` will be loaded first. If the model path specified by `Global.checkpoints` is wrong, the one specified by `Global.pretrained_model` will be loaded.
+
+### 2.3 Training with New Backbone
+
+The network part completes the construction of the network. PaddleOCR divides the network into four parts, which are under [ppocr/modeling](../../ppocr/modeling). The data entering the network passes through these four parts in sequence
+(transforms -> backbones -> necks -> heads).
+
+```bash
+├── architectures # Code for building network
+├── transforms # Image Transformation Module
+├── backbones # Feature extraction module
+├── necks # Feature enhancement module
+└── heads # Output module
+```
+
+If the Backbone to be replaced has a corresponding implementation in PaddleOCR, you can directly modify the parameters in the `Backbone` part of the configuration yml file.
+
+However, if you want to use a new Backbone, an example of replacing the backbones is as follows:
+
+1. Create a new file under the [ppocr/modeling/backbones](../../ppocr/modeling/backbones) folder, such as my_backbone.py.
+2. Add code in the my_backbone.py file, the sample code is as follows:
+
+```python
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class MyBackbone(nn.Layer):
+ def __init__(self, *args, **kwargs):
+ super(MyBackbone, self).__init__()
+ # your init code
+ self.conv = nn.xxxx
+
+ def forward(self, inputs):
+ # your network forward
+ y = self.conv(inputs)
+ return y
+```
-## EVALUATION
+3. Import the added module in the [ppocr/modeling/backbones/\__init\__.py](../../ppocr/modeling/backbones/__init__.py) file.
-PaddleOCR calculates three indicators for evaluating performance of OCR detection task: Precision, Recall, and Hmean.
+After adding the four-part modules of the network, you only need to configure them in the configuration file to use, such as:
+
+```yaml
+ Backbone:
+ name: MyBackbone
+ args1: args1
+```
+
+**NOTE**: More details about replacing the Backbone and other modules can be found in [doc](add_new_algorithm_en.md).
+
+## 3. Evaluation and Test
+
+### 3.1 Evaluation
+
+PaddleOCR calculates three indicators for evaluating performance of OCR detection task: Precision, Recall, and Hmean(F-Score).
Run the following code to calculate the evaluation indicators. The result will be saved in the test result file specified by `save_res_path` in the configuration file `det_db_mv3.yml`
@@ -95,10 +177,9 @@ The model parameters during training are saved in the `Global.save_model_dir` di
python3 tools/eval.py -c configs/det/det_mv3_db.yml -o Global.checkpoints="{path/to/weights}/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
```
+* Note: `box_thresh` and `unclip_ratio` are parameters required for DB post-processing and do not need to be set when evaluating the EAST and SAST models.
-* Note: `box_thresh` and `unclip_ratio` are parameters required for DB post-processing, and not need to be set when evaluating the EAST model.
-
-## TEST
+### 3.2 Test
Test the detection result on a single image:
```shell
@@ -107,7 +188,7 @@ python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./
When testing the DB model, adjust the post-processing threshold:
```shell
-python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5
+python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy" PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=2.0
```
@@ -115,3 +196,33 @@ Test the detection result on all images in the folder:
```shell
python3 tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy"
```
+
+## 4. Inference
+
+The inference model (the model saved by `paddle.jit.save`) is generally a solidified model saved after the model training is completed, and is mostly used to give prediction in deployment.
+
+The model saved during the training process is the checkpoints model, which saves the parameters of the model and is mostly used to resume training.
+
+Compared with the checkpoints model, the inference model will additionally save the structural information of the model. Therefore, it is easier to deploy because the model structure and model parameters are already solidified in the inference model file, and is suitable for integration with actual systems.
+
+Firstly, we can convert DB trained model to inference model:
+```shell
+python3 tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model="./output/det_db/best_accuracy" Global.save_inference_dir="./output/det_db_inference/"
+```
+
+The detection inference model prediction:
+```shell
+python3 tools/infer/predict_det.py --det_algorithm="DB" --det_model_dir="./output/det_db_inference/" --image_dir="./doc/imgs/" --use_gpu=True
+```
+
+If it is other detection algorithms, such as the EAST, the det_algorithm parameter needs to be modified to EAST, and the default is the DB algorithm:
+```shell
+python3 tools/infer/predict_det.py --det_algorithm="EAST" --det_model_dir="./output/det_db_inference/" --image_dir="./doc/imgs/" --use_gpu=True
+```
+
+## 5. FAQ
+
+Q1: The prediction results of trained model and inference model are inconsistent?
+**A**: Most of the problems are caused by the inconsistency of the pre-processing and post-processing parameters during the prediction of the trained model and the pre-processing and post-processing parameters during the prediction of the inference model. Taking the model trained by the det_mv3_db.yml configuration file as an example, the solution to the problem of inconsistent prediction results between the training model and the inference model is as follows:
+- Check whether the [preprocessing of the trained model](https://github.com/PaddlePaddle/PaddleOCR/blob/c1ed243fb68d5d466258243092e56cbae32e2c14/configs/det/det_mv3_db.yml#L116) is consistent with the [preprocessing of the inference model](https://github.com/PaddlePaddle/PaddleOCR/blob/c1ed243fb68d5d466258243092e56cbae32e2c14/tools/infer/predict_det.py#L42). When the algorithm is evaluated, the input image size affects the accuracy. To stay consistent with the paper, the image is resized to [736, 1280] in the icdar15 training configuration file, but the inference model only uses one set of default parameters: for speed, the longest side of the image is limited to 960 for resizing by default. The preprocessing functions of the trained model and of the inference model are both located in [ppocr/data/imaug/operators.py](https://github.com/PaddlePaddle/PaddleOCR/blob/c1ed243fb68d5d466258243092e56cbae32e2c14/ppocr/data/imaug/operators.py#L147)
+- Check whether the [post-processing of the trained model](https://github.com/PaddlePaddle/PaddleOCR/blob/c1ed243fb68d5d466258243092e56cbae32e2c14/configs/det/det_mv3_db.yml#L51) is consistent with the [post-processing parameters of the inference](https://github.com/PaddlePaddle/PaddleOCR/blob/c1ed243fb68d5d466258243092e56cbae32e2c14/tools/infer/utility.py#L50).
diff --git a/doc/doc_en/environment_en.md b/doc/doc_en/environment_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..9aad92cafb809a0eec519808b1c1755403b39318
--- /dev/null
+++ b/doc/doc_en/environment_en.md
@@ -0,0 +1,348 @@
+# Environment Preparation
+
+Recommended working environment:
+- PaddlePaddle >= 2.0.0 (2.1.2)
+- python3.7
+- CUDA10.1 / CUDA10.2
+- CUDNN 7.6
+
+* [1. Python Environment Setup](#1)
+ + [1.1 Windows](#1.1)
+ + [1.2 Mac](#1.2)
+ + [1.3 Linux](#1.3)
+* [2. Install PaddlePaddle 2.0](#2)
+
+
+
+
+## 1. Python Environment Setup
+
+
+
+### 1.1 Windows
+
+#### 1.1.1 Install Anaconda
+
+- Note: To use paddlepaddle you need to install python environment first, here we choose python integrated environment Anaconda toolkit
+
+ - Anaconda is a common python package manager
+ - After installing Anaconda, you can install the python environment, as well as numpy and other required toolkit environment.
+
+- Anaconda download.
+
+ - Address: https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?C=M&O=D
+
+ - Most Win10 computers are 64-bit operating systems, choose x86_64 version; if the computer is a 32-bit operating system, choose x86.exe
+
+
+
+ - After the download is complete, double-click the installer to enter the graphical interface
+
+ - The default installation location is C drive, it is recommended to change the installation location to D drive.
+
+
+
+  - Check the option to add conda to the environment variables, and ignore the warning that pops up
+
+
+
+
+#### 1.1.2 Opening the terminal and creating the conda environment
+
+- Open Anaconda Prompt terminal: bottom left Windows Start Menu -> Anaconda3 -> Anaconda Prompt start console
+
+
+
+
+- Create a new conda environment
+
+ ```shell
+ # Enter the following command at the command line to create an environment named paddle_env
+ # Here to speed up the download, use the Tsinghua source
+ conda create --name paddle_env python=3.8 --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ # This is a one line command
+ ```
+
+ This command will create an executable environment named paddle_env with python version 3.8, which will take a while depending on the network status
+
+ The command line will then output a prompt, type y and enter to continue the installation
+
+
+
+- To activate the conda environment you just created, enter the following command at the command line.
+
+ ```shell
+ # Activate the paddle_env environment
+ conda activate paddle_env
+ # View the current location of python
+ where python
+ ```
+
+
+
+The above anaconda environment and python environment are installed
+
+
+
+
+
+### 1.2 Mac
+
+#### 1.2.1 Installing Anaconda
+
+- Note: To use paddlepaddle you need to install the python environment first, here we choose the python integrated environment Anaconda toolkit
+
+ - Anaconda is a common python package manager
+ - After installing Anaconda, you can install the python environment, as well as numpy and other required toolkit environment
+
+- Anaconda download:.
+
+ - Address: https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?C=M&O=D
+
+
+
+ - Select `Anaconda3-2021.05-MacOSX-x86_64.pkg` at the bottom to download
+
+- After downloading, double click on the .pkg file to enter the graphical interface
+
+ - Just follow the default settings, it will take a while to install
+
+- It is recommended to install a code editor such as vscode or pycharm
+
+#### 1.2.2 Open a terminal and create a conda environment
+
+- Open the terminal
+
+ - Press command and spacebar at the same time, type "terminal" in the focus search, double click to enter terminal
+
+- **Add conda to the environment variables**
+
+ - Environment variables are added so that the system can recognize the conda command
+
+ - Open `~/.bash_profile` in the terminal by typing the following command.
+
+ ```shell
+ vim ~/.bash_profile
+ ```
+
+ - Add conda as an environment variable in `~/.bash_profile`.
+
+ ```shell
+ # Press i first to enter edit mode
+ # In the first line type.
+ export PATH="~/opt/anaconda3/bin:$PATH"
+ # If you customized the installation location during installation, change ~/opt/anaconda3/bin to the bin folder in the customized installation directory
+ ```
+
+ ```shell
+ # The modified ~/.bash_profile file should look like this (where xxx is the username)
+ export PATH="~/opt/anaconda3/bin:$PATH"
+ # >>> conda initialize >>>
+ # !!! Contents within this block are managed by 'conda init' !!!
+ __conda_setup="$('/Users/xxx/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
+ if [ $? -eq 0 ]; then
+ eval "$__conda_setup"
+ else
+ if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
+ . "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
+ else
+ export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
+ fi
+ fi
+ unset __conda_setup
+ # <<< conda initialize <<<
+ ```
+
+ - When you are done, press `esc` to exit edit mode, then type `:wq!` and enter to save and exit
+
+ - Verify that the conda command is recognized.
+
+ - Enter `source ~/.bash_profile` in the terminal to update the environment variables
+ - Enter `conda info --envs` in the terminal again, if it shows that there is a base environment, then conda has been added to the environment variables
+
+- Create a new conda environment
+
+ ```shell
+ # Enter the following command at the command line to create an environment called paddle_env
+ # Here to speed up the download, use Tsinghua source
+ conda create --name paddle_env python=3.8 --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
+ ```
+
+ - This command will create an executable environment named paddle_env with python version 3.8, which will take a while depending on the network status
+
+ - The command line will then output a prompt, type y and enter to continue the installation
+
+
+- To activate the conda environment you just created, enter the following command at the command line.
+
+ ```shell
+ # Activate the paddle_env environment
+ conda activate paddle_env
+ # View the current location of python
+ where python
+ ```
+
+
+
+The above anaconda environment and python environment are installed
+
+
+
+
+
+### 1.3 Linux
+
+Linux users can choose to run either Anaconda or Docker. If you are familiar with Docker and need to train the PaddleOCR model, it is recommended to use the Docker environment, where the development process of PaddleOCR is run. If you are not familiar with Docker, you can also use Anaconda to run the project.
+
+#### 1.3.1 Anaconda environment configuration
+
+- Note: To use paddlepaddle you need to install the python environment first, here we choose the python integrated environment Anaconda toolkit
+
+ - Anaconda is a common python package manager
+ - After installing Anaconda, you can install the python environment, as well as numpy and other required toolkit environment
+
+- **Download Anaconda**.
+
+ - Download at: https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?C=M&O=D
+
+
+
+
+
+ - Select the appropriate version for your operating system
+ - Type `uname -m` in the terminal to check the command set used by your system
+
+ - Download method 1: Download locally, then transfer the installation package to the linux server
+
+ - Download method 2: Directly use linux command line to download
+
+ ```shell
+ # First install wget
+ sudo apt-get install wget # Ubuntu
+ sudo yum install wget # CentOS
+ ```
+ ```bash
+ # Then use wget to download from Tsinghua source
+ # If you want to download Anaconda3-2021.05-Linux-x86_64.sh, the download command is as follows
+ wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2021.05-Linux-x86_64.sh
+  # If you want to download another version, change the file name after the last / to the version you want to download
+ ```
+
+- To install Anaconda.
+
+ - Type `sh Anaconda3-2021.05-Linux-x86_64.sh` at the command line
+ - If you downloaded a different version, replace the file name of the command with the name of the file you downloaded
+ - Just follow the installation instructions
+ - You can exit by typing q when viewing the license
+
+- **Add conda to the environment variables**
+
+ - If you have already added conda to the environment variable path during the installation, you can skip this step
+
+ - Open `~/.bashrc` in a terminal.
+
+ ```shell
+ # Enter the following command in the terminal.
+ vim ~/.bashrc
+ ```
+
+ - Add conda as an environment variable in `~/.bashrc`.
+
+ ```shell
+  # Press i first to enter edit mode, then add the following line at the top of the file:
+ export PATH="~/anaconda3/bin:$PATH"
+ # If you customized the installation location during installation, change ~/anaconda3/bin to the bin folder in the customized installation directory
+ ```
+
+ ```shell
+  # The modified ~/.bashrc file should look like this (where xxx is the username; the conda paths below should match your own installation location)
+  export PATH="~/anaconda3/bin:$PATH"
+ # >>> conda initialize >>>
+ # !!! Contents within this block are managed by 'conda init' !!!
+ __conda_setup="$('/Users/xxx/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
+ if [ $? -eq 0 ]; then
+ eval "$__conda_setup"
+ else
+ if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
+ . "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
+ else
+ export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
+ fi
+ fi
+ unset __conda_setup
+ # <<< conda initialize <<<
+ ```
+
+  - When you are done, press `esc` to exit edit mode, then type `:wq!` and press Enter to save and exit
+
+ - Verify that the conda command is recognized.
+
+  - Enter `source ~/.bashrc` in the terminal to update the environment variables
+ - Enter `conda info --envs` in the terminal again, if it shows that there is a base environment, then conda has been added to the environment variables
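+
+  For convenience, the two verification commands can be run back to back:
+
+  ```shell
+  # Reload the shell configuration, then list the conda environments
+  source ~/.bashrc
+  conda info --envs
+  ```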
+
+- Create a new conda environment
+
+ ```shell
+ # Enter the following command at the command line to create an environment called paddle_env
+ # Here to speed up the download, use Tsinghua source
+ conda create --name paddle_env python=3.8 --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
+ ```
+
+  - This command will create a conda environment named paddle_env with Python 3.8; it may take a while depending on the network status
+
+  - The command line will then output a prompt; type y and press Enter to continue the installation
+
+
+
+- To activate the conda environment you just created, enter the following command at the command line.
+
+ ```shell
+ # Activate the paddle_env environment
+ conda activate paddle_env
+ ```
+
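+  To confirm that the activation worked, you can check the current location of python; it should point into the paddle_env environment:
+
+  ```shell
+  # View the current location of python
+  which python
+  ```
+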
+At this point, the Anaconda environment and the Python environment have been installed
+
+
+#### 1.3.2 Docker environment preparation
+
+**The first time you use this docker image, it will be downloaded automatically. Please be patient.**
+
+```bash
+# Switch to the working directory
+cd /home/Projects
+# The docker container only needs to be created the first time; you do not need to run this command on subsequent runs
+# Create a docker container named ppocr and map the current directory to the /paddle directory of the container
+
+# If using CPU, use docker instead of nvidia-docker to create docker
+sudo docker run --name ppocr -v $PWD:/paddle --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
+
+# If using GPU, use nvidia-docker to create docker
+# docker image registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda11.2-cudnn8 is recommended for CUDA11.2 + CUDNN8.
+sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
+
+```
+You can also visit [DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/) to get the image that fits your machine.
+
+```
+# Press Ctrl+P+Q to detach from the container; to re-enter it, use the following command:
+sudo docker container exec -it ppocr /bin/bash
+```
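+
+If you created a GPU container with nvidia-docker, a quick way to confirm that the GPU is visible inside the container is to run nvidia-smi (it should be available when the container was started through the NVIDIA runtime):
+
+```bash
+# Run inside the container: list the GPUs visible to the container
+nvidia-smi
+```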
+
+
+
+## 2. Install PaddlePaddle 2.0
+
+- If you have cuda9 or cuda10 installed on your machine, please run the following command to install
+
+```bash
+python3 -m pip install paddlepaddle-gpu -i https://mirror.baidu.com/pypi/simple
+```
+
+- If you only have cpu on your machine, please run the following command to install
+
+```bash
+python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
+```
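+
+After either installation, you can run a quick sanity check; `paddle.utils.run_check()` verifies that PaddlePaddle was installed correctly and can run on the current device:
+
+```bash
+# Verify the installation
+python3 -c "import paddle; paddle.utils.run_check()"
+```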
+
+For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index e30355fb8e29031bd4ce040a86ad0f57d18ce398..019ac4d0ac15aceed89286048d2c4d88a259e501 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -1,5 +1,5 @@
-# Reasoning based on Python prediction engine
+# Inference Based on Python Prediction Engine
The inference model (the model saved by `paddle.jit.save`) is generally a solidified model saved after the model training is completed, and is mostly used to give prediction in deployment.
@@ -10,37 +10,36 @@ For more details, please refer to the document [Classification Framework](https:
Next, we first introduce how to convert a trained model into an inference model, and then we will introduce text detection, text recognition, angle class, and the concatenation of them based on inference model.
-- [CONVERT TRAINING MODEL TO INFERENCE MODEL](#CONVERT)
- - [Convert detection model to inference model](#Convert_detection_model)
- - [Convert recognition model to inference model](#Convert_recognition_model)
- - [Convert angle classification model to inference model](#Convert_angle_class_model)
+- [1. Convert Training Model to Inference Model](#CONVERT)
+ - [1.1 Convert Detection Model to Inference Model](#Convert_detection_model)
+ - [1.2 Convert Recognition Model to Inference Model](#Convert_recognition_model)
+ - [1.3 Convert Angle Classification Model to Inference Model](#Convert_angle_class_model)
-- [TEXT DETECTION MODEL INFERENCE](#DETECTION_MODEL_INFERENCE)
- - [1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE](#LIGHTWEIGHT_DETECTION)
- - [2. DB TEXT DETECTION MODEL INFERENCE](#DB_DETECTION)
- - [3. EAST TEXT DETECTION MODEL INFERENCE](#EAST_DETECTION)
- - [4. SAST TEXT DETECTION MODEL INFERENCE](#SAST_DETECTION)
- - [5. Multilingual model inference](#Multilingual model inference)
+- [2. Text Detection Model Inference](#DETECTION_MODEL_INFERENCE)
+ - [2.1 Lightweight Chinese Detection Model Inference](#LIGHTWEIGHT_DETECTION)
+ - [2.2 DB Text Detection Model Inference](#DB_DETECTION)
+  - [2.3 EAST Text Detection Model Inference](#EAST_DETECTION)
+  - [2.4 SAST Text Detection Model Inference](#SAST_DETECTION)
-- [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE)
- - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION)
- - [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION)
- - [3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE](#SRN-BASED_RECOGNITION)
- - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS)
- - [4. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE)
+- [3. Text Recognition Model Inference](#RECOGNITION_MODEL_INFERENCE)
+  - [3.1 Lightweight Chinese Text Recognition Model Inference](#LIGHTWEIGHT_RECOGNITION)
+ - [3.2 CTC-Based Text Recognition Model Inference](#CTC-BASED_RECOGNITION)
+ - [3.3 SRN-Based Text Recognition Model Inference](#SRN-BASED_RECOGNITION)
+ - [3.4 Text Recognition Model Inference Using Custom Characters Dictionary](#USING_CUSTOM_CHARACTERS)
+ - [3.5 Multilingual Model Inference](#MULTILINGUAL_MODEL_INFERENCE)
-- [ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
- - [1. ANGLE CLASSIFICATION MODEL INFERENCE](#ANGLE_CLASS_MODEL_INFERENCE)
+- [4. Angle Classification Model Inference](#ANGLE_CLASS_MODEL_INFERENCE)
-- [TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION](#CONCATENATION)
- - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_CHINESE_MODEL)
- - [2. OTHER MODELS](#OTHER_MODELS)
+- [5. Text Detection Angle Classification And Recognition Inference Concatenation](#CONCATENATION)
+ - [5.1 Lightweight Chinese Model](#LIGHTWEIGHT_CHINESE_MODEL)
+ - [5.2 Other Models](#OTHER_MODELS)
-## CONVERT TRAINING MODEL TO INFERENCE MODEL
+## 1. Convert Training Model to Inference Model
-### Convert detection model to inference model
+
+### 1.1 Convert Detection Model to Inference Model
Download the lightweight Chinese detection model:
```
@@ -67,7 +66,7 @@ inference/det_db/
```
-### Convert recognition model to inference model
+### 1.2 Convert Recognition Model to Inference Model
Download the lightweight Chinese recognition model:
```
@@ -95,7 +94,7 @@ inference/det_db/
```
-### Convert angle classification model to inference model
+### 1.3 Convert Angle Classification Model to Inference Model
Download the angle classification model:
```
@@ -122,13 +121,13 @@ inference/det_db/
-## TEXT DETECTION MODEL INFERENCE
+## 2. Text Detection Model Inference
The following will introduce the lightweight Chinese detection model inference, DB text detection model inference and EAST text detection model inference. The default configuration is based on the inference setting of the DB text detection model.
Because EAST and DB algorithms are very different, when inference, it is necessary to **adapt the EAST text detection algorithm by passing in corresponding parameters**.
-### 1. LIGHTWEIGHT CHINESE DETECTION MODEL INFERENCE
+### 2.1 Lightweight Chinese Detection Model Inference
For lightweight Chinese detection model inference, you can execute the following commands:
@@ -163,7 +162,7 @@ python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_di
```
-### 2. DB TEXT DETECTION MODEL INFERENCE
+### 2.2 DB Text Detection Model Inference
First, convert the model saved in the DB text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_db_v2.0_train.tar)), you can use the following command to convert:
@@ -184,7 +183,7 @@ The visualized text detection results are saved to the `./inference_results` fol
**Note**: Since the ICDAR2015 dataset has only 1,000 training images, mainly for English scenes, the above model has very poor detection result on Chinese text images.
-### 3. EAST TEXT DETECTION MODEL INFERENCE
+### 2.3 EAST Text Detection Model Inference
First, convert the model saved in the EAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar)), you can use the following command to convert:
@@ -205,7 +204,7 @@ The visualized text detection results are saved to the `./inference_results` fol
-### 4. SAST TEXT DETECTION MODEL INFERENCE
+### 2.4 SAST Text Detection Model Inference
#### (1). Quadrangle text detection model (ICDAR2015)
First, convert the model saved in the SAST text detection training process into an inference model. Taking the model based on the Resnet50_vd backbone network and trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can use the following command to convert:
@@ -243,13 +242,13 @@ The visualized text detection results are saved to the `./inference_results` fol
**Note**: SAST post-processing locality aware NMS has two versions: Python and C++. The speed of C++ version is obviously faster than that of Python version. Due to the compilation version problem of NMS of C++ version, C++ version NMS will be called only in Python 3.5 environment, and python version NMS will be called in other cases.
-## TEXT RECOGNITION MODEL INFERENCE
+## 3. Text Recognition Model Inference
The following will introduce the lightweight Chinese recognition model inference, other CTC-based and Attention-based text recognition models inference. For Chinese text recognition, it is recommended to choose the recognition model based on CTC loss. In practice, it is also found that the result of the model based on Attention loss is not as good as the one based on CTC loss. In addition, if the characters dictionary is modified during training, make sure that you use the same characters set during inferencing. Please check below for details.
-### 1. LIGHTWEIGHT CHINESE TEXT RECOGNITION MODEL REFERENCE
+### 3.1 Lightweight Chinese Text Recognition Model Inference
For lightweight Chinese recognition model inference, you can execute the following commands:
@@ -269,7 +268,7 @@ Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.9897658)
```
-### 2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE
+### 3.2 CTC-Based Text Recognition Model Inference
Taking CRNN as an example, we introduce the recognition model inference based on CTC loss. Rosetta and Star-Net are used in a similar way, No need to set the recognition algorithm parameter rec_algorithm.
@@ -282,7 +281,7 @@ python3 tools/export_model.py -c configs/det/rec_r34_vd_none_bilstm_ctc.yml -o G
For CRNN text recognition model inference, execute the following commands:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```

@@ -292,6 +291,7 @@ After executing the command, the recognition result of the above image is as fol
```bash
Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
```
+
**Note**:Since the above model refers to [DTRB](https://arxiv.org/abs/1904.01906) text recognition training and evaluation process, it is different from the training of lightweight Chinese recognition model in two aspects:
- The image resolution used in training is different: the image resolution used in training the above model is [3,32,100], while during our Chinese model training, in order to ensure the recognition effect of long text, the image resolution used in training is [3, 32, 320]. The default shape parameter of the inference stage is the image resolution used in training phase, that is [3, 32, 320]. Therefore, when running inference of the above English model here, you need to set the shape of the recognition image through the parameter `rec_image_shape`.
@@ -304,7 +304,7 @@ dict_character = list(self.character_str)
```
-### 3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE
+### 3.3 SRN-Based Text Recognition Model Inference
The recognition model based on SRN requires additional setting of the recognition algorithm parameter
--rec_algorithm="SRN". At the same time, it is necessary to ensure that the predicted shape is consistent
@@ -314,25 +314,26 @@ with the training, such as: --rec_image_shape="1, 64, 256"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
--rec_model_dir="./inference/srn/" \
--rec_image_shape="1, 64, 256" \
- --rec_char_type="en" \
+ --rec_char_dict_path="./ppocr/utils/ic15_dict.txt" \
--rec_algorithm="SRN"
```
-### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY
+### 3.4 Text Recognition Model Inference Using Custom Characters Dictionary
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
-### 5. MULTILINGAUL MODEL INFERENCE
+
+### 3.5 Multilingual Model Inference
If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```

@@ -343,13 +344,7 @@ Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
```
-## ANGLE CLASSIFICATION MODEL INFERENCE
-
-The following will introduce the angle classification model inference.
-
-
-
-### 1.ANGLE CLASSIFICATION MODEL INFERENCE
+## 4. Angle Classification Model Inference
For angle classification model inference, you can execute the following commands:
@@ -371,10 +366,10 @@ After executing the command, the prediction results (classification angle and sc
```
-## TEXT DETECTION ANGLE CLASSIFICATION AND RECOGNITION INFERENCE CONCATENATION
+## 5. Text Detection Angle Classification and Recognition Inference Concatenation
-### 1. LIGHTWEIGHT CHINESE MODEL
+### 5.1 Lightweight Chinese Model
When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`; the parameter `det_model_dir` specifies the path of the detection inference model, the parameter `cls_model_dir` specifies the path of the angle classification inference model, and the parameter `rec_model_dir` specifies the path of the recognition inference model. The parameter `use_angle_cls` controls whether to enable the angle classification model. The parameter `use_mp` specifies whether to use multi-process inference, and `total_process_num` specifies the number of processes when multi-process is used. The visualized recognition results are saved to the `./inference_results` folder by default.
@@ -388,14 +383,14 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --de
# use multi-process
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=false --use_mp=True --total_process_num=6
```
-```
+
After executing the command, the recognition result image is as follows:

-### 2. OTHER MODELS
+### 5.2 Other Models
If you want to try other detection algorithms or recognition algorithms, please refer to the above text detection model inference and text recognition model inference, update the corresponding configuration and model.
@@ -404,7 +399,7 @@ If you want to try other detection algorithms or recognition algorithms, please
The following command uses the combination of the EAST text detection and STAR-Net text recognition:
```
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
After executing the command, the recognition result image is as follows:
diff --git a/doc/doc_en/inference_ppocr_en.md b/doc/doc_en/inference_ppocr_en.md
new file mode 100755
index 0000000000000000000000000000000000000000..fa3b1c88713f01e8e411cf95d107b4b58dd7f4e1
--- /dev/null
+++ b/doc/doc_en/inference_ppocr_en.md
@@ -0,0 +1,135 @@
+
+# Python Inference for PP-OCR Model Library
+
+This article introduces the use of the Python inference engine for the PP-OCR model library. The content covers text detection, text recognition, the direction classifier, and the three of them chained in series, on both CPU and GPU.
+
+
+- [Text Detection Model Inference](#DETECTION_MODEL_INFERENCE)
+
+- [Text Recognition Model Inference](#RECOGNITION_MODEL_INFERENCE)
+ - [1. Lightweight Chinese Recognition Model Inference](#LIGHTWEIGHT_RECOGNITION)
+  - [2. Multilingual Model Inference](#MULTILINGUAL_MODEL_INFERENCE)
+
+- [Angle Classification Model Inference](#ANGLE_CLASS_MODEL_INFERENCE)
+
+- [Text Detection Angle Classification and Recognition Inference Concatenation](#CONCATENATION)
+
+
+
+## Text Detection Model Inference
+
+The default configuration is based on the inference setting of the DB text detection model. For lightweight Chinese detection model inference, you can execute the following commands:
+
+```
+# download DB text detection inference model
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
+tar xf ch_ppocr_mobile_v2.0_det_infer.tar
+# predict
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/"
+```
+
+The visualized text detection results are saved to the ./inference_results folder by default, and the names of the result files are prefixed with 'det_res'. Examples of results are as follows:
+
+
+
+You can use the parameters `limit_type` and `det_limit_side_len` to limit the size of the input image.
+The optional values of `limit_type` are [`max`, `min`], and
+`det_limit_side_len` is a positive integer, generally set to a multiple of 32, such as 960.
+
+The default setting of the parameters is `limit_type='max', det_limit_side_len=960`, which means that the longest side of the network input image cannot exceed 960;
+if it does, the image will be resized proportionally so that its longest side equals `det_limit_side_len`.
+If set to `limit_type='min', det_limit_side_len=960`, the shortest side of the image is limited to 960.
+
+If the resolution of the input picture is relatively large and you want to predict at a larger resolution, you can set `det_limit_side_len` to the desired value, such as 1216:
+```
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --det_limit_type=max --det_limit_side_len=1216
+```
+
+If you want to use the CPU for prediction, execute the command as follows
+```
+python3 tools/infer/predict_det.py --image_dir="./doc/imgs/1.jpg" --det_model_dir="./inference/det_db/" --use_gpu=False
+```
+
+
+
+## Text Recognition Model Inference
+
+
+
+### 1. Lightweight Chinese Recognition Model Inference
+
+For lightweight Chinese recognition model inference, you can execute the following commands:
+
+```
+# download CRNN text recognition inference model
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
+tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_10.png" --rec_model_dir="ch_ppocr_mobile_v2.0_rec_infer"
+```
+
+
+
+After executing the command, the prediction results (recognized text and score) of the above image will be printed on the screen.
+
+```bash
+Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.9897658)
+```
+
+
+
+### 2. Multilingual Model Inference
+If you need to predict with other language models, you need to specify the dictionary path via `--rec_char_dict_path` when using the inference model. To get the correct visualization results,
+you also need to specify the visualization font path via `--vis_font_path`; fonts for the supported languages are provided by default under the `doc/fonts` path. Taking Korean recognition as an example:
+
+```
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
+```
+
+
+After executing the command, the prediction result of the above figure is:
+
+``` text
+Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
+```
+
+
+
+## Angle Classification Model Inference
+
+For angle classification model inference, you can execute the following commands:
+
+
+```
+# download text angle class inference model:
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar
+tar xf ch_ppocr_mobile_v2.0_cls_infer.tar
+python3 tools/infer/predict_cls.py --image_dir="./doc/imgs_words_en/word_10.png" --cls_model_dir="ch_ppocr_mobile_v2.0_cls_infer"
+```
+
+
+After executing the command, the prediction results (classification angle and score) of the above image will be printed on the screen.
+
+```
+ Predicts of ./doc/imgs_words_en/word_10.png:['0', 0.9999995]
+```
+
+
+## Text Detection Angle Classification and Recognition Inference Concatenation
+
+When performing prediction, you need to specify the path of a single image or a folder of images through the parameter `image_dir`; the parameter `det_model_dir` specifies the path of the detection inference model, the parameter `cls_model_dir` specifies the path of the angle classification inference model, and the parameter `rec_model_dir` specifies the path of the recognition inference model. The parameter `use_angle_cls` controls whether to enable the angle classification model. The parameter `use_mp` specifies whether to use multi-process inference, and `total_process_num` specifies the number of processes when multi-process is used. The visualized recognition results are saved to the `./inference_results` folder by default.
+
+```shell
+# use direction classifier
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --cls_model_dir="./inference/cls/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=true
+
+# do not use direction classifier
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/"
+
+# use multi-process
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --det_model_dir="./inference/det_db/" --rec_model_dir="./inference/rec_crnn/" --use_angle_cls=false --use_mp=True --total_process_num=6
+```
+
+
+After executing the command, the recognition result image is as follows:
+
+
diff --git a/doc/doc_en/models_and_config_en.md b/doc/doc_en/models_and_config_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..414d844d63d51a2b53feea035c1f735594d73fe0
--- /dev/null
+++ b/doc/doc_en/models_and_config_en.md
@@ -0,0 +1,48 @@
+# PP-OCR Model and Configuration
+The PP-OCR Model and Configuration chapter introduces some basic concepts of the OCR model together with the content and role of the configuration file, so that you can have a better experience when tuning parameters and training the model later.
+
+This chapter contains three parts. First, [PP-OCR Model Download](./models_list_en.md) explains the concept of PP-OCR model types and provides download links for all models. Then, [Yml Configuration](./config_en.md) details the parameters needed to fine-tune the PP-OCR models. Finally, [Python Inference for PP-OCR Model Library](./inference_ppocr_en.md) introduces the use of the PP-OCR model library, so that you can quickly obtain test results from the rich set of provided models through the Python inference engine.
+
+------
+
+Let's first understand some basic concepts.
+
+- [Introduction about OCR](#introduction-about-ocr)
+  * [Basic Concepts of OCR Detection Model](#basic-concepts-of-ocr-detection-model)
+  * [Basic Concepts of OCR Recognition Model](#basic-concepts-of-ocr-recognition-model)
+  * [PP-OCR Model](#pp-ocr-model)
+
+
+## 1. Introduction about OCR
+
+This section briefly introduces the basic concepts of OCR detection model and recognition model, and introduces PaddleOCR's PP-OCR model.
+
+OCR (Optical Character Recognition) is currently the general term for text recognition. It is not limited to document or book text recognition, but also includes recognizing text in natural scenes, which is also called STR (Scene Text Recognition).
+
+OCR text recognition generally includes two parts, text detection and text recognition. The text detection module first uses a detection algorithm to detect text lines in the image, and then a recognition algorithm identifies the specific text within each text line.
+
+
+### 1.1 Basic Concepts of OCR Detection Model
+
+Text detection can locate the text area in the image, and then usually mark the word or text line in the form of a bounding box. Traditional text detection algorithms mostly extract features manually, which are characterized by fast speed and good effect in simple scenes, but the effect will be greatly reduced when faced with natural scenes. Currently, deep learning methods are mostly used.
+
+Text detection algorithms based on deep learning can be roughly divided into the following categories:
+1. Methods based on object detection. Generally, after the text boxes are predicted, the final text boxes are filtered through NMS. These are mostly four-point (quadrilateral) text boxes, which are not ideal for curved text scenes. Typical algorithms include EAST, Text Box, and similar methods.
+2. Methods based on text segmentation. The text line is regarded as the segmentation target, and the enclosing text box is then constructed from the segmentation result. These methods can handle curved text, but the effect is not ideal when text lines cross. Typical algorithms include DB, PSENet, and similar methods.
+3. Hybrid methods that combine object detection and segmentation.
+
+
+### 1.2 Basic Concepts of OCR Recognition Model
+
+The input of the OCR recognition algorithm is generally a text line image with relatively little background information, in which the text occupies the main part. Recognition algorithms can be divided into two types:
+1. CTC-based methods. The text prediction module of the recognition algorithm is based on CTC, and the commonly used combination is CNN+RNN+CTC. There are also algorithms that try to add Transformer modules to the network, and so on.
+2. Attention-based methods. The text prediction module of the recognition algorithm is based on Attention, and the commonly used combination is CNN+RNN+Attention.
+
+
+### 1.3 PP-OCR Model
+
+PaddleOCR integrates many OCR algorithms: text detection algorithms include DB, EAST, SAST, etc., and text recognition algorithms include CRNN, RARE, StarNet, Rosetta, SRN, among others.
+
+Among them, PaddleOCR has released the PP-OCR series of models for general OCR in Chinese and English natural scenes. The PP-OCR model is composed of the DB and CRNN algorithms. Trained on massive Chinese data with dedicated model tuning methods, it has strong text detection and recognition capabilities in Chinese scenes. PaddleOCR has also launched the high-precision, ultra-lightweight PP-OCRv2 model: the detection model is only 3M and the recognition model only 8.5M. Using [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)'s model quantization method, the detection model can be compressed to 0.8M without reducing accuracy, and the recognition model to 3M, which makes it more suitable for mobile deployment scenarios.
diff --git a/doc/doc_en/models_en.md b/doc/doc_en/models_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..37c4a174563abc68085a103e11e2ddb3bd954714
--- /dev/null
+++ b/doc/doc_en/models_en.md
@@ -0,0 +1,46 @@
+# PP-OCR Model Zoo
+The PP-OCR model zoo section explains some basic concepts of the OCR model and how to quickly use the models in the PP-OCR model library.
+
+This section contains two parts. Firstly, [PP-OCR Model Download](./models_list_en.md) explains the concept of PP-OCR model types and provides links to download all models. The next [Python Inference for PP-OCR Model Zoo](./inference_ppocr_en.md) is an introduction to the use of the PP-OCR model library, which can quickly utilize the rich model library models to obtain test results through the Python inference engine.
+
+------
+
+Let's first understand some basic concepts.
+
+- [Introduction about OCR](#introduction-about-ocr)
+ * [Basic Concepts of OCR Detection Model](#basic-concepts-of-ocr-detection-model)
+ * [Basic Concepts of OCR Recognition Model](#basic-concepts-of-ocr-recognition-model)
+ * [PP-OCR Model](#pp-ocr-model)
+
+
+## 1. Introduction about OCR
+
+This section briefly introduces the basic concepts of OCR detection model and recognition model, and introduces PaddleOCR's PP-OCR model.
+
+OCR (Optical Character Recognition) is currently the general term for text recognition. It is not limited to document or book text recognition, but also includes recognizing text in natural scenes, which is also called STR (Scene Text Recognition).
+
+OCR text recognition generally includes two parts, text detection and text recognition. The text detection module first uses a detection algorithm to detect text lines in the image, and then a recognition algorithm identifies the specific text within each text line.
+
+
+### 1.1 Basic Concepts of OCR Detection Model
+
+Text detection can locate the text area in the image, and then usually mark the word or text line in the form of a bounding box. Traditional text detection algorithms mostly extract features manually, which are characterized by fast speed and good effect in simple scenes, but the effect will be greatly reduced when faced with natural scenes. Currently, deep learning methods are mostly used.
+
+Text detection algorithms based on deep learning can be roughly divided into the following categories:
+1. Methods based on object detection. Generally, after the text boxes are predicted, the final text boxes are filtered through NMS. These are mostly four-point (quadrilateral) text boxes, which are not ideal for curved text scenes. Typical algorithms include EAST, Text Box, and similar methods.
+2. Methods based on text segmentation. The text line is regarded as the segmentation target, and the enclosing text box is then constructed from the segmentation result. These methods can handle curved text, but the effect is not ideal when text lines cross. Typical algorithms include DB, PSENet, and similar methods.
+3. Hybrid methods that combine object detection and segmentation.
+
+
+### 1.2 Basic Concepts of OCR Recognition Model
+
+The input of the OCR recognition algorithm is generally a text line image with relatively little background information, in which the text occupies the main part. Recognition algorithms can be divided into two types:
+1. CTC-based methods. The text prediction module of the recognition algorithm is based on CTC, and the commonly used combination is CNN+RNN+CTC. There are also algorithms that try to add Transformer modules to the network, and so on.
+2. Attention-based methods. The text prediction module of the recognition algorithm is based on Attention, and the commonly used combination is CNN+RNN+Attention.
+
+
+### 1.3 PP-OCR Model
+
+PaddleOCR integrates many OCR algorithms: text detection algorithms include DB, EAST, SAST, etc., and text recognition algorithms include CRNN, RARE, StarNet, Rosetta, SRN, among others.
+
+Among them, PaddleOCR has released the PP-OCR series of models for general OCR in Chinese and English natural scenes. The PP-OCR model is composed of the DB and CRNN algorithms. Trained on massive Chinese data with dedicated model tuning methods, it has strong text detection and recognition capabilities in Chinese scenes. PaddleOCR has also launched the high-precision, ultra-lightweight PP-OCRv2 model: the detection model is only 3M and the recognition model only 8.5M. Using [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim)'s model quantization method, the detection model can be compressed to 0.8M without reducing accuracy, and the recognition model to 3M, which makes it more suitable for mobile deployment scenarios.
diff --git a/doc/doc_en/models_list_en.md b/doc/doc_en/models_list_en.md
index 9bee4aef5121b1964a9bdbdeeaad4e81dd9ff6d4..3b9b5518701f052079af1398a4fa3e3770eb12a1 100644
--- a/doc/doc_en/models_list_en.md
+++ b/doc/doc_en/models_list_en.md
@@ -1,7 +1,8 @@
-## OCR model list(V2.0, updated on 2021.1.20)
+## OCR model list (V2.1, updated on 2021.9.6)
> **Note**
-> 1. Compared with [models 1.1](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/models_list_en.md), which are trained with static graph programming paradigm, models 2.0 are the dynamic graph trained version and achieve close performance.
-> 2. All models in this tutorial are all ppocr-series models, for more introduction of algorithms and models based on public dataset, you can refer to [algorithm overview tutorial](./algorithm_overview_en.md).
+> 1. Compared with model v2.0, the v2.1 detection model has an improvement in accuracy, and the v2.1 recognition model is optimized in both accuracy and CPU speed.
+> 2. Compared with [models 1.1](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/doc/doc_en/models_list_en.md), which are trained with static graph programming paradigm, models 2.0 are the dynamic graph trained version and achieve close performance.
+> 3. All models in this tutorial are all ppocr-series models, for more introduction of algorithms and models based on public dataset, you can refer to [algorithm overview tutorial](./algorithm_overview_en.md).
- [1. Text Detection Model](#Detection)
- [2. Text Recognition Model](#Recognition)
@@ -28,6 +29,8 @@ Relationship of the above models is as follows.
|model name|description|config|model size|download|
| --- | --- | --- | --- | --- |
+|ch_PP-OCRv2_det_slim|slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)| 3M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_quant_infer.tar)|
+|ch_PP-OCRv2_det|Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_PP-OCRv2_det_cml.yml](../../configs/det/ch_PP-OCRv2/ch_PP-OCR_det_cml.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_distill_train.tar)|
|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|2.6M |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar)|
|ch_ppocr_mobile_v2.0_det|Original lightweight model, supporting Chinese, English, multilingual text detection|[ch_det_mv3_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml)|3M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)|
|ch_ppocr_server_v2.0_det|General model, which is larger than the lightweight model, but achieved better performance|[ch_det_res18_db_v2.0.yml](../../configs/det/ch_ppocr_v2.0/ch_det_res18_db_v2.0.yml)|47M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar)|
@@ -40,6 +43,8 @@ Relationship of the above models is as follows.
|model name|description|config|model size|download|
| --- | --- | --- | --- | --- |
+|ch_PP-OCRv2_rec_slim|Slim quantization with distillation lightweight model, supporting Chinese, English, multilingual text recognition|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)| 9M |[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_quant_train.tar) |
+|ch_PP-OCRv2_rec|Original lightweight model, supporting Chinese, English, multilingual text recognition|[ch_PP-OCRv2_rec.yml](../../configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml)|8.5M|[inference model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_train.tar) |
|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)| 6M | [inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
|ch_ppocr_mobile_v2.0_rec|Original lightweight model, supporting Chinese, English and number recognition|[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml)|5.2M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_train.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) |
|ch_ppocr_server_v2.0_rec|General model, supporting Chinese, English and number recognition|[rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml)|94.8M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_train.tar) / [pre-trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) |
@@ -58,45 +63,6 @@ Relationship of the above models is as follows.
#### Multilingual Recognition Model(Updating...)
-**Note:** The configuration file of the new multi language model is generated by code. You can use the `--help` parameter to check which multi language are supported by current PaddleOCR.
-
-```bash
-# The code needs to run in the specified directory
-cd {your/path/}PaddleOCR/configs/rec/multi_language/
-python3 generate_multi_language_configs.py --help
-```
-
-Take the Italian configuration file as an example:
-##### 1.Generate Italian configuration file to test the model provided
-you can generate the default configuration file through the following command, and use the default language dictionary provided by paddleocr for prediction.
-```bash
-# The code needs to run in the specified directory
-cd {your/path/}PaddleOCR/configs/rec/multi_language/
-# Set the required language configuration file through -l or --language parameter
-# This command will write the default parameter to the configuration file.
-python3 generate_multi_language_configs.py -l it
-```
-##### 2. Generate Italian configuration file to train your own data
-If you want to train your own model, you can prepare the training set file, verification set file, dictionary file and training data path. Here we assume that the Italian training set, verification set, dictionary and training data path are:
-- Training set:{your/path/}PaddleOCR/train_data/train_list.txt
-- Validation set: {your/path/}PaddleOCR/train_data/val_list.txt
-- Use the default dictionary provided by paddleocr:{your/path/}PaddleOCR/ppocr/utils/dict/it_dict.txt
-- Training data path:{your/path/}PaddleOCR/train_data
-```bash
-# The code needs to run in the specified directory
-cd {your/path/}PaddleOCR/configs/rec/multi_language/
-# The -l or --language parameter is required
-# --train modify train_list path
-# --val modify eval_list path
-# --data_dir modify data dir
-# -o modify default parameters
-# --dict Change the dictionary path. The example uses the default dictionary path, so that this parameter can be empty.
-python3 generate_multi_language_configs.py -l it \
---train {path/to/train_list} \
---val {path/to/val_list} \
---data_dir {path/to/data_dir} \
--o Global.use_gpu=False
-```
|model name| dict file | description|config|model size|download|
| --- | --- | --- |--- | --- | --- |
| french_mobile_v2.0_rec | ppocr/utils/dict/french_dict.txt | Lightweight model for French recognition|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
@@ -120,12 +86,14 @@ For more supported languages, please refer to : [Multi-language model](./multi_l
|model name|description|config|model size|download|
| --- | --- | --- | --- | --- |
-|ch_ppocr_mobile_slim_v2.0_cls|Slim quantized model|[cls_mv3.yml](../../configs/cls/cls_mv3.yml)| 2.1M | [inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_slim_train.tar) |
-|ch_ppocr_mobile_v2.0_cls|Original model|[cls_mv3.yml](../../configs/cls/cls_mv3.yml)|1.38M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) |
+|ch_ppocr_mobile_slim_v2.0_cls|Slim quantized model for text angle classification|[cls_mv3.yml](../../configs/cls/cls_mv3.yml)| 2.1M | [inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_slim_train.tar) |
+|ch_ppocr_mobile_v2.0_cls|Original model for text angle classification|[cls_mv3.yml](../../configs/cls/cls_mv3.yml)|1.38M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) |
### 4. Paddle-Lite Model
|Version|Introduction|Model size|Detection model|Text Direction model|Recognition model|Paddle-Lite branch|
|---|---|---|---|---|---|---|
-|V2.0|extra-lightweight chinese OCR optimized model|7.8M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_opt.nb)|v2.9|
-|V2.0(slim)|extra-lightweight chinese OCR optimized model|3.3M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_slim_opt.nb)|v2.9|
+|PP-OCRv2|extra-lightweight chinese OCR optimized model|11M|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer_opt.nb)|v2.9|
+|PP-OCRv2(slim)|extra-lightweight chinese OCR optimized model|4.9M|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_slim_opt.nb)|v2.9|
+|V2.0|ppocr_v2.0 extra-lightweight chinese OCR optimized model|7.8M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_opt.nb)|v2.9|
+|V2.0(slim)|ppocr_v2.0 extra-lightweight chinese OCR optimized model|3.3M|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_det_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_cls_slim_opt.nb)|[download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/lite/ch_ppocr_mobile_v2.0_rec_slim_opt.nb)|v2.9|
diff --git a/doc/doc_en/multi_languages_en.md b/doc/doc_en/multi_languages_en.md
index 43650c6ddfdd8c27ab44d0495111a767aeac9ca8..545be5524f2c52c9799d3b013f1aac8baf1a379f 100644
--- a/doc/doc_en/multi_languages_en.md
+++ b/doc/doc_en/multi_languages_en.md
@@ -198,13 +198,13 @@ If necessary, you can read related documents:
| Language | Abbreviation | | Language | Abbreviation |
| --- | --- | --- | --- | --- |
-|chinese and english|ch| |Arabic|ar|
-|english|en| |Hindi|hi|
-|french|fr| |Uyghur|ug|
-|german|german| |Persian|fa|
-|japan|japan| |Urdu|ur|
-|korean|korean| | Serbian(latin) |rs_latin|
-|chinese traditional |ch_tra| |Occitan |oc|
+|Chinese & English|ch| |Arabic|ar|
+|English|en| |Hindi|hi|
+|French|fr| |Uyghur|ug|
+|German|german| |Persian|fa|
+|Japanese|japan| |Urdu|ur|
+|Korean|korean| | Serbian(latin) |rs_latin|
+|Chinese Traditional |chinese_cht| |Occitan |oc|
| Italian |it| |Marathi|mr|
|Spanish |es| |Nepali|ne|
| Portuguese|pt| |Serbian(cyrillic)|rs_cyrillic|
diff --git a/doc/doc_en/paddleOCR_overview_en.md b/doc/doc_en/paddleOCR_overview_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..073c3ec889b2f21e9e40f5f7d1d6dc719e3dcac9
--- /dev/null
+++ b/doc/doc_en/paddleOCR_overview_en.md
@@ -0,0 +1,39 @@
+# PaddleOCR Overview and Project Clone
+
+## 1. PaddleOCR Overview
+
+PaddleOCR contains rich text detection, text recognition and end-to-end algorithms. Combining actual testing and industrial experience, PaddleOCR chooses DB and CRNN as the basic detection and recognition models, and proposes a series of models, named PP-OCR, for industrial applications after a series of optimization strategies. The PP-OCR model is aimed at general scenarios and forms a model library according to different languages. Based on the capabilities of PP-OCR, PaddleOCR releases the PP-Structure tool library for document scene tasks, including two major tasks: layout analysis and table recognition. To cover the entire process of putting OCR into industrial practice, PaddleOCR provides large-scale data production tools and a variety of prediction deployment tools to help developers quickly turn ideas into reality.
+
+
+

+
+
+
+
+## 2. Project Clone
+
+### **2.1 Clone PaddleOCR repo**
+
+```
+# Recommend
+git clone https://github.com/PaddlePaddle/PaddleOCR
+
+# If you cannot pull successfully due to network problems, you can also choose to use the code hosting on the cloud:
+
+git clone https://gitee.com/paddlepaddle/PaddleOCR
+
+# Note: The cloud-hosting code may not be able to synchronize the update with this GitHub project in real time. There might be a delay of 3-5 days. Please give priority to the recommended method.
+```
+
+### **2.2 Install third-party libraries**
+
+```
+cd PaddleOCR
+pip3 install -r requirements.txt
+```
+
+If you get the error `OSError: [WinError 126] The specified module could not be found` when installing shapely on Windows:
+
+Please try to download the Shapely whl file from [http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely](http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely).
+
+Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
\ No newline at end of file
diff --git a/doc/doc_en/quickstart_en.md b/doc/doc_en/quickstart_en.md
index a5c0881de30bfd4b76d30c7840b6585b5d7e2af9..0055d8f7a89d0d218d001ea94fd4c620de5d037f 100644
--- a/doc/doc_en/quickstart_en.md
+++ b/doc/doc_en/quickstart_en.md
@@ -1,103 +1,252 @@
-# Quick start of Chinese OCR model
+# PaddleOCR Quick Start
-## 1. Prepare for the environment
+[PaddleOCR Quick Start](#paddleocr-quick-start)
-Please refer to [quick installation](./installation_en.md) to configure the PaddleOCR operating environment.
++ [1. Install PaddleOCR Whl Package](#1-install-paddleocr-whl-package)
+* [2. Easy-to-Use](#2-easy-to-use)
+ + [2.1 Use by Command Line](#21-use-by-command-line)
+ - [2.1.1 English and Chinese Model](#211-english-and-chinese-model)
+ - [2.1.2 Multi-language Model](#212-multi-language-model)
+ - [2.1.3 Layout Analysis](#213-layoutAnalysis)
+ + [2.2 Use by Code](#22-use-by-code)
+ - [2.2.1 Chinese & English Model and Multilingual Model](#221-chinese---english-model-and-multilingual-model)
+ - [2.2.2 Layout Analysis](#222-layoutAnalysis)
-* Note: Support the use of PaddleOCR through whl package installation,pelease refer [PaddleOCR Package](./whl_en.md).
-## 2.inference models
-The detection and recognition models on the mobile and server sides are as follows. For more models (including multiple languages), please refer to [PP-OCR v2.0 series model list](../doc_ch/models_list.md)
+
-| Model introduction | Model name | Recommended scene | Detection model | Direction Classifier | Recognition model |
-| ------------ | --------------- | ----------------|---- | ---------- | -------- |
-| Ultra-lightweight Chinese OCR model (8.1M) | ch_ppocr_mobile_v2.0_xx |Mobile-side/Server-side|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar) / [pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar)|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar) / [pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_pre.tar) |
-| Universal Chinese OCR model (143M) | ch_ppocr_server_v2.0_xx |Server-side |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar) / [pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar) / [pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar) |[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar) / [pretrained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_pre.tar) |
+## 1. Install PaddleOCR Whl Package
+```bash
+pip install "paddleocr>=2.0.1" # Recommend to use version 2.0.1+
+```
-* If `wget` is not installed in the windows environment, you can copy the link to the browser to download when downloading the model, then uncompress it and place it in the corresponding directory.
+- **For windows users:** If you get the error `OSError: [WinError 126] The specified module could not be found` when installing shapely on Windows, please try to download the Shapely whl file [here](http://www.lfd.uci.edu/~gohlke/pythonlibs/#shapely).
-Copy the download address of the `inference model` for detection and recognition in the table above, and uncompress them.
+ Reference: [Solve shapely installation on windows](https://stackoverflow.com/questions/44398265/install-shapely-oserror-winerror-126-the-specified-module-could-not-be-found)
-```
-mkdir inference && cd inference
-# Download the detection model and unzip
-wget {url/of/detection/inference_model} && tar xf {name/of/detection/inference_model/package}
-# Download the recognition model and unzip
-wget {url/of/recognition/inference_model} && tar xf {name/of/recognition/inference_model/package}
-# Download the direction classifier model and unzip
-wget {url/of/classification/inference_model} && tar xf {name/of/classification/inference_model/package}
-cd ..
-```
+- **For layout analysis users**, run the following command to install **Layout-Parser**
-Take the ultra-lightweight model as an example:
+ ```bash
+ pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+ ```
+
+
+## 2. Easy-to-Use
+
+
+
+### 2.1 Use by Command Line
+
+PaddleOCR provides a series of test images. Click [here](https://paddleocr.bj.bcebos.com/dygraph_v2.1/ppocr_img.zip) to download them, and then switch to the corresponding directory in the terminal:
+
+```bash
+cd /path/to/ppocr_img
```
-mkdir inference && cd inference
-# Download the detection model of the ultra-lightweight Chinese OCR model and uncompress it
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_det_infer.tar
-# Download the recognition model of the ultra-lightweight Chinese OCR model and uncompress it
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar
-# Download the angle classifier model of the ultra-lightweight Chinese OCR model and uncompress it
-wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar && tar xf ch_ppocr_mobile_v2.0_cls_infer.tar
-cd ..
-```
-After decompression, the file structure should be as follows:
+If you do not use the provided test image, you can replace the following `--image_dir` parameter with the corresponding test image path
+
+
+
+#### 2.1.1 Chinese and English Model
+
+* Detection, direction classification and recognition: set the direction classifier parameter `--use_angle_cls true` to recognize vertical text.
+
+ ```bash
+ paddleocr --image_dir ./imgs_en/img_12.jpg --use_angle_cls true --lang en
+ ```
+
+ Output will be a list, each item contains bounding box, text and recognition confidence
+
+ ```bash
+ [[[442.0, 173.0], [1169.0, 173.0], [1169.0, 225.0], [442.0, 225.0]], ['ACKNOWLEDGEMENTS', 0.99283075]]
+ [[[393.0, 340.0], [1207.0, 342.0], [1207.0, 389.0], [393.0, 387.0]], ['We would like to thank all the designers and', 0.9357758]]
+ [[[399.0, 398.0], [1204.0, 398.0], [1204.0, 433.0], [399.0, 433.0]], ['contributors whohave been involved in the', 0.9592447]]
+ ......
+ ```
+
+* Only detection: set `--rec` to `false`
+
+ ```bash
+ paddleocr --image_dir ./imgs_en/img_12.jpg --rec false
+ ```
+
+ Output will be a list, each item only contains bounding box
+ ```bash
+ [[756.0, 812.0], [805.0, 812.0], [805.0, 830.0], [756.0, 830.0]]
+ [[820.0, 803.0], [1085.0, 801.0], [1085.0, 836.0], [820.0, 838.0]]
+ [[393.0, 801.0], [715.0, 805.0], [715.0, 839.0], [393.0, 836.0]]
+ ......
+ ```
+
+* Only recognition: set `--det` to `false`
+
+ ```bash
+ paddleocr --image_dir ./imgs_words_en/word_10.png --det false --lang en
+ ```
+
+ Output will be a list, each item contains text and recognition confidence
+
+ ```bash
+ ['PAIN', 0.990372]
+ ```
+
+If you need to use the 2.0 model, please specify the parameter `--version PP-OCR`; paddleocr uses the 2.1 model by default (`--version PP-OCRv2`). More whl package usage can be found in [whl package](./whl_en.md).
+
+
+#### 2.1.2 Multi-language Model
+
+PaddleOCR currently supports 80 languages, which can be switched by modifying the `--lang` parameter.
+
+``` bash
+paddleocr --image_dir ./doc/imgs_en/254.jpg --lang=en
```
-├── ch_ppocr_mobile_v2.0_cls_infer
-│ ├── inference.pdiparams
-│ ├── inference.pdiparams.info
-│ └── inference.pdmodel
-├── ch_ppocr_mobile_v2.0_det_infer
-│ ├── inference.pdiparams
-│ ├── inference.pdiparams.info
-│ └── inference.pdmodel
-├── ch_ppocr_mobile_v2.0_rec_infer
- ├── inference.pdiparams
- ├── inference.pdiparams.info
- └── inference.pdmodel
+
+
+

+

+
+The result is a list; each item contains a text box, the text and the recognition confidence
+
+```text
+[('PHO CAPITAL', 0.95723116), [[66.0, 50.0], [327.0, 44.0], [327.0, 76.0], [67.0, 82.0]]]
+[('107 State Street', 0.96311164), [[72.0, 90.0], [451.0, 84.0], [452.0, 116.0], [73.0, 121.0]]]
+[('Montpelier Vermont', 0.97389287), [[69.0, 132.0], [501.0, 126.0], [501.0, 158.0], [70.0, 164.0]]]
+[('8022256183', 0.99810505), [[71.0, 175.0], [363.0, 170.0], [364.0, 202.0], [72.0, 207.0]]]
+[('REG 07-24-201706:59 PM', 0.93537045), [[73.0, 299.0], [653.0, 281.0], [654.0, 318.0], [74.0, 336.0]]]
+[('045555', 0.99346405), [[509.0, 331.0], [651.0, 325.0], [652.0, 356.0], [511.0, 362.0]]]
+[('CT1', 0.9988654), [[535.0, 367.0], [654.0, 367.0], [654.0, 406.0], [535.0, 406.0]]]
+......
```
-## 3. Single image or image set prediction
+Commonly used multilingual abbreviations include
+
+| Language | Abbreviation | | Language | Abbreviation | | Language | Abbreviation |
+| ------------------- | ------------ | ---- | -------- | ------------ | ---- | -------- | ------------ |
+| Chinese & English | ch | | French | fr | | Japanese | japan |
+| English | en | | German | german | | Korean | korean |
+| Chinese Traditional | chinese_cht | | Italian | it | | Russian | ru |
-* The following code implements text detection、angle class and recognition process. When performing prediction, you need to specify the path of a single image or image set through the parameter `image_dir`, the parameter `det_model_dir` specifies the path to detect the inference model, the parameter `rec_model_dir` specifies the path to identify the inference model, the parameter `use_angle_cls` specifies whether to use the direction classifier, the parameter `cls_model_dir` specifies the path to identify the direction classifier model, the parameter `use_space_char` specifies whether to predict the space char. The visual results are saved to the `./inference_results` folder by default.
+A list of all languages and their corresponding abbreviations can be found in [Multi-Language Model Tutorial](./multi_languages_en.md)
+
+#### 2.1.3 Layout Analysis
+Layout analysis divides a document image into five types of regions: text, title, list, picture and table. For the first three types of regions, the OCR model is used directly to detect and recognize the text of the corresponding regions, and the results are saved as txt files. For table regions, after table structuring, the table image is converted into an Excel file with the same table layout. Picture regions are cropped and saved as separate images.
+
+To use the layout analysis function of PaddleOCR, you need to specify `--type=structure`
```bash
+paddleocr --image_dir=../doc/table/1.png --type=structure
+```
-# Predict a single image specified by image_dir
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_ppocr_mobile_v2.0_det_infer/" --rec_model_dir="./inference/ch_ppocr_mobile_v2.0_rec_infer/" --cls_model_dir="./inference/ch_ppocr_mobile_v2.0_cls_infer/" --use_angle_cls=True --use_space_char=True
+- **Results Format**
+
+  The result returned by PP-Structure is a list of dicts; an example is as follows:
+
+ ```shell
+ [
+ { 'type': 'Text',
+ 'bbox': [34, 432, 345, 462],
+ 'res': ([[36.0, 437.0, 341.0, 437.0, 341.0, 446.0, 36.0, 447.0], [41.0, 454.0, 125.0, 453.0, 125.0, 459.0, 41.0, 460.0]],
+ [('Tigure-6. The performance of CNN and IPT models using difforen', 0.90060663), ('Tent ', 0.465441)])
+ }
+ ]
+ ```
+
+ The description of each field in dict is as follows
+
+ | Parameter | Description |
+ | --------- | ------------------------------------------------------------ |
+ | type | Type of image area |
+ | bbox | The coordinates of the image area in the original image, respectively [left upper x, left upper y, right bottom x, right bottom y] |
+  | res | OCR or table recognition result of the image area. <br> Table: HTML string of the table; <br> OCR: a tuple containing the detection coordinates and recognition results of each single line of text |
+
+- **Parameter Description:**
+
+ | Parameter | Description | Default value |
+ | --------------- | ------------------------------------------------------------ | -------------------------------------------- |
+ | output | The path where excel and recognition results are saved | ./output/table |
+ | table_max_len | The long side of the image is resized in table structure model | 488 |
+ | table_model_dir | inference model path of table structure model | None |
+ | table_char_type | dict path of table structure model | ../ppocr/utils/dict/table_structure_dict.txt |
+
+
+
+### 2.2 Use by Code
+
+
+#### 2.2.1 Chinese & English Model and Multilingual Model
+
+* Detection, angle classification and recognition:
+
+```python
+from paddleocr import PaddleOCR, draw_ocr
+# PaddleOCR supports Chinese, English, French, German, Korean and Japanese.
+# You can set the parameter `lang` to `ch`, `en`, `fr`, `german`, `korean` or `japan`
+# to switch the language model accordingly.
+ocr = PaddleOCR(use_angle_cls=True, lang='en') # need to run only once to download and load model into memory
+img_path = './imgs_en/img_12.jpg'
+result = ocr.ocr(img_path, cls=True)
+for line in result:
+ print(line)
+
+
+# draw result
+from PIL import Image
+image = Image.open(img_path).convert('RGB')
+boxes = [line[0] for line in result]
+txts = [line[1][0] for line in result]
+scores = [line[1][1] for line in result]
+im_show = draw_ocr(image, boxes, txts, scores, font_path='./fonts/simfang.ttf')
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
-# Predict imageset specified by image_dir
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs/" --det_model_dir="./inference/ch_ppocr_mobile_v2.0_det_infer/" --rec_model_dir="./inference/ch_ppocr_mobile_v2.0_rec_infer/" --cls_model_dir="./inference/ch_ppocr_mobile_v2.0_cls_infer/" --use_angle_cls=True --use_space_char=True
+The output is a list; each item contains the bounding box, text and recognition confidence
-# If you want to use the CPU for prediction, you need to set the use_gpu parameter to False
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_ppocr_mobile_v2.0_det_infer/" --rec_model_dir="./inference/ch_ppocr_mobile_v2.0_rec_infer/" --cls_model_dir="./inference/ch_ppocr_mobile_v2.0_cls_infer/" --use_angle_cls=True --use_space_char=True --use_gpu=False
+```bash
+[[[442.0, 173.0], [1169.0, 173.0], [1169.0, 225.0], [442.0, 225.0]], ['ACKNOWLEDGEMENTS', 0.99283075]]
+[[[393.0, 340.0], [1207.0, 342.0], [1207.0, 389.0], [393.0, 387.0]], ['We would like to thank all the designers and', 0.9357758]]
+[[[399.0, 398.0], [1204.0, 398.0], [1204.0, 433.0], [399.0, 433.0]], ['contributors whohave been involved in the', 0.9592447]]
+......
```
-- Universal Chinese OCR model
+Visualization of results
-Please follow the above steps to download the corresponding models and update the relevant parameters, The example is as follows.
+
+

+
+
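+As with the command line, detection-only or recognition-only runs are also possible from the Python API. The following is a minimal sketch, assuming the `det`, `rec` and `cls` keyword arguments of `ocr()` mirror the CLI flags shown in section 2.1 (see [whl package](./whl_en.md) for the full API):
+
+```python
+from paddleocr import PaddleOCR
+
+ocr = PaddleOCR(use_angle_cls=True, lang='en')  # downloads and loads the models on first use
+
+# recognition only: pass a cropped text-line image and skip detection
+rec_result = ocr.ocr('./imgs_words_en/word_10.png', det=False, cls=False)
+print(rec_result)  # e.g. [('PAIN', 0.990372)]
+
+# detection only: return bounding boxes without recognizing the text
+det_result = ocr.ocr('./imgs_en/img_12.jpg', rec=False)
+print(det_result[:3])
+```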
-```
-# Predict a single image specified by image_dir
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_ppocr_server_v2.0_det_infer/" --rec_model_dir="./inference/ch_ppocr_server_v2.0_rec_infer/" --cls_model_dir="./inference/ch_ppocr_mobile_v2.0_cls_infer/" --use_angle_cls=True --use_space_char=True
-```
+#### 2.2.2 Layout Analysis
+
+```python
+import os
+import cv2
+from paddleocr import PPStructure, draw_structure_result, save_structure_res
-* Note
- - If you want to use the recognition model which does not support space char recognition, please update the source code to the latest version and add parameters `--use_space_char=False`.
- - If you do not want to use direction classifier, please update the source code to the latest version and add parameters `--use_angle_cls=False`.
+table_engine = PPStructure(show_log=True)
+save_folder = './output/table'
+img_path = './table/1.png'
+img = cv2.imread(img_path)
+result = table_engine(img)
+save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0])
-For more text detection and recognition tandem reasoning, please refer to the document tutorial
-: [Inference with Python inference engine](./inference_en.md)。
+for line in result:
+ line.pop('img')
+ print(line)
-In addition, the tutorial also provides other deployment methods for the Chinese OCR model:
-- [Server-side C++ inference](../../deploy/cpp_infer/readme_en.md)
-- [Service deployment](../../deploy/hubserving)
-- [End-to-end deployment](https://github.com/PaddlePaddle/PaddleOCR/tree/develop/deploy/lite)
+from PIL import Image
+
+font_path = './fonts/simfang.ttf'
+image = Image.open(img_path).convert('RGB')
+im_show = draw_structure_result(image, result, font_path=font_path)
+im_show = Image.fromarray(im_show)
+im_show.save('result.jpg')
+```
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 634ec783aa5e1dd6c9202385cf2978d140ca44a1..51857ba16b7773ef38452fad6aa070f2117a9086 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -1,28 +1,28 @@
-## TEXT RECOGNITION
+# Text Recognition
-- [1 DATA PREPARATION](#DATA_PREPARATION)
+- [1. Data Preparation](#DATA_PREPARATION)
- [1.1 Costom Dataset](#Costom_Dataset)
- [1.2 Dataset Download](#Dataset_download)
- [1.3 Dictionary](#Dictionary)
- [1.4 Add Space Category](#Add_space_category)
-- [2 TRAINING](#TRAINING)
+- [2. Training](#TRAINING)
- [2.1 Data Augmentation](#Data_Augmentation)
- - [2.2 Training](#Training)
- - [2.3 Multi-language](#Multi_language)
+ - [2.2 General Training](#Training)
+ - [2.3 Multi-language Training](#Multi_language)
-- [3 EVALUATION](#EVALUATION)
+- [3. Evaluation](#EVALUATION)
-- [4 PREDICTION](#PREDICTION)
- - [4.1 Training engine prediction](#Training_engine_prediction)
+- [4. Prediction](#PREDICTION)
+- [5. Convert to Inference Model](#Inference)
-### DATA PREPARATION
+## 1. Data Preparation
PaddleOCR supports two data formats:
-- `LMDB` is used to train data sets stored in lmdb format;
-- `general data` is used to train data sets stored in text files:
+- `LMDB` is used to train data sets stored in lmdb format (LMDBDataSet);
+- `general data` is used to train data sets stored in text files (SimpleDataSet):
Please organize the dataset as follows:
@@ -36,7 +36,7 @@ mklink /d /train_data/dataset
```
-#### 1.1 Costom dataset
+### 1.1 Custom Dataset
If you want to use your own data for training, please refer to the following to organize your data.
@@ -84,11 +84,14 @@ Similar to the training set, the test set also needs to be provided a folder con
```
-#### 1.2 Dataset download
+### 1.2 Dataset Download
-If you do not have a dataset locally, you can download it on the official website [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads). Also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) ,download the lmdb format dataset required for benchmark
+- ICDAR2015
-If you want to reproduce the paper indicators of SRN, you need to download offline [augmented data](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA), extraction code: y3ry. The augmented data is obtained by rotation and perturbation of mjsynth and synthtext. Please unzip the data to {your_path}/PaddleOCR/train_data/data_lmdb_Release/training/path.
+If you do not have a dataset locally, you can download the [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) dataset from the official website.
+You can also refer to [DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here) to download the lmdb-format datasets required for the benchmark.
+
+If you want to reproduce the results of the SAR paper, you also need to download the extra [SynthAdd](https://pan.baidu.com/share/init?surl=uV0LtoNmcxbO-0YA7Ch4dg) dataset (extraction code: 627x). In addition, the real datasets icdar2013, icdar2015, cocotext and IIIT5k are used for training. Please refer to the SAR paper for details.
PaddleOCR provides label files for training the icdar2015 dataset, which can be downloaded in the following ways:
@@ -99,8 +102,28 @@ wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_t
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
```
+PaddleOCR also provides a data format conversion script, which can convert the labels from the ICDAR official website into the data format
+supported by PaddleOCR. The conversion tool is in `ppocr/utils/gen_label.py`; here the training set is taken as an example:
+
+```
+# convert the official gt to rec_gt_label.txt
+python gen_label.py --mode="rec" --input_path="{path/of/origin/label}" --output_label="rec_gt_label.txt"
+```
+
+The data format is as follows: (a) is the original picture, and (b) is the Ground Truth text file corresponding to each picture:
+
+
+
+
+- Multilingual dataset
+
+The multi-language model is trained in the same way as the Chinese model. The training set consists of 1 million (100w) synthetic images. A small number of fonts and test data can be downloaded via the following two links.
+* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi
+* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
+
+
-#### 1.3 Dictionary
+### 1.3 Dictionary
Finally, a dictionary ({word_dict_name}.txt) needs to be provided so that when the model is trained, all the characters that appear can be mapped to the dictionary index.
@@ -138,21 +161,31 @@ The current multi-language model is still in the demo stage and will continue to
If you like, you can submit the dictionary file to [dict](../../ppocr/utils/dict) and we will thank you in the Repo.
-To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml` and set `character_type` to `ch`.
+To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml`.
- Custom dictionary
If you need to customize dic file, please add character_dict_path field in configs/rec/rec_icdar15_train.yml to point to your dictionary path. And set character_type to ch.
-#### 1.4 Add space category
+### 1.4 Add Space Category
If you want to support the recognition of the `space` category, please set the `use_space_char` field in the yml file to `True`.
-**Note: use_space_char only takes effect when character_type=ch**
-
-### 2 TRAINING
+## 2. Training
+
+
+### 2.1 Data Augmentation
+
+PaddleOCR provides a variety of data augmentation methods. All the augmentation methods are enabled by default.
+
+The default perturbation methods are: cvtColor, blur, jitter, Gaussian noise, random crop, perspective, color reverse and TIA augmentation.
+
+Each disturbance method is selected with a 40% probability during the training process. For specific code implementation, please refer to: [rec_img_aug.py](../../ppocr/data/imaug/rec_img_aug.py)
+
+
+### 2.2 General Training
PaddleOCR provides training scripts, evaluation scripts, and prediction scripts. In this section, the CRNN recognition model will be used as an example:
@@ -170,21 +203,15 @@ tar -xf rec_mv3_none_bilstm_ctc_v2.0_train.tar && rm -rf rec_mv3_none_bilstm_ctc
Start training:
```
-# GPU training Support single card and multi-card training, specify the card number through --gpus
+# GPU training: supports single-card and multi-card training
# Training icdar15 English data and The training log will be automatically saved as train.log under "{save_model_dir}"
+
+# Single-card training (long training time, not recommended)
+python3 tools/train.py -c configs/rec/rec_icdar15_train.yml
+# Multi-card training, specify the card numbers via --gpus
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/rec/rec_icdar15_train.yml
```
-
-#### 2.1 Data Augmentation
-PaddleOCR provides a variety of data augmentation methods. If you want to add disturbance during training, please set `distort: true` in the configuration file.
-
-The default perturbation methods are: cvtColor, blur, jitter, Gasuss noise, random crop, perspective, color reverse.
-
-Each disturbance method is selected with a 50% probability during the training process. For specific code implementation, please refer to: [img_tools.py](https://github.com/PaddlePaddle/PaddleOCR/blob/develop/ppocr/data/rec/img_tools.py)
-
-
-#### 2.2 Training
PaddleOCR supports alternating training and evaluation. You can modify `eval_batch_step` in `configs/rec/rec_icdar15_train.yml` to set the evaluation frequency. By default, it is evaluated every 500 iter and the best acc model is saved under `output/rec_CRNN/best_accuracy` during the evaluation process.
@@ -207,6 +234,8 @@ If the evaluation set is large, the test will be time-consuming. It is recommend
| rec_mv3_tps_bilstm_att.yml | CRNN | Mobilenet_v3 | TPS | BiLSTM | att |
| rec_r34_vd_tps_bilstm_att.yml | CRNN | Resnet34_vd | TPS | BiLSTM | att |
| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn |
+| rec_mtb_nrtr.yml | NRTR | nrtr_mtb | None | transformer encoder | transformer decoder |
+| rec_r31_sar.yml | SAR | ResNet31 | None | LSTM encoder | LSTM decoder |
For training Chinese data, it is recommended to use
@@ -219,7 +248,6 @@ Global:
# Add a custom dictionary, such as modify the dictionary, please point the path to the new dictionary
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
# Modify character type
- character_type: ch
...
# Whether to recognize spaces
use_space_char: True
@@ -277,108 +305,25 @@ Eval:
**Note that the configuration file for prediction/evaluation must be consistent with the training.**
-#### 2.3 Multi-language
-
-PaddleOCR currently supports 80 (except Chinese) language recognition. A multi-language configuration file template is
-provided under the path `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml)。
-
-There are two ways to create the required configuration file::
-
-1. Automatically generated by script
-
-[generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) Can help you generate configuration files for multi-language models
-
-- Take Italian as an example, if your data is prepared in the following format:
- ```
- |-train_data
- |- it_train.txt # train_set label
- |- it_val.txt # val_set label
- |- data
- |- word_001.jpg
- |- word_002.jpg
- |- word_003.jpg
- | ...
- ```
-
- You can use the default parameters to generate a configuration file:
-
- ```bash
- # The code needs to be run in the specified directory
- cd PaddleOCR/configs/rec/multi_language/
- # Set the configuration file of the language to be generated through the -l or --language parameter.
- # This command will write the default parameters into the configuration file
- python3 generate_multi_language_configs.py -l it
- ```
-
-- If your data is placed in another location, or you want to use your own dictionary, you can generate the configuration file by specifying the relevant parameters:
-
- ```bash
- # -l or --language field is required
- # --train to modify the training set
- # --val to modify the validation set
- # --data_dir to modify the data set directory
- # --dict to modify the dict path
- # -o to modify the corresponding default parameters
- cd PaddleOCR/configs/rec/multi_language/
- python3 generate_multi_language_configs.py -l it \ # language
- --train {path/of/train_label.txt} \ # path of train_label
- --val {path/of/val_label.txt} \ # path of val_label
- --data_dir {train_data/path} \ # root directory of training data
- --dict {path/of/dict} \ # path of dict
- -o Global.use_gpu=False # whether to use gpu
- ...
-
- ```
-Italian is made up of Latin letters, so after executing the command, you will get the rec_latin_lite_train.yml.
-
-2. Manually modify the configuration file
-
- You can also manually modify the following fields in the template:
-
- ```
- Global:
- use_gpu: True
- epoch_num: 500
- ...
- character_type: it # language
- character_dict_path: {path/of/dict} # path of dict
-
- Train:
- dataset:
- name: SimpleDataSet
- data_dir: train_data/ # root directory of training data
- label_file_list: ["./train_data/train_list.txt"] # train label path
- ...
-
- Eval:
- dataset:
- name: SimpleDataSet
- data_dir: train_data/ # root directory of val data
- label_file_list: ["./train_data/val_list.txt"] # val label path
- ...
-
- ```
+### 2.3 Multi-language Training
Currently, the multi-language algorithms supported by PaddleOCR are:
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
-The multi-language model training method is the same as the Chinese model. The training data set is 100w synthetic data. A small amount of fonts and test data can be downloaded using the following two methods.
-* [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA),Extraction code:frgi.
-* [Google drive](https://drive.google.com/file/d/18cSWX7wXSy4G0tbKJ0d9PuIaiwRLHpjA/view)
If you want to finetune on the basis of the existing model effect, please refer to the following instructions to modify the configuration file:
@@ -417,7 +362,8 @@ Eval:
```
-### 3 EVALUATION
+
+## 3. Evaluation
The evaluation dataset can be set by modifying the `Eval.dataset.label_file_list` field in the `configs/rec/rec_icdar15_train.yml` file.
@@ -427,20 +373,39 @@ python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/rec/rec
```
-### 4 PREDICTION
+## 4. Prediction
-
-#### 4.1 Training engine prediction
Using the model trained by paddleocr, you can quickly get prediction through the following script.
-The default prediction picture is stored in `infer_img`, and the weight is specified via `-o Global.checkpoints`:
+The default prediction image is stored in `infer_img`, and the trained weights are specified via `-o Global.checkpoints`:
+
+
+According to the `save_model_dir` and `save_epoch_step` fields set in the configuration file, the following parameters will be saved:
+
+```
+output/rec/
+├── best_accuracy.pdopt
+├── best_accuracy.pdparams
+├── best_accuracy.states
+├── config.yml
+├── iter_epoch_3.pdopt
+├── iter_epoch_3.pdparams
+├── iter_epoch_3.states
+├── latest.pdopt
+├── latest.pdparams
+├── latest.states
+└── train.log
+```
+
+Among them, `best_accuracy.*` is the best model on the evaluation set, `iter_epoch_x.*` are the models saved every `save_epoch_step` epochs, and `latest.*` is the model of the last epoch.
```
# Predict English results
python3 tools/infer_rec.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.load_static_weights=false Global.infer_img=doc/imgs_words/en/word_1.jpg
```
+
Input image:

@@ -469,3 +434,37 @@ Get the prediction result of the input image:
infer_img: doc/imgs_words/ch/word_1.jpg
result: ('韩国小馆', 0.997218)
```
+
+
+
+## 5. Convert to Inference Model
+
+The recognition model is converted to an inference model in the same way as the detection model, as follows:
+
+```
+# -c: set the yml configuration file of the training algorithm
+# -o: set optional parameters
+# Global.pretrained_model: set the path of the trained model to be converted, without the file suffix .pdmodel, .pdopt or .pdparams
+# Global.save_inference_dir: set the directory where the converted model will be saved
+
+python3 tools/export_model.py -c configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_rec_train/best_accuracy Global.save_inference_dir=./inference/rec_crnn/
+```
+
+If you have a model trained on your own dataset with a different dictionary file, please make sure that you modify the `character_dict_path` in the configuration file to your dictionary file path.
+
+After the conversion is successful, there are three files in the model save directory:
+
+```
+inference/rec_crnn/
+ ├── inference.pdiparams # The parameter file of recognition inference model
+ ├── inference.pdiparams.info # The parameter information of recognition inference model, which can be ignored
+ └── inference.pdmodel # The program file of recognition model
+```
+
+- Text recognition model inference with a custom character dictionary
+
+  If the text dictionary was modified during training, you need to specify the dictionary path via `--rec_char_dict_path` when predicting with the inference model, and set `rec_char_type=ch`:
+
+ ```
+ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
+ ```
diff --git a/doc/doc_en/training_en.md b/doc/doc_en/training_en.md
new file mode 100644
index 0000000000000000000000000000000000000000..aa5500ac88fef97829b4f19c5421e36f18ae1812
--- /dev/null
+++ b/doc/doc_en/training_en.md
@@ -0,0 +1,155 @@
+# Model Training
+
+- [1. Yml Configuration](#1-yml-configuration)
+- [2. Basic Concepts](#2-basic-concepts)
+  * [2.1 Learning Rate](#21-learning-rate)
+  * [2.2 Regularization](#22-regularization)
+  * [2.3 Evaluation Indicators](#23-evaluation-indicators)
+- [3. Data and Vertical Scenes](#3-data-and-vertical-scenes)
+  * [3.1 Training Data](#31-training-data)
+  * [3.2 Vertical Scene](#32-vertical-scene)
+  * [3.3 Build Your Own Dataset](#33-build-your-own-dataset)
+- [4. FAQ](#4-faq)
+
+
+This article introduces the basic concepts that you need to master for model training, as well as the tuning methods used during training.
+
+It also briefly introduces how the PaddleOCR training data is composed and how to prepare data to finetune the model for vertical scenarios.
+
+
+
+## 1. Yml Configuration
+
+The PaddleOCR model uses configuration files to manage network training and evaluation parameters. In the configuration file, you can set the model, optimizer, loss function, and pre- and post-processing parameters of the model. PaddleOCR reads these parameters from the configuration file and then builds the complete training process to train the model. The model can also be tuned simply by modifying the parameters in the configuration file, which is easy to use and convenient to modify.
+
+For the complete configuration file description, please refer to [Configuration File](./config_en.md)
+
+
+## 2. Basic Concepts
+
+The following parameters need to be paid attention to when tuning the model:
+
+
+### 2.1 Learning Rate
+
+The learning rate is one of the important hyperparameters for training neural networks. It represents the step length of the gradient moving to the optimal solution of the loss function in each iteration.
+A variety of learning rate update strategies are provided in PaddleOCR, which can be modified through configuration files, for example:
+
+```
+Optimizer:
+ ...
+ lr:
+ name: Piecewise
+ decay_epochs : [700, 800]
+ values : [0.001, 0.0001]
+ warmup_epoch: 5
+```
+
+Piecewise stands for piecewise-constant decay: a different learning rate is specified for each training stage,
+and the learning rate stays constant within each stage.
+
+warmup_epoch means that in the first 5 epochs, the learning rate will gradually increase from 0 to base_lr. For all strategies, please refer to the code [learning_rate.py](../../ppocr/optimizer/learning_rate.py).
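+For intuition, assuming the warmup is linear (as in Paddle's `LinearWarmup`), the learning rate during the first `warmup_epoch` epochs of the schedule above can be written as:
+
+```latex
+% linear warmup from 0 up to base_lr (the first value of the Piecewise schedule)
+lr(e) = \text{base\_lr} \cdot \frac{e}{\text{warmup\_epoch}}, \qquad 0 \le e < \text{warmup\_epoch}
+```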
+
+
+### 2.2 Regularization
+
+Regularization can effectively prevent overfitting. PaddleOCR provides L1 and L2 regularization, the two most commonly used regularization methods.
+L1 regularization adds a penalty term to the objective function that reduces the sum of the absolute values of the parameters,
+while L2 regularization adds a penalty term that reduces the sum of the squared parameters.
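+For reference, the two penalties can be written as follows (a standard formulation; `L` is the task loss, `w` are the trainable parameters, and the coefficient corresponds to the `factor` field in the configuration below):
+
+```latex
+% L1 regularization: penalize the sum of absolute parameter values
+L_{total} = L(w) + \lambda \sum_i |w_i|
+% L2 regularization: penalize the sum of squared parameter values
+L_{total} = L(w) + \lambda \sum_i w_i^{2}
+```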
+The configuration method is as follows:
+
+```
+Optimizer:
+ ...
+ regularizer:
+ name: L2
+ factor: 2.0e-05
+```
+
+### 2.3 Evaluation Indicators
+
+(1) Detection stage: evaluation is based on the IOU between the detected box and the labeled box; if the IOU is greater than a certain threshold, the detection is judged to be correct. Note that, unlike boxes in general object detection, the detected and labeled boxes here are represented by polygons. Detection precision: the proportion of correct detections among all detected boxes, mainly used to measure false detections. Detection recall: the proportion of correct detections among all labeled boxes, mainly used to measure missed detections.
+
+(2) Recognition stage: character recognition accuracy, i.e. the ratio of correctly recognized text lines to the total number of labeled text lines. A line is counted as correct only when the entire line of text is recognized correctly.
+
+(3) End-to-end statistics: end-to-end recall: the proportion of labeled text lines that are both accurately detected and correctly recognized; end-to-end precision: the proportion of detected text lines that are both accurately detected and correctly recognized. A detection is considered accurate when the IOU between the detected box and the labeled box is greater than a certain threshold, and the recognition is considered correct when the text recognized in that box is identical to the labeled text.
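+Taking the detection stage as an example, these statistics are usually summarized as precision, recall and their harmonic mean (hmean); the end-to-end metrics take the same form with the stricter matching criterion described above:
+
+```latex
+precision = \frac{\text{correct detections}}{\text{all detections}}, \qquad
+recall = \frac{\text{correct detections}}{\text{all labeled boxes}}, \qquad
+hmean = \frac{2 \cdot precision \cdot recall}{precision + recall}
+```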
+
+
+
+## 3. Data and Vertical Scenes
+
+
+
+### 3.1 Training Data
+
+The currently open-sourced models, datasets and data scales are as follows:
+
+- Detection:
+  - English dataset: ICDAR2015
+  - Chinese dataset: LSVT street view dataset, 30k (3w) training images
+
+- Recognition:
+  - English dataset: MJSynth and SynthText synthetic data, tens of millions of images in total.
+  - Chinese dataset: images cropped from the LSVT street view dataset according to the ground truth and position-rectified, 300k (30w) images in total; in addition, 5 million (500w) images synthesized based on the LSVT corpus.
+  - Multilingual datasets: 1 million (100w) synthetic images generated per language with different corpora and fonts, using ICDAR-MLT as the validation set.
+
+The public datasets above are all open source; you can search for and download them yourself, or refer to [Chinese dataset](./datasets.md). The synthetic data is not open source, but you can synthesize data yourself with open-source synthesis tools such as [text_renderer](https://github.com/Sanster/text_renderer), [SynthText](https://github.com/ankush-me/SynthText) and [TextRecognitionDataGenerator](https://github.com/Belval/TextRecognitionDataGenerator).
+
+
+
+### 3.2 Vertical Scene
+
+PaddleOCR mainly focuses on general OCR. If you have requirements for a vertical domain, you can train your own model with PaddleOCR plus vertical-domain data.
+If labeled data is lacking, or you do not want to invest in research and development costs, it is recommended to directly call an open API, which already covers some of the more common vertical scenarios.
+
+
+
+### 3.3 Build Your Own Dataset
+
+The following experience is worth referring to when building your own dataset:
+
+(1) The amount of data in the training set:
+
+  a. Detection requires relatively little data. For fine-tuning based on the PaddleOCR model, about 500 images are generally enough to achieve good results.
+  b. Recognition differs between English and Chinese: English scenarios generally need hundreds of thousands of images to achieve good results, while Chinese needs several million or more.
+
+
+(2) When the amount of training data is small, you can try the following three ways to get more data:
+
+  a. Manually collect more training data; this is the most direct and effective way.
+  b. Synthesize data with basic image processing or transformations based on PIL and OpenCV, for example using the ImageFont, Image and ImageDraw modules of PIL to write text onto a background image, and OpenCV's rotation and affine transformations, Gaussian filtering, and so on (see the sketch after this list).
+  c. Use data generation algorithms to synthesize data, such as pix2pix.
+
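+A minimal sketch of option (b) above; the background image, font file and output name are placeholders, not files shipped with PaddleOCR:
+
+```python
+import random
+
+import cv2
+import numpy as np
+from PIL import Image, ImageDraw, ImageFont
+
+# write text onto a background crop with PIL
+bg = Image.open('bg.jpg').convert('RGB')        # placeholder background image
+font = ImageFont.truetype('simfang.ttf', 32)    # placeholder font file
+draw = ImageDraw.Draw(bg)
+draw.text((10, 10), 'PaddleOCR 2021', font=font, fill=(50, 50, 50))
+
+# basic perturbations with OpenCV: a small rotation (affine transform) and Gaussian blur
+img = cv2.cvtColor(np.array(bg), cv2.COLOR_RGB2BGR)
+h, w = img.shape[:2]
+angle = random.uniform(-5, 5)
+M = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1.0)
+img = cv2.warpAffine(img, M, (w, h), borderValue=(255, 255, 255))
+img = cv2.GaussianBlur(img, (3, 3), 0)
+
+cv2.imwrite('synth_word_001.jpg', img)
+```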
+
+
+## 4. FAQ
+
+**Q**: How to choose a suitable network input shape when training CRNN recognition?
+
+  A: The height is generally set to 32, and the longest width can be chosen in one of two ways:
+
+  (1) Compute the aspect-ratio distribution of the training images and choose a maximum aspect ratio that covers about 80% of the training samples (a rough sketch is given after this answer).
+
+  (2) Count the number of characters in the training samples and choose the longest character count that covers about 80% of them; then, taking the aspect ratio of a Chinese character as roughly 1 and of an English character as 3:1, estimate the longest width.
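+  A rough sketch of method (1), assuming the label file follows the `SimpleDataSet` format (`image_path\tlabel` per line); the label file and data directory below are placeholders for your own dataset:
+
+```python
+import os
+
+import numpy as np
+from PIL import Image
+
+label_file = './train_data/ic15_data/rec_gt_train.txt'  # placeholder label file
+data_root = './train_data/ic15_data'                    # placeholder image root
+
+ratios = []
+with open(label_file, 'r', encoding='utf-8') as f:
+    for line in f:
+        img_rel_path = line.strip().split('\t')[0]
+        w, h = Image.open(os.path.join(data_root, img_rel_path)).size
+        ratios.append(w / h)
+
+# pick the aspect ratio covering about 80% of the samples, then scale it to height 32
+max_ratio = np.percentile(ratios, 80)
+print('suggested input shape: 3, 32, {}'.format(int(round(32 * max_ratio))))
+```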
+
+**Q**: During recognition training, the accuracy on the training set has reached 90%, but the accuracy on the validation set stays around 70%. What should I do?
+
+  A: If the training-set accuracy is 90% while the validation set only reaches about 70%, the model is most likely overfitting. There are two methods to try:
+
+  (1) Add more augmentation methods or increase the augmentation [probability](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/ppocr/data/imaug/rec_img_aug.py#L341); the default is 0.4.
+
+  (2) Increase the [l2 decay value](https://github.com/PaddlePaddle/PaddleOCR/blob/a501603d54ff5513fc4fc760319472e59da25424/configs/rec/ch_ppocr_v1.1/rec_chinese_lite_train_v1.1.yml#L47) in the configuration.
+
+**Q**: When training the recognition model, the loss drops normally, but acc is always 0.
+
+  A: It is normal for acc to be 0 at the beginning of recognition training; the metric will rise after a longer period of training.
+
+
+***
+Click the following links for detailed training tutorial:
+- [text detection model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
+- [text recognition model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
+- [text direction classification model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
diff --git a/doc/doc_en/update_en.md b/doc/doc_en/update_en.md
index ca2ecb0535ce27bc7f98a476752a131f869761d5..660688c6d6991a4744dbc327d24e9c677afa0fc1 100644
--- a/doc/doc_en/update_en.md
+++ b/doc/doc_en/update_en.md
@@ -1,4 +1,9 @@
# RECENT UPDATES
+- 2021.9.7 Released PaddleOCR v2.3 and proposed [PP-OCRv2](#PP-OCRv2). The CPU inference speed of PP-OCRv2 is 220% higher than that of PP-OCR server, and its F-score is 7% higher than that of PP-OCR mobile.
+- 2021.8.3 Released PaddleOCR v2.2, adding a new structured document analysis toolkit, i.e., [PP-Structure](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.2/ppstructure/README.md), which supports layout analysis and table recognition (one-key export of table images to Excel files).
+- 2021.4.8 Released the end-to-end text recognition algorithm [PGNet](https://www.aaai.org/AAAI21Papers/AAAI-2885.WangP.pdf), published in AAAI 2021. Find the tutorial [here](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/pgnet_en.md); released multilingual recognition [models](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md), supporting recognition of more than 80 languages; especially, the performance of the [English recognition model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/models_list_en.md#English) is optimized.
+
+- 2021.1.21 Updated more than 25 multilingual recognition models [models list](./doc/doc_en/models_list_en.md), including: English, Chinese, German, French, Japanese, Spanish, Portuguese, Russian, Arabic and so on. Models for more languages will continue to be updated; see the [Develop Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048).
- 2020.12.15 update Data synthesis tool, i.e., [Style-Text](../../StyleText/README.md),easy to synthesize a large number of images which are similar to the target scene image.
- 2020.11.25 Update a new data annotation tool, i.e., [PPOCRLabel](../../PPOCRLabel/README.md), which is helpful to improve the labeling efficiency. Moreover, the labeling results can be used in training of the PP-OCR system directly.
- 2020.9.22 Update the PP-OCR technical article, https://arxiv.org/abs/2009.09941
diff --git a/doc/doc_en/visualization_en.md b/doc/doc_en/visualization_en.md
index f9c455e5b3510a9f262c6bf59b8adfbaef3fa01d..71cfb043462f34f2b3bef594364d33f15e98d81e 100644
--- a/doc/doc_en/visualization_en.md
+++ b/doc/doc_en/visualization_en.md
@@ -1,5 +1,10 @@
# Visualization
+
+## PP-OCRv2
+
+
+
## ch_ppocr_server_2.0
diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md
index c8c8353accdf7f6ce179d3700547bfe9bd70c200..c2577e1e151e4675abab5139da099db9ad20fb4b 100644
--- a/doc/doc_en/whl_en.md
+++ b/doc/doc_en/whl_en.md
@@ -1,4 +1,4 @@
-# paddleocr package
+# PaddleOCR Package
## 1 Get started quickly
### 1.1 install package
diff --git a/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic001.jpg b/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..45ffdb53aa431c8d25cc7219b2c0523690182ab6
Binary files /dev/null and b/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic001.jpg differ
diff --git a/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic002.jpg b/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..7ac153aee0d703580971539b5cff95587c0e830e
Binary files /dev/null and b/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic002.jpg differ
diff --git a/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic003.jpg b/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..781aade629651b5adf24fcc76b84a9674154b8b8
Binary files /dev/null and b/doc/imgs_results/PP-OCRv2/PP-OCRv2-pic003.jpg differ
diff --git a/doc/install/linux/anaconda_download.png b/doc/install/linux/anaconda_download.png
new file mode 100755
index 0000000000000000000000000000000000000000..6ab6db30899d8431874e52bbe97af242e638ed6c
Binary files /dev/null and b/doc/install/linux/anaconda_download.png differ
diff --git a/doc/install/linux/conda_create.png b/doc/install/linux/conda_create.png
new file mode 100755
index 0000000000000000000000000000000000000000..533f592b7c1db78699d9166278e91332d3d8f258
Binary files /dev/null and b/doc/install/linux/conda_create.png differ
diff --git a/doc/install/mac/anaconda_start.png b/doc/install/mac/anaconda_start.png
new file mode 100755
index 0000000000000000000000000000000000000000..a860f5e56a76558a764d3d92055743832f4d5acb
Binary files /dev/null and b/doc/install/mac/anaconda_start.png differ
diff --git a/doc/install/mac/conda_activate.png b/doc/install/mac/conda_activate.png
new file mode 100755
index 0000000000000000000000000000000000000000..a2e6074e912988218b62068476b9d5d22deb0d71
Binary files /dev/null and b/doc/install/mac/conda_activate.png differ
diff --git a/doc/install/mac/conda_create.png b/doc/install/mac/conda_create.png
new file mode 100755
index 0000000000000000000000000000000000000000..9ff10c241be39216ea8255826ea50844368f27e8
Binary files /dev/null and b/doc/install/mac/conda_create.png differ
diff --git a/doc/install/windows/Anaconda_download.png b/doc/install/windows/Anaconda_download.png
new file mode 100644
index 0000000000000000000000000000000000000000..83a03414934a12f7071389ef664b6fd5e7df956f
Binary files /dev/null and b/doc/install/windows/Anaconda_download.png differ
diff --git a/doc/install/windows/anaconda_install_env.png b/doc/install/windows/anaconda_install_env.png
new file mode 100644
index 0000000000000000000000000000000000000000..7a22542712a3fa5d471f13d940806d483225c38f
Binary files /dev/null and b/doc/install/windows/anaconda_install_env.png differ
diff --git a/doc/install/windows/anaconda_install_folder.png b/doc/install/windows/anaconda_install_folder.png
new file mode 100644
index 0000000000000000000000000000000000000000..e9fac29eaa92fc445d324a565e95c064a984f9bf
Binary files /dev/null and b/doc/install/windows/anaconda_install_folder.png differ
diff --git a/doc/install/windows/anaconda_prompt.png b/doc/install/windows/anaconda_prompt.png
new file mode 100755
index 0000000000000000000000000000000000000000..1087610ae01f5c6181434e3dcc11189b138d419c
Binary files /dev/null and b/doc/install/windows/anaconda_prompt.png differ
diff --git a/doc/install/windows/conda_list_env.png b/doc/install/windows/conda_list_env.png
new file mode 100644
index 0000000000000000000000000000000000000000..5ffa0037c5e62b75c7b452a4012b7015b03c3f5f
Binary files /dev/null and b/doc/install/windows/conda_list_env.png differ
diff --git a/doc/install/windows/conda_new_env.png b/doc/install/windows/conda_new_env.png
new file mode 100644
index 0000000000000000000000000000000000000000..eed667ec3d4a4419cdfdd842fe57a1efca734c94
Binary files /dev/null and b/doc/install/windows/conda_new_env.png differ
diff --git a/doc/joinus.PNG b/doc/joinus.PNG
index 7a10f7aac3748062184085b68583c637d3963117..974a4bd008d7b103de044cf8b4dbf37f09a0d06b 100644
Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ
diff --git a/doc/overview.png b/doc/overview.png
new file mode 100644
index 0000000000000000000000000000000000000000..c5c4e09d6730bb0b1ca2c0b5442079ceb41ecdfa
Binary files /dev/null and b/doc/overview.png differ
diff --git a/doc/overview_en.png b/doc/overview_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..b44da4e9874d6a2162a8bb05ff1b479875bd65f3
Binary files /dev/null and b/doc/overview_en.png differ
diff --git a/doc/ppocrv2_framework.jpg b/doc/ppocrv2_framework.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e5f1a2ef47601c3a9eaef43a6046a15ea0319e2b
Binary files /dev/null and b/doc/ppocrv2_framework.jpg differ
diff --git a/doc/table/1.png b/doc/table/1.png
index 47df618ab1bef431a5dd94418c01be16b09d31aa..faff6e3178662407961fe074a9202015f755e2f8 100644
Binary files a/doc/table/1.png and b/doc/table/1.png differ
diff --git a/doc/table/table.jpg b/doc/table/table.jpg
index 3daa619e52dc2471df62ea7767be3bff350b623f..95fdf84d92908d4b21f49fb516601334867163b1 100644
Binary files a/doc/table/table.jpg and b/doc/table/table.jpg differ
diff --git a/paddleocr.py b/paddleocr.py
index c52737f55b61cd29c08367adb6d7e05c561e933e..a98efd34088701d5eb5602743cf75b7d5e80157f 100644
--- a/paddleocr.py
+++ b/paddleocr.py
@@ -33,104 +33,141 @@ from tools.infer.utility import draw_ocr, str2bool
from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import OCRSystem, save_structure_res
-__all__ = ['PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result', 'save_structure_res','download_with_progressbar']
-
-model_urls = {
- 'det': {
- 'ch':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
- 'en':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
- 'structure': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
+__all__ = [
+ 'PaddleOCR', 'PPStructure', 'draw_ocr', 'draw_structure_result',
+ 'save_structure_res', 'download_with_progressbar'
+]
+
+SUPPORT_DET_MODEL = ['DB']
+VERSION = '2.2.1'
+SUPPORT_REC_MODEL = ['CRNN']
+BASE_DIR = os.path.expanduser("~/.paddleocr/")
+
+DEFAULT_MODEL_VERSION = '2.0'
+MODEL_URLS = {
+ '2.1': {
+ 'det': {
+ 'ch': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar',
+ },
+ },
+ 'rec': {
+ 'ch': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar',
+ 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
+ }
+ }
},
- 'rec': {
- 'ch': {
- 'url':
- 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
+ '2.0': {
+ 'det': {
+ 'ch': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
+ },
+ 'en': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
+ },
+ 'structure': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
+ }
},
- 'en': {
- 'url':
+ 'rec': {
+ 'ch': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
+ 'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
+ },
+ 'en': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/en_dict.txt'
- },
- 'french': {
- 'url':
+ 'dict_path': './ppocr/utils/en_dict.txt'
+ },
+ 'french': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/french_dict.txt'
- },
- 'german': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/french_dict.txt'
+ },
+ 'german': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/german_dict.txt'
- },
- 'korean': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/german_dict.txt'
+ },
+ 'korean': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/korean_dict.txt'
- },
- 'japan': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/korean_dict.txt'
+ },
+ 'japan': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/japan_dict.txt'
- },
- 'chinese_cht': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/japan_dict.txt'
+ },
+ 'chinese_cht': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
- },
- 'ta': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
+ },
+ 'ta': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ta_dict.txt'
- },
- 'te': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/ta_dict.txt'
+ },
+ 'te': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/te_dict.txt'
- },
- 'ka': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/te_dict.txt'
+ },
+ 'ka': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/ka_dict.txt'
- },
- 'latin': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/ka_dict.txt'
+ },
+ 'latin': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/latin_dict.txt'
- },
- 'arabic': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/latin_dict.txt'
+ },
+ 'arabic': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
- },
- 'cyrillic': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/arabic_dict.txt'
+ },
+ 'cyrillic': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
- },
- 'devanagari': {
- 'url':
+ 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
+ },
+ 'devanagari': {
+ 'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
- 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
+ 'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
+ },
+ 'structure': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
+ 'dict_path': 'ppocr/utils/dict/table_dict.txt'
+ }
+ },
+ 'cls': {
+ 'ch': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
+ }
},
- 'structure': {
- 'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
- 'dict_path': 'ppocr/utils/dict/table_dict.txt'
+ 'table': {
+ 'en': {
+ 'url':
+ 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
+ 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
+ }
}
- },
- 'cls': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
- 'table': {
- 'url': 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
- 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
}
}
-SUPPORT_DET_MODEL = ['DB']
-VERSION = '2.2'
-SUPPORT_REC_MODEL = ['CRNN']
-BASE_DIR = os.path.expanduser("~/.paddleocr/")
-
def parse_args(mMain=True):
import argparse
@@ -140,6 +177,7 @@ def parse_args(mMain=True):
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--type", type=str, default='ocr')
+ parser.add_argument("--version", type=str, default='2.1')
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
@@ -155,19 +193,19 @@ def parse_args(mMain=True):
def parse_lang(lang):
latin_lang = [
- 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga',
- 'hr', 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms',
- 'mt', 'nl', 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk',
- 'sl', 'sq', 'sv', 'sw', 'tl', 'tr', 'uz', 'vi'
+ 'af', 'az', 'bs', 'cs', 'cy', 'da', 'de', 'es', 'et', 'fr', 'ga', 'hr',
+ 'hu', 'id', 'is', 'it', 'ku', 'la', 'lt', 'lv', 'mi', 'ms', 'mt', 'nl',
+ 'no', 'oc', 'pi', 'pl', 'pt', 'ro', 'rs_latin', 'sk', 'sl', 'sq', 'sv',
+ 'sw', 'tl', 'tr', 'uz', 'vi'
]
arabic_lang = ['ar', 'fa', 'ug', 'ur']
cyrillic_lang = [
- 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd',
- 'ava', 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
+ 'ru', 'rs_cyrillic', 'be', 'bg', 'uk', 'mn', 'abq', 'ady', 'kbd', 'ava',
+ 'dar', 'inh', 'che', 'lbe', 'lez', 'tab'
]
devanagari_lang = [
- 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new',
- 'gom', 'sa', 'bgc'
+ 'hi', 'mr', 'ne', 'bh', 'mai', 'ang', 'bho', 'mah', 'sck', 'new', 'gom',
+ 'sa', 'bgc'
]
if lang in latin_lang:
lang = "latin"
@@ -177,9 +215,9 @@ def parse_lang(lang):
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
- assert lang in model_urls[
+ assert lang in MODEL_URLS[DEFAULT_MODEL_VERSION][
'rec'], 'param lang must in {}, but got {}'.format(
- model_urls['rec'].keys(), lang)
+ MODEL_URLS[DEFAULT_MODEL_VERSION]['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
elif lang == 'structure':
@@ -189,6 +227,35 @@ def parse_lang(lang):
return lang, det_lang
+def get_model_config(version, model_type, lang):
+ if version not in MODEL_URLS:
+ logger.warning('version {} not in {}, use version {} instead'.format(
+ version, MODEL_URLS.keys(), DEFAULT_MODEL_VERSION))
+ version = DEFAULT_MODEL_VERSION
+ if model_type not in MODEL_URLS[version]:
+ if model_type in MODEL_URLS[DEFAULT_MODEL_VERSION]:
+ logger.warning(
+                'version {} does not support {} models, use version {} instead'.
+ format(version, model_type, DEFAULT_MODEL_VERSION))
+ version = DEFAULT_MODEL_VERSION
+ else:
+            logger.error('{} models are not supported, we only support {}'.format(
+ model_type, MODEL_URLS[DEFAULT_MODEL_VERSION].keys()))
+ sys.exit(-1)
+ if lang not in MODEL_URLS[version][model_type]:
+ if lang in MODEL_URLS[DEFAULT_MODEL_VERSION][model_type]:
+            logger.warning('lang {} is not supported in {}, use {} instead'.
+ format(lang, version, DEFAULT_MODEL_VERSION))
+ version = DEFAULT_MODEL_VERSION
+ else:
+ logger.error(
+                'lang {} is not supported, we only support {} for {} models'.
+ format(lang, MODEL_URLS[DEFAULT_MODEL_VERSION][model_type].keys(
+ ), model_type))
+ sys.exit(-1)
+ return MODEL_URLS[version][model_type][lang]
+
+
class PaddleOCR(predict_system.TextSystem):
def __init__(self, **kwargs):
"""
@@ -204,15 +271,21 @@ class PaddleOCR(predict_system.TextSystem):
lang, det_lang = parse_lang(params.lang)
# init model dir
- params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
- os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
- model_urls['det'][det_lang])
- params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
- os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
- model_urls['rec'][lang]['url'])
- params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
- os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
- model_urls['cls'])
+ det_model_config = get_model_config(params.version, 'det', det_lang)
+ params.det_model_dir, det_url = confirm_model_dir_url(
+ params.det_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
+ det_model_config['url'])
+ rec_model_config = get_model_config(params.version, 'rec', lang)
+ params.rec_model_dir, rec_url = confirm_model_dir_url(
+ params.rec_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
+ rec_model_config['url'])
+ cls_model_config = get_model_config(params.version, 'cls', 'ch')
+ params.cls_model_dir, cls_url = confirm_model_dir_url(
+ params.cls_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
+ cls_model_config['url'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
@@ -226,7 +299,8 @@ class PaddleOCR(predict_system.TextSystem):
sys.exit(0)
if params.rec_char_dict_path is None:
- params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
+ params.rec_char_dict_path = str(
+ Path(__file__).parent / rec_model_config['dict_path'])
print(params)
# init det_model and rec_model
@@ -293,24 +367,32 @@ class PPStructure(OCRSystem):
lang, det_lang = parse_lang(params.lang)
# init model dir
- params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
- os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
- model_urls['det'][det_lang])
- params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
- os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
- model_urls['rec'][lang]['url'])
- params.table_model_dir, table_url = confirm_model_dir_url(params.table_model_dir,
- os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
- model_urls['table']['url'])
+ det_model_config = get_model_config(params.version, 'det', det_lang)
+ params.det_model_dir, det_url = confirm_model_dir_url(
+ params.det_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
+ det_model_config['url'])
+ rec_model_config = get_model_config(params.version, 'rec', lang)
+ params.rec_model_dir, rec_url = confirm_model_dir_url(
+ params.rec_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
+ rec_model_config['url'])
+ table_model_config = get_model_config(params.version, 'table', 'en')
+ params.table_model_dir, table_url = confirm_model_dir_url(
+ params.table_model_dir,
+ os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
+ table_model_config['url'])
# download model
maybe_download(params.det_model_dir, det_url)
maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.table_model_dir, table_url)
if params.rec_char_dict_path is None:
- params.rec_char_dict_path = str(Path(__file__).parent / model_urls['rec'][lang]['dict_path'])
+ params.rec_char_dict_path = str(
+ Path(__file__).parent / rec_model_config['dict_path'])
if params.table_char_dict_path is None:
- params.table_char_dict_path = str(Path(__file__).parent / model_urls['table']['dict_path'])
+ params.table_char_dict_path = str(
+ Path(__file__).parent / table_model_config['dict_path'])
print(params)
super().__init__(params)
@@ -374,4 +456,3 @@ def main():
for item in result:
item.pop('img')
logger.info(item)
-
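
The new `version` argument and the `get_model_config` fallback above can be exercised directly from the Python API. A minimal sketch, assuming this branch of PaddleOCR is installed; the image path is illustrative:

```
from paddleocr import PaddleOCR

# ask for the 2.1 model set; unknown versions fall back to DEFAULT_MODEL_VERSION with a warning
ocr = PaddleOCR(version='2.1', lang='en', use_angle_cls=True)
result = ocr.ocr('doc/imgs_en/img_12.jpg', cls=True)
for line in result:
    print(line)
```
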
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index e860c5a6986f495e6384d9df93c24795c04a0d5f..0bb3d506483a331fba48feafeff9ca2d439f3782 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -49,14 +49,12 @@ def term_mp(sig_num, frame):
os.killpg(pgid, signal.SIGKILL)
-signal.signal(signal.SIGINT, term_mp)
-signal.signal(signal.SIGTERM, term_mp)
-
-
def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet']
+ support_dict = [
+ 'SimpleDataSet', 'LMDBDataSet', 'PGDataSet', 'PubTabDataSet'
+ ]
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
@@ -96,4 +94,8 @@ def build_dataloader(config, mode, device, logger, seed=None):
return_list=True,
use_shared_memory=use_shared_memory)
+ # support exit using ctrl+c
+ signal.signal(signal.SIGINT, term_mp)
+ signal.signal(signal.SIGTERM, term_mp)
+
return data_loader
diff --git a/ppocr/data/imaug/ColorJitter.py b/ppocr/data/imaug/ColorJitter.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b542abc8f9dc5af76529f9feb4bcb8b47b5f7d0
--- /dev/null
+++ b/ppocr/data/imaug/ColorJitter.py
@@ -0,0 +1,26 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from paddle.vision.transforms import ColorJitter as pp_ColorJitter
+
+__all__ = ['ColorJitter']
+
+class ColorJitter(object):
+ def __init__(self, brightness=0, contrast=0, saturation=0, hue=0,**kwargs):
+ self.aug = pp_ColorJitter(brightness, contrast, saturation, hue)
+
+ def __call__(self, data):
+ image = data['image']
+ image = self.aug(image)
+ data['image'] = image
+ return data
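
A minimal usage sketch of the new `ColorJitter` op. The jitter strengths and the random image are illustrative; it assumes PaddlePaddle and this repo are importable and that the op receives HWC numpy images, as elsewhere in the detection pipeline:

```
import numpy as np
from ppocr.data.imaug.ColorJitter import ColorJitter

op = ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1)
data = {'image': np.random.randint(0, 255, (48, 160, 3), dtype=np.uint8)}
data = op(data)
print(data['image'].shape)  # same HWC shape, colors randomly perturbed
```
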
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 52194eb964f7a7fd159cc1a42b73d280f8ee5fb4..5aaa1cd71eb791efa94e6bd812f3ab76632c96c6 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -19,11 +19,13 @@ from __future__ import unicode_literals
from .iaa_augment import IaaAugment
from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
-from .random_crop_data import EastRandomCropData, PSERandomCrop
+from .random_crop_data import EastRandomCropData, RandomCropImgMask
+from .make_pse_gt import MakePseGt
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg, NRTRRecResizeImg, SARRecResizeImg
from .randaugment import RandAugment
from .copy_paste import CopyPaste
+from .ColorJitter import ColorJitter
from .operators import *
from .label_ops import *
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index d222c4109c3723bc1adb71ee7c21a27a010f8f45..0a4fad621a9038e71a9d43eb4e12f78e7e92d73d 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -21,6 +21,8 @@ import numpy as np
import string
import json
+from ppocr.utils.logging import get_logger
+
class ClsLabelEncode(object):
def __init__(self, label_list, **kwargs):
@@ -92,31 +94,23 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False):
- support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
- 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
- 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
- ]
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
- if character_type == "en":
+ self.lower = False
+
+ if character_dict_path is None:
+ logger = get_logger()
+ logger.warning(
+                "The character_dict_path is None, the model can only recognize numbers and lowercase letters"
+ )
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type == "EN_symbol":
- # same with ASTER setting (use 94 char).
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
- elif character_type in support_character_type:
+ self.lower = True
+ else:
self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -125,7 +119,6 @@ class BaseRecLabelEncode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
- self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@@ -147,7 +140,7 @@ class BaseRecLabelEncode(object):
"""
if len(text) == 0 or len(text) > self.max_text_len:
return None
- if self.character_type == "en":
+ if self.lower:
text = text.lower()
text_list = []
for char in text:
@@ -161,18 +154,47 @@ class BaseRecLabelEncode(object):
return text_list
+class NRTRLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+
+ super(NRTRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len - 1:
+ return None
+ data['length'] = np.array(len(text))
+ text.insert(0, 2)
+ text.append(3)
+ text = text + [0] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def add_special_char(self, dict_character):
+        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+
class CTCLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(CTCLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(CTCLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
@@ -182,6 +204,11 @@ class CTCLabelEncode(BaseRecLabelEncode):
data['length'] = np.array(len(text))
text = text + [0] * (self.max_text_len - len(text))
data['label'] = np.array(text)
+
+ label = [0] * len(self.character)
+ for x in text:
+ label[x] += 1
+ data['label_ace'] = np.array(label)
return data
def add_special_char(self, dict_character):
@@ -193,12 +220,10 @@ class E2ELabelEncodeTest(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='EN',
use_space_char=False,
**kwargs):
- super(E2ELabelEncodeTest,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(E2ELabelEncodeTest, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
import json
@@ -267,12 +292,10 @@ class AttnLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(AttnLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(AttnLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -309,18 +332,46 @@ class AttnLabelEncode(BaseRecLabelEncode):
return idx
+class SEEDLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SEEDLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.end_str = "eos"
+ dict_character = dict_character + [self.end_str]
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len:
+ return None
+        data['length'] = np.array(len(text)) + 1  # include eos
+ text = text + [len(self.character) - 1] * (self.max_text_len - len(text)
+ )
+ data['label'] = np.array(text)
+ return data
+
+
class SRNLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length=25,
character_dict_path=None,
- character_type='en',
use_space_char=False,
**kwargs):
- super(SRNLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(SRNLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
@@ -388,7 +439,6 @@ class TableLabelEncode(object):
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
-
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
@@ -521,3 +571,47 @@ class TableLabelEncode(object):
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
+
+
+class SARLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ use_space_char=False,
+ **kwargs):
+ super(SARLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
+
+ def add_special_char(self, dict_character):
+        beg_end_str = "<BOS/EOS>"
+        unknown_str = "<UKN>"
+        padding_str = "<PAD>"
+ dict_character = dict_character + [unknown_str]
+ self.unknown_idx = len(dict_character) - 1
+ dict_character = dict_character + [beg_end_str]
+ self.start_idx = len(dict_character) - 1
+ self.end_idx = len(dict_character) - 1
+ dict_character = dict_character + [padding_str]
+ self.padding_idx = len(dict_character) - 1
+
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ if text is None:
+ return None
+ if len(text) >= self.max_text_len - 1:
+ return None
+ data['length'] = np.array(len(text))
+ target = [self.start_idx] + text + [self.end_idx]
+ padded_text = [self.padding_idx for _ in range(self.max_text_len)]
+
+ padded_text[:len(target)] = target
+ data['label'] = np.array(padded_text)
+ return data
+
+ def get_ignored_tokens(self):
+ return [self.padding_idx]
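
The label layout produced by `SARLabelEncode.__call__` can be summarized with a standalone sketch; the 6-character toy dictionary and `max_text_len` are made up for illustration:

```
# toy dictionary: 6 characters + <UKN> + <BOS/EOS> + <PAD>, mirroring add_special_char
chars = list("abcdef")
unknown_idx = len(chars)               # 6
start_idx = end_idx = len(chars) + 1   # 7, BOS and EOS share one index
padding_idx = len(chars) + 2           # 8
max_text_len = 10

text = [0, 1, 2]                       # encoded "abc"
target = [start_idx] + text + [end_idx]
padded = [padding_idx] * max_text_len
padded[:len(target)] = target
print(padded)                          # [7, 0, 1, 2, 7, 8, 8, 8, 8, 8]
```
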
diff --git a/ppocr/data/imaug/make_pse_gt.py b/ppocr/data/imaug/make_pse_gt.py
new file mode 100644
index 0000000000000000000000000000000000000000..55abc8970784fd00843d2e91f259c58b65ae8579
--- /dev/null
+++ b/ppocr/data/imaug/make_pse_gt.py
@@ -0,0 +1,85 @@
+# -*- coding:utf-8 -*-
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import cv2
+import numpy as np
+import pyclipper
+from shapely.geometry import Polygon
+
+__all__ = ['MakePseGt']
+
+class MakePseGt(object):
+ r'''
+ Making binary mask from detection data with ICDAR format.
+ Typically following the process of class `MakeICDARData`.
+ '''
+
+ def __init__(self, kernel_num=7, size=640, min_shrink_ratio=0.4, **kwargs):
+ self.kernel_num = kernel_num
+ self.min_shrink_ratio = min_shrink_ratio
+ self.size = size
+
+ def __call__(self, data):
+
+ image = data['image']
+ text_polys = data['polys']
+ ignore_tags = data['ignore_tags']
+
+ h, w, _ = image.shape
+ short_edge = min(h, w)
+ if short_edge < self.size:
+ # keep short_size >= self.size
+ scale = self.size / short_edge
+ image = cv2.resize(image, dsize=None, fx=scale, fy=scale)
+ text_polys *= scale
+
+ gt_kernels = []
+ for i in range(1,self.kernel_num+1):
+ # s1->sn, from big to small
+ rate = 1.0 - (1.0 - self.min_shrink_ratio) / (self.kernel_num - 1) * i
+ text_kernel, ignore_tags = self.generate_kernel(image.shape[0:2], rate, text_polys, ignore_tags)
+ gt_kernels.append(text_kernel)
+
+ training_mask = np.ones(image.shape[0:2], dtype='uint8')
+ for i in range(text_polys.shape[0]):
+ if ignore_tags[i]:
+ cv2.fillPoly(training_mask, text_polys[i].astype(np.int32)[np.newaxis, :, :], 0)
+
+ gt_kernels = np.array(gt_kernels)
+ gt_kernels[gt_kernels > 0] = 1
+
+ data['image'] = image
+ data['polys'] = text_polys
+ data['gt_kernels'] = gt_kernels[0:]
+ data['gt_text'] = gt_kernels[0]
+ data['mask'] = training_mask.astype('float32')
+ return data
+
+ def generate_kernel(self, img_size, shrink_ratio, text_polys, ignore_tags=None):
+ h, w = img_size
+ text_kernel = np.zeros((h, w), dtype=np.float32)
+ for i, poly in enumerate(text_polys):
+ polygon = Polygon(poly)
+ distance = polygon.area * (1 - shrink_ratio * shrink_ratio) / (polygon.length + 1e-6)
+ subject = [tuple(l) for l in poly]
+ pco = pyclipper.PyclipperOffset()
+ pco.AddPath(subject, pyclipper.JT_ROUND,
+ pyclipper.ET_CLOSEDPOLYGON)
+ shrinked = np.array(pco.Execute(-distance))
+
+ if len(shrinked) == 0 or shrinked.size == 0:
+ if ignore_tags is not None:
+ ignore_tags[i] = True
+ continue
+ try:
+ shrinked = np.array(shrinked[0]).reshape(-1, 2)
+ except:
+ if ignore_tags is not None:
+ ignore_tags[i] = True
+ continue
+ cv2.fillPoly(text_kernel, [shrinked.astype(np.int32)], i + 1)
+ return text_kernel, ignore_tags
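
`MakePseGt` shrinks each text polygon `kernel_num` times with a Vatti clipping offset of `area * (1 - rate^2) / perimeter`. The shrink schedule for the defaults is easy to inspect on its own:

```
kernel_num = 7
min_shrink_ratio = 0.4
rates = [1.0 - (1.0 - min_shrink_ratio) / (kernel_num - 1) * i
         for i in range(1, kernel_num + 1)]
print(rates)  # roughly [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]: kernels go from large to small
```
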
diff --git a/ppocr/data/imaug/operators.py b/ppocr/data/imaug/operators.py
index 2535b4420c503f2e9e9cc5a677ef70c4dd9c36be..87e3088d07a8c5a2eea5d4deff87c69a753e215b 100644
--- a/ppocr/data/imaug/operators.py
+++ b/ppocr/data/imaug/operators.py
@@ -23,6 +23,7 @@ import sys
import six
import cv2
import numpy as np
+import fasttext
class DecodeImage(object):
@@ -57,6 +58,39 @@ class DecodeImage(object):
return data
+class NRTRDecodeImage(object):
+ """ decode image """
+
+ def __init__(self, img_mode='RGB', channel_first=False, **kwargs):
+ self.img_mode = img_mode
+ self.channel_first = channel_first
+
+ def __call__(self, data):
+ img = data['image']
+ if six.PY2:
+ assert type(img) is str and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ else:
+ assert type(img) is bytes and len(
+ img) > 0, "invalid input 'img' in DecodeImage"
+ img = np.frombuffer(img, dtype='uint8')
+
+ img = cv2.imdecode(img, 1)
+
+ if img is None:
+ return None
+ if self.img_mode == 'GRAY':
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+ elif self.img_mode == 'RGB':
+ assert img.shape[2] == 3, 'invalid shape of image[%s]' % (img.shape)
+ img = img[:, :, ::-1]
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ if self.channel_first:
+ img = img.transpose((2, 0, 1))
+ data['image'] = img
+ return data
+
+
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
@@ -81,7 +115,7 @@ class NormalizeImage(object):
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
data['image'] = (
- img.astype('float32') * self.scale - self.mean) / self.std
+ img.astype('float32') * self.scale - self.mean) / self.std
return data
@@ -101,6 +135,17 @@ class ToCHWImage(object):
return data
+class Fasttext(object):
+ def __init__(self, path="None", **kwargs):
+ self.fast_model = fasttext.load_model(path)
+
+ def __call__(self, data):
+ label = data['label']
+ fast_label = self.fast_model[label]
+ data['fast_label'] = fast_label
+ return data
+
+
class KeepKeys(object):
def __init__(self, keep_keys, **kwargs):
self.keep_keys = keep_keys
@@ -112,6 +157,34 @@ class KeepKeys(object):
return data_list
+class Resize(object):
+ def __init__(self, size=(640, 640), **kwargs):
+ self.size = size
+
+ def resize_image(self, img):
+ resize_h, resize_w = self.size
+ ori_h, ori_w = img.shape[:2] # (h, w, c)
+ ratio_h = float(resize_h) / ori_h
+ ratio_w = float(resize_w) / ori_w
+ img = cv2.resize(img, (int(resize_w), int(resize_h)))
+ return img, [ratio_h, ratio_w]
+
+ def __call__(self, data):
+ img = data['image']
+ text_polys = data['polys']
+
+ img_resize, [ratio_h, ratio_w] = self.resize_image(img)
+ new_boxes = []
+ for box in text_polys:
+ new_box = []
+ for cord in box:
+ new_box.append([cord[0] * ratio_w, cord[1] * ratio_h])
+ new_boxes.append(new_box)
+ data['image'] = img_resize
+ data['polys'] = np.array(new_boxes, dtype=np.float32)
+ return data
+
+
class DetResizeForTest(object):
def __init__(self, **kwargs):
super(DetResizeForTest, self).__init__()
@@ -183,7 +256,7 @@ class DetResizeForTest(object):
else:
ratio = 1.
elif self.limit_type == 'resize_long':
- ratio = float(limit_side_len) / max(h,w)
+ ratio = float(limit_side_len) / max(h, w)
else:
raise Exception('not support limit type, image ')
resize_h = int(h * ratio)
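
A small sketch of the new `Resize` detection op; the input size and box are illustrative, and it assumes this repo and its image-augmentation dependencies are importable (note that `operators.py` now imports `fasttext` at module level, so that package must be installed):

```
import numpy as np
from ppocr.data.imaug.operators import Resize

op = Resize(size=(640, 640))
data = {
    'image': np.zeros((720, 1280, 3), dtype=np.uint8),
    'polys': np.array([[[100, 100], [300, 100], [300, 200], [100, 200]]], dtype=np.float32),
}
data = op(data)
print(data['image'].shape)  # (640, 640, 3)
print(data['polys'][0])     # x scaled by 640/1280, y scaled by 640/720
```
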
diff --git a/ppocr/data/imaug/random_crop_data.py b/ppocr/data/imaug/random_crop_data.py
index 4d67cff61d6f340be6d80d8243c68909a94c4e88..7c1c25abb56a0cf7d4d59b8523962bd5d81c873a 100644
--- a/ppocr/data/imaug/random_crop_data.py
+++ b/ppocr/data/imaug/random_crop_data.py
@@ -164,47 +164,55 @@ class EastRandomCropData(object):
return data
-class PSERandomCrop(object):
- def __init__(self, size, **kwargs):
+class RandomCropImgMask(object):
+ def __init__(self, size, main_key, crop_keys, p=3 / 8, **kwargs):
self.size = size
+ self.main_key = main_key
+ self.crop_keys = crop_keys
+ self.p = p
def __call__(self, data):
- imgs = data['imgs']
+ image = data['image']
- h, w = imgs[0].shape[0:2]
+ h, w = image.shape[0:2]
th, tw = self.size
if w == tw and h == th:
- return imgs
+ return data
-        # if the label contains text instances, crop with a given probability, controlled by threshold_label_map
- if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
-            # top-left corner of the text instances
- tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
+ mask = data[self.main_key]
+ if np.max(mask) > 0 and random.random() > self.p:
+ # make sure to crop the text region
+ tl = np.min(np.where(mask > 0), axis=1) - (th, tw)
tl[tl < 0] = 0
-            # bottom-right corner of the text instances
- br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
+ br = np.max(np.where(mask > 0), axis=1) - (th, tw)
br[br < 0] = 0
-            # make sure there is enough room to crop when the bottom-right corner is chosen
+
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)
- for _ in range(50000):
- i = random.randint(tl[0], br[0])
- j = random.randint(tl[1], br[1])
-                # make sure shrink_label_map contains text
- if imgs[1][i:i + th, j:j + tw].sum() <= 0:
- continue
- else:
- break
+ i = random.randint(tl[0], br[0]) if tl[0] < br[0] else 0
+ j = random.randint(tl[1], br[1]) if tl[1] < br[1] else 0
else:
- i = random.randint(0, h - th)
- j = random.randint(0, w - tw)
+ i = random.randint(0, h - th) if h - th > 0 else 0
+ j = random.randint(0, w - tw) if w - tw > 0 else 0
# return i, j, th, tw
- for idx in range(len(imgs)):
- if len(imgs[idx].shape) == 3:
- imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
- else:
- imgs[idx] = imgs[idx][i:i + th, j:j + tw]
- data['imgs'] = imgs
+ for k in data:
+ if k in self.crop_keys:
+ if len(data[k].shape) == 3:
+                    if np.argmin(data[k].shape) == 0:
+                        # channel-first map: (C, H, W)
+                        img = data[k][:, i:i + th, j:j + tw]
+                    elif np.argmin(data[k].shape) == 2:
+                        # channel-last image: (H, W, C)
+                        img = data[k][i:i + th, j:j + tw, :]
+                    else:
+                        img = data[k]
+                else:
+                    # single-channel mask: (H, W)
+                    img = data[k][i:i + th, j:j + tw]
+ data[k] = img
return data
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 28e6bd0bce768c45dbc334c15ace601fd6403f5d..b4de6de95b09ced803375d9a3bb857194ef3e64b 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -16,7 +16,7 @@ import math
import cv2
import numpy as np
import random
-
+from PIL import Image
from .text_image_aug import tia_perspective, tia_stretch, tia_distort
@@ -43,22 +43,64 @@ class ClsResizeImg(object):
return data
+class NRTRRecResizeImg(object):
+ def __init__(self, image_shape, resize_type, padding=False, **kwargs):
+ self.image_shape = image_shape
+ self.resize_type = resize_type
+ self.padding = padding
+
+ def __call__(self, data):
+ img = data['image']
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+ image_shape = self.image_shape
+ if self.padding:
+ imgC, imgH, imgW = image_shape
+            # todo: change to 0 and modify the image shape
+ h = img.shape[0]
+ w = img.shape[1]
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
+ norm_img = np.expand_dims(resized_image, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ resized_image = norm_img.astype(np.float32) / 128. - 1.
+ padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
+ padding_im[:, :, 0:resized_w] = resized_image
+ data['image'] = padding_im
+ return data
+ if self.resize_type == 'PIL':
+ image_pil = Image.fromarray(np.uint8(img))
+ img = image_pil.resize(self.image_shape, Image.ANTIALIAS)
+ img = np.array(img)
+ if self.resize_type == 'OpenCV':
+ img = cv2.resize(img, self.image_shape)
+ norm_img = np.expand_dims(img, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ data['image'] = norm_img.astype(np.float32) / 128. - 1.
+ return data
+
+
class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
- character_type='ch',
+ character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
+ padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
- self.character_type = character_type
+ self.character_dict_path = character_dict_path
+ self.padding = padding
def __call__(self, data):
img = data['image']
- if self.infer_mode and self.character_type == "ch":
+ if self.infer_mode and self.character_dict_path is not None:
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
- norm_img = resize_norm_img(img, self.image_shape)
+ norm_img = resize_norm_img(img, self.image_shape, self.padding)
data['image'] = norm_img
return data
@@ -83,16 +125,72 @@ class SRNRecResizeImg(object):
return data
-def resize_norm_img(img, image_shape):
- imgC, imgH, imgW = image_shape
+class SARRecResizeImg(object):
+ def __init__(self, image_shape, width_downsample_ratio=0.25, **kwargs):
+ self.image_shape = image_shape
+ self.width_downsample_ratio = width_downsample_ratio
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img, resize_shape, pad_shape, valid_ratio = resize_norm_img_sar(
+ img, self.image_shape, self.width_downsample_ratio)
+ data['image'] = norm_img
+ data['resized_shape'] = resize_shape
+ data['pad_shape'] = pad_shape
+ data['valid_ratio'] = valid_ratio
+ return data
+
+
+def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
+ imgC, imgH, imgW_min, imgW_max = image_shape
h = img.shape[0]
w = img.shape[1]
+ valid_ratio = 1.0
+ # make sure new_width is an integral multiple of width_divisor.
+ width_divisor = int(1 / width_downsample_ratio)
+ # resize
ratio = w / float(h)
- if math.ceil(imgH * ratio) > imgW:
+ resize_w = math.ceil(imgH * ratio)
+ if resize_w % width_divisor != 0:
+ resize_w = round(resize_w / width_divisor) * width_divisor
+ if imgW_min is not None:
+ resize_w = max(imgW_min, resize_w)
+ if imgW_max is not None:
+ valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+ resize_w = min(imgW_max, resize_w)
+ resized_image = cv2.resize(img, (resize_w, imgH))
+ resized_image = resized_image.astype('float32')
+ # norm
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ resize_shape = resized_image.shape
+ padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+ padding_im[:, :, 0:resize_w] = resized_image
+ pad_shape = padding_im.shape
+
+ return padding_im, resize_shape, pad_shape, valid_ratio
+
+
+def resize_norm_img(img, image_shape, padding=True):
+ imgC, imgH, imgW = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ if not padding:
+ resized_image = cv2.resize(
+ img, (imgW, imgH), interpolation=cv2.INTER_LINEAR)
resized_w = imgW
else:
- resized_w = int(math.ceil(imgH * ratio))
- resized_image = cv2.resize(img, (resized_w, imgH))
+ ratio = w / float(h)
+ if math.ceil(imgH * ratio) > imgW:
+ resized_w = imgW
+ else:
+ resized_w = int(math.ceil(imgH * ratio))
+ resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
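
The width handling in `resize_norm_img_sar` keeps the aspect ratio, snaps the new width to a multiple of `1/width_downsample_ratio`, clamps it to `[imgW_min, imgW_max]` and records which fraction of the padded width is real content. A standalone sketch of that arithmetic (the `[3, 48, 48, 160]` shape and the 32x100 crop are illustrative):

```
import math

imgH, imgW_min, imgW_max = 48, 48, 160            # illustrative SAR image_shape (without channels)
width_downsample_ratio = 0.25
h, w = 32, 100                                    # an illustrative text crop

width_divisor = int(1 / width_downsample_ratio)   # new width must be a multiple of 4
resize_w = math.ceil(imgH * (w / float(h)))       # keep aspect ratio at height 48
if resize_w % width_divisor != 0:
    resize_w = round(resize_w / width_divisor) * width_divisor
resize_w = max(imgW_min, resize_w)
valid_ratio = min(1.0, resize_w / imgW_max)
resize_w = min(imgW_max, resize_w)
print(resize_w, valid_ratio)                      # 152 0.95
```
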
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index e9c3394cbe930d5169ae005e7582a2902e697b7e..6a33e1342506f26ccaa4a146f3f02fadfbd741a2 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -15,7 +15,6 @@ import numpy as np
import os
import random
from paddle.io import Dataset
-
from .imaug import transform, create_operators
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 025ae7ca5cc604eea59423ca7f523c37c1492e35..f3f4cd49332b605ec3a0e65e688d965fd91a5cdf 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -20,11 +20,15 @@ import paddle.nn as nn
from .det_db_loss import DBLoss
from .det_east_loss import EASTLoss
from .det_sast_loss import SASTLoss
+from .det_pse_loss import PSELoss
# rec loss
from .rec_ctc_loss import CTCLoss
from .rec_att_loss import AttentionLoss
from .rec_srn_loss import SRNLoss
+from .rec_nrtr_loss import NRTRLoss
+from .rec_sar_loss import SARLoss
+from .rec_aster_loss import AsterLoss
# cls loss
from .cls_loss import ClsLoss
@@ -41,10 +45,12 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
+
def build_loss(config):
support_dict = [
- 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'AttentionLoss',
- 'SRNLoss', 'PGLoss', 'CombinedLoss', 'TableAttentionLoss'
+ 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss',
+ 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'NRTRLoss',
+ 'TableAttentionLoss', 'SARLoss', 'AsterLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
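
With the registry extended above, the new losses are picked by name from the training config. A hedged sketch (requires PaddlePaddle and this repo on the path; `alpha=0.7` is only an illustrative value):

```
from ppocr.losses import build_loss

# 'PSELoss' is now in the support list; alpha balances the text loss against the kernel loss
loss_class = build_loss({'name': 'PSELoss', 'alpha': 0.7})
print(type(loss_class).__name__)  # PSELoss
```
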
diff --git a/ppocr/losses/ace_loss.py b/ppocr/losses/ace_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf15f8e3a7b355bd9e8b69435a5dae01fc75a892
--- /dev/null
+++ b/ppocr/losses/ace_loss.py
@@ -0,0 +1,49 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+
+
+class ACELoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super().__init__()
+ self.loss_func = nn.CrossEntropyLoss(
+ weight=None,
+ ignore_index=0,
+ reduction='none',
+ soft_label=True,
+ axis=-1)
+
+ def __call__(self, predicts, batch):
+ if isinstance(predicts, (list, tuple)):
+ predicts = predicts[-1]
+
+ B, N = predicts.shape[:2]
+ div = paddle.to_tensor([N]).astype('float32')
+
+ predicts = nn.functional.softmax(predicts, axis=-1)
+ aggregation_preds = paddle.sum(predicts, axis=1)
+ aggregation_preds = paddle.divide(aggregation_preds, div)
+
+ length = batch[2].astype("float32")
+ batch = batch[3].astype("float32")
+ batch[:, 0] = paddle.subtract(div, length)
+ batch = paddle.divide(batch, div)
+
+ loss = self.loss_func(aggregation_preds, batch)
+ return {"loss_ace": loss}
diff --git a/ppocr/losses/basic_loss.py b/ppocr/losses/basic_loss.py
index 8306523ac1a933f0c664fc0b4cf077659cccdee3..d2ef5e5ac9692eec5bc30774c4451eab7706705d 100644
--- a/ppocr/losses/basic_loss.py
+++ b/ppocr/losses/basic_loss.py
@@ -56,31 +56,34 @@ class CELoss(nn.Layer):
class KLJSLoss(object):
def __init__(self, mode='kl'):
- assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
+ assert mode in ['kl', 'js', 'KL', 'JS'
+ ], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
self.mode = mode
def __call__(self, p1, p2, reduction="mean"):
- loss = paddle.multiply(p2, paddle.log( (p2+1e-5)/(p1+1e-5) + 1e-5))
+ loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js":
- loss += paddle.multiply(p1, paddle.log((p1+1e-5)/(p2+1e-5) + 1e-5))
+ loss += paddle.multiply(
+ p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
if reduction == "mean":
- loss = paddle.mean(loss, axis=[1,2])
- elif reduction=="none" or reduction is None:
- return loss
+ loss = paddle.mean(loss, axis=[1, 2])
+ elif reduction == "none" or reduction is None:
+ return loss
else:
- loss = paddle.sum(loss, axis=[1,2])
+ loss = paddle.sum(loss, axis=[1, 2])
+
+ return loss
- return loss
class DMLLoss(nn.Layer):
"""
DMLLoss
"""
- def __init__(self, act=None):
+ def __init__(self, act=None, use_log=False):
super().__init__()
if act is not None:
assert act in ["softmax", "sigmoid"]
@@ -90,20 +93,24 @@ class DMLLoss(nn.Layer):
self.act = nn.Sigmoid()
else:
self.act = None
-
+
+ self.use_log = use_log
+
self.jskl_loss = KLJSLoss(mode="js")
def forward(self, out1, out2):
if self.act is not None:
out1 = self.act(out1)
out2 = self.act(out2)
- if len(out1.shape) < 2:
+ if self.use_log:
+ # for recognition distillation, log is needed for feature map
log_out1 = paddle.log(out1)
log_out2 = paddle.log(out2)
loss = (F.kl_div(
log_out1, out2, reduction='batchmean') + F.kl_div(
log_out2, out1, reduction='batchmean')) / 2.0
else:
+ # for detection distillation log is not needed
loss = self.jskl_loss(out1, out2)
return loss
diff --git a/ppocr/losses/center_loss.py b/ppocr/losses/center_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbef4df965e2659c6aa63c0c69cd8798143df485
--- /dev/null
+++ b/ppocr/losses/center_loss.py
@@ -0,0 +1,89 @@
+#copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import os
+import pickle
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class CenterLoss(nn.Layer):
+ """
+ Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
+ """
+ def __init__(self,
+ num_classes=6625,
+ feat_dim=96,
+ init_center=False,
+ center_file_path=None):
+ super().__init__()
+ self.num_classes = num_classes
+ self.feat_dim = feat_dim
+ self.centers = paddle.randn(
+ shape=[self.num_classes, self.feat_dim]).astype("float64")
+
+ if init_center:
+ assert os.path.exists(
+ center_file_path
+ ), f"center path({center_file_path}) must exist when init_center is set as True."
+ with open(center_file_path, 'rb') as f:
+ char_dict = pickle.load(f)
+ for key in char_dict.keys():
+ self.centers[key] = paddle.to_tensor(char_dict[key])
+
+ def __call__(self, predicts, batch):
+ assert isinstance(predicts, (list, tuple))
+ features, predicts = predicts
+
+ feats_reshape = paddle.reshape(
+ features, [-1, features.shape[-1]]).astype("float64")
+ label = paddle.argmax(predicts, axis=2)
+ label = paddle.reshape(label, [label.shape[0] * label.shape[1]])
+
+ batch_size = feats_reshape.shape[0]
+
+ #calc l2 distance between feats and centers
+ square_feat = paddle.sum(paddle.square(feats_reshape),
+ axis=1,
+ keepdim=True)
+ square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
+
+ square_center = paddle.sum(paddle.square(self.centers),
+ axis=1,
+ keepdim=True)
+ square_center = paddle.expand(
+ square_center, [self.num_classes, batch_size]).astype("float64")
+ square_center = paddle.transpose(square_center, [1, 0])
+
+ distmat = paddle.add(square_feat, square_center)
+ feat_dot_center = paddle.matmul(feats_reshape,
+ paddle.transpose(self.centers, [1, 0]))
+ distmat = distmat - 2.0 * feat_dot_center
+
+ #generate the mask
+ classes = paddle.arange(self.num_classes).astype("int64")
+ label = paddle.expand(
+ paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
+ mask = paddle.equal(
+ paddle.expand(classes, [batch_size, self.num_classes]),
+ label).astype("float64")
+ dist = paddle.multiply(distmat, mask)
+
+ loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
+ return {'loss_center': loss}
diff --git a/ppocr/losses/combined_loss.py b/ppocr/losses/combined_loss.py
index 0d6fe968d0d7733200a4cfd21d779196cccaba03..72f706e37d6eb0c640cc30de80afe00bce82fd13 100644
--- a/ppocr/losses/combined_loss.py
+++ b/ppocr/losses/combined_loss.py
@@ -15,6 +15,10 @@
import paddle
import paddle.nn as nn
+from .rec_ctc_loss import CTCLoss
+from .center_loss import CenterLoss
+from .ace_loss import ACELoss
+
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
@@ -49,11 +53,15 @@ class CombinedLoss(nn.Layer):
loss = loss_func(input, batch, **kargs)
if isinstance(loss, paddle.Tensor):
loss = {"loss_{}_{}".format(str(loss), idx): loss}
+
weight = self.loss_weight[idx]
- for key in loss.keys():
- if key == "loss":
- loss_all += loss[key] * weight
- else:
- loss_dict["{}_{}".format(key, idx)] = loss[key]
+
+ loss = {key: loss[key] * weight for key in loss}
+
+ if "loss" in loss:
+ loss_all += loss["loss"]
+ else:
+ loss_all += paddle.add_n(list(loss.values()))
+ loss_dict.update(loss)
loss_dict["loss"] = loss_all
return loss_dict
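
The rewritten weighting rule in `CombinedLoss.forward` scales every entry of each sub-loss dict by its weight before summing; a standalone sketch with made-up numbers:

```
sub_losses = [{"loss_ctc": 2.0}, {"loss_center": 4.0, "loss_ace": 1.0}]
loss_weight = [1.0, 0.05]

loss_all = 0.0
loss_dict = {}
for idx, loss in enumerate(sub_losses):
    loss = {key: loss[key] * loss_weight[idx] for key in loss}
    if "loss" in loss:
        loss_all += loss["loss"]
    else:
        loss_all += sum(loss.values())   # stand-in for paddle.add_n
    loss_dict.update(loss)
loss_dict["loss"] = loss_all
print(loss_dict)  # {'loss_ctc': 2.0, 'loss_center': 0.2, 'loss_ace': 0.05, 'loss': 2.25}
```
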
diff --git a/ppocr/losses/det_basic_loss.py b/ppocr/losses/det_basic_loss.py
index eba5526dd2bd1c0328130b50817172df437cc360..7017236c284e55710f242275a413d56d32158d34 100644
--- a/ppocr/losses/det_basic_loss.py
+++ b/ppocr/losses/det_basic_loss.py
@@ -75,12 +75,6 @@ class BalanceLoss(nn.Layer):
mask (variable): masked maps.
return: (variable) balanced loss
"""
- # if self.main_loss_type in ['DiceLoss']:
- # # For the loss that returns to scalar value, perform ohem on the mask
- # mask = ohem_batch(pred, gt, mask, self.negative_ratio)
- # loss = self.loss(pred, gt, mask)
- # return loss
-
positive = gt * mask
negative = (1 - gt) * mask
@@ -153,53 +147,4 @@ class BCELoss(nn.Layer):
def forward(self, input, label, mask=None, weight=None, name=None):
loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
- return loss
-
-
-def ohem_single(score, gt_text, training_mask, ohem_ratio):
- pos_num = (int)(np.sum(gt_text > 0.5)) - (
- int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))
-
- if pos_num == 0:
- # selected_mask = gt_text.copy() * 0 # may be not good
- selected_mask = training_mask
- selected_mask = selected_mask.reshape(
- 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
- return selected_mask
-
- neg_num = (int)(np.sum(gt_text <= 0.5))
- neg_num = (int)(min(pos_num * ohem_ratio, neg_num))
-
- if neg_num == 0:
- selected_mask = training_mask
- selected_mask = selected_mask.reshape(
- 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
- return selected_mask
-
- neg_score = score[gt_text <= 0.5]
-    # sort negative sample scores from high to low
- neg_score_sorted = np.sort(-neg_score)
- threshold = -neg_score_sorted[neg_num - 1]
-    # select the mask of high-scoring negative samples and the positive samples
- selected_mask = ((score >= threshold) |
- (gt_text > 0.5)) & (training_mask > 0.5)
- selected_mask = selected_mask.reshape(
- 1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')
- return selected_mask
-
-
-def ohem_batch(scores, gt_texts, training_masks, ohem_ratio):
- scores = scores.numpy()
- gt_texts = gt_texts.numpy()
- training_masks = training_masks.numpy()
-
- selected_masks = []
- for i in range(scores.shape[0]):
- selected_masks.append(
- ohem_single(scores[i, :, :], gt_texts[i, :, :], training_masks[
- i, :, :], ohem_ratio))
-
- selected_masks = np.concatenate(selected_masks, 0)
- selected_masks = paddle.to_tensor(selected_masks)
-
- return selected_masks
+ return loss
\ No newline at end of file
diff --git a/ppocr/losses/det_pse_loss.py b/ppocr/losses/det_pse_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..78423091f841f29b1217f73f79beb26fe1575844
--- /dev/null
+++ b/ppocr/losses/det_pse_loss.py
@@ -0,0 +1,145 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+import numpy as np
+from ppocr.utils.iou import iou
+
+
+class PSELoss(nn.Layer):
+ def __init__(self,
+ alpha,
+ ohem_ratio=3,
+ kernel_sample_mask='pred',
+ reduction='sum',
+ eps=1e-6,
+ **kwargs):
+ """Implement PSE Loss.
+ """
+ super(PSELoss, self).__init__()
+ assert reduction in ['sum', 'mean', 'none']
+ self.alpha = alpha
+ self.ohem_ratio = ohem_ratio
+ self.kernel_sample_mask = kernel_sample_mask
+ self.reduction = reduction
+ self.eps = eps
+
+ def forward(self, outputs, labels):
+ predicts = outputs['maps']
+ predicts = F.interpolate(predicts, scale_factor=4)
+
+ texts = predicts[:, 0, :, :]
+ kernels = predicts[:, 1:, :, :]
+ gt_texts, gt_kernels, training_masks = labels[1:]
+
+ # text loss
+ selected_masks = self.ohem_batch(texts, gt_texts, training_masks)
+
+ loss_text = self.dice_loss(texts, gt_texts, selected_masks)
+ iou_text = iou((texts > 0).astype('int64'),
+ gt_texts,
+ training_masks,
+ reduce=False)
+ losses = dict(loss_text=loss_text, iou_text=iou_text)
+
+ # kernel loss
+ loss_kernels = []
+ if self.kernel_sample_mask == 'gt':
+ selected_masks = gt_texts * training_masks
+ elif self.kernel_sample_mask == 'pred':
+ selected_masks = (
+ F.sigmoid(texts) > 0.5).astype('float32') * training_masks
+
+ for i in range(kernels.shape[1]):
+ kernel_i = kernels[:, i, :, :]
+ gt_kernel_i = gt_kernels[:, i, :, :]
+ loss_kernel_i = self.dice_loss(kernel_i, gt_kernel_i,
+ selected_masks)
+ loss_kernels.append(loss_kernel_i)
+ loss_kernels = paddle.mean(paddle.stack(loss_kernels, axis=1), axis=1)
+ iou_kernel = iou((kernels[:, -1, :, :] > 0).astype('int64'),
+ gt_kernels[:, -1, :, :],
+ training_masks * gt_texts,
+ reduce=False)
+ losses.update(dict(loss_kernels=loss_kernels, iou_kernel=iou_kernel))
+ loss = self.alpha * loss_text + (1 - self.alpha) * loss_kernels
+ losses['loss'] = loss
+ if self.reduction == 'sum':
+ losses = {x: paddle.sum(v) for x, v in losses.items()}
+ elif self.reduction == 'mean':
+ losses = {x: paddle.mean(v) for x, v in losses.items()}
+ return losses
+
+ def dice_loss(self, input, target, mask):
+ input = F.sigmoid(input)
+
+ input = input.reshape([input.shape[0], -1])
+ target = target.reshape([target.shape[0], -1])
+ mask = mask.reshape([mask.shape[0], -1])
+
+ input = input * mask
+ target = target * mask
+
+ a = paddle.sum(input * target, 1)
+ b = paddle.sum(input * input, 1) + self.eps
+ c = paddle.sum(target * target, 1) + self.eps
+ d = (2 * a) / (b + c)
+ return 1 - d
+
+ def ohem_single(self, score, gt_text, training_mask, ohem_ratio=3):
+ pos_num = int(paddle.sum((gt_text > 0.5).astype('float32'))) - int(
+ paddle.sum(
+ paddle.logical_and((gt_text > 0.5), (training_mask <= 0.5))
+ .astype('float32')))
+
+ if pos_num == 0:
+ selected_mask = training_mask
+ selected_mask = selected_mask.reshape(
+ [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
+ 'float32')
+ return selected_mask
+
+ neg_num = int(paddle.sum((gt_text <= 0.5).astype('float32')))
+ neg_num = int(min(pos_num * ohem_ratio, neg_num))
+
+ if neg_num == 0:
+ selected_mask = training_mask
+ selected_mask = selected_mask.view(
+ 1, selected_mask.shape[0],
+ selected_mask.shape[1]).astype('float32')
+ return selected_mask
+
+ neg_score = paddle.masked_select(score, gt_text <= 0.5)
+ neg_score_sorted = paddle.sort(-neg_score)
+ threshold = -neg_score_sorted[neg_num - 1]
+
+ selected_mask = paddle.logical_and(
+ paddle.logical_or((score >= threshold), (gt_text > 0.5)),
+ (training_mask > 0.5))
+ selected_mask = selected_mask.reshape(
+ [1, selected_mask.shape[0], selected_mask.shape[1]]).astype(
+ 'float32')
+ return selected_mask
+
+ def ohem_batch(self, scores, gt_texts, training_masks, ohem_ratio=3):
+ selected_masks = []
+ for i in range(scores.shape[0]):
+ selected_masks.append(
+ self.ohem_single(scores[i, :, :], gt_texts[i, :, :],
+ training_masks[i, :, :], ohem_ratio))
+
+ selected_masks = paddle.concat(selected_masks, 0).astype('float32')
+ return selected_masks
diff --git a/ppocr/losses/distillation_loss.py b/ppocr/losses/distillation_loss.py
index 75f0a773152e52c98ada5c1907f1c8cc2f72d8f3..06aa7fa8458a5deece75f1393fe7300e8227d3ca 100644
--- a/ppocr/losses/distillation_loss.py
+++ b/ppocr/losses/distillation_loss.py
@@ -44,20 +44,22 @@ class DistillationDMLLoss(DMLLoss):
def __init__(self,
model_name_pairs=[],
act=None,
+ use_log=False,
key=None,
maps_name=None,
name="dml"):
- super().__init__(act=act)
+ super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
self.name = name
self.maps_name = self._check_maps_name(maps_name)
-
+
def _check_model_name_pairs(self, model_name_pairs):
if not isinstance(model_name_pairs, list):
return []
- elif isinstance(model_name_pairs[0], list) and isinstance(model_name_pairs[0][0], str):
+ elif isinstance(model_name_pairs[0], list) and isinstance(
+ model_name_pairs[0][0], str):
return model_name_pairs
else:
return [model_name_pairs]
@@ -110,11 +112,11 @@ class DistillationDMLLoss(DMLLoss):
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}_{}_{}_{}".format(key, pair[
- 0], pair[1], map_name, idx)] = loss[key]
+ 0], pair[1], self.maps_name, idx)] = loss[key]
else:
- loss_dict["{}_{}_{}".format(self.name, self.maps_name[_c],
- idx)] = loss
-
+ loss_dict["{}_{}_{}".format(self.name, self.maps_name[
+ _c], idx)] = loss
+
loss_dict = _sum_loss(loss_dict)
return loss_dict
diff --git a/ppocr/losses/rec_aster_loss.py b/ppocr/losses/rec_aster_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb99d29a638540b02649a8912051339c08b22dd
--- /dev/null
+++ b/ppocr/losses/rec_aster_loss.py
@@ -0,0 +1,99 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class CosineEmbeddingLoss(nn.Layer):
+ def __init__(self, margin=0.):
+ super(CosineEmbeddingLoss, self).__init__()
+ self.margin = margin
+ self.epsilon = 1e-12
+
+ def forward(self, x1, x2, target):
+ similarity = paddle.fluid.layers.reduce_sum(
+ x1 * x2, dim=-1) / (paddle.norm(
+ x1, axis=-1) * paddle.norm(
+ x2, axis=-1) + self.epsilon)
+ one_list = paddle.full_like(target, fill_value=1)
+ out = paddle.fluid.layers.reduce_mean(
+ paddle.where(
+ paddle.equal(target, one_list), 1. - similarity,
+ paddle.maximum(
+ paddle.zeros_like(similarity), similarity - self.margin)))
+
+ return out
+
+
+class AsterLoss(nn.Layer):
+ def __init__(self,
+ weight=None,
+ size_average=True,
+ ignore_index=-100,
+ sequence_normalize=False,
+ sample_normalize=True,
+ **kwargs):
+ super(AsterLoss, self).__init__()
+ self.weight = weight
+ self.size_average = size_average
+ self.ignore_index = ignore_index
+ self.sequence_normalize = sequence_normalize
+ self.sample_normalize = sample_normalize
+ self.loss_sem = CosineEmbeddingLoss()
+ self.is_cosin_loss = True
+ self.loss_func_rec = nn.CrossEntropyLoss(weight=None, reduction='none')
+
+ def forward(self, predicts, batch):
+ targets = batch[1].astype("int64")
+ label_lengths = batch[2].astype('int64')
+ sem_target = batch[3].astype('float32')
+ embedding_vectors = predicts['embedding_vectors']
+ rec_pred = predicts['rec_pred']
+
+ if not self.is_cosin_loss:
+ sem_loss = paddle.sum(self.loss_sem(embedding_vectors, sem_target))
+ else:
+ label_target = paddle.ones([embedding_vectors.shape[0]])
+ sem_loss = paddle.sum(
+ self.loss_sem(embedding_vectors, sem_target, label_target))
+
+ # rec loss
+ batch_size, def_max_length = targets.shape[0], targets.shape[1]
+
+ mask = paddle.zeros([batch_size, def_max_length])
+ for i in range(batch_size):
+ mask[i, :label_lengths[i]] = 1
+ mask = paddle.cast(mask, "float32")
+ max_length = max(label_lengths)
+ assert max_length == rec_pred.shape[1]
+ targets = targets[:, :max_length]
+ mask = mask[:, :max_length]
+ rec_pred = paddle.reshape(rec_pred, [-1, rec_pred.shape[2]])
+ input = nn.functional.log_softmax(rec_pred, axis=1)
+ targets = paddle.reshape(targets, [-1, 1])
+ mask = paddle.reshape(mask, [-1, 1])
+ output = -paddle.index_sample(input, index=targets) * mask
+ output = paddle.sum(output)
+ if self.sequence_normalize:
+ output = output / paddle.sum(mask)
+ if self.sample_normalize:
+ output = output / batch_size
+
+ loss = output + sem_loss * 0.1
+ return {'loss': loss}
diff --git a/ppocr/losses/rec_ctc_loss.py b/ppocr/losses/rec_ctc_loss.py
index 6c0b56ff84db4ff23786fb781d461bf9fbc86ef2..063d68e30861e092e10fa3068e4b7f4755b6197f 100755
--- a/ppocr/losses/rec_ctc_loss.py
+++ b/ppocr/losses/rec_ctc_loss.py
@@ -21,16 +21,24 @@ from paddle import nn
class CTCLoss(nn.Layer):
- def __init__(self, **kwargs):
+ def __init__(self, use_focal_loss=False, **kwargs):
super(CTCLoss, self).__init__()
self.loss_func = nn.CTCLoss(blank=0, reduction='none')
+ self.use_focal_loss = use_focal_loss
def forward(self, predicts, batch):
+ if isinstance(predicts, (list, tuple)):
+ predicts = predicts[-1]
predicts = predicts.transpose((1, 0, 2))
N, B, _ = predicts.shape
preds_lengths = paddle.to_tensor([N] * B, dtype='int64')
labels = batch[1].astype("int32")
label_lengths = batch[2].astype('int64')
loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
- loss = loss.mean() # sum
+ if self.use_focal_loss:
+ weight = paddle.exp(-loss)
+ weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
+ weight = paddle.square(weight)
+ loss = paddle.multiply(loss, weight)
+ loss = loss.mean()
return {'loss': loss}
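
The `use_focal_loss` branch applies the usual focal modulation `(1 - exp(-loss))^2` to the per-sample CTC losses before averaging; numerically (a standalone numpy sketch):

```
import numpy as np

ctc_loss = np.array([0.05, 0.5, 3.0])     # per-sample CTC losses, easy -> hard
weight = (1.0 - np.exp(-ctc_loss)) ** 2   # same computation as the paddle ops above
print(weight)                             # ~[0.002, 0.155, 0.903]: easy samples are down-weighted
print((ctc_loss * weight).mean())
```
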
diff --git a/ppocr/losses/rec_enhanced_ctc_loss.py b/ppocr/losses/rec_enhanced_ctc_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b57be6468e2ec75811442e7449525267e7d9e82e
--- /dev/null
+++ b/ppocr/losses/rec_enhanced_ctc_loss.py
@@ -0,0 +1,70 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+from .ace_loss import ACELoss
+from .center_loss import CenterLoss
+from .rec_ctc_loss import CTCLoss
+
+
+class EnhancedCTCLoss(nn.Layer):
+ def __init__(self,
+ use_focal_loss=False,
+ use_ace_loss=False,
+ ace_loss_weight=0.1,
+ use_center_loss=False,
+ center_loss_weight=0.05,
+ num_classes=6625,
+ feat_dim=96,
+ init_center=False,
+ center_file_path=None,
+ **kwargs):
+ super(EnhancedCTCLoss, self).__init__()
+ self.ctc_loss_func = CTCLoss(use_focal_loss=use_focal_loss)
+
+ self.use_ace_loss = False
+ if use_ace_loss:
+ self.use_ace_loss = use_ace_loss
+ self.ace_loss_func = ACELoss()
+ self.ace_loss_weight = ace_loss_weight
+
+ self.use_center_loss = False
+ if use_center_loss:
+ self.use_center_loss = use_center_loss
+ self.center_loss_func = CenterLoss(
+ num_classes=num_classes,
+ feat_dim=feat_dim,
+ init_center=init_center,
+ center_file_path=center_file_path)
+ self.center_loss_weight = center_loss_weight
+
+ def __call__(self, predicts, batch):
+ loss = self.ctc_loss_func(predicts, batch)["loss"]
+
+ if self.use_center_loss:
+ center_loss = self.center_loss_func(
+ predicts, batch)["loss_center"] * self.center_loss_weight
+ loss = loss + center_loss
+
+ if self.use_ace_loss:
+ ace_loss = self.ace_loss_func(
+ predicts, batch)["loss_ace"] * self.ace_loss_weight
+ loss = loss + ace_loss
+
+ return {'enhanced_ctc_loss': loss}
diff --git a/ppocr/losses/rec_nrtr_loss.py b/ppocr/losses/rec_nrtr_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..41714dd2a3ae15eeedc62521d97935f68271c598
--- /dev/null
+++ b/ppocr/losses/rec_nrtr_loss.py
@@ -0,0 +1,30 @@
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+
+
+class NRTRLoss(nn.Layer):
+ def __init__(self, smoothing=True, **kwargs):
+ super(NRTRLoss, self).__init__()
+ self.loss_func = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
+ self.smoothing = smoothing
+
+ def forward(self, pred, batch):
+ pred = pred.reshape([-1, pred.shape[2]])
+ max_len = batch[2].max()
+ tgt = batch[1][:, 1:2 + max_len]
+ tgt = tgt.reshape([-1])
+ if self.smoothing:
+ eps = 0.1
+ n_class = pred.shape[1]
+ one_hot = F.one_hot(tgt, pred.shape[1])
+ one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
+ log_prb = F.log_softmax(pred, axis=1)
+ non_pad_mask = paddle.not_equal(
+ tgt, paddle.zeros(
+ tgt.shape, dtype='int64'))
+ loss = -(one_hot * log_prb).sum(axis=1)
+ loss = loss.masked_select(non_pad_mask).mean()
+ else:
+ loss = self.loss_func(pred, tgt)
+ return {'loss': loss}
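
When `smoothing=True`, `NRTRLoss` replaces the hard one-hot target with `1 - eps` on the true class and `eps / (n_class - 1)` on every other class, masks out the padding index 0, and averages the negative log-likelihood under that soft target. The same computation on toy values (shapes and numbers are illustrative):

```
import paddle
import paddle.nn.functional as F

eps, n_class = 0.1, 5
pred = paddle.rand([4, n_class])                      # 4 decoding steps
tgt = paddle.to_tensor([2, 3, 1, 0], dtype='int64')   # index 0 is padding

one_hot = F.one_hot(tgt, n_class)
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, axis=1)

non_pad_mask = paddle.not_equal(tgt, paddle.zeros(tgt.shape, dtype='int64'))
loss = -(one_hot * log_prb).sum(axis=1)
loss = loss.masked_select(non_pad_mask).mean()        # padded step is ignored
```
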
diff --git a/ppocr/losses/rec_sar_loss.py b/ppocr/losses/rec_sar_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8bd8bb0ca395fa4658e57b8dcac52a3e94aadce
--- /dev/null
+++ b/ppocr/losses/rec_sar_loss.py
@@ -0,0 +1,28 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class SARLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super(SARLoss, self).__init__()
+ self.loss_func = paddle.nn.loss.CrossEntropyLoss(
+ reduction="mean", ignore_index=92)
+
+ def forward(self, predicts, batch):
+ predict = predicts[:, :
+ -1, :] # ignore the last output step so it has the same seq_len as the targets
+ label = batch[1].astype(
+ "int64")[:, 1:] # ignore first index of target in loss calculation
+ batch_size, num_steps, num_classes = predict.shape[0], predict.shape[
+ 1], predict.shape[2]
+ assert len(label.shape) == len(list(predict.shape)) - 1, \
+ "The target's shape should be [N, num_steps] when the input's shape is [N, num_steps, num_classes]"
+
+ inputs = paddle.reshape(predict, [-1, num_classes])
+ targets = paddle.reshape(label, [-1])
+ loss = self.loss_func(inputs, targets)
+ return {'loss': loss}
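
`SARLoss` aligns teacher-forced predictions with their targets by dropping the last prediction step and the first target token, then applies a flat cross entropy that ignores index 92 (presumably the padding index of the dictionary used with this head). A shape-only sketch with made-up tensors:

```
import paddle

N, num_steps, num_classes = 2, 6, 93
predicts = paddle.rand([N, num_steps, num_classes])
labels = paddle.randint(0, num_classes, [N, num_steps], dtype='int64')

predict = predicts[:, :-1, :]          # drop the last output step
label = labels[:, 1:]                  # drop the start token

inputs = paddle.reshape(predict, [-1, num_classes])
targets = paddle.reshape(label, [-1])
loss = paddle.nn.CrossEntropyLoss(reduction="mean", ignore_index=92)(inputs, targets)
```
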
diff --git a/ppocr/metrics/eval_det_iou.py b/ppocr/metrics/eval_det_iou.py
index 0e32b2d19281de9a18a1fe0343bd7e8237825b7b..bc05e7df7d1d21abfb9d9fbd224ecd7254d9f393 100644
--- a/ppocr/metrics/eval_det_iou.py
+++ b/ppocr/metrics/eval_det_iou.py
@@ -169,21 +169,10 @@ class DetectionIoUEvaluator(object):
numGlobalCareDet += numDetCare
perSampleMetrics = {
- 'precision': precision,
- 'recall': recall,
- 'hmean': hmean,
- 'pairs': pairs,
- 'iouMat': [] if len(detPols) > 100 else iouMat.tolist(),
- 'gtPolPoints': gtPolPoints,
- 'detPolPoints': detPolPoints,
'gtCare': numGtCare,
'detCare': numDetCare,
- 'gtDontCare': gtDontCarePolsNum,
- 'detDontCare': detDontCarePolsNum,
'detMatched': detMatched,
- 'evaluationLog': evaluationLog
}
-
return perSampleMetrics
def combine_results(self, results):
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index 66c084d771dece0e2974bc72a177b53f564a8f2e..db2f41c3a140ecebc42b71ee03f0ecb5cf50ca80 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -13,13 +13,20 @@
# limitations under the License.
import Levenshtein
+import string
class RecMetric(object):
- def __init__(self, main_indicator='acc', **kwargs):
+ def __init__(self, main_indicator='acc', is_filter=False, **kwargs):
self.main_indicator = main_indicator
+ self.is_filter = is_filter
self.reset()
+ def _normalize_text(self, text):
+ text = ''.join(
+ filter(lambda x: x in (string.digits + string.ascii_letters), text))
+ return text.lower()
+
def __call__(self, pred_label, *args, **kwargs):
preds, labels = pred_label
correct_num = 0
@@ -28,6 +35,9 @@ class RecMetric(object):
for (pred, pred_conf), (target, _) in zip(preds, labels):
pred = pred.replace(" ", "")
target = target.replace(" ", "")
+ if self.is_filter:
+ pred = self._normalize_text(pred)
+ target = self._normalize_text(target)
norm_edit_dis += Levenshtein.distance(pred, target) / max(
len(pred), len(target), 1)
if pred == target:
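
The new `is_filter` option makes the metric case-insensitive and discards everything except ASCII letters and digits before computing accuracy and the normalized edit distance, which matches the usual protocol for English benchmark evaluation. The normalization itself is just a filter plus `lower()`:

```
import string

def normalize_text(text):
    # keep ASCII letters and digits, then lowercase (mirrors RecMetric._normalize_text)
    text = ''.join(filter(lambda x: x in (string.digits + string.ascii_letters), text))
    return text.lower()

print(normalize_text("Hello, World!"))   # -> helloworld
```
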
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index dbd18070b36f7e99c62de94048ab53d1bedcebe0..c498d9862abcfc85eaf29ed1d949230a1dc1629c 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -14,7 +14,6 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index f4fe8c76be0835f55f402f35ad6a91a5ca116d88..169eb821f110d4a212068ebab4d46d636e241307 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -26,8 +26,12 @@ def build_backbone(config, model_type):
from .rec_resnet_vd import ResNet
from .rec_resnet_fpn import ResNetFPN
from .rec_mv1_enhance import MobileNetV1Enhance
+ from .rec_nrtr_mtb import MTB
+ from .rec_resnet_31 import ResNet31
+ from .rec_resnet_aster import ResNet_ASTER
support_dict = [
- "MobileNetV1Enhance", "MobileNetV3", "ResNet", "ResNetFPN"
+ 'MobileNetV1Enhance', 'MobileNetV3', 'ResNet', 'ResNetFPN', 'MTB',
+ "ResNet31", "ResNet_ASTER"
]
elif model_type == "e2e":
from .e2e_resnet_vd_pg import ResNet
diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py
index fe874fac1af439bfb47ba9050a61f02db302e224..04a909b8ccafd8e62f9a7076c7dedf63ff745303 100644
--- a/ppocr/modeling/backbones/rec_mv1_enhance.py
+++ b/ppocr/modeling/backbones/rec_mv1_enhance.py
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,26 +16,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-import paddle
-from paddle import ParamAttr
-import paddle.nn as nn
-import paddle.nn.functional as F
-from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
-from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
-from paddle.nn.initializer import KaimingNormal
import math
import numpy as np
import paddle
-from paddle import ParamAttr, reshape, transpose, concat, split
+from paddle import ParamAttr, reshape, transpose
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
-import math
-from paddle.nn.functional import hardswish, hardsigmoid
from paddle.regularizer import L2Decay
+from paddle.nn.functional import hardswish, hardsigmoid
class ConvBNLayer(nn.Layer):
diff --git a/ppocr/modeling/backbones/rec_nrtr_mtb.py b/ppocr/modeling/backbones/rec_nrtr_mtb.py
new file mode 100644
index 0000000000000000000000000000000000000000..22e02a6371c3ff8b28fd88b5cfa1087309d551f8
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_nrtr_mtb.py
@@ -0,0 +1,48 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle import nn
+import paddle
+
+
+class MTB(nn.Layer):
+ def __init__(self, cnn_num, in_channels):
+ super(MTB, self).__init__()
+ self.block = nn.Sequential()
+ self.out_channels = in_channels
+ self.cnn_num = cnn_num
+ if self.cnn_num == 2:
+ for i in range(self.cnn_num):
+ self.block.add_sublayer(
+ 'conv_{}'.format(i),
+ nn.Conv2D(
+ in_channels=in_channels
+ if i == 0 else 32 * (2**(i - 1)),
+ out_channels=32 * (2**i),
+ kernel_size=3,
+ stride=2,
+ padding=1))
+ self.block.add_sublayer('relu_{}'.format(i), nn.ReLU())
+ self.block.add_sublayer('bn_{}'.format(i),
+ nn.BatchNorm2D(32 * (2**i)))
+
+ def forward(self, images):
+ x = self.block(images)
+ if self.cnn_num == 2:
+ # (b, w, h, c)
+ x = paddle.transpose(x, [0, 3, 2, 1])
+ x_shape = paddle.shape(x)
+ x = paddle.reshape(
+ x, [x_shape[0], x_shape[1], x_shape[2] * x_shape[3]])
+ return x
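
With `cnn_num=2`, the MTB stem applies two stride-2 conv + ReLU + BN blocks (32 then 64 channels) and flattens height and channels into the feature axis, turning an image into a width-major sequence for the NRTR encoder. A hedged shape check, assuming the repository is importable and a 1x32x100 grayscale crop:

```
import paddle
from ppocr.modeling.backbones.rec_nrtr_mtb import MTB

backbone = MTB(cnn_num=2, in_channels=1)
x = paddle.rand([2, 1, 32, 100])     # [batch, channels, height, width]
feat = backbone(x)
print(feat.shape)                    # [2, 25, 512] = [batch, W/4, (H/4) * 64]
```
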
diff --git a/ppocr/modeling/backbones/rec_resnet_31.py b/ppocr/modeling/backbones/rec_resnet_31.py
new file mode 100644
index 0000000000000000000000000000000000000000..f60729cdcced2af7626e5615ca323e32c99432ec
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_31.py
@@ -0,0 +1,176 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+
+__all__ = ["ResNet31"]
+
+
+def conv3x3(in_channel, out_channel, stride=1):
+ return nn.Conv2D(
+ in_channel,
+ out_channel,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias_attr=False
+ )
+
+
+class BasicBlock(nn.Layer):
+ expansion = 1
+ def __init__(self, in_channels, channels, stride=1, downsample=False):
+ super().__init__()
+ self.conv1 = conv3x3(in_channels, channels, stride)
+ self.bn1 = nn.BatchNorm2D(channels)
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(channels, channels)
+ self.bn2 = nn.BatchNorm2D(channels)
+ self.downsample = downsample
+ if downsample:
+ self.downsample = nn.Sequential(
+ nn.Conv2D(in_channels, channels * self.expansion, 1, stride, bias_attr=False),
+ nn.BatchNorm2D(channels * self.expansion),
+ )
+ else:
+ self.downsample = nn.Sequential()
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet31(nn.Layer):
+ '''
+ Args:
+ in_channels (int): Number of channels of input image tensor.
+ layers (list[int]): List of BasicBlock number for each stage.
+ channels (list[int]): List of out_channels of Conv2D layer.
+ out_indices (None | Sequence[int]): Indices of output stages.
+ last_stage_pool (bool): If True, add `MaxPool2D` layer to last stage.
+ '''
+ def __init__(self,
+ in_channels=3,
+ layers=[1, 2, 5, 3],
+ channels=[64, 128, 256, 256, 512, 512, 512],
+ out_indices=None,
+ last_stage_pool=False):
+ super(ResNet31, self).__init__()
+ assert isinstance(in_channels, int)
+ assert isinstance(last_stage_pool, bool)
+
+ self.out_indices = out_indices
+ self.last_stage_pool = last_stage_pool
+
+ # conv 1 (Conv Conv)
+ self.conv1_1 = nn.Conv2D(in_channels, channels[0], kernel_size=3, stride=1, padding=1)
+ self.bn1_1 = nn.BatchNorm2D(channels[0])
+ self.relu1_1 = nn.ReLU()
+
+ self.conv1_2 = nn.Conv2D(channels[0], channels[1], kernel_size=3, stride=1, padding=1)
+ self.bn1_2 = nn.BatchNorm2D(channels[1])
+ self.relu1_2 = nn.ReLU()
+
+ # conv 2 (Max-pooling, Residual block, Conv)
+ self.pool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block2 = self._make_layer(channels[1], channels[2], layers[0])
+ self.conv2 = nn.Conv2D(channels[2], channels[2], kernel_size=3, stride=1, padding=1)
+ self.bn2 = nn.BatchNorm2D(channels[2])
+ self.relu2 = nn.ReLU()
+
+ # conv 3 (Max-pooling, Residual block, Conv)
+ self.pool3 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block3 = self._make_layer(channels[2], channels[3], layers[1])
+ self.conv3 = nn.Conv2D(channels[3], channels[3], kernel_size=3, stride=1, padding=1)
+ self.bn3 = nn.BatchNorm2D(channels[3])
+ self.relu3 = nn.ReLU()
+
+ # conv 4 (Max-pooling, Residual block, Conv)
+ self.pool4 = nn.MaxPool2D(kernel_size=(2, 1), stride=(2, 1), padding=0, ceil_mode=True)
+ self.block4 = self._make_layer(channels[3], channels[4], layers[2])
+ self.conv4 = nn.Conv2D(channels[4], channels[4], kernel_size=3, stride=1, padding=1)
+ self.bn4 = nn.BatchNorm2D(channels[4])
+ self.relu4 = nn.ReLU()
+
+ # conv 5 ((Max-pooling), Residual block, Conv)
+ self.pool5 = None
+ if self.last_stage_pool:
+ self.pool5 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0, ceil_mode=True)
+ self.block5 = self._make_layer(channels[4], channels[5], layers[3])
+ self.conv5 = nn.Conv2D(channels[5], channels[5], kernel_size=3, stride=1, padding=1)
+ self.bn5 = nn.BatchNorm2D(channels[5])
+ self.relu5 = nn.ReLU()
+
+ self.out_channels = channels[-1]
+
+ def _make_layer(self, input_channels, output_channels, blocks):
+ layers = []
+ for _ in range(blocks):
+ downsample = None
+ if input_channels != output_channels:
+ downsample = nn.Sequential(
+ nn.Conv2D(
+ input_channels,
+ output_channels,
+ kernel_size=1,
+ stride=1,
+ bias_attr=False),
+ nn.BatchNorm2D(output_channels),
+ )
+
+ layers.append(BasicBlock(input_channels, output_channels, downsample=downsample))
+ input_channels = output_channels
+ return nn.Sequential(*layers)
+
+
+ def forward(self, x):
+ x = self.conv1_1(x)
+ x = self.bn1_1(x)
+ x = self.relu1_1(x)
+
+ x = self.conv1_2(x)
+ x = self.bn1_2(x)
+ x = self.relu1_2(x)
+
+ outs = []
+ for i in range(4):
+ layer_index = i + 2
+ pool_layer = getattr(self, f'pool{layer_index}')
+ block_layer = getattr(self, f'block{layer_index}')
+ conv_layer = getattr(self, f'conv{layer_index}')
+ bn_layer = getattr(self, f'bn{layer_index}')
+ relu_layer = getattr(self, f'relu{layer_index}')
+
+ if pool_layer is not None:
+ x = pool_layer(x)
+ x = block_layer(x)
+ x = conv_layer(x)
+ x = bn_layer(x)
+ x = relu_layer(x)
+
+ outs.append(x)
+
+ if self.out_indices is not None:
+ return tuple([outs[i] for i in self.out_indices])
+
+ return x
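
ResNet31 halves both spatial dimensions in the first two pooled stages but afterwards only pools along the height (`pool4` uses a (2, 1) window and `pool5` is disabled unless `last_stage_pool=True`), so the width stays at 1/4 resolution for the recognition sequence. A hedged shape check, assuming the repository is importable:

```
import paddle
from ppocr.modeling.backbones.rec_resnet_31 import ResNet31

backbone = ResNet31(in_channels=3)
x = paddle.rand([2, 3, 32, 100])
feat = backbone(x)
print(feat.shape)    # [2, 512, 4, 25]: height /8, width /4, 512 channels
```
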
diff --git a/ppocr/modeling/backbones/rec_resnet_aster.py b/ppocr/modeling/backbones/rec_resnet_aster.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdecaf46af98f9b967d9a339f82d4e938abdc6d9
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_aster.py
@@ -0,0 +1,140 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+
+import sys
+import math
+
+
+def conv3x3(in_planes, out_planes, stride=1):
+ """3x3 convolution with padding"""
+ return nn.Conv2D(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias_attr=False)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+ """1x1 convolution"""
+ return nn.Conv2D(
+ in_planes, out_planes, kernel_size=1, stride=stride, bias_attr=False)
+
+
+def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000):
+ # [n_position]
+ positions = paddle.arange(0, n_position)
+ # [feat_dim]
+ dim_range = paddle.arange(0, feat_dim)
+ dim_range = paddle.pow(wave_length, 2 * (dim_range // 2) / feat_dim)
+ # [n_position, feat_dim]
+ angles = paddle.unsqueeze(
+ positions, axis=1) / paddle.unsqueeze(
+ dim_range, axis=0)
+ angles = paddle.cast(angles, "float32")
+ angles[:, 0::2] = paddle.sin(angles[:, 0::2])
+ angles[:, 1::2] = paddle.cos(angles[:, 1::2])
+ return angles
+
+
+class AsterBlock(nn.Layer):
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
+ super(AsterBlock, self).__init__()
+ self.conv1 = conv1x1(inplanes, planes, stride)
+ self.bn1 = nn.BatchNorm2D(planes)
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2D(planes)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+ return out
+
+
+class ResNet_ASTER(nn.Layer):
+ """For aster or crnn"""
+
+ def __init__(self, with_lstm=True, n_group=1, in_channels=3):
+ super(ResNet_ASTER, self).__init__()
+ self.with_lstm = with_lstm
+ self.n_group = n_group
+
+ self.layer0 = nn.Sequential(
+ nn.Conv2D(
+ in_channels,
+ 32,
+ kernel_size=(3, 3),
+ stride=1,
+ padding=1,
+ bias_attr=False),
+ nn.BatchNorm2D(32),
+ nn.ReLU())
+
+ self.inplanes = 32
+ self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50]
+ self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25]
+ self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25]
+ self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25]
+ self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25]
+
+ if with_lstm:
+ self.rnn = nn.LSTM(512, 256, direction="bidirect", num_layers=2)
+ self.out_channels = 2 * 256
+ else:
+ self.out_channels = 512
+
+ def _make_layer(self, planes, blocks, stride):
+ downsample = None
+ if stride != [1, 1] or self.inplanes != planes:
+ downsample = nn.Sequential(
+ conv1x1(self.inplanes, planes, stride), nn.BatchNorm2D(planes))
+
+ layers = []
+ layers.append(AsterBlock(self.inplanes, planes, stride, downsample))
+ self.inplanes = planes
+ for _ in range(1, blocks):
+ layers.append(AsterBlock(self.inplanes, planes))
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x0 = self.layer0(x)
+ x1 = self.layer1(x0)
+ x2 = self.layer2(x1)
+ x3 = self.layer3(x2)
+ x4 = self.layer4(x3)
+ x5 = self.layer5(x4)
+
+ cnn_feat = x5.squeeze(2) # [N, c, w]
+ cnn_feat = paddle.transpose(cnn_feat, perm=[0, 2, 1])
+ if self.with_lstm:
+ rnn_feat, _ = self.rnn(cnn_feat)
+ return rnn_feat
+ else:
+ return cnn_feat
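
ResNet_ASTER keeps halving the height while freezing the width after the second stage (per-stage strides [2, 2], [2, 2], [2, 1], [2, 1], [2, 1]), squeezes the height axis once it reaches 1, and optionally feeds the result to a 2-layer bidirectional LSTM, so a height-32 input becomes a width-length sequence of 512-dim features. A hedged check:

```
import paddle
from ppocr.modeling.backbones.rec_resnet_aster import ResNet_ASTER

backbone = ResNet_ASTER(with_lstm=True, in_channels=3)
x = paddle.rand([2, 3, 32, 100])
feat = backbone(x)
print(feat.shape)    # [2, 25, 512]: 25 time steps, 2 * 256 BiLSTM features
```
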
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 5096479415f504aa9f074d55bd9b2e4a31c730b4..fdadfed5e3fe30b6bd311a07d6ba36869f175488 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -20,18 +20,24 @@ def build_head(config):
from .det_db_head import DBHead
from .det_east_head import EASTHead
from .det_sast_head import SASTHead
+ from .det_pse_head import PSEHead
from .e2e_pg_head import PGHead
# rec head
from .rec_ctc_head import CTCHead
from .rec_att_head import AttentionHead
from .rec_srn_head import SRNHead
+ from .rec_nrtr_head import Transformer
+ from .rec_sar_head import SARHead
+ from .rec_aster_head import AsterHead
# cls head
from .cls_head import ClsHead
support_dict = [
- 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'AttentionHead',
- 'SRNHead', 'PGHead', 'TableAttentionHead']
+ 'DBHead', 'PSEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead',
+ 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
+ 'TableAttentionHead', 'SARHead', 'AsterHead'
+ ]
#table head
from .table_att_head import TableAttentionHead
diff --git a/ppocr/modeling/heads/det_pse_head.py b/ppocr/modeling/heads/det_pse_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..db800f57a216ab437b724988ce692a9ac0c545d9
--- /dev/null
+++ b/ppocr/modeling/heads/det_pse_head.py
@@ -0,0 +1,35 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from paddle import nn
+
+
+class PSEHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ hidden_dim=256,
+ out_channels=7,
+ **kwargs):
+ super(PSEHead, self).__init__()
+ self.conv1 = nn.Conv2D(in_channels, hidden_dim, kernel_size=3, stride=1, padding=1)
+ self.bn1 = nn.BatchNorm2D(hidden_dim)
+ self.relu1 = nn.ReLU()
+
+ self.conv2 = nn.Conv2D(hidden_dim, out_channels, kernel_size=1, stride=1, padding=0)
+
+
+ def forward(self, x, **kwargs):
+ out = self.conv1(x)
+ out = self.relu1(self.bn1(out))
+ out = self.conv2(out)
+ return {'maps': out}
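
`PSEHead` is deliberately small: a 3x3 conv + BN + ReLU followed by a 1x1 conv that emits `out_channels` score maps at the resolution of the incoming feature (7 maps by default, i.e. one per progressively shrunk kernel in PSENet). A hedged sketch; the 96-channel, 184x184 input is an illustrative assumption, not a value from a shipped config:

```
import paddle
from ppocr.modeling.heads.det_pse_head import PSEHead

head = PSEHead(in_channels=96, hidden_dim=256, out_channels=7)
fpn_feat = paddle.rand([1, 96, 184, 184])   # illustrative fused FPN output
maps = head(fpn_feat)['maps']
print(maps.shape)                           # [1, 7, 184, 184]
```
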
diff --git a/ppocr/modeling/heads/multiheadAttention.py b/ppocr/modeling/heads/multiheadAttention.py
new file mode 100755
index 0000000000000000000000000000000000000000..900865ba1a8d80a108b3247ce1aff91c242860f2
--- /dev/null
+++ b/ppocr/modeling/heads/multiheadAttention.py
@@ -0,0 +1,163 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import nn
+import paddle.nn.functional as F
+from paddle.nn import Linear
+from paddle.nn.initializer import XavierUniform as xavier_uniform_
+from paddle.nn.initializer import Constant as constant_
+from paddle.nn.initializer import XavierNormal as xavier_normal_
+
+zeros_ = constant_(value=0.)
+ones_ = constant_(value=1.)
+
+
+class MultiheadAttention(nn.Layer):
+ """Allows the model to jointly attend to information
+ from different representation subspaces.
+ See reference: Attention Is All You Need
+
+ .. math::
+ \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
+ \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
+
+ Args:
+ embed_dim: total dimension of the model
+ num_heads: parallel attention layers, or heads
+
+ """
+
+ def __init__(self,
+ embed_dim,
+ num_heads,
+ dropout=0.,
+ bias=True,
+ add_bias_kv=False,
+ add_zero_attn=False):
+ super(MultiheadAttention, self).__init__()
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+ assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+ self.scaling = self.head_dim**-0.5
+ self.out_proj = Linear(embed_dim, embed_dim, bias_attr=bias)
+ self._reset_parameters()
+ self.conv1 = paddle.nn.Conv2D(
+ in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+ self.conv2 = paddle.nn.Conv2D(
+ in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+ self.conv3 = paddle.nn.Conv2D(
+ in_channels=embed_dim, out_channels=embed_dim, kernel_size=(1, 1))
+
+ def _reset_parameters(self):
+ xavier_uniform_(self.out_proj.weight)
+
+ def forward(self,
+ query,
+ key,
+ value,
+ key_padding_mask=None,
+ incremental_state=None,
+ attn_mask=None):
+ """
+ Inputs of forward function
+ query: [target length, batch size, embed dim]
+ key: [sequence length, batch size, embed dim]
+ value: [sequence length, batch size, embed dim]
+ key_padding_mask: if True, mask padding based on batch size
+ incremental_state: if provided, previous time steps are cached
+ need_weights: output attn_output_weights
+ static_kv: key and value are static
+
+ Outputs of forward function
+ attn_output: [target length, batch size, embed dim]
+ attn_output_weights: [batch size, target length, sequence length]
+ """
+ q_shape = paddle.shape(query)
+ src_shape = paddle.shape(key)
+ q = self._in_proj_q(query)
+ k = self._in_proj_k(key)
+ v = self._in_proj_v(value)
+ q *= self.scaling
+ q = paddle.transpose(
+ paddle.reshape(
+ q, [q_shape[0], q_shape[1], self.num_heads, self.head_dim]),
+ [1, 2, 0, 3])
+ k = paddle.transpose(
+ paddle.reshape(
+ k, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
+ [1, 2, 0, 3])
+ v = paddle.transpose(
+ paddle.reshape(
+ v, [src_shape[0], q_shape[1], self.num_heads, self.head_dim]),
+ [1, 2, 0, 3])
+ if key_padding_mask is not None:
+ assert key_padding_mask.shape[0] == q_shape[1]
+ assert key_padding_mask.shape[1] == src_shape[0]
+ attn_output_weights = paddle.matmul(q,
+ paddle.transpose(k, [0, 1, 3, 2]))
+ if attn_mask is not None:
+ attn_mask = paddle.unsqueeze(paddle.unsqueeze(attn_mask, 0), 0)
+ attn_output_weights += attn_mask
+ if key_padding_mask is not None:
+ attn_output_weights = paddle.reshape(
+ attn_output_weights,
+ [q_shape[1], self.num_heads, q_shape[0], src_shape[0]])
+ key = paddle.unsqueeze(paddle.unsqueeze(key_padding_mask, 1), 2)
+ key = paddle.cast(key, 'float32')
+ y = paddle.full(
+ shape=paddle.shape(key), dtype='float32', fill_value='-inf')
+ y = paddle.where(key == 0., key, y)
+ attn_output_weights += y
+ attn_output_weights = F.softmax(
+ attn_output_weights.astype('float32'),
+ axis=-1,
+ dtype=paddle.float32 if attn_output_weights.dtype == paddle.float16
+ else attn_output_weights.dtype)
+ attn_output_weights = F.dropout(
+ attn_output_weights, p=self.dropout, training=self.training)
+
+ attn_output = paddle.matmul(attn_output_weights, v)
+ attn_output = paddle.reshape(
+ paddle.transpose(attn_output, [2, 0, 1, 3]),
+ [q_shape[0], q_shape[1], self.embed_dim])
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output
+
+ def _in_proj_q(self, query):
+ query = paddle.transpose(query, [1, 2, 0])
+ query = paddle.unsqueeze(query, axis=2)
+ res = self.conv1(query)
+ res = paddle.squeeze(res, axis=2)
+ res = paddle.transpose(res, [2, 0, 1])
+ return res
+
+ def _in_proj_k(self, key):
+ key = paddle.transpose(key, [1, 2, 0])
+ key = paddle.unsqueeze(key, axis=2)
+ res = self.conv2(key)
+ res = paddle.squeeze(res, axis=2)
+ res = paddle.transpose(res, [2, 0, 1])
+ return res
+
+ def _in_proj_v(self, value):
+ value = paddle.transpose(value, [1, 2, 0]) #(1, 2, 0)
+ value = paddle.unsqueeze(value, axis=2)
+ res = self.conv3(value)
+ res = paddle.squeeze(res, axis=2)
+ res = paddle.transpose(res, [2, 0, 1])
+ return res
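
This attention implementation expects time-major `[T, N, E]` inputs and projects q/k/v with 1x1 convolutions over a `[N, E, 1, T]` view of the sequence instead of linear layers; the output keeps the `[T, N, E]` layout. A hedged self-attention smoke test, assuming the repository is importable:

```
import paddle
from ppocr.modeling.heads.multiheadAttention import MultiheadAttention

attn = MultiheadAttention(embed_dim=512, num_heads=8)
src = paddle.rand([26, 2, 512])    # [sequence length, batch, embed_dim]
out = attn(src, src, src)          # self-attention, no masks
print(out.shape)                   # [26, 2, 512]
```
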
diff --git a/ppocr/modeling/heads/rec_aster_head.py b/ppocr/modeling/heads/rec_aster_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4961897b409020fe6cff72eb96f3257156fa33ac
--- /dev/null
+++ b/ppocr/modeling/heads/rec_aster_head.py
@@ -0,0 +1,389 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import paddle
+from paddle import nn
+from paddle.nn import functional as F
+
+
+class AsterHead(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ sDim,
+ attDim,
+ max_len_labels,
+ time_step=25,
+ beam_width=5,
+ **kwargs):
+ super(AsterHead, self).__init__()
+ self.num_classes = out_channels
+ self.in_planes = in_channels
+ self.sDim = sDim
+ self.attDim = attDim
+ self.max_len_labels = max_len_labels
+ self.decoder = AttentionRecognitionHead(in_channels, out_channels, sDim,
+ attDim, max_len_labels)
+ self.time_step = time_step
+ self.embeder = Embedding(self.time_step, in_channels)
+ self.beam_width = beam_width
+ self.eos = self.num_classes - 1
+
+ def forward(self, x, targets=None, embed=None):
+ return_dict = {}
+ embedding_vectors = self.embeder(x)
+
+ if self.training:
+ rec_targets, rec_lengths, _ = targets
+ rec_pred = self.decoder([x, rec_targets, rec_lengths],
+ embedding_vectors)
+ return_dict['rec_pred'] = rec_pred
+ return_dict['embedding_vectors'] = embedding_vectors
+ else:
+ rec_pred, rec_pred_scores = self.decoder.beam_search(
+ x, self.beam_width, self.eos, embedding_vectors)
+ return_dict['rec_pred'] = rec_pred
+ return_dict['rec_pred_scores'] = rec_pred_scores
+ return_dict['embedding_vectors'] = embedding_vectors
+
+ return return_dict
+
+
+class Embedding(nn.Layer):
+ def __init__(self, in_timestep, in_planes, mid_dim=4096, embed_dim=300):
+ super(Embedding, self).__init__()
+ self.in_timestep = in_timestep
+ self.in_planes = in_planes
+ self.embed_dim = embed_dim
+ self.mid_dim = mid_dim
+ self.eEmbed = nn.Linear(
+ in_timestep * in_planes,
+ self.embed_dim) # Embed encoder output to a word-embedding like
+
+ def forward(self, x):
+ x = paddle.reshape(x, [paddle.shape(x)[0], -1])
+ x = self.eEmbed(x)
+ return x
+
+
+class AttentionRecognitionHead(nn.Layer):
+ """
+ input: [b x 16 x 64 x in_planes]
+ output: probability sequence: [b x T x num_classes]
+ """
+
+ def __init__(self, in_channels, out_channels, sDim, attDim, max_len_labels):
+ super(AttentionRecognitionHead, self).__init__()
+ self.num_classes = out_channels # number of output classes, including the end-of-sequence symbol
+ self.in_planes = in_channels
+ self.sDim = sDim
+ self.attDim = attDim
+ self.max_len_labels = max_len_labels
+
+ self.decoder = DecoderUnit(
+ sDim=sDim, xDim=in_channels, yDim=self.num_classes, attDim=attDim)
+
+ def forward(self, x, embed):
+ x, targets, lengths = x
+ batch_size = paddle.shape(x)[0]
+ # Decoder
+ state = self.decoder.get_initial_state(embed)
+ outputs = []
+ for i in range(max(lengths)):
+ if i == 0:
+ y_prev = paddle.full(
+ shape=[batch_size], fill_value=self.num_classes)
+ else:
+ y_prev = targets[:, i - 1]
+ output, state = self.decoder(x, state, y_prev)
+ outputs.append(output)
+ outputs = paddle.concat([_.unsqueeze(1) for _ in outputs], 1)
+ return outputs
+
+ # inference stage.
+ def sample(self, x):
+ x, _, _ = x
+ batch_size = paddle.shape(x)[0]
+ # Decoder
+ state = paddle.zeros([1, batch_size, self.sDim])
+
+ predicted_ids, predicted_scores = [], []
+ for i in range(self.max_len_labels):
+ if i == 0:
+ y_prev = paddle.full(
+ shape=[batch_size], fill_value=self.num_classes)
+ else:
+ y_prev = predicted
+
+ output, state = self.decoder(x, state, y_prev)
+ output = F.softmax(output, axis=1)
+ score, predicted = paddle.max(output, axis=1), paddle.argmax(output, axis=1)
+ predicted_ids.append(predicted.unsqueeze(1))
+ predicted_scores.append(score.unsqueeze(1))
+ predicted_ids = paddle.concat(predicted_ids, 1)
+ predicted_scores = paddle.concat(predicted_scores, 1)
+ # return predicted_ids.squeeze(), predicted_scores.squeeze()
+ return predicted_ids, predicted_scores
+
+ def beam_search(self, x, beam_width, eos, embed):
+ def _inflate(tensor, times, dim):
+ repeat_dims = [1] * tensor.dim()
+ repeat_dims[dim] = times
+ output = paddle.tile(tensor, repeat_dims)
+ return output
+
+ # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py
+ batch_size, l, d = x.shape
+ x = paddle.tile(
+ paddle.transpose(
+ x.unsqueeze(1), perm=[1, 0, 2, 3]), [beam_width, 1, 1, 1])
+ inflated_encoder_feats = paddle.reshape(
+ paddle.transpose(
+ x, perm=[1, 0, 2, 3]), [-1, l, d])
+
+ # Initialize the decoder
+ state = self.decoder.get_initial_state(embed, tile_times=beam_width)
+
+ pos_index = paddle.reshape(
+ paddle.arange(batch_size) * beam_width, shape=[-1, 1])
+
+ # Initialize the scores
+ sequence_scores = paddle.full(
+ shape=[batch_size * beam_width, 1], fill_value=-float('Inf'))
+ index = [i * beam_width for i in range(0, batch_size)]
+ sequence_scores[index] = 0.0
+
+ # Initialize the input vector
+ y_prev = paddle.full(
+ shape=[batch_size * beam_width], fill_value=self.num_classes)
+
+ # Store decisions for backtracking
+ stored_scores = list()
+ stored_predecessors = list()
+ stored_emitted_symbols = list()
+
+ for i in range(self.max_len_labels):
+ output, state = self.decoder(inflated_encoder_feats, state, y_prev)
+ state = paddle.unsqueeze(state, axis=0)
+ log_softmax_output = paddle.nn.functional.log_softmax(
+ output, axis=1)
+
+ sequence_scores = _inflate(sequence_scores, self.num_classes, 1)
+ sequence_scores += log_softmax_output
+ scores, candidates = paddle.topk(
+ paddle.reshape(sequence_scores, [batch_size, -1]),
+ beam_width,
+ axis=1)
+
+ # Reshape input = (bk, 1) and sequence_scores = (bk, 1)
+ y_prev = paddle.reshape(
+ candidates % self.num_classes, shape=[batch_size * beam_width])
+ sequence_scores = paddle.reshape(
+ scores, shape=[batch_size * beam_width, 1])
+
+ # Update fields for next timestep
+ pos_index = paddle.expand_as(pos_index, candidates)
+ predecessors = paddle.cast(
+ candidates / self.num_classes + pos_index, dtype='int64')
+ predecessors = paddle.reshape(
+ predecessors, shape=[batch_size * beam_width, 1])
+ state = paddle.index_select(
+ state, index=predecessors.squeeze(), axis=1)
+
+ # Update sequence scores and erase scores for the end-of-sequence symbol so that they aren't expanded
+ stored_scores.append(sequence_scores.clone())
+ y_prev = paddle.reshape(y_prev, shape=[-1, 1])
+ eos_prev = paddle.full_like(y_prev, fill_value=eos)
+ mask = eos_prev == y_prev
+ mask = paddle.nonzero(mask)
+ if mask.dim() > 0:
+ sequence_scores = sequence_scores.numpy()
+ mask = mask.numpy()
+ sequence_scores[mask] = -float('inf')
+ sequence_scores = paddle.to_tensor(sequence_scores)
+
+ # Cache results for backtracking
+ stored_predecessors.append(predecessors)
+ y_prev = paddle.squeeze(y_prev)
+ stored_emitted_symbols.append(y_prev)
+
+ # Do backtracking to return the optimal values
+ #====== backtrack ======#
+ # Initialize return variables given different types
+ p = list()
+ l = [[self.max_len_labels] * beam_width for _ in range(batch_size)
+ ] # Placeholder for lengths of top-k sequences
+
+ # the last step output of the beams are not sorted
+ # thus they are sorted here
+ sorted_score, sorted_idx = paddle.topk(
+ paddle.reshape(
+ stored_scores[-1], shape=[batch_size, beam_width]),
+ beam_width)
+
+ # initialize the sequence scores with the sorted last step beam scores
+ s = sorted_score.clone()
+
+ batch_eos_found = [0] * batch_size # the number of EOS found
+ # in the backward loop below for each batch
+ t = self.max_len_labels - 1
+ # initialize the back pointer with the sorted order of the last step beams.
+ # add pos_index for indexing variable with b*k as the first dimension.
+ t_predecessors = paddle.reshape(
+ sorted_idx + pos_index.expand_as(sorted_idx),
+ shape=[batch_size * beam_width])
+ while t >= 0:
+ # Re-order the variables with the back pointer
+ current_symbol = paddle.index_select(
+ stored_emitted_symbols[t], index=t_predecessors, axis=0)
+ t_predecessors = paddle.index_select(
+ stored_predecessors[t].squeeze(), index=t_predecessors, axis=0)
+ eos_indices = stored_emitted_symbols[t] == eos
+ eos_indices = paddle.nonzero(eos_indices)
+
+ if eos_indices.dim() > 0:
+ for i in range(eos_indices.shape[0] - 1, -1, -1):
+ # Indices of the EOS symbol for both variables
+ # with b*k as the first dimension, and b, k for
+ # the first two dimensions
+ idx = eos_indices[i]
+ b_idx = int(idx[0] / beam_width)
+ # The indices of the replacing position
+ # according to the replacement strategy noted above
+ res_k_idx = beam_width - (batch_eos_found[b_idx] %
+ beam_width) - 1
+ batch_eos_found[b_idx] += 1
+ res_idx = b_idx * beam_width + res_k_idx
+
+ # Replace the old information in return variables
+ # with the new ended sequence information
+ t_predecessors[res_idx] = stored_predecessors[t][idx[0]]
+ current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]]
+ s[b_idx, res_k_idx] = stored_scores[t][idx[0], 0]
+ l[b_idx][res_k_idx] = t + 1
+
+ # record the back tracked results
+ p.append(current_symbol)
+ t -= 1
+
+ # Sort and re-order again as the added ended sequences may change
+ # the order (very unlikely)
+ s, re_sorted_idx = s.topk(beam_width)
+ for b_idx in range(batch_size):
+ l[b_idx] = [
+ l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx, :]
+ ]
+
+ re_sorted_idx = paddle.reshape(
+ re_sorted_idx + pos_index.expand_as(re_sorted_idx),
+ [batch_size * beam_width])
+
+ # Reverse the sequences and re-order at the same time
+ # It is reversed because the backtracking happens in reverse time order
+ p = [
+ paddle.reshape(
+ paddle.index_select(step, re_sorted_idx, 0),
+ shape=[batch_size, beam_width, -1]) for step in reversed(p)
+ ]
+ p = paddle.concat(p, -1)[:, 0, :]
+ return p, paddle.ones_like(p)
+
+
+class AttentionUnit(nn.Layer):
+ def __init__(self, sDim, xDim, attDim):
+ super(AttentionUnit, self).__init__()
+
+ self.sDim = sDim
+ self.xDim = xDim
+ self.attDim = attDim
+
+ self.sEmbed = nn.Linear(sDim, attDim)
+ self.xEmbed = nn.Linear(xDim, attDim)
+ self.wEmbed = nn.Linear(attDim, 1)
+
+ def forward(self, x, sPrev):
+ batch_size, T, _ = x.shape # [b x T x xDim]
+ x = paddle.reshape(x, [-1, self.xDim]) # [(b x T) x xDim]
+ xProj = self.xEmbed(x) # [(b x T) x attDim]
+ xProj = paddle.reshape(xProj, [batch_size, T, -1]) # [b x T x attDim]
+
+ sPrev = sPrev.squeeze(0)
+ sProj = self.sEmbed(sPrev) # [b x attDim]
+ sProj = paddle.unsqueeze(sProj, 1) # [b x 1 x attDim]
+ sProj = paddle.expand(sProj,
+ [batch_size, T, self.attDim]) # [b x T x attDim]
+
+ sumTanh = paddle.tanh(sProj + xProj)
+ sumTanh = paddle.reshape(sumTanh, [-1, self.attDim])
+
+ vProj = self.wEmbed(sumTanh) # [(b x T) x 1]
+ vProj = paddle.reshape(vProj, [batch_size, T])
+ alpha = F.softmax(
+ vProj, axis=1) # attention weights for each sample in the minibatch
+ return alpha
+
+
+class DecoderUnit(nn.Layer):
+ def __init__(self, sDim, xDim, yDim, attDim):
+ super(DecoderUnit, self).__init__()
+ self.sDim = sDim
+ self.xDim = xDim
+ self.yDim = yDim
+ self.attDim = attDim
+ self.emdDim = attDim
+
+ self.attention_unit = AttentionUnit(sDim, xDim, attDim)
+ self.tgt_embedding = nn.Embedding(
+ yDim + 1, self.emdDim, weight_attr=nn.initializer.Normal(
+ std=0.01)) # the extra index is reserved for the start token
+ self.gru = nn.GRUCell(input_size=xDim + self.emdDim, hidden_size=sDim)
+ self.fc = nn.Linear(
+ sDim,
+ yDim,
+ weight_attr=nn.initializer.Normal(std=0.01),
+ bias_attr=nn.initializer.Constant(value=0))
+ self.embed_fc = nn.Linear(300, self.sDim)
+
+ def get_initial_state(self, embed, tile_times=1):
+ assert embed.shape[1] == 300
+ state = self.embed_fc(embed) # N * sDim
+ if tile_times != 1:
+ state = state.unsqueeze(1)
+ trans_state = paddle.transpose(state, perm=[1, 0, 2])
+ state = paddle.tile(trans_state, repeat_times=[tile_times, 1, 1])
+ trans_state = paddle.transpose(state, perm=[1, 0, 2])
+ state = paddle.reshape(trans_state, shape=[-1, self.sDim])
+ state = state.unsqueeze(0) # 1 * N * sDim
+ return state
+
+ def forward(self, x, sPrev, yPrev):
+ # x: feature sequence from the image decoder.
+ batch_size, T, _ = x.shape
+ alpha = self.attention_unit(x, sPrev)
+ context = paddle.squeeze(paddle.matmul(alpha.unsqueeze(1), x), axis=1)
+ yPrev = paddle.cast(yPrev, dtype="int64")
+ yProj = self.tgt_embedding(yPrev)
+
+ concat_context = paddle.concat([yProj, context], 1)
+ concat_context = paddle.squeeze(concat_context, 1)
+ sPrev = paddle.squeeze(sPrev, 0)
+ output, state = self.gru(concat_context, sPrev)
+ output = paddle.squeeze(output, axis=1)
+ output = self.fc(output)
+ return output, state
\ No newline at end of file
diff --git a/ppocr/modeling/heads/rec_ctc_head.py b/ppocr/modeling/heads/rec_ctc_head.py
index 9c38d31fa0abcf39a583e5edcebfc8f336f41c46..35d33d5f56b3b378286565cbfa9755f43343b278 100755
--- a/ppocr/modeling/heads/rec_ctc_head.py
+++ b/ppocr/modeling/heads/rec_ctc_head.py
@@ -38,6 +38,7 @@ class CTCHead(nn.Layer):
out_channels,
fc_decay=0.0004,
mid_channels=None,
+ return_feats=False,
**kwargs):
super(CTCHead, self).__init__()
if mid_channels is None:
@@ -66,14 +67,22 @@ class CTCHead(nn.Layer):
bias_attr=bias_attr2)
self.out_channels = out_channels
self.mid_channels = mid_channels
+ self.return_feats = return_feats
def forward(self, x, targets=None):
if self.mid_channels is None:
predicts = self.fc(x)
else:
- predicts = self.fc1(x)
- predicts = self.fc2(predicts)
-
+ x = self.fc1(x)
+ predicts = self.fc2(x)
+
+ if self.return_feats:
+ result = (x, predicts)
+ else:
+ result = predicts
+
if not self.training:
predicts = F.softmax(predicts, axis=2)
- return predicts
+ result = predicts
+
+ return result
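
With `mid_channels` set and `return_feats=True`, the head now returns the `fc1` features alongside the logits during training, which is what the center-loss branch of `EnhancedCTCLoss` consumes; at inference time only the softmax probabilities are returned. A hedged training-mode sketch; the `in_channels` keyword and the channel sizes below are illustrative assumptions, not values from a shipped config:

```
import paddle
from ppocr.modeling.heads.rec_ctc_head import CTCHead

head = CTCHead(in_channels=288, out_channels=6625,
               mid_channels=96, return_feats=True)
head.train()
seq = paddle.rand([2, 80, 288])      # [batch, time steps, neck channels]
feats, logits = head(seq)
print(feats.shape, logits.shape)     # [2, 80, 96] [2, 80, 6625]
```
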
diff --git a/ppocr/modeling/heads/rec_nrtr_head.py b/ppocr/modeling/heads/rec_nrtr_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..38ba0c917840ea7d1e2a3c2bf0da32c2c35f2b40
--- /dev/null
+++ b/ppocr/modeling/heads/rec_nrtr_head.py
@@ -0,0 +1,826 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+import copy
+from paddle import nn
+import paddle.nn.functional as F
+from paddle.nn import LayerList
+from paddle.nn.initializer import XavierNormal as xavier_uniform_
+from paddle.nn import Dropout, Linear, LayerNorm, Conv2D
+import numpy as np
+from ppocr.modeling.heads.multiheadAttention import MultiheadAttention
+from paddle.nn.initializer import Constant as constant_
+from paddle.nn.initializer import XavierNormal as xavier_normal_
+
+zeros_ = constant_(value=0.)
+ones_ = constant_(value=1.)
+
+
+class Transformer(nn.Layer):
+ """A transformer model. Users can modify the attributes as needed. The architecture
+ is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
+ Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
+ Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
+ Processing Systems, pages 6000-6010.
+
+ Args:
+ d_model: the number of expected features in the encoder/decoder inputs (default=512).
+ nhead: the number of heads in the multiheadattention models (default=8).
+ num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
+ num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+ custom_encoder: custom encoder (default=None).
+ custom_decoder: custom decoder (default=None).
+
+ """
+
+ def __init__(self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=6,
+ beam_size=0,
+ num_decoder_layers=6,
+ dim_feedforward=1024,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1,
+ custom_encoder=None,
+ custom_decoder=None,
+ in_channels=0,
+ out_channels=0,
+ scale_embedding=True):
+ super(Transformer, self).__init__()
+ self.out_channels = out_channels + 1
+ self.embedding = Embeddings(
+ d_model=d_model,
+ vocab=self.out_channels,
+ padding_idx=0,
+ scale_embedding=scale_embedding)
+ self.positional_encoding = PositionalEncoding(
+ dropout=residual_dropout_rate,
+ dim=d_model, )
+ if custom_encoder is not None:
+ self.encoder = custom_encoder
+ else:
+ if num_encoder_layers > 0:
+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, attention_dropout_rate,
+ residual_dropout_rate)
+ self.encoder = TransformerEncoder(encoder_layer,
+ num_encoder_layers)
+ else:
+ self.encoder = None
+
+ if custom_decoder is not None:
+ self.decoder = custom_decoder
+ else:
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, attention_dropout_rate,
+ residual_dropout_rate)
+ self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers)
+
+ self._reset_parameters()
+ self.beam_size = beam_size
+ self.d_model = d_model
+ self.nhead = nhead
+ self.tgt_word_prj = nn.Linear(
+ d_model, self.out_channels, bias_attr=False)
+ w0 = np.random.normal(0.0, d_model**-0.5,
+ (d_model, self.out_channels)).astype(np.float32)
+ self.tgt_word_prj.weight.set_value(w0)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+
+ if isinstance(m, nn.Conv2D):
+ xavier_normal_(m.weight)
+ if m.bias is not None:
+ zeros_(m.bias)
+
+ def forward_train(self, src, tgt):
+ tgt = tgt[:, :-1]
+
+ tgt_key_padding_mask = self.generate_padding_mask(tgt)
+ tgt = self.embedding(tgt).transpose([1, 0, 2])
+ tgt = self.positional_encoding(tgt)
+ tgt_mask = self.generate_square_subsequent_mask(tgt.shape[0])
+
+ if self.encoder is not None:
+ src = self.positional_encoding(src.transpose([1, 0, 2]))
+ memory = self.encoder(src)
+ else:
+ memory = src.squeeze(2).transpose([2, 0, 1])
+ output = self.decoder(
+ tgt,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=None,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=None)
+ output = output.transpose([1, 0, 2])
+ logit = self.tgt_word_prj(output)
+ return logit
+
+ def forward(self, src, targets=None):
+ """Take in and process masked source/target sequences.
+ Args:
+ src: the sequence to the encoder (required).
+ tgt: the sequence to the decoder (required).
+ Shape:
+ - src: :math:`(S, N, E)`.
+ - tgt: :math:`(T, N, E)`.
+ Examples:
+ >>> output = transformer_model(src, tgt)
+ """
+
+ if self.training:
+ max_len = targets[1].max()
+ tgt = targets[0][:, :2 + max_len]
+ return self.forward_train(src, tgt)
+ else:
+ if self.beam_size > 0:
+ return self.forward_beam(src)
+ else:
+ return self.forward_test(src)
+
+ def forward_test(self, src):
+ bs = paddle.shape(src)[0]
+ if self.encoder is not None:
+ src = self.positional_encoding(paddle.transpose(src, [1, 0, 2]))
+ memory = self.encoder(src)
+ else:
+ memory = paddle.transpose(paddle.squeeze(src, 2), [2, 0, 1])
+ dec_seq = paddle.full((bs, 1), 2, dtype=paddle.int64)
+ dec_prob = paddle.full((bs, 1), 1., dtype=paddle.float32)
+ for len_dec_seq in range(1, 25):
+ dec_seq_embed = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
+ dec_seq_embed = self.positional_encoding(dec_seq_embed)
+ tgt_mask = self.generate_square_subsequent_mask(
+ paddle.shape(dec_seq_embed)[0])
+ output = self.decoder(
+ dec_seq_embed,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None)
+ dec_output = paddle.transpose(output, [1, 0, 2])
+ dec_output = dec_output[:, -1, :]
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
+ preds_idx = paddle.argmax(word_prob, axis=1)
+ if paddle.equal_all(
+ preds_idx,
+ paddle.full(
+ paddle.shape(preds_idx), 3, dtype='int64')):
+ break
+ preds_prob = paddle.max(word_prob, axis=1)
+ dec_seq = paddle.concat(
+ [dec_seq, paddle.reshape(preds_idx, [-1, 1])], axis=1)
+ dec_prob = paddle.concat(
+ [dec_prob, paddle.reshape(preds_prob, [-1, 1])], axis=1)
+ return [dec_seq, dec_prob]
+
+ def forward_beam(self, images):
+ ''' Translation work in one batch '''
+
+ def get_inst_idx_to_tensor_position_map(inst_idx_list):
+ ''' Indicate the position of an instance in a tensor. '''
+ return {
+ inst_idx: tensor_position
+ for tensor_position, inst_idx in enumerate(inst_idx_list)
+ }
+
+ def collect_active_part(beamed_tensor, curr_active_inst_idx,
+ n_prev_active_inst, n_bm):
+ ''' Collect tensor parts associated to active instances. '''
+
+ beamed_tensor_shape = paddle.shape(beamed_tensor)
+ n_curr_active_inst = len(curr_active_inst_idx)
+ new_shape = (n_curr_active_inst * n_bm, beamed_tensor_shape[1],
+ beamed_tensor_shape[2])
+
+ beamed_tensor = beamed_tensor.reshape([n_prev_active_inst, -1])
+ beamed_tensor = beamed_tensor.index_select(
+ curr_active_inst_idx, axis=0)
+ beamed_tensor = beamed_tensor.reshape(new_shape)
+
+ return beamed_tensor
+
+ def collate_active_info(src_enc, inst_idx_to_position_map,
+ active_inst_idx_list):
+ # Sentences which are still active are collected,
+ # so the decoder will not run on completed sentences.
+
+ n_prev_active_inst = len(inst_idx_to_position_map)
+ active_inst_idx = [
+ inst_idx_to_position_map[k] for k in active_inst_idx_list
+ ]
+ active_inst_idx = paddle.to_tensor(active_inst_idx, dtype='int64')
+ active_src_enc = collect_active_part(
+ src_enc.transpose([1, 0, 2]), active_inst_idx,
+ n_prev_active_inst, n_bm).transpose([1, 0, 2])
+ active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
+ active_inst_idx_list)
+ return active_src_enc, active_inst_idx_to_position_map
+
+ def beam_decode_step(inst_dec_beams, len_dec_seq, enc_output,
+ inst_idx_to_position_map, n_bm,
+ memory_key_padding_mask):
+ ''' Decode and update beam status, and then return active beam idx '''
+
+ def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
+ dec_partial_seq = [
+ b.get_current_state() for b in inst_dec_beams if not b.done
+ ]
+ dec_partial_seq = paddle.stack(dec_partial_seq)
+ dec_partial_seq = dec_partial_seq.reshape([-1, len_dec_seq])
+ return dec_partial_seq
+
+ def predict_word(dec_seq, enc_output, n_active_inst, n_bm,
+ memory_key_padding_mask):
+ dec_seq = paddle.transpose(self.embedding(dec_seq), [1, 0, 2])
+ dec_seq = self.positional_encoding(dec_seq)
+ tgt_mask = self.generate_square_subsequent_mask(
+ paddle.shape(dec_seq)[0])
+ dec_output = self.decoder(
+ dec_seq,
+ enc_output,
+ tgt_mask=tgt_mask,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=memory_key_padding_mask, )
+ dec_output = paddle.transpose(dec_output, [1, 0, 2])
+ dec_output = dec_output[:,
+ -1, :] # Pick the last step: (bh * bm) * d_h
+ word_prob = F.softmax(self.tgt_word_prj(dec_output), axis=1)
+ word_prob = paddle.reshape(word_prob, [n_active_inst, n_bm, -1])
+ return word_prob
+
+ def collect_active_inst_idx_list(inst_beams, word_prob,
+ inst_idx_to_position_map):
+ active_inst_idx_list = []
+ for inst_idx, inst_position in inst_idx_to_position_map.items():
+ is_inst_complete = inst_beams[inst_idx].advance(word_prob[
+ inst_position])
+ if not is_inst_complete:
+ active_inst_idx_list += [inst_idx]
+
+ return active_inst_idx_list
+
+ n_active_inst = len(inst_idx_to_position_map)
+ dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
+ word_prob = predict_word(dec_seq, enc_output, n_active_inst, n_bm,
+ None)
+ # Update the beam with predicted word prob information and collect incomplete instances
+ active_inst_idx_list = collect_active_inst_idx_list(
+ inst_dec_beams, word_prob, inst_idx_to_position_map)
+ return active_inst_idx_list
+
+ def collect_hypothesis_and_scores(inst_dec_beams, n_best):
+ all_hyp, all_scores = [], []
+ for inst_idx in range(len(inst_dec_beams)):
+ scores, tail_idxs = inst_dec_beams[inst_idx].sort_scores()
+ all_scores += [scores[:n_best]]
+ hyps = [
+ inst_dec_beams[inst_idx].get_hypothesis(i)
+ for i in tail_idxs[:n_best]
+ ]
+ all_hyp += [hyps]
+ return all_hyp, all_scores
+
+ with paddle.no_grad():
+ #-- Encode
+ if self.encoder is not None:
+ src = self.positional_encoding(images.transpose([1, 0, 2]))
+ src_enc = self.encoder(src)
+ else:
+ src_enc = images.squeeze(2).transpose([0, 2, 1])
+
+ n_bm = self.beam_size
+ src_shape = paddle.shape(src_enc)
+ inst_dec_beams = [Beam(n_bm) for _ in range(1)]
+ active_inst_idx_list = list(range(1))
+ # Repeat data for beam search
+ src_enc = paddle.tile(src_enc, [1, n_bm, 1])
+ inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
+ active_inst_idx_list)
+ # Decode
+ for len_dec_seq in range(1, 25):
+ src_enc_copy = src_enc.clone()
+ active_inst_idx_list = beam_decode_step(
+ inst_dec_beams, len_dec_seq, src_enc_copy,
+ inst_idx_to_position_map, n_bm, None)
+ if not active_inst_idx_list:
+ break # all instances have finished their path to the end token
+ src_enc, inst_idx_to_position_map = collate_active_info(
+ src_enc_copy, inst_idx_to_position_map,
+ active_inst_idx_list)
+ batch_hyp, batch_scores = collect_hypothesis_and_scores(inst_dec_beams,
+ 1)
+ result_hyp = []
+ hyp_scores = []
+ for bs_hyp, score in zip(batch_hyp, batch_scores):
+ l = len(bs_hyp[0])
+ bs_hyp_pad = bs_hyp[0] + [3] * (25 - l)
+ result_hyp.append(bs_hyp_pad)
+ score = float(score) / l
+ hyp_score = [score for _ in range(25)]
+ hyp_scores.append(hyp_score)
+ return [
+ paddle.to_tensor(
+ np.array(result_hyp), dtype=paddle.int64),
+ paddle.to_tensor(hyp_scores)
+ ]
+
+ def generate_square_subsequent_mask(self, sz):
+ """Generate a square mask for the sequence. The masked positions are filled with float('-inf').
+ Unmasked positions are filled with float(0.0).
+ """
+ mask = paddle.zeros([sz, sz], dtype='float32')
+ mask_inf = paddle.triu(
+ paddle.full(
+ shape=[sz, sz], dtype='float32', fill_value='-inf'),
+ diagonal=1)
+ mask = mask + mask_inf
+ return mask
+
+ def generate_padding_mask(self, x):
+ padding_mask = paddle.equal(x, paddle.to_tensor(0, dtype=x.dtype))
+ return padding_mask
+
+ def _reset_parameters(self):
+ """Initiate parameters in the transformer model."""
+
+ for p in self.parameters():
+ if p.dim() > 1:
+ xavier_uniform_(p)
+
+
+class TransformerEncoder(nn.Layer):
+ """TransformerEncoder is a stack of N encoder layers
+ Args:
+ encoder_layer: an instance of the TransformerEncoderLayer() class (required).
+ num_layers: the number of sub-encoder-layers in the encoder (required).
+ norm: the layer normalization component (optional).
+ """
+
+ def __init__(self, encoder_layer, num_layers):
+ super(TransformerEncoder, self).__init__()
+ self.layers = _get_clones(encoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ def forward(self, src):
+ """Pass the input through the encoder layers in turn.
+ Args:
+ src: the sequence to the encoder (required).
+ mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ """
+ output = src
+
+ for i in range(self.num_layers):
+ output = self.layers[i](output,
+ src_mask=None,
+ src_key_padding_mask=None)
+
+ return output
+
+
+class TransformerDecoder(nn.Layer):
+ """TransformerDecoder is a stack of N decoder layers
+
+ Args:
+ decoder_layer: an instance of the TransformerDecoderLayer() class (required).
+ num_layers: the number of sub-decoder-layers in the decoder (required).
+ norm: the layer normalization component (optional).
+
+ """
+
+ def __init__(self, decoder_layer, num_layers):
+ super(TransformerDecoder, self).__init__()
+ self.layers = _get_clones(decoder_layer, num_layers)
+ self.num_layers = num_layers
+
+ def forward(self,
+ tgt,
+ memory,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None):
+ """Pass the inputs (and mask) through the decoder layer in turn.
+
+ Args:
+ tgt: the sequence to the decoder (required).
+ memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+ """
+ output = tgt
+ for i in range(self.num_layers):
+ output = self.layers[i](
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask)
+
+ return output
+
+
+class TransformerEncoderLayer(nn.Layer):
+ """TransformerEncoderLayer is made up of self-attn and feedforward network.
+ This standard encoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+ in a different way during application.
+
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+ dropout: the dropout value (default=0.1).
+
+ """
+
+ def __init__(self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1):
+ super(TransformerEncoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model, nhead, dropout=attention_dropout_rate)
+
+ self.conv1 = Conv2D(
+ in_channels=d_model,
+ out_channels=dim_feedforward,
+ kernel_size=(1, 1))
+ self.conv2 = Conv2D(
+ in_channels=dim_feedforward,
+ out_channels=d_model,
+ kernel_size=(1, 1))
+
+ self.norm1 = LayerNorm(d_model)
+ self.norm2 = LayerNorm(d_model)
+ self.dropout1 = Dropout(residual_dropout_rate)
+ self.dropout2 = Dropout(residual_dropout_rate)
+
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
+ """Pass the input through the endocder layer.
+ Args:
+ src: the sequnce to the encoder layer (required).
+ src_mask: the mask for the src sequence (optional).
+ src_key_padding_mask: the mask for the src keys per batch (optional).
+ """
+ src2 = self.self_attn(
+ src,
+ src,
+ src,
+ attn_mask=src_mask,
+ key_padding_mask=src_key_padding_mask)
+ src = src + self.dropout1(src2)
+ src = self.norm1(src)
+
+ src = paddle.transpose(src, [1, 2, 0])
+ src = paddle.unsqueeze(src, 2)
+ src2 = self.conv2(F.relu(self.conv1(src)))
+ src2 = paddle.squeeze(src2, 2)
+ src2 = paddle.transpose(src2, [2, 0, 1])
+ src = paddle.squeeze(src, 2)
+ src = paddle.transpose(src, [2, 0, 1])
+
+ src = src + self.dropout2(src2)
+ src = self.norm2(src)
+ return src
+
+
+class TransformerDecoderLayer(nn.Layer):
+ """TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
+ This standard decoder layer is based on the paper "Attention Is All You Need".
+ Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
+ Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
+ Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
+ in a different way during application.
+
+ Args:
+ d_model: the number of expected features in the input (required).
+ nhead: the number of heads in the multiheadattention models (required).
+ dim_feedforward: the dimension of the feedforward network model (default=2048).
+        attention_dropout_rate: the dropout rate used inside multi-head attention (default=0.0).
+        residual_dropout_rate: the dropout rate applied to the residual branches (default=0.1).
+
+ """
+
+ def __init__(self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ attention_dropout_rate=0.0,
+ residual_dropout_rate=0.1):
+ super(TransformerDecoderLayer, self).__init__()
+ self.self_attn = MultiheadAttention(
+ d_model, nhead, dropout=attention_dropout_rate)
+ self.multihead_attn = MultiheadAttention(
+ d_model, nhead, dropout=attention_dropout_rate)
+
+ self.conv1 = Conv2D(
+ in_channels=d_model,
+ out_channels=dim_feedforward,
+ kernel_size=(1, 1))
+ self.conv2 = Conv2D(
+ in_channels=dim_feedforward,
+ out_channels=d_model,
+ kernel_size=(1, 1))
+
+ self.norm1 = LayerNorm(d_model)
+ self.norm2 = LayerNorm(d_model)
+ self.norm3 = LayerNorm(d_model)
+ self.dropout1 = Dropout(residual_dropout_rate)
+ self.dropout2 = Dropout(residual_dropout_rate)
+ self.dropout3 = Dropout(residual_dropout_rate)
+
+ def forward(self,
+ tgt,
+ memory,
+ tgt_mask=None,
+ memory_mask=None,
+ tgt_key_padding_mask=None,
+ memory_key_padding_mask=None):
+ """Pass the inputs (and mask) through the decoder layer.
+
+ Args:
+ tgt: the sequence to the decoder layer (required).
+            memory: the sequence from the last layer of the encoder (required).
+ tgt_mask: the mask for the tgt sequence (optional).
+ memory_mask: the mask for the memory sequence (optional).
+ tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
+ memory_key_padding_mask: the mask for the memory keys per batch (optional).
+
+ """
+ tgt2 = self.self_attn(
+ tgt,
+ tgt,
+ tgt,
+ attn_mask=tgt_mask,
+ key_padding_mask=tgt_key_padding_mask)
+ tgt = tgt + self.dropout1(tgt2)
+ tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ tgt,
+ memory,
+ memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask)
+ tgt = tgt + self.dropout2(tgt2)
+ tgt = self.norm2(tgt)
+
+        # position-wise feed-forward implemented with 1x1 convolutions
+ tgt = paddle.transpose(tgt, [1, 2, 0])
+ tgt = paddle.unsqueeze(tgt, 2)
+ tgt2 = self.conv2(F.relu(self.conv1(tgt)))
+ tgt2 = paddle.squeeze(tgt2, 2)
+ tgt2 = paddle.transpose(tgt2, [2, 0, 1])
+ tgt = paddle.squeeze(tgt, 2)
+ tgt = paddle.transpose(tgt, [2, 0, 1])
+
+ tgt = tgt + self.dropout3(tgt2)
+ tgt = self.norm3(tgt)
+ return tgt
+
+
+def _get_clones(module, N):
+ return LayerList([copy.deepcopy(module) for i in range(N)])
+
+
+class PositionalEncoding(nn.Layer):
+ """Inject some information about the relative or absolute position of the tokens
+ in the sequence. The positional encodings have the same dimension as
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
+ functions of different frequencies.
+ .. math::
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
+        where pos is the word position and i is the embed idx.
+ Args:
+        dim: the embed dim (required).
+        dropout: the dropout value (default=0.1).
+        max_len: the max. length of the incoming sequence (default=5000).
+    Examples:
+        >>> pos_encoder = PositionalEncoding(dropout=0.1, dim=512)
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ super(PositionalEncoding, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = paddle.zeros([max_len, dim])
+ position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
+ div_term = paddle.exp(
+ paddle.arange(0, dim, 2).astype('float32') *
+ (-math.log(10000.0) / dim))
+ pe[:, 0::2] = paddle.sin(position * div_term)
+ pe[:, 1::2] = paddle.cos(position * div_term)
+ pe = paddle.unsqueeze(pe, 0)
+ pe = paddle.transpose(pe, [1, 0, 2])
+ self.register_buffer('pe', pe)
+
+ def forward(self, x):
+ """Inputs of forward function
+ Args:
+ x: the sequence fed to the positional encoder model (required).
+ Shape:
+ x: [sequence length, batch size, embed dim]
+ output: [sequence length, batch size, embed dim]
+ Examples:
+ >>> output = pos_encoder(x)
+ """
+ x = x + self.pe[:paddle.shape(x)[0], :]
+ return self.dropout(x)
+
+
+class PositionalEncoding_2d(nn.Layer):
+ """Inject some information about the relative or absolute position of the tokens
+ in the sequence. The positional encodings have the same dimension as
+ the embeddings, so that the two can be summed. Here, we use sine and cosine
+ functions of different frequencies.
+ .. math::
+ \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
+ \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
+        where pos is the word position and i is the embed idx.
+ Args:
+        dim: the embed dim (required).
+        dropout: the dropout value (default=0.1).
+        max_len: the max. length of the incoming sequence (default=5000).
+    Examples:
+        >>> pos_encoder = PositionalEncoding_2d(dropout=0.1, dim=512)
+ """
+
+ def __init__(self, dropout, dim, max_len=5000):
+ super(PositionalEncoding_2d, self).__init__()
+ self.dropout = nn.Dropout(p=dropout)
+
+ pe = paddle.zeros([max_len, dim])
+ position = paddle.arange(0, max_len, dtype=paddle.float32).unsqueeze(1)
+ div_term = paddle.exp(
+ paddle.arange(0, dim, 2).astype('float32') *
+ (-math.log(10000.0) / dim))
+ pe[:, 0::2] = paddle.sin(position * div_term)
+ pe[:, 1::2] = paddle.cos(position * div_term)
+ pe = paddle.transpose(paddle.unsqueeze(pe, 0), [1, 0, 2])
+ self.register_buffer('pe', pe)
+
+ self.avg_pool_1 = nn.AdaptiveAvgPool2D((1, 1))
+ self.linear1 = nn.Linear(dim, dim)
+        self.linear1.weight.set_value(paddle.ones_like(self.linear1.weight))  # init weights to 1
+ self.avg_pool_2 = nn.AdaptiveAvgPool2D((1, 1))
+ self.linear2 = nn.Linear(dim, dim)
+        self.linear2.weight.set_value(paddle.ones_like(self.linear2.weight))  # init weights to 1
+
+ def forward(self, x):
+ """Inputs of forward function
+ Args:
+ x: the sequence fed to the positional encoder model (required).
+ Shape:
+ x: [sequence length, batch size, embed dim]
+ output: [sequence length, batch size, embed dim]
+ Examples:
+ >>> output = pos_encoder(x)
+ """
+ w_pe = self.pe[:paddle.shape(x)[-1], :]
+ w1 = self.linear1(self.avg_pool_1(x).squeeze()).unsqueeze(0)
+ w_pe = w_pe * w1
+ w_pe = paddle.transpose(w_pe, [1, 2, 0])
+ w_pe = paddle.unsqueeze(w_pe, 2)
+
+        h_pe = self.pe[:paddle.shape(x)[-2], :]
+ w2 = self.linear2(self.avg_pool_2(x).squeeze()).unsqueeze(0)
+ h_pe = h_pe * w2
+ h_pe = paddle.transpose(h_pe, [1, 2, 0])
+ h_pe = paddle.unsqueeze(h_pe, 3)
+
+ x = x + w_pe + h_pe
+ x = paddle.transpose(
+ paddle.reshape(x,
+ [x.shape[0], x.shape[1], x.shape[2] * x.shape[3]]),
+ [2, 0, 1])
+
+ return self.dropout(x)
+
+
+class Embeddings(nn.Layer):
+ def __init__(self, d_model, vocab, padding_idx, scale_embedding):
+ super(Embeddings, self).__init__()
+ self.embedding = nn.Embedding(vocab, d_model, padding_idx=padding_idx)
+ w0 = np.random.normal(0.0, d_model**-0.5,
+ (vocab, d_model)).astype(np.float32)
+ self.embedding.weight.set_value(w0)
+ self.d_model = d_model
+ self.scale_embedding = scale_embedding
+
+ def forward(self, x):
+ if self.scale_embedding:
+ x = self.embedding(x)
+ return x * math.sqrt(self.d_model)
+ return self.embedding(x)
+
+
+class Beam():
+ ''' Beam search '''
+
+ def __init__(self, size, device=False):
+
+ self.size = size
+ self._done = False
+ # The score for each translation on the beam.
+ self.scores = paddle.zeros((size, ), dtype=paddle.float32)
+ self.all_scores = []
+ # The backpointers at each time-step.
+ self.prev_ks = []
+ # The outputs at each time-step.
+ self.next_ys = [paddle.full((size, ), 0, dtype=paddle.int64)]
+        self.next_ys[0][0] = 2  # index 2 is the start token (<s>)
+
+ def get_current_state(self):
+ "Get the outputs for the current timestep."
+ return self.get_tentative_hypothesis()
+
+ def get_current_origin(self):
+ "Get the backpointers for the current timestep."
+ return self.prev_ks[-1]
+
+ @property
+ def done(self):
+ return self._done
+
+ def advance(self, word_prob):
+ "Update beam status and check if finished or not."
+ num_words = word_prob.shape[1]
+
+ # Sum the previous scores.
+ if len(self.prev_ks) > 0:
+ beam_lk = word_prob + self.scores.unsqueeze(1).expand_as(word_prob)
+ else:
+ beam_lk = word_prob[0]
+
+ flat_beam_lk = beam_lk.reshape([-1])
+ best_scores, best_scores_id = flat_beam_lk.topk(self.size, 0, True,
+ True) # 1st sort
+ self.all_scores.append(self.scores)
+ self.scores = best_scores
+ # bestScoresId is flattened as a (beam x word) array,
+ # so we need to calculate which word and beam each score came from
+ prev_k = best_scores_id // num_words
+ self.prev_ks.append(prev_k)
+ self.next_ys.append(best_scores_id - prev_k * num_words)
+ # End condition is when top-of-beam is EOS.
+        if self.next_ys[-1][0] == 3:  # index 3 is the end token (</s>)
+ self._done = True
+ self.all_scores.append(self.scores)
+
+ return self._done
+
+ def sort_scores(self):
+ "Sort the scores."
+ return self.scores, paddle.to_tensor(
+ [i for i in range(int(self.scores.shape[0]))], dtype='int32')
+
+ def get_the_best_score_and_idx(self):
+ "Get the score of the best in the beam."
+ scores, ids = self.sort_scores()
+ return scores[1], ids[1]
+
+ def get_tentative_hypothesis(self):
+ "Get the decoded sequence for the current timestep."
+ if len(self.next_ys) == 1:
+ dec_seq = self.next_ys[0].unsqueeze(1)
+ else:
+ _, keys = self.sort_scores()
+ hyps = [self.get_hypothesis(k) for k in keys]
+ hyps = [[2] + h for h in hyps]
+ dec_seq = paddle.to_tensor(hyps, dtype='int64')
+ return dec_seq
+
+ def get_hypothesis(self, k):
+ """ Walk back to construct the full hypothesis. """
+ hyp = []
+ for j in range(len(self.prev_ks) - 1, -1, -1):
+ hyp.append(self.next_ys[j + 1][k])
+ k = self.prev_ks[j][k]
+ return list(map(lambda x: x.item(), hyp[::-1]))
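
The `PositionalEncoding` layer above fills its `pe` buffer from the sinusoidal formula in its docstring. Below is a minimal NumPy sketch of that table (sizes are arbitrary, chosen only so the values can be checked by hand); it mirrors the buffer construction rather than reproducing the Paddle module.

```python
import numpy as np

def sinusoid_table(max_len=4, dim=6):
    # PE(pos, 2i) = sin(pos / 10000^(2i/dim)), PE(pos, 2i+1) = cos(pos / 10000^(2i/dim))
    position = np.arange(max_len, dtype=np.float32)[:, None]        # (max_len, 1)
    div_term = np.exp(np.arange(0, dim, 2, dtype=np.float32) *
                      (-np.log(10000.0) / dim))                     # (dim // 2,)
    pe = np.zeros((max_len, dim), dtype=np.float32)
    pe[:, 0::2] = np.sin(position * div_term)                       # even columns
    pe[:, 1::2] = np.cos(position * div_term)                       # odd columns
    return pe

print(sinusoid_table())  # row 0 is [0, 1, 0, 1, 0, 1]; later rows interleave sin/cos
```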
diff --git a/ppocr/modeling/heads/rec_sar_head.py b/ppocr/modeling/heads/rec_sar_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7107788d9ef3b49ac6d4dcd4a8133a9603ada19b
--- /dev/null
+++ b/ppocr/modeling/heads/rec_sar_head.py
@@ -0,0 +1,384 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import ParamAttr
+import paddle.nn as nn
+import paddle.nn.functional as F
+
+
+class SAREncoder(nn.Layer):
+ """
+ Args:
+ enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+ enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
+ enc_gru (bool): If True, use GRU, else LSTM in encoder.
+ d_model (int): Dim of channels from backbone.
+ d_enc (int): Dim of encoder RNN layer.
+ mask (bool): If True, mask padding in RNN sequence.
+ """
+
+ def __init__(self,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ enc_gru=False,
+ d_model=512,
+ d_enc=512,
+ mask=True,
+ **kwargs):
+ super().__init__()
+ assert isinstance(enc_bi_rnn, bool)
+ assert isinstance(enc_drop_rnn, (int, float))
+ assert 0 <= enc_drop_rnn < 1.0
+ assert isinstance(enc_gru, bool)
+ assert isinstance(d_model, int)
+ assert isinstance(d_enc, int)
+ assert isinstance(mask, bool)
+
+ self.enc_bi_rnn = enc_bi_rnn
+ self.enc_drop_rnn = enc_drop_rnn
+ self.mask = mask
+
+ # LSTM Encoder
+ if enc_bi_rnn:
+ direction = 'bidirectional'
+ else:
+ direction = 'forward'
+ kwargs = dict(
+ input_size=d_model,
+ hidden_size=d_enc,
+ num_layers=2,
+ time_major=False,
+ dropout=enc_drop_rnn,
+ direction=direction)
+ if enc_gru:
+ self.rnn_encoder = nn.GRU(**kwargs)
+ else:
+ self.rnn_encoder = nn.LSTM(**kwargs)
+
+ # global feature transformation
+ encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
+ self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)
+
+ def forward(self, feat, img_metas=None):
+ if img_metas is not None:
+ assert len(img_metas[0]) == feat.shape[0]
+
+ valid_ratios = None
+ if img_metas is not None and self.mask:
+ valid_ratios = img_metas[-1]
+
+ h_feat = feat.shape[2] # bsz c h w
+ feat_v = F.max_pool2d(
+ feat, kernel_size=(h_feat, 1), stride=1, padding=0)
+ feat_v = feat_v.squeeze(2) # bsz * C * W
+ feat_v = paddle.transpose(feat_v, perm=[0, 2, 1]) # bsz * W * C
+ holistic_feat = self.rnn_encoder(feat_v)[0] # bsz * T * C
+
+ if valid_ratios is not None:
+ valid_hf = []
+ T = holistic_feat.shape[1]
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_step = min(T, math.ceil(T * valid_ratio)) - 1
+ valid_hf.append(holistic_feat[i, valid_step, :])
+ valid_hf = paddle.stack(valid_hf, axis=0)
+ else:
+ valid_hf = holistic_feat[:, -1, :] # bsz * C
+ holistic_feat = self.linear(valid_hf) # bsz * C
+
+ return holistic_feat
+
+
+class BaseDecoder(nn.Layer):
+ def __init__(self, **kwargs):
+ super().__init__()
+
+ def forward_train(self, feat, out_enc, targets, img_metas):
+ raise NotImplementedError
+
+ def forward_test(self, feat, out_enc, img_metas):
+ raise NotImplementedError
+
+ def forward(self,
+ feat,
+ out_enc,
+ label=None,
+ img_metas=None,
+ train_mode=True):
+ self.train_mode = train_mode
+
+ if train_mode:
+ return self.forward_train(feat, out_enc, label, img_metas)
+ return self.forward_test(feat, out_enc, img_metas)
+
+
+class ParallelSARDecoder(BaseDecoder):
+ """
+ Args:
+ out_channels (int): Output class number.
+ enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
+ dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
+ dec_drop_rnn (float): Dropout of RNN layer in decoder.
+ dec_gru (bool): If True, use GRU, else LSTM in decoder.
+ d_model (int): Dim of channels from backbone.
+ d_enc (int): Dim of encoder RNN layer.
+ d_k (int): Dim of channels of attention module.
+ pred_dropout (float): Dropout probability of prediction layer.
+ max_seq_len (int): Maximum sequence length for decoding.
+ mask (bool): If True, mask padding in feature map.
+ start_idx (int): Index of start token.
+ padding_idx (int): Index of padding token.
+ pred_concat (bool): If True, concat glimpse feature from
+ attention with holistic feature and hidden state.
+ """
+
+ def __init__(
+ self,
+ out_channels, # 90 + unknown + start + padding
+ enc_bi_rnn=False,
+ dec_bi_rnn=False,
+ dec_drop_rnn=0.0,
+ dec_gru=False,
+ d_model=512,
+ d_enc=512,
+ d_k=64,
+ pred_dropout=0.1,
+ max_text_length=30,
+ mask=True,
+ pred_concat=True,
+ **kwargs):
+ super().__init__()
+
+ self.num_classes = out_channels
+ self.enc_bi_rnn = enc_bi_rnn
+ self.d_k = d_k
+ self.start_idx = out_channels - 2
+ self.padding_idx = out_channels - 1
+ self.max_seq_len = max_text_length
+ self.mask = mask
+ self.pred_concat = pred_concat
+
+ encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
+ decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)
+
+ # 2D attention layer
+ self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
+ self.conv3x3_1 = nn.Conv2D(
+ d_model, d_k, kernel_size=3, stride=1, padding=1)
+ self.conv1x1_2 = nn.Linear(d_k, 1)
+
+ # Decoder RNN layer
+ if dec_bi_rnn:
+ direction = 'bidirectional'
+ else:
+ direction = 'forward'
+
+ kwargs = dict(
+ input_size=encoder_rnn_out_size,
+ hidden_size=encoder_rnn_out_size,
+ num_layers=2,
+ time_major=False,
+ dropout=dec_drop_rnn,
+ direction=direction)
+ if dec_gru:
+ self.rnn_decoder = nn.GRU(**kwargs)
+ else:
+ self.rnn_decoder = nn.LSTM(**kwargs)
+
+ # Decoder input embedding
+ self.embedding = nn.Embedding(
+ self.num_classes,
+ encoder_rnn_out_size,
+ padding_idx=self.padding_idx)
+
+ # Prediction layer
+ self.pred_dropout = nn.Dropout(pred_dropout)
+ pred_num_classes = self.num_classes - 1
+ if pred_concat:
+ fc_in_channel = decoder_rnn_out_size + d_model + d_enc
+ else:
+ fc_in_channel = d_model
+ self.prediction = nn.Linear(fc_in_channel, pred_num_classes)
+
+ def _2d_attention(self,
+ decoder_input,
+ feat,
+ holistic_feat,
+ valid_ratios=None):
+
+ y = self.rnn_decoder(decoder_input)[0]
+ # y: bsz * (seq_len + 1) * hidden_size
+
+ attn_query = self.conv1x1_1(y) # bsz * (seq_len + 1) * attn_size
+ bsz, seq_len, attn_size = attn_query.shape
+ attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
+ # (bsz, seq_len + 1, attn_size, 1, 1)
+
+ attn_key = self.conv3x3_1(feat)
+ # bsz * attn_size * h * w
+ attn_key = attn_key.unsqueeze(1)
+ # bsz * 1 * attn_size * h * w
+
+ attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
+
+ # bsz * (seq_len + 1) * attn_size * h * w
+ attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
+ # bsz * (seq_len + 1) * h * w * attn_size
+ attn_weight = self.conv1x1_2(attn_weight)
+ # bsz * (seq_len + 1) * h * w * 1
+ bsz, T, h, w, c = attn_weight.shape
+ assert c == 1
+
+ if valid_ratios is not None:
+ # cal mask of attention weight
+ for i, valid_ratio in enumerate(valid_ratios):
+ valid_width = min(w, math.ceil(w * valid_ratio))
+ if valid_width < w:
+ attn_weight[i, :, :, valid_width:, :] = float('-inf')
+
+ attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
+ attn_weight = F.softmax(attn_weight, axis=-1)
+
+ attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
+ attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
+ # attn_weight: bsz * T * c * h * w
+ # feat: bsz * c * h * w
+ attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
+ (3, 4),
+ keepdim=False)
+ # bsz * (seq_len + 1) * C
+
+ # Linear transformation
+ if self.pred_concat:
+ hf_c = holistic_feat.shape[-1]
+ holistic_feat = paddle.expand(
+ holistic_feat, shape=[bsz, seq_len, hf_c])
+ y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
+ else:
+ y = self.prediction(attn_feat)
+ # bsz * (seq_len + 1) * num_classes
+ if self.train_mode:
+ y = self.pred_dropout(y)
+
+ return y
+
+ def forward_train(self, feat, out_enc, label, img_metas):
+ '''
+ img_metas: [label, valid_ratio]
+ '''
+ if img_metas is not None:
+ assert len(img_metas[0]) == feat.shape[0]
+
+ valid_ratios = None
+ if img_metas is not None and self.mask:
+ valid_ratios = img_metas[-1]
+
+ label = label.cuda()
+ lab_embedding = self.embedding(label)
+ # bsz * seq_len * emb_dim
+ out_enc = out_enc.unsqueeze(1)
+ # bsz * 1 * emb_dim
+ in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
+ # bsz * (seq_len + 1) * C
+ out_dec = self._2d_attention(
+ in_dec, feat, out_enc, valid_ratios=valid_ratios)
+ # bsz * (seq_len + 1) * num_classes
+
+ return out_dec[:, 1:, :] # bsz * seq_len * num_classes
+
+ def forward_test(self, feat, out_enc, img_metas):
+ if img_metas is not None:
+ assert len(img_metas[0]) == feat.shape[0]
+
+ valid_ratios = None
+ if img_metas is not None and self.mask:
+ valid_ratios = img_metas[-1]
+
+ seq_len = self.max_seq_len
+ bsz = feat.shape[0]
+ start_token = paddle.full(
+ (bsz, ), fill_value=self.start_idx, dtype='int64')
+ # bsz
+ start_token = self.embedding(start_token)
+ # bsz * emb_dim
+ emb_dim = start_token.shape[1]
+ start_token = start_token.unsqueeze(1)
+ start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
+ # bsz * seq_len * emb_dim
+ out_enc = out_enc.unsqueeze(1)
+ # bsz * 1 * emb_dim
+ decoder_input = paddle.concat((out_enc, start_token), axis=1)
+ # bsz * (seq_len + 1) * emb_dim
+
+ outputs = []
+ for i in range(1, seq_len + 1):
+ decoder_output = self._2d_attention(
+ decoder_input, feat, out_enc, valid_ratios=valid_ratios)
+ char_output = decoder_output[:, i, :] # bsz * num_classes
+ char_output = F.softmax(char_output, -1)
+ outputs.append(char_output)
+ max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
+ char_embedding = self.embedding(max_idx) # bsz * emb_dim
+ if i < seq_len:
+ decoder_input[:, i + 1, :] = char_embedding
+
+ outputs = paddle.stack(outputs, 1) # bsz * seq_len * num_classes
+
+ return outputs
+
+
+class SARHead(nn.Layer):
+ def __init__(self,
+ out_channels,
+ enc_bi_rnn=False,
+ enc_drop_rnn=0.1,
+ enc_gru=False,
+ dec_bi_rnn=False,
+ dec_drop_rnn=0.0,
+ dec_gru=False,
+ d_k=512,
+ pred_dropout=0.1,
+ max_text_length=30,
+ pred_concat=True,
+ **kwargs):
+ super(SARHead, self).__init__()
+
+ # encoder module
+ self.encoder = SAREncoder(
+ enc_bi_rnn=enc_bi_rnn, enc_drop_rnn=enc_drop_rnn, enc_gru=enc_gru)
+
+ # decoder module
+ self.decoder = ParallelSARDecoder(
+ out_channels=out_channels,
+ enc_bi_rnn=enc_bi_rnn,
+ dec_bi_rnn=dec_bi_rnn,
+ dec_drop_rnn=dec_drop_rnn,
+ dec_gru=dec_gru,
+ d_k=d_k,
+ pred_dropout=pred_dropout,
+ max_text_length=max_text_length,
+ pred_concat=pred_concat)
+
+ def forward(self, feat, targets=None):
+ '''
+        targets: [label, valid_ratio]
+ '''
+ holistic_feat = self.encoder(feat, targets) # bsz c
+
+ if self.training:
+ label = targets[0] # label
+ label = paddle.to_tensor(label, dtype='int64')
+ final_out = self.decoder(
+ feat, holistic_feat, label, img_metas=targets)
+        else:
+ final_out = self.decoder(
+ feat,
+ holistic_feat,
+ label=None,
+ img_metas=targets,
+ train_mode=False)
+ # (bsz, seq_len, num_classes)
+
+ return final_out
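
A quick way to read `SARHead` is to trace tensor shapes through inference. The sketch below assumes a 512-channel backbone feature map and a hypothetical 38-class character set (characters + unknown + start + padding); the numbers are illustrative and not taken from any config in this patch.

```python
import paddle
from ppocr.modeling.heads.rec_sar_head import SARHead

head = SARHead(out_channels=38, max_text_length=30)   # hypothetical class count
head.eval()                                           # take the greedy-decoding branch

feat = paddle.randn([2, 512, 8, 32])                  # bsz * C * H * W from the backbone
with paddle.no_grad():
    probs = head(feat, targets=None)                  # SAREncoder -> ParallelSARDecoder.forward_test
print(probs.shape)                                    # expected: [2, 30, 37] = bsz * seq_len * (num_classes - 1)
```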
diff --git a/ppocr/modeling/necks/__init__.py b/ppocr/modeling/necks/__init__.py
index e97c4f64bdc9acd6729d67a9c6ff7a7563f6c95e..5606a4c35f68021e7f151a7eae4a0da4d5b6b95e 100644
--- a/ppocr/modeling/necks/__init__.py
+++ b/ppocr/modeling/necks/__init__.py
@@ -22,7 +22,8 @@ def build_neck(config):
from .rnn import SequenceEncoder
from .pg_fpn import PGFPN
from .table_fpn import TableFPN
- support_dict = ['DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
+ from .fpn import FPN
+    support_dict = ['FPN', 'DBFPN', 'EASTFPN', 'SASTFPN', 'SequenceEncoder', 'PGFPN', 'TableFPN']
module_name = config.pop('name')
assert module_name in support_dict, Exception('neck only support {}'.format(
diff --git a/ppocr/modeling/necks/fpn.py b/ppocr/modeling/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8728a5c9ded5b9c174fd34f088d8012961f65ec0
--- /dev/null
+++ b/ppocr/modeling/necks/fpn.py
@@ -0,0 +1,100 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle.nn as nn
+import paddle
+import math
+import paddle.nn.functional as F
+
+class Conv_BN_ReLU(nn.Layer):
+ def __init__(self, in_planes, out_planes, kernel_size=1, stride=1, padding=0):
+ super(Conv_BN_ReLU, self).__init__()
+ self.conv = nn.Conv2D(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
+ bias_attr=False)
+ self.bn = nn.BatchNorm2D(out_planes, momentum=0.1)
+ self.relu = nn.ReLU()
+
+ for m in self.sublayers():
+ if isinstance(m, nn.Conv2D):
+ n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+ m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', default_initializer=paddle.nn.initializer.Normal(0, math.sqrt(2. / n)))
+ elif isinstance(m, nn.BatchNorm2D):
+ m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32', default_initializer=paddle.nn.initializer.Constant(1.0))
+ m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32', default_initializer=paddle.nn.initializer.Constant(0.0))
+
+ def forward(self, x):
+ return self.relu(self.bn(self.conv(x)))
+
+class FPN(nn.Layer):
+ def __init__(self, in_channels, out_channels):
+ super(FPN, self).__init__()
+
+ # Top layer
+ self.toplayer_ = Conv_BN_ReLU(in_channels[3], out_channels, kernel_size=1, stride=1, padding=0)
+ # Lateral layers
+ self.latlayer1_ = Conv_BN_ReLU(in_channels[2], out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.latlayer2_ = Conv_BN_ReLU(in_channels[1], out_channels, kernel_size=1, stride=1, padding=0)
+
+ self.latlayer3_ = Conv_BN_ReLU(in_channels[0], out_channels, kernel_size=1, stride=1, padding=0)
+
+ # Smooth layers
+ self.smooth1_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.smooth2_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ self.smooth3_ = Conv_BN_ReLU(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+
+ self.out_channels = out_channels * 4
+ for m in self.sublayers():
+ if isinstance(m, nn.Conv2D):
+ n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
+ m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32',
+ default_initializer=paddle.nn.initializer.Normal(0,
+ math.sqrt(2. / n)))
+ elif isinstance(m, nn.BatchNorm2D):
+ m.weight = paddle.create_parameter(shape=m.weight.shape, dtype='float32',
+ default_initializer=paddle.nn.initializer.Constant(1.0))
+ m.bias = paddle.create_parameter(shape=m.bias.shape, dtype='float32',
+ default_initializer=paddle.nn.initializer.Constant(0.0))
+
+ def _upsample(self, x, scale=1):
+ return F.upsample(x, scale_factor=scale, mode='bilinear')
+
+ def _upsample_add(self, x, y, scale=1):
+ return F.upsample(x, scale_factor=scale, mode='bilinear') + y
+
+ def forward(self, x):
+ f2, f3, f4, f5 = x
+ p5 = self.toplayer_(f5)
+
+ f4 = self.latlayer1_(f4)
+        p4 = self._upsample_add(p5, f4, 2)
+ p4 = self.smooth1_(p4)
+
+ f3 = self.latlayer2_(f3)
+        p3 = self._upsample_add(p4, f3, 2)
+ p3 = self.smooth2_(p3)
+
+ f2 = self.latlayer3_(f2)
+        p2 = self._upsample_add(p3, f2, 2)
+ p2 = self.smooth3_(p2)
+
+ p3 = self._upsample(p3, 2)
+ p4 = self._upsample(p4, 4)
+ p5 = self._upsample(p5, 8)
+
+ fuse = paddle.concat([p2, p3, p4, p5], axis=1)
+ return fuse
\ No newline at end of file
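
The `FPN` neck above fuses four backbone stages into a single stride-4 feature map with `out_channels * 4` channels. A small shape sketch, using hypothetical ResNet-style channel counts and strides (4/8/16/32), illustrates the fusion:

```python
import paddle
from ppocr.modeling.necks.fpn import FPN

neck = FPN(in_channels=[64, 128, 256, 512], out_channels=128)   # hypothetical backbone channels
feats = [
    paddle.randn([1, 64, 160, 160]),    # stride 4
    paddle.randn([1, 128, 80, 80]),     # stride 8
    paddle.randn([1, 256, 40, 40]),     # stride 16
    paddle.randn([1, 512, 20, 20]),     # stride 32
]
fuse = neck(feats)
print(fuse.shape)   # expected: [1, 512, 160, 160] -> out_channels * 4 at stride 4
```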
diff --git a/ppocr/modeling/necks/rnn.py b/ppocr/modeling/necks/rnn.py
index de87b3d9895168657f8c9722177c026b992c2966..86e649028f8fbb76cb5a1fd85381bd361277c6ee 100644
--- a/ppocr/modeling/necks/rnn.py
+++ b/ppocr/modeling/necks/rnn.py
@@ -51,7 +51,7 @@ class EncoderWithFC(nn.Layer):
super(EncoderWithFC, self).__init__()
self.out_channels = hidden_size
weight_attr, bias_attr = get_para_bias_attr(
- l2_decay=0.00001, k=in_channels, name='reduce_encoder_fea')
+ l2_decay=0.00001, k=in_channels)
self.fc = nn.Linear(
in_channels,
hidden_size,
diff --git a/ppocr/modeling/transforms/__init__.py b/ppocr/modeling/transforms/__init__.py
index 78eaecccc55f77d6624aa0c5bdb839acc3462129..405ab3cc6c0380654f61e42e523ddc85839139b3 100755
--- a/ppocr/modeling/transforms/__init__.py
+++ b/ppocr/modeling/transforms/__init__.py
@@ -17,8 +17,9 @@ __all__ = ['build_transform']
def build_transform(config):
from .tps import TPS
+ from .stn import STN_ON
- support_dict = ['TPS']
+ support_dict = ['TPS', 'STN_ON']
module_name = config.pop('name')
assert module_name in support_dict, Exception(
diff --git a/ppocr/modeling/transforms/stn.py b/ppocr/modeling/transforms/stn.py
new file mode 100644
index 0000000000000000000000000000000000000000..215895f4c4c719f407f4998f7429d965e0529ddc
--- /dev/null
+++ b/ppocr/modeling/transforms/stn.py
@@ -0,0 +1,132 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+
+from .tps_spatial_transformer import TPSSpatialTransformer
+
+
+def conv3x3_block(in_channels, out_channels, stride=1):
+ n = 3 * 3 * out_channels
+ w = math.sqrt(2. / n)
+ conv_layer = nn.Conv2D(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ weight_attr=nn.initializer.Normal(
+ mean=0.0, std=w),
+ bias_attr=nn.initializer.Constant(0))
+ block = nn.Sequential(conv_layer, nn.BatchNorm2D(out_channels), nn.ReLU())
+ return block
+
+
+class STN(nn.Layer):
+ def __init__(self, in_channels, num_ctrlpoints, activation='none'):
+ super(STN, self).__init__()
+ self.in_channels = in_channels
+ self.num_ctrlpoints = num_ctrlpoints
+ self.activation = activation
+ self.stn_convnet = nn.Sequential(
+ conv3x3_block(in_channels, 32), #32x64
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(32, 64), #16x32
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(64, 128), # 8*16
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(128, 256), # 4*8
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(256, 256), # 2*4,
+ nn.MaxPool2D(
+ kernel_size=2, stride=2),
+ conv3x3_block(256, 256)) # 1*2
+ self.stn_fc1 = nn.Sequential(
+ nn.Linear(
+ 2 * 256,
+ 512,
+ weight_attr=nn.initializer.Normal(0, 0.001),
+ bias_attr=nn.initializer.Constant(0)),
+ nn.BatchNorm1D(512),
+ nn.ReLU())
+ fc2_bias = self.init_stn()
+ self.stn_fc2 = nn.Linear(
+ 512,
+ num_ctrlpoints * 2,
+ weight_attr=nn.initializer.Constant(0.0),
+ bias_attr=nn.initializer.Assign(fc2_bias))
+
+ def init_stn(self):
+ margin = 0.01
+ sampling_num_per_side = int(self.num_ctrlpoints / 2)
+ ctrl_pts_x = np.linspace(margin, 1. - margin, sampling_num_per_side)
+ ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin
+ ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1 - margin)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ ctrl_points = np.concatenate(
+ [ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32)
+ if self.activation == 'none':
+ pass
+ elif self.activation == 'sigmoid':
+ ctrl_points = -np.log(1. / ctrl_points - 1.)
+ ctrl_points = paddle.to_tensor(ctrl_points)
+ fc2_bias = paddle.reshape(
+ ctrl_points, shape=[ctrl_points.shape[0] * ctrl_points.shape[1]])
+ return fc2_bias
+
+ def forward(self, x):
+ x = self.stn_convnet(x)
+ batch_size, _, h, w = x.shape
+ x = paddle.reshape(x, shape=(batch_size, -1))
+ img_feat = self.stn_fc1(x)
+ x = self.stn_fc2(0.1 * img_feat)
+ if self.activation == 'sigmoid':
+ x = F.sigmoid(x)
+ x = paddle.reshape(x, shape=[-1, self.num_ctrlpoints, 2])
+ return img_feat, x
+
+
+class STN_ON(nn.Layer):
+ def __init__(self, in_channels, tps_inputsize, tps_outputsize,
+ num_control_points, tps_margins, stn_activation):
+ super(STN_ON, self).__init__()
+ self.tps = TPSSpatialTransformer(
+ output_image_size=tuple(tps_outputsize),
+ num_control_points=num_control_points,
+ margins=tuple(tps_margins))
+ self.stn_head = STN(in_channels=in_channels,
+ num_ctrlpoints=num_control_points,
+ activation=stn_activation)
+ self.tps_inputsize = tps_inputsize
+ self.out_channels = in_channels
+
+ def forward(self, image):
+ stn_input = paddle.nn.functional.interpolate(
+ image, self.tps_inputsize, mode="bilinear", align_corners=True)
+ stn_img_feat, ctrl_points = self.stn_head(stn_input)
+ x, _ = self.tps(image, ctrl_points)
+ return x
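
`STN_ON` first downsamples the image to `tps_inputsize` and runs the `STN` localization network, whose five max-pool stages assume a 32x64 input so that the flattened feature matches the `2 * 256` input of `stn_fc1`; the predicted control points then drive `TPSSpatialTransformer`. A shape sketch of the localization step, with assumed sizes:

```python
import paddle
from ppocr.modeling.transforms.stn import STN

stn = STN(in_channels=3, num_ctrlpoints=20, activation='none')
x = paddle.rand([2, 3, 32, 64])            # already resized to tps_inputsize (assumed 32x64)
img_feat, ctrl_points = stn(x)
print(img_feat.shape, ctrl_points.shape)   # expected: [2, 512] and [2, 20, 2]
```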
diff --git a/ppocr/modeling/transforms/tps.py b/ppocr/modeling/transforms/tps.py
index dcce6246ac64b4b84229cbd69a4dc53c658b4c7b..6cd68555369dd1ddbd6ccf5236688a4b957b8525 100644
--- a/ppocr/modeling/transforms/tps.py
+++ b/ppocr/modeling/transforms/tps.py
@@ -231,7 +231,8 @@ class GridGenerator(nn.Layer):
""" Return inv_delta_C which is needed to calculate T """
F = self.F
hat_eye = paddle.eye(F, dtype='float64') # F x F
- hat_C = paddle.norm(C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
+ hat_C = paddle.norm(
+ C.reshape([1, F, 2]) - C.reshape([F, 1, 2]), axis=2) + hat_eye
hat_C = (hat_C**2) * paddle.log(hat_C)
delta_C = paddle.concat( # F+3 x F+3
[
diff --git a/ppocr/modeling/transforms/tps_spatial_transformer.py b/ppocr/modeling/transforms/tps_spatial_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b510acb0d4012c9a4d90c7ca07cac895f0bf242e
--- /dev/null
+++ b/ppocr/modeling/transforms/tps_spatial_transformer.py
@@ -0,0 +1,152 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import numpy as np
+import itertools
+
+
+def grid_sample(input, grid, canvas=None):
+ input.stop_gradient = False
+ output = F.grid_sample(input, grid)
+ if canvas is None:
+ return output
+ else:
+ input_mask = paddle.ones(shape=input.shape)
+ output_mask = F.grid_sample(input_mask, grid)
+ padded_output = output * output_mask + canvas * (1 - output_mask)
+ return padded_output
+
+
+# phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2
+def compute_partial_repr(input_points, control_points):
+ N = input_points.shape[0]
+ M = control_points.shape[0]
+ pairwise_diff = paddle.reshape(
+ input_points, shape=[N, 1, 2]) - paddle.reshape(
+ control_points, shape=[1, M, 2])
+ # original implementation, very slow
+ # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance
+ pairwise_diff_square = pairwise_diff * pairwise_diff
+ pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :,
+ 1]
+ repr_matrix = 0.5 * pairwise_dist * paddle.log(pairwise_dist)
+ # fix numerical error for 0 * log(0), substitute all nan with 0
+ mask = repr_matrix != repr_matrix
+ repr_matrix[mask] = 0
+ return repr_matrix
+
+
+# output_ctrl_pts are specified, according to our task.
+def build_output_control_points(num_control_points, margins):
+ margin_x, margin_y = margins
+ num_ctrl_pts_per_side = num_control_points // 2
+ ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side)
+ ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y
+ ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y)
+ ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
+ ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
+ output_ctrl_pts_arr = np.concatenate(
+ [ctrl_pts_top, ctrl_pts_bottom], axis=0)
+ output_ctrl_pts = paddle.to_tensor(output_ctrl_pts_arr)
+ return output_ctrl_pts
+
+
+class TPSSpatialTransformer(nn.Layer):
+ def __init__(self,
+ output_image_size=None,
+ num_control_points=None,
+ margins=None):
+ super(TPSSpatialTransformer, self).__init__()
+ self.output_image_size = output_image_size
+ self.num_control_points = num_control_points
+ self.margins = margins
+
+ self.target_height, self.target_width = output_image_size
+ target_control_points = build_output_control_points(num_control_points,
+ margins)
+ N = num_control_points
+
+ # create padded kernel matrix
+ forward_kernel = paddle.zeros(shape=[N + 3, N + 3])
+ target_control_partial_repr = compute_partial_repr(
+ target_control_points, target_control_points)
+ target_control_partial_repr = paddle.cast(target_control_partial_repr,
+ forward_kernel.dtype)
+ forward_kernel[:N, :N] = target_control_partial_repr
+ forward_kernel[:N, -3] = 1
+ forward_kernel[-3, :N] = 1
+ target_control_points = paddle.cast(target_control_points,
+ forward_kernel.dtype)
+ forward_kernel[:N, -2:] = target_control_points
+ forward_kernel[-2:, :N] = paddle.transpose(
+ target_control_points, perm=[1, 0])
+ # compute inverse matrix
+ inverse_kernel = paddle.inverse(forward_kernel)
+
+        # create target coordinate matrix
+ HW = self.target_height * self.target_width
+ target_coordinate = list(
+ itertools.product(
+ range(self.target_height), range(self.target_width)))
+ target_coordinate = paddle.to_tensor(target_coordinate) # HW x 2
+ Y, X = paddle.split(
+ target_coordinate, target_coordinate.shape[1], axis=1)
+ Y = Y / (self.target_height - 1)
+ X = X / (self.target_width - 1)
+ target_coordinate = paddle.concat(
+ [X, Y], axis=1) # convert from (y, x) to (x, y)
+ target_coordinate_partial_repr = compute_partial_repr(
+ target_coordinate, target_control_points)
+ target_coordinate_repr = paddle.concat(
+ [
+ target_coordinate_partial_repr, paddle.ones(shape=[HW, 1]),
+ target_coordinate
+ ],
+ axis=1)
+
+ # register precomputed matrices
+ self.inverse_kernel = inverse_kernel
+ self.padding_matrix = paddle.zeros(shape=[3, 2])
+ self.target_coordinate_repr = target_coordinate_repr
+ self.target_control_points = target_control_points
+
+ def forward(self, input, source_control_points):
+ assert source_control_points.ndimension() == 3
+ assert source_control_points.shape[1] == self.num_control_points
+ assert source_control_points.shape[2] == 2
+ batch_size = paddle.shape(source_control_points)[0]
+
+ self.padding_matrix = paddle.expand(
+ self.padding_matrix, shape=[batch_size, 3, 2])
+ Y = paddle.concat([source_control_points, self.padding_matrix], 1)
+ mapping_matrix = paddle.matmul(self.inverse_kernel, Y)
+ source_coordinate = paddle.matmul(self.target_coordinate_repr,
+ mapping_matrix)
+
+ grid = paddle.reshape(
+ source_coordinate,
+ shape=[-1, self.target_height, self.target_width, 2])
+ grid = paddle.clip(grid, 0,
+ 1) # the source_control_points may be out of [0, 1].
+ # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1]
+ grid = 2.0 * grid - 1.0
+ output_maps = grid_sample(input, grid, canvas=None)
+ return output_maps, source_coordinate
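
`compute_partial_repr` stores the TPS radial basis phi(x1, x2) = r^2 * log(r) in the equivalent form 0.5 * r^2 * log(r^2), which avoids a square root. A tiny NumPy check of that identity (values are arbitrary):

```python
import numpy as np

p, q = np.array([0.2, 0.7]), np.array([0.5, 0.1])
r2 = np.sum((p - q) ** 2)                    # squared distance, as used in the Paddle code
lhs = 0.5 * r2 * np.log(r2)                  # form used in compute_partial_repr
rhs = r2 * np.log(np.sqrt(r2))               # textbook phi(r) = r^2 * log(r)
print(np.isclose(lhs, rhs))                  # True
```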
diff --git a/ppocr/optimizer/optimizer.py b/ppocr/optimizer/optimizer.py
index 8215b92d8c8d05c2b3c2e95ac989bf4ea011310b..34098c0fad553f7d39f6b5341e4da70a263eeaea 100644
--- a/ppocr/optimizer/optimizer.py
+++ b/ppocr/optimizer/optimizer.py
@@ -127,3 +127,34 @@ class RMSProp(object):
grad_clip=self.grad_clip,
parameters=parameters)
return opt
+
+
+class Adadelta(object):
+ def __init__(self,
+ learning_rate=0.001,
+ epsilon=1e-08,
+ rho=0.95,
+ parameter_list=None,
+ weight_decay=None,
+ grad_clip=None,
+ name=None,
+ **kwargs):
+ self.learning_rate = learning_rate
+ self.epsilon = epsilon
+ self.rho = rho
+ self.parameter_list = parameter_list
+ self.weight_decay = weight_decay
+ self.grad_clip = grad_clip
+ self.name = name
+
+ def __call__(self, parameters):
+ opt = optim.Adadelta(
+ learning_rate=self.learning_rate,
+ epsilon=self.epsilon,
+ rho=self.rho,
+ weight_decay=self.weight_decay,
+ grad_clip=self.grad_clip,
+ name=self.name,
+ parameters=parameters)
+ return opt
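
Like the other wrappers in this file, the new `Adadelta` class is a small factory that is later called with the model parameters. A hypothetical usage sketch (the model is a stand-in):

```python
import paddle.nn as nn
from ppocr.optimizer.optimizer import Adadelta

model = nn.Linear(16, 4)                  # stand-in for a real model
opt = Adadelta(learning_rate=1.0, rho=0.95, epsilon=1e-6)(model.parameters())
```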
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index 654ddf39d23590fbaf7f7b9b57f38cc86a1b6669..3a4ebf52a3bd91ffd509b113103dab900588b0bd 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import copy
+import platform
__all__ = ['build_post_process']
@@ -25,17 +26,22 @@ from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \
- TableLabelDecode
+    TableLabelDecode, NRTRLabelDecode, SARLabelDecode, SEEDLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
+if platform.system() != "Windows":
+    # pse is not supported on Windows
+ from .pse_postprocess import PSEPostProcess
+
def build_post_process(config, global_config=None):
support_dict = [
- 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
- 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
- 'DistillationCTCLabelDecode', 'TableLabelDecode',
- 'DistillationDBPostProcess'
+ 'DBPostProcess', 'PSEPostProcess', 'EASTPostProcess', 'SASTPostProcess',
+ 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
+ 'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
+ 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
+ 'SEEDLabelDecode'
]
config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/pse_postprocess/__init__.py b/ppocr/postprocess/pse_postprocess/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..680473bf4b1863ac695dc8173778e59bd4fdacf9
--- /dev/null
+++ b/ppocr/postprocess/pse_postprocess/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .pse_postprocess import PSEPostProcess
\ No newline at end of file
diff --git a/ppocr/postprocess/pse_postprocess/pse/README.md b/ppocr/postprocess/pse_postprocess/pse/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..9c2d9eaeaa5f93550358ebdd4d9161330b78a86f
--- /dev/null
+++ b/ppocr/postprocess/pse_postprocess/pse/README.md
@@ -0,0 +1,5 @@
+## Build
+Code from https://github.com/whai362/pan_pp.pytorch
+```shell
+python3 setup.py build_ext --inplace
+```
diff --git a/ppocr/postprocess/pse_postprocess/pse/__init__.py b/ppocr/postprocess/pse_postprocess/pse/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..97b8d8aff0cf229a4e3ec1961638273bd201822a
--- /dev/null
+++ b/ppocr/postprocess/pse_postprocess/pse/__init__.py
@@ -0,0 +1,23 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+import os
+import subprocess
+
+python_path = sys.executable
+
+if subprocess.call('cd ppocr/postprocess/pse_postprocess/pse;{} setup.py build_ext --inplace;cd -'.format(python_path), shell=True) != 0:
+ raise RuntimeError('Cannot compile pse: {}'.format(os.path.dirname(os.path.realpath(__file__))))
+
+from .pse import pse
\ No newline at end of file
diff --git a/ppocr/postprocess/pse_postprocess/pse/pse.pyx b/ppocr/postprocess/pse_postprocess/pse/pse.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..b2be49e9471865c11b840207f922258e67a554b6
--- /dev/null
+++ b/ppocr/postprocess/pse_postprocess/pse/pse.pyx
@@ -0,0 +1,70 @@
+
+import numpy as np
+import cv2
+cimport numpy as np
+cimport cython
+cimport libcpp
+cimport libcpp.pair
+cimport libcpp.queue
+from libcpp.pair cimport *
+from libcpp.queue cimport *
+
+@cython.boundscheck(False)
+@cython.wraparound(False)
+cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
+ np.ndarray[np.int32_t, ndim=2] label,
+ int kernel_num,
+ int label_num,
+ float min_area=0):
+ cdef np.ndarray[np.int32_t, ndim=2] pred
+ pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
+
+ for label_idx in range(1, label_num):
+ if np.sum(label == label_idx) < min_area:
+ label[label == label_idx] = 0
+
+ cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
+ queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
+ cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
+ queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
+ cdef np.int16_t* dx = [-1, 1, 0, 0]
+ cdef np.int16_t* dy = [0, 0, -1, 1]
+ cdef np.int16_t tmpx, tmpy
+
+ points = np.array(np.where(label > 0)).transpose((1, 0))
+ for point_idx in range(points.shape[0]):
+ tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
+ que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
+ pred[tmpx, tmpy] = label[tmpx, tmpy]
+
+ cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
+ cdef int cur_label
+    for kernel_idx in range(kernel_num - 2, -1, -1):  # kernels[:-1] holds kernel_num - 1 maps
+ while not que.empty():
+ cur = que.front()
+ que.pop()
+ cur_label = pred[cur.first, cur.second]
+
+ is_edge = True
+ for j in range(4):
+ tmpx = cur.first + dx[j]
+ tmpy = cur.second + dy[j]
+ if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
+ continue
+ if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
+ continue
+
+ que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
+ pred[tmpx, tmpy] = cur_label
+ is_edge = False
+ if is_edge:
+ nxt_que.push(cur)
+
+ que, nxt_que = nxt_que, que
+
+ return pred
+
+def pse(kernels, min_area):
+ kernel_num = kernels.shape[0]
+ label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
+ return _pse(kernels[:-1], label, kernel_num, label_num, min_area)
\ No newline at end of file
diff --git a/ppocr/postprocess/pse_postprocess/pse/setup.py b/ppocr/postprocess/pse_postprocess/pse/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..03746782af791938bff31c24e4a760f566c73b49
--- /dev/null
+++ b/ppocr/postprocess/pse_postprocess/pse/setup.py
@@ -0,0 +1,14 @@
+from distutils.core import setup, Extension
+from Cython.Build import cythonize
+import numpy
+
+setup(ext_modules=cythonize(Extension(
+ 'pse',
+ sources=['pse.pyx'],
+ language='c++',
+ include_dirs=[numpy.get_include()],
+ library_dirs=[],
+ libraries=[],
+ extra_compile_args=['-O3'],
+ extra_link_args=[]
+)))
diff --git a/ppocr/postprocess/pse_postprocess/pse_postprocess.py b/ppocr/postprocess/pse_postprocess/pse_postprocess.py
new file mode 100755
index 0000000000000000000000000000000000000000..4b89d221d284602933ab3d4f21468fcae79ef310
--- /dev/null
+++ b/ppocr/postprocess/pse_postprocess/pse_postprocess.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import cv2
+import paddle
+from paddle.nn import functional as F
+
+from ppocr.postprocess.pse_postprocess.pse import pse
+
+
+class PSEPostProcess(object):
+ """
+ The post process for PSE.
+ """
+
+ def __init__(self,
+ thresh=0.5,
+ box_thresh=0.85,
+ min_area=16,
+ box_type='box',
+ scale=4,
+ **kwargs):
+        assert box_type in ['box', 'poly'], 'Only box and poly are supported'
+ self.thresh = thresh
+ self.box_thresh = box_thresh
+ self.min_area = min_area
+ self.box_type = box_type
+ self.scale = scale
+
+ def __call__(self, outs_dict, shape_list):
+ pred = outs_dict['maps']
+ if not isinstance(pred, paddle.Tensor):
+ pred = paddle.to_tensor(pred)
+ pred = F.interpolate(pred, scale_factor=4 // self.scale, mode='bilinear')
+
+ score = F.sigmoid(pred[:, 0, :, :])
+
+ kernels = (pred > self.thresh).astype('float32')
+ text_mask = kernels[:, 0, :, :]
+ kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
+
+ score = score.numpy()
+ kernels = kernels.numpy().astype(np.uint8)
+
+ boxes_batch = []
+ for batch_index in range(pred.shape[0]):
+ boxes, scores = self.boxes_from_bitmap(score[batch_index], kernels[batch_index], shape_list[batch_index])
+
+ boxes_batch.append({'points': boxes, 'scores': scores})
+ return boxes_batch
+
+ def boxes_from_bitmap(self, score, kernels, shape):
+ label = pse(kernels, self.min_area)
+ return self.generate_box(score, label, shape)
+
+ def generate_box(self, score, label, shape):
+ src_h, src_w, ratio_h, ratio_w = shape
+ label_num = np.max(label) + 1
+
+ boxes = []
+ scores = []
+ for i in range(1, label_num):
+ ind = label == i
+ points = np.array(np.where(ind)).transpose((1, 0))[:, ::-1]
+
+ if points.shape[0] < self.min_area:
+ label[ind] = 0
+ continue
+
+ score_i = np.mean(score[ind])
+ if score_i < self.box_thresh:
+ label[ind] = 0
+ continue
+
+ if self.box_type == 'box':
+ rect = cv2.minAreaRect(points)
+ bbox = cv2.boxPoints(rect)
+ elif self.box_type == 'poly':
+ box_height = np.max(points[:, 1]) + 10
+ box_width = np.max(points[:, 0]) + 10
+
+ mask = np.zeros((box_height, box_width), np.uint8)
+ mask[points[:, 1], points[:, 0]] = 255
+
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+ bbox = np.squeeze(contours[0], 1)
+ else:
+ raise NotImplementedError
+
+ bbox[:, 0] = np.clip(
+ np.round(bbox[:, 0] / ratio_w), 0, src_w)
+ bbox[:, 1] = np.clip(
+ np.round(bbox[:, 1] / ratio_h), 0, src_h)
+ boxes.append(bbox)
+ scores.append(score_i)
+ return boxes, scores
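
`PSEPostProcess` expects the network output under the key `maps` and one `shape_list` row of `[src_h, src_w, ratio_h, ratio_w]` per image. The sketch below feeds random maps through it just to show the calling convention and the returned structure; sizes are dummy values, and importing the module triggers the in-place Cython build described in the pse README, so Cython and a C++ compiler must be available.

```python
import numpy as np
from ppocr.postprocess.pse_postprocess import PSEPostProcess

post = PSEPostProcess(thresh=0.5, box_thresh=0.85, min_area=16, box_type='box', scale=1)
maps = np.random.rand(1, 7, 96, 96).astype('float32')   # 1 text map + 6 kernel maps (dummy)
shape_list = np.array([[384.0, 384.0, 1.0, 1.0]])        # src_h, src_w, ratio_h, ratio_w
boxes_batch = post({'maps': maps}, shape_list)
print(len(boxes_batch[0]['points']), 'boxes')            # random input usually yields 0 boxes
```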
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 8ebe5b2741b77537b46b8057d9aa9c36dc99aeec..ef1a43fd0ee65f3e55a8f72dfd2f96c478da1a9a 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -15,38 +15,21 @@ import numpy as np
import string
import paddle
from paddle.nn import functional as F
+import re
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False):
- support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
- 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
- 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
- ]
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
-
+ def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
- if character_type == "en":
+ self.character_str = []
+ if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type == "EN_symbol":
- # same with ASTER setting (use 94 char).
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
- elif character_type in support_character_type:
- self.character_str = []
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
+ else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -56,9 +39,6 @@ class BaseRecLabelDecode(object):
self.character_str.append(" ")
dict_character = list(self.character_str)
- else:
- raise NotImplementedError
- self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@@ -101,15 +81,14 @@ class BaseRecLabelDecode(object):
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, tuple):
+ preds = preds[-1]
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
preds_idx = preds.argmax(axis=2)
@@ -133,13 +112,12 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
def __init__(self,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
model_name=["student"],
key=None,
**kwargs):
- super(DistillationCTCLabelDecode, self).__init__(
- character_dict_path, character_type, use_space_char)
+ super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
@@ -156,16 +134,77 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
return output
+class NRTRLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
+ super(NRTRLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+
+ if len(preds) == 2:
+ preds_id = preds[0]
+ preds_prob = preds[1]
+ if isinstance(preds_id, paddle.Tensor):
+ preds_id = preds_id.numpy()
+ if isinstance(preds_prob, paddle.Tensor):
+ preds_prob = preds_prob.numpy()
+            if preds_id[0][0] == 2:  # strip the leading start token (<s>)
+ preds_idx = preds_id[:, 1:]
+ preds_prob = preds_prob[:, 1:]
+ else:
+ preds_idx = preds_id
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ else:
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label[:, 1:])
+ return text, label
+
+ def add_special_char(self, dict_character):
+        dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] == 3: # end
+ break
+ try:
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ except:
+ continue
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text.lower(), np.mean(conf_list)))
+ return result_list
+
+
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -239,16 +278,91 @@ class AttnLabelDecode(BaseRecLabelDecode):
return idx
+class SEEDLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(SEEDLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = dict_character + [self.end_str]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ end_idx = self.get_beg_end_flag_idx("eos")
+ return [end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "sos":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "eos":
+ idx = np.array(self.dict[self.end_str])
+ else:
+            assert False, "unsupported type %s in get_beg_end_flag_idx" % beg_or_end
+ return idx
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ [end_idx] = self.get_ignored_tokens()
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if int(text_index[batch_idx][idx]) == int(end_idx):
+ break
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+ batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list)))
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ """
+ text = self.decode(text)
+ if label is None:
+ return text
+ else:
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+ """
+ preds_idx = preds["rec_pred"]
+ if isinstance(preds_idx, paddle.Tensor):
+ preds_idx = preds_idx.numpy()
+ if "rec_pred_scores" in preds:
+ preds_idx = preds["rec_pred"]
+ preds_prob = preds["rec_pred_scores"]
+ else:
+ preds_idx = preds["rec_pred"].argmax(axis=2)
+ preds_prob = preds["rec_pred"].max(axis=2)
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+ if label is None:
+ return text
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+
+
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='en',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
@@ -324,10 +438,9 @@ class SRNLabelDecode(BaseRecLabelDecode):
class TableLabelDecode(object):
""" """
- def __init__(self,
- character_dict_path,
- **kwargs):
- list_character, list_elem = self.load_char_elem_dict(character_dict_path)
+ def __init__(self, character_dict_path, **kwargs):
+ list_character, list_elem = self.load_char_elem_dict(
+ character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
@@ -346,7 +459,8 @@ class TableLabelDecode(object):
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
- substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split("\t")
+ substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
+ "\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
@@ -366,14 +480,14 @@ class TableLabelDecode(object):
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
- if isinstance(structure_probs,paddle.Tensor):
+ if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
- if isinstance(loc_preds,paddle.Tensor):
+ if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
- structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(structure_idx,
- structure_probs, 'elem')
+ structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
+ structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
@@ -388,8 +502,13 @@ class TableLabelDecode(object):
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
- return {'res_html_code': res_html_code_list, 'res_loc': res_loc_list, 'res_score_list': result_score_list,
- 'res_elem_idx_list': result_elem_idx_list,'structure_str_list':structure_str}
+ return {
+ 'res_html_code': res_html_code_list,
+ 'res_loc': res_loc_list,
+ 'res_score_list': result_score_list,
+ 'res_elem_idx_list': result_elem_idx_list,
+ 'structure_str_list': structure_str
+ }
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
@@ -454,3 +573,79 @@ class TableLabelDecode(object):
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
+
+
+class SARLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self, character_dict_path=None, use_space_char=False,
+ **kwargs):
+ super(SARLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
+
+ self.rm_symbol = kwargs.get('rm_symbol', False)
+
+ def add_special_char(self, dict_character):
+        beg_end_str = "<BOS/EOS>"
+        unknown_str = "<UKN>"
+        padding_str = "<PAD>"
+ dict_character = dict_character + [unknown_str]
+ self.unknown_idx = len(dict_character) - 1
+ dict_character = dict_character + [beg_end_str]
+ self.start_idx = len(dict_character) - 1
+ self.end_idx = len(dict_character) - 1
+ dict_character = dict_character + [padding_str]
+ self.padding_idx = len(dict_character) - 1
+ return dict_character
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+
+ batch_size = len(text_index)
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] in ignored_tokens:
+ continue
+ if int(text_index[batch_idx][idx]) == int(self.end_idx):
+ if text_prob is None and idx == 0:
+ continue
+ else:
+ break
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+ batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+ text = ''.join(char_list)
+ if self.rm_symbol:
+ comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
+ text = text.lower()
+ text = comp.sub('', text)
+ result_list.append((text, np.mean(conf_list)))
+ return result_list
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ if isinstance(preds, paddle.Tensor):
+ preds = preds.numpy()
+ preds_idx = preds.argmax(axis=2)
+ preds_prob = preds.max(axis=2)
+
+ text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
+
+ if label is None:
+ return text
+ label = self.decode(label, is_remove_duplicate=False)
+ return text, label
+
+ def get_ignored_tokens(self):
+ return [self.padding_idx]
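
For orientation, the snippet below is a minimal, standalone sketch (not part of this diff) of the greedy decode that `CTCLabelDecode` performs: argmax over the class axis, drop repeated indices and the blank, then map indices to characters. The character table and the fake prediction array are made up for illustration only.

```python
import numpy as np

# Hypothetical character table: index 0 is the CTC blank, as in CTCLabelDecode.
characters = ['blank', 'h', 'e', 'l', 'o']

# Fake network output with shape (batch=1, time=6, classes=5).
preds = np.array([[[0.1, 0.8, 0.05, 0.03, 0.02],   # h
                   [0.1, 0.05, 0.8, 0.03, 0.02],   # e
                   [0.1, 0.05, 0.03, 0.8, 0.02],   # l
                   [0.8, 0.05, 0.03, 0.1, 0.02],   # blank separates the two l's
                   [0.1, 0.05, 0.03, 0.8, 0.02],   # l
                   [0.1, 0.05, 0.03, 0.02, 0.8]]]) # o

preds_idx = preds.argmax(axis=2)   # (1, 6) index sequence
preds_prob = preds.max(axis=2)

for idx_seq, prob_seq in zip(preds_idx, preds_prob):
    chars, confs = [], []
    for t, idx in enumerate(idx_seq):
        if idx == 0:                          # skip the blank token
            continue
        if t > 0 and idx_seq[t - 1] == idx:   # is_remove_duplicate=True behaviour
            continue
        chars.append(characters[idx])
        confs.append(prob_seq[t])
    print(''.join(chars), float(np.mean(confs)))   # -> hello 0.8
```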
diff --git a/ppocr/utils/EN_symbol_dict.txt b/ppocr/utils/EN_symbol_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1aef43d6b842731a54cbe682ccda5c2dbfa694d9
--- /dev/null
+++ b/ppocr/utils/EN_symbol_dict.txt
@@ -0,0 +1,94 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+{
+|
+}
+~
\ No newline at end of file
diff --git a/ppocr/utils/dict90.txt b/ppocr/utils/dict90.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a945ae9c526e4faa68852eb3fb47d078a2f3f6ce
--- /dev/null
+++ b/ppocr/utils/dict90.txt
@@ -0,0 +1,90 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+_
+`
+~
\ No newline at end of file
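
The two dictionary files added above (`EN_symbol_dict.txt`, `dict90.txt`) are plain one-character-per-line lists. As a rough sketch of how such a file becomes the character-to-index mapping used by the decoders earlier in this diff (the actual loading lives in `BaseRecLabelDecode`, which is not shown here), one could do:

```python
# Sketch only: the real loading happens inside BaseRecLabelDecode.
def load_char_dict(path, use_space_char=False):
    chars = []
    with open(path, "rb") as f:
        for line in f.readlines():
            chars.append(line.decode("utf-8").strip("\n").strip("\r\n"))
    if use_space_char:
        chars.append(" ")
    return {c: i for i, c in enumerate(chars)}

char_to_idx = load_char_dict("ppocr/utils/dict90.txt")
print(len(char_to_idx))  # 90 entries for dict90.txt
```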
diff --git a/ppocr/utils/iou.py b/ppocr/utils/iou.py
new file mode 100644
index 0000000000000000000000000000000000000000..20529dee2d14083f3de4ac034668d004136c56e2
--- /dev/null
+++ b/ppocr/utils/iou.py
@@ -0,0 +1,48 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+EPS = 1e-6
+
+def iou_single(a, b, mask, n_class):
+ valid = mask == 1
+ a = a.masked_select(valid)
+ b = b.masked_select(valid)
+ miou = []
+ for i in range(n_class):
+        if a.shape == [0] and a.shape == b.shape:
+ inter = paddle.to_tensor(0.0)
+ union = paddle.to_tensor(0.0)
+ else:
+ inter = ((a == i).logical_and(b == i)).astype('float32')
+ union = ((a == i).logical_or(b == i)).astype('float32')
+ miou.append(paddle.sum(inter) / (paddle.sum(union) + EPS))
+ miou = sum(miou) / len(miou)
+ return miou
+
+def iou(a, b, mask, n_class=2, reduce=True):
+ batch_size = a.shape[0]
+
+ a = a.reshape([batch_size, -1])
+ b = b.reshape([batch_size, -1])
+ mask = mask.reshape([batch_size, -1])
+
+ iou = paddle.zeros((batch_size,), dtype='float32')
+ for i in range(batch_size):
+ iou[i] = iou_single(a[i], b[i], mask[i], n_class)
+
+ if reduce:
+ iou = paddle.mean(iou)
+ return iou
\ No newline at end of file
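
A small usage sketch for the new `ppocr/utils/iou.py` helper: it expects the prediction, ground truth, and mask as Paddle tensors of shape `[batch, ...]` and returns the mean class-wise IoU over the batch. The random inputs below are purely illustrative.

```python
import numpy as np
import paddle
from ppocr.utils.iou import iou

batch, h, w = 2, 32, 32
pred = paddle.to_tensor(np.random.randint(0, 2, (batch, h, w)).astype("int64"))
gt   = paddle.to_tensor(np.random.randint(0, 2, (batch, h, w)).astype("int64"))
mask = paddle.to_tensor(np.ones((batch, h, w), dtype="int64"))  # 1 = pixel is valid

score = iou(pred, gt, mask, n_class=2)   # scalar tensor, averaged over the batch
print(float(score))
```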
diff --git a/ppocr/utils/profiler.py b/ppocr/utils/profiler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4e28bc6bea9ca912a0786d879a48ec0349e7698
--- /dev/null
+++ b/ppocr/utils/profiler.py
@@ -0,0 +1,110 @@
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import paddle
+
+# A global variable to record the number of calling times for profiler
+# functions. It is used to specify the tracing range of training steps.
+_profiler_step_id = 0
+
+# A global variable to avoid parsing from string every time.
+_profiler_options = None
+
+
+class ProfilerOptions(object):
+ '''
+ Use a string to initialize a ProfilerOptions.
+ The string should be in the format: "key1=value1;key2=value;key3=value3".
+ For example:
+ "profile_path=model.profile"
+ "batch_range=[50, 60]; profile_path=model.profile"
+ "batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
+    ProfilerOptions supports the following key-value pairs:
+      batch_range - an integer list, e.g. [100, 110].
+      state - a string, the optional values are 'CPU', 'GPU' or 'All'.
+      sorted_key - a string, the optional values are 'calls', 'total',
+                   'max', 'min' or 'ave'.
+ tracer_option - a string, the optional values are 'Default', 'OpDetail',
+ 'AllOpDetail'.
+ profile_path - a string, the path to save the serialized profile data,
+ which can be used to generate a timeline.
+ exit_on_finished - a boolean.
+ '''
+
+ def __init__(self, options_str):
+ assert isinstance(options_str, str)
+
+ self._options = {
+ 'batch_range': [10, 20],
+ 'state': 'All',
+ 'sorted_key': 'total',
+ 'tracer_option': 'Default',
+ 'profile_path': '/tmp/profile',
+ 'exit_on_finished': True
+ }
+ self._parse_from_string(options_str)
+
+ def _parse_from_string(self, options_str):
+ for kv in options_str.replace(' ', '').split(';'):
+ key, value = kv.split('=')
+ if key == 'batch_range':
+ value_list = value.replace('[', '').replace(']', '').split(',')
+ value_list = list(map(int, value_list))
+ if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
+ 1] > value_list[0]:
+ self._options[key] = value_list
+ elif key == 'exit_on_finished':
+ self._options[key] = value.lower() in ("yes", "true", "t", "1")
+ elif key in [
+ 'state', 'sorted_key', 'tracer_option', 'profile_path'
+ ]:
+ self._options[key] = value
+
+ def __getitem__(self, name):
+ if self._options.get(name, None) is None:
+ raise ValueError(
+ "ProfilerOptions does not have an option named %s." % name)
+ return self._options[name]
+
+
+def add_profiler_step(options_str=None):
+ '''
+ Enable the operator-level timing using PaddlePaddle's profiler.
+    The profiler uses an independent variable to count the profiler steps.
+ One call of this function is treated as a profiler step.
+
+ Args:
+        options_str - a string used to initialize the ProfilerOptions.
+ Default is None, and the profiler is disabled.
+ '''
+ if options_str is None:
+ return
+
+ global _profiler_step_id
+ global _profiler_options
+
+ if _profiler_options is None:
+ _profiler_options = ProfilerOptions(options_str)
+
+ if _profiler_step_id == _profiler_options['batch_range'][0]:
+ paddle.utils.profiler.start_profiler(
+ _profiler_options['state'], _profiler_options['tracer_option'])
+ elif _profiler_step_id == _profiler_options['batch_range'][1]:
+ paddle.utils.profiler.stop_profiler(_profiler_options['sorted_key'],
+ _profiler_options['profile_path'])
+ if _profiler_options['exit_on_finished']:
+ sys.exit(0)
+
+ _profiler_step_id += 1
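
Per its docstring, `add_profiler_step` is intended to be called once per training iteration; profiling switches on and off at the configured batch range. A hypothetical training loop using it (the option string follows the format documented above, and `exit_on_finished=false` keeps the loop running after profiling ends):

```python
from ppocr.utils.profiler import add_profiler_step

# Profile iterations 10..20 and dump a timeline-compatible file.
profiler_options = "batch_range=[10, 20]; profile_path=/tmp/ocr.profile; exit_on_finished=false"

for step in range(100):
    add_profiler_step(profiler_options)   # no-op outside the configured batch range
    # ... run one training step here ...
```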
diff --git a/ppocr/utils/save_load.py b/ppocr/utils/save_load.py
index 3bb022ed98b140995b79ceea93d7f494d3f5930d..a7d24dd71a6e35ca619c2a3f90df3a202b8ad94b 100644
--- a/ppocr/utils/save_load.py
+++ b/ppocr/utils/save_load.py
@@ -108,14 +108,15 @@ def load_dygraph_params(config, model, logger, optimizer):
for k1, k2 in zip(state_dict.keys(), params.keys()):
if list(state_dict[k1].shape) == list(params[k2].shape):
new_state_dict[k1] = params[k2]
- else:
- logger.info(
- f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
- )
+ else:
+ logger.info(
+ f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
+ )
model.set_state_dict(new_state_dict)
logger.info(f"loaded pretrained_model successful from {pm}")
return {}
+
def load_pretrained_params(model, path):
if path is None:
return False
@@ -138,6 +139,7 @@ def load_pretrained_params(model, path):
print(f"load pretrain successful from {path}")
return model
+
def save_model(model,
optimizer,
model_path,
diff --git a/ppstructure/README.md b/ppstructure/README.md
index 8e1642cc75cc52b179d0f8441a8da2fe86e78d7b..849c5c5667ff0532dfee35479715880192df0dc5 100644
--- a/ppstructure/README.md
+++ b/ppstructure/README.md
@@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/
# CPU
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
-# For more,refer[Installation](https://www.paddlepaddle.org.cn/install/quick)。
```
+For more information, refer to [Installation](https://www.paddlepaddle.org.cn/install/quick).
- **(2) Install Layout-Parser**
```bash
-pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
### 2.2 Install PaddleOCR(including PP-OCR and PP-Structure)
@@ -124,8 +124,6 @@ Most of the parameters are consistent with the paddleocr whl package, see [doc o
After running, each image will have a directory with the same name under the directory specified in the output field. Each table in the picture will be stored as an excel and figure area will be cropped and saved, the excel and image file name will be the coordinates of the table in the image.
## 4. PP-Structure Pipeline
-
-the process is as follows

In PP-Structure, the image will be analyzed by layoutparser first. In the layout analysis, the area in the image will be classified, including **text, title, image, list and table** 5 categories. For the first 4 types of areas, directly use the PP-OCR to complete the text detection and recognition. The table area will be converted to an excel file of the same table style via Table OCR.
@@ -180,10 +178,10 @@ OCR and table recognition model
|model name|description|model size|download|
| --- | --- | --- | --- |
-|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
-|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
-|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
-|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
-|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_det|Slim pruned lightweight model, supporting Chinese, English, multilingual text detection|2.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_rec|Slim pruned and quantized lightweight model, supporting Chinese, English and number recognition|6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
+|en_ppocr_mobile_v2.0_table_det|Text detection of English table scenes trained on PubLayNet dataset|4.7M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
+|en_ppocr_mobile_v2.0_table_rec|Text recognition of English table scene trained on PubLayNet dataset|6.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
+|en_ppocr_mobile_v2.0_table_structure|Table structure prediction of English table scene trained on PubLayNet dataset|18.6M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
If you need to use other models, you can download the model in [model_list](../doc/doc_en/models_list_en.md) or use your own trained model to configure it to the three fields of `det_model_dir`, `rec_model_dir`, `table_model_dir` .
diff --git a/ppstructure/README_ch.md b/ppstructure/README_ch.md
index c8acac590039647cf52f47b16a99092ff68f2b6e..821a6c3e36361abefa4d754537fdbd694e844efe 100644
--- a/ppstructure/README_ch.md
+++ b/ppstructure/README_ch.md
@@ -30,13 +30,13 @@ python3 -m pip install paddlepaddle-gpu==2.1.1 -i https://mirror.baidu.com/pypi/
# CPU安装
python3 -m pip install paddlepaddle==2.1.1 -i https://mirror.baidu.com/pypi/simple
-# 更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
```
+更多需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
- **(2) 安装 Layout-Parser**
```bash
-pip3 install -U premailer paddleocr https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
+pip3 install -U https://paddleocr.bj.bcebos.com/whl/layoutparser-0.0.0-py3-none-any.whl
```
### 2.2 安装PaddleOCR(包含PP-OCR和PP-Structure)
@@ -179,10 +179,10 @@ OCR和表格识别模型
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
-|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
-|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) |
-|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) |
-|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) |
-|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_det|slim裁剪版超轻量模型,支持中英文、多语种文本检测|2.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/slim/ch_ppocr_mobile_v2.0_det_prune_infer.tar) |
+|ch_ppocr_mobile_slim_v2.0_rec|slim裁剪量化版超轻量模型,支持中英文、数字识别|6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_slim_train.tar) |
+|en_ppocr_mobile_v2.0_table_det|PubLayNet数据集训练的英文表格场景的文字检测|4.7M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_det_train.tar) |
+|en_ppocr_mobile_v2.0_table_rec|PubLayNet数据集训练的英文表格场景的文字识别|6.9M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_rec_train.tar) |
+|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
如需要使用其他模型,可以在 [model_list](../doc/doc_ch/models_list.md) 下载模型或者使用自己训练好的模型配置到`det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
diff --git a/ppstructure/layout/train_layoutparser_model.md b/ppstructure/layout/train_layoutparser_model.md
index 08f5ebbf1aa276e4a3ecf27af46442161afcda1f..58975d71606e45b2f68a7f68565459042ef32775 100644
--- a/ppstructure/layout/train_layoutparser_model.md
+++ b/ppstructure/layout/train_layoutparser_model.md
@@ -4,9 +4,9 @@
[1.1 Requirements](#Requirements)
- [1.2 Install PaddleDetection](#Install PaddleDetection)
+ [1.2 Install PaddleDetection](#Install_PaddleDetection)
-[2. Data preparation](#Data preparation)
+[2. Data preparation](#Data_preparation)
[3. Configuration](#Configuration)
@@ -16,7 +16,7 @@
[6. Deployment](#Deployment)
- [6.1 Export model](#Export model)
+ [6.1 Export model](#Export_model)
[6.2 Inference](#Inference)
@@ -35,7 +35,7 @@
- CUDA >= 10.1
- cuDNN >= 7.6
-
+
### 1.2 Install PaddleDetection
@@ -51,7 +51,7 @@ pip install -r requirements.txt
For more installation tutorials, please refer to: [Install doc](https://github.com/PaddlePaddle/PaddleDetection/blob/release/2.1/docs/tutorials/INSTALL_cn.md)
-
+
## 2. Data preparation
@@ -165,7 +165,7 @@ python tools/infer.py -c configs/ppyolo/ppyolov2_r50vd_dcn_365e_coco.yml --infer
Use your trained model in Layout Parser
-
+
### 6.1 Export model
diff --git a/ppstructure/table/README.md b/ppstructure/table/README.md
index a8d10b79e507ab59ef2481982a33902e4a95e73e..67c4d8e26d5c615f4a930752005420ba1abcc834 100644
--- a/ppstructure/table/README.md
+++ b/ppstructure/table/README.md
@@ -41,7 +41,7 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# run
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
Note: The above model is trained on the PubLayNet dataset and only supports English scanning scenarios. If you need to identify other scenarios, you need to train the model yourself and replace the three fields `det_model_dir`, `rec_model_dir`, `table_model_dir`.
diff --git a/ppstructure/table/README_ch.md b/ppstructure/table/README_ch.md
index 2ded403c371984a447f94268d23ca1c6240cf432..2e90ad33423da347b5a51444f2be53ed2eb67a7a 100644
--- a/ppstructure/table/README_ch.md
+++ b/ppstructure/table/README_ch.md
@@ -1,6 +1,16 @@
# 表格识别
+* [1. 表格识别 pipeline](#1)
+* [2. 性能](#2)
+* [3. 使用](#3)
+ + [3.1 快速开始](#31)
+ + [3.2 训练](#32)
+ + [3.3 评估](#33)
+ + [3.4 预测](#34)
+
+
## 1. 表格识别 pipeline
+
表格识别主要包含三个模型
1. 单行文本检测-DB
2. 单行文本识别-CRNN
@@ -17,6 +27,8 @@
3. 由单行文字的坐标、识别结果和单元格的坐标一起组合出单元格的识别结果。
4. 单元格的识别结果和表格结构一起构造表格的html字符串。
+
+
## 2. 性能
我们在 PubTabNet[1] 评估数据集上对算法进行了评估,性能如下
@@ -26,8 +38,9 @@
| EDD[2] | 88.3 |
| Ours | 93.32 |
+
## 3. 使用
-
+
### 3.1 快速开始
```python
@@ -43,12 +56,12 @@ wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_tab
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar
cd ..
# 执行预测
-python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/ppocr_keys_v1.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=ch --det_limit_side_len=736 --det_limit_type=min --output ../output/table
+python3 table/predict_table.py --det_model_dir=inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=inference/en_ppocr_mobile_v2.0_table_rec_infer --table_model_dir=inference/en_ppocr_mobile_v2.0_table_structure_infer --image_dir=../doc/table/table.jpg --rec_char_dict_path=../ppocr/utils/dict/table_dict.txt --table_char_dict_path=../ppocr/utils/dict/table_structure_dict.txt --rec_char_type=EN --det_limit_side_len=736 --det_limit_type=min --output ../output/table
```
运行完成后,每张图片的excel表格会保存到output字段指定的目录下
note: 上述模型是在 PubLayNet 数据集上训练的表格识别模型,仅支持英文扫描场景,如需识别其他场景需要自己训练模型后替换 `det_model_dir`,`rec_model_dir`,`table_model_dir`三个字段即可。
-
+
### 3.2 训练
在这一章节中,我们仅介绍表格结构模型的训练,[文字检测](../../doc/doc_ch/detection.md)和[文字识别](../../doc/doc_ch/recognition.md)的模型训练请参考对应的文档。
@@ -75,7 +88,7 @@ python3 tools/train.py -c configs/table/table_mv3.yml -o Global.checkpoints=./yo
**注意**:`Global.checkpoints`的优先级高于`Global.pretrain_weights`的优先级,即同时指定两个参数时,优先加载`Global.checkpoints`指定的模型,如果`Global.checkpoints`指定的模型路径有误,会加载`Global.pretrain_weights`指定的模型。
-
+
### 3.3 评估
表格使用 [TEDS(Tree-Edit-Distance-based Similarity)](https://github.com/ibm-aur-nlp/PubTabNet/tree/master/src) 作为模型的评估指标。在进行模型评估之前,需要将pipeline中的三个模型分别导出为inference模型(我们已经提供好),还需要准备评估的gt, gt示例如下:
@@ -100,7 +113,7 @@ python3 table/eval_table.py --det_model_dir=path/to/det_model_dir --rec_model_di
```bash
teds: 93.32
```
-
+
### 3.4 预测
```python
diff --git a/requirements.txt b/requirements.txt
index 351d409092a1f387b720c3ff2d43889170f320a7..6758a59bad20f6ffa271766fc4d0df5ebf4c7a4b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
shapely
-scikit-image==0.17.2
+scikit-image==0.18.3
imgaug==0.4.0
pyclipper
lmdb
@@ -7,4 +7,9 @@ tqdm
numpy
visualdl
python-Levenshtein
-opencv-contrib-python==4.4.0.46
\ No newline at end of file
+opencv-contrib-python==4.4.0.46
+cython
+lxml
+premailer
+openpyxl
+fasttext==0.9.1
\ No newline at end of file
diff --git a/tests/ocr_det_params.txt b/tests/ocr_det_params.txt
deleted file mode 100644
index 6aff66c6aa8591c9f48c81cf857809f956a3cda2..0000000000000000000000000000000000000000
--- a/tests/ocr_det_params.txt
+++ /dev/null
@@ -1,52 +0,0 @@
-===========================train_params===========================
-model_name:ocr_det
-python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:null
-Global.epoch_num:lite_train_infer=2|whole_train_infer=300
-Global.save_model_dir:./output/
-Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4
-Global.pretrained_model:null
-train_model_name:latest
-train_infer_img_dir:./train_data/icdar2015/text_localization/ch4_test_images/
-null:null
-##
-trainer:norm_train|pact_train
-norm_train:tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
-pact_train:deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o
-fpgm_train:null
-distill_train:null
-null:null
-null:null
-##
-===========================eval_params===========================
-eval:tools/eval.py -c configs/det/det_mv3_db.yml -o
-null:null
-##
-===========================infer_params===========================
-Global.save_inference_dir:./output/
-Global.pretrained_model:
-norm_export:tools/export_model.py -c configs/det/det_mv3_db.yml -o
-quant_export:deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o
-fpgm_export:deploy/slim/prune/export_prune_model.py
-distill_export:null
-export1:null
-export2:null
-##
-infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
-infer_export:null
-infer_quant:False
-inference:tools/infer/predict_det.py
---use_gpu:True|False
---enable_mkldnn:True|False
---cpu_threads:1|6
---rec_batch_num:1
---use_tensorrt:False|True
---precision:fp32|fp16|int8
---det_model_dir:
---image_dir:./inference/ch_det_data_50/all-sum-510/
---save_log_path:null
---benchmark:True
-null:null
-
diff --git a/tests/prepare.sh b/tests/prepare.sh
deleted file mode 100644
index 418e5661ad0f315bc60b8fda37742c115b395b7c..0000000000000000000000000000000000000000
--- a/tests/prepare.sh
+++ /dev/null
@@ -1,76 +0,0 @@
-#!/bin/bash
-FILENAME=$1
-# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer']
-MODE=$2
-
-dataline=$(cat ${FILENAME})
-
-# parser params
-IFS=$'\n'
-lines=(${dataline})
-function func_parser_key(){
- strs=$1
- IFS=":"
- array=(${strs})
- tmp=${array[0]}
- echo ${tmp}
-}
-function func_parser_value(){
- strs=$1
- IFS=":"
- array=(${strs})
- tmp=${array[1]}
- echo ${tmp}
-}
-IFS=$'\n'
-# The training params
-model_name=$(func_parser_value "${lines[1]}")
-
-trainer_list=$(func_parser_value "${lines[14]}")
-
-# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer']
-MODE=$2
-
-if [ ${MODE} = "lite_train_infer" ];then
- # pretrain lite train data
- wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
- rm -rf ./train_data/icdar2015
- rm -rf ./train_data/ic15_data
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_lite.tar
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar # todo change to bcebos
-
- cd ./train_data/ && tar xf icdar2015_lite.tar && tar xf ic15_data.tar
- ln -s ./icdar2015_lite ./icdar2015
- cd ../
-elif [ ${MODE} = "whole_train_infer" ];then
- wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
- rm -rf ./train_data/icdar2015
- rm -rf ./train_data/ic15_data
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
- cd ./train_data/ && tar xf icdar2015.tar && tar xf ic15_data.tar && cd ../
-elif [ ${MODE} = "whole_infer" ];then
- wget -nc -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/MobileNetV3_large_x0_5_pretrained.pdparams
- rm -rf ./train_data/icdar2015
- rm -rf ./train_data/ic15_data
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015_infer.tar
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
- cd ./train_data/ && tar xf icdar2015_infer.tar && tar xf ic15_data.tar
- ln -s ./icdar2015_infer ./icdar2015
- cd ../
-else
- if [ ${model_name} = "ocr_det" ]; then
- eval_model_name="ch_ppocr_mobile_v2.0_det_infer"
- rm -rf ./train_data/icdar2015
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ch_det_data_50.tar
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar
- cd ./inference && tar xf ${eval_model_name}.tar && tar xf ch_det_data_50.tar && cd ../
- else
- rm -rf ./train_data/ic15_data
- eval_model_name="ch_ppocr_mobile_v2.0_rec_infer"
- wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/ic15_data.tar
- wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
- cd ./inference && tar xf ${eval_model_name}.tar && tar xf ic15_data.tar && cd ../
- fi
-fi
-
diff --git a/tests/readme.md b/tests/readme.md
deleted file mode 100644
index 1c5e0faee90cad9709b6e4d517cbf7830aa2bb8e..0000000000000000000000000000000000000000
--- a/tests/readme.md
+++ /dev/null
@@ -1,58 +0,0 @@
-
-# 介绍
-
-test.sh和params.txt文件配合使用,完成OCR轻量检测和识别模型从训练到预测的流程测试。
-
-# 安装依赖
-- 安装PaddlePaddle >= 2.0
-- 安装PaddleOCR依赖
- ```
- pip3 install -r ../requirements.txt
- ```
-- 安装autolog
- ```
- git clone https://github.com/LDOUBLEV/AutoLog
- cd AutoLog
- pip3 install -r requirements.txt
- python3 setup.py bdist_wheel
- pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
- cd ../
- ```
-
-# 目录介绍
-
-```bash
-tests/
-├── ocr_det_params.txt # 测试OCR检测模型的参数配置文件
-├── ocr_rec_params.txt # 测试OCR识别模型的参数配置文件
-└── prepare.sh # 完成test.sh运行所需要的数据和模型下载
-└── test.sh # 根据
-```
-
-# 使用方法
-test.sh包含四种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是:
-- 模式1 lite_train_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度;
-```
-bash test/prepare.sh ./tests/ocr_det_params.txt 'lite_train_infer'
-bash tests/test.sh ./tests/ocr_det_params.txt 'lite_train_infer'
-```
-- 模式2 whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理;
-```
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_infer'
-bash tests/test.sh ./tests/ocr_det_params.txt 'whole_infer'
-```
-
-- 模式3 infer 不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度;
-```
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'infer'
-用法1:
-bash tests/test.sh ./tests/ocr_det_params.txt 'infer'
-用法2: 指定GPU卡预测,第三个传入参数为GPU卡号
-bash tests/test.sh ./tests/ocr_det_params.txt 'infer' '1'
-```
-
-模式4: whole_train_infer , CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度
-```
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_train_infer'
-bash tests/test.sh ./tests/ocr_det_params.txt 'whole_train_infer'
-```
diff --git a/tools/eval.py b/tools/eval.py
index 0120baab0f34d5fadbbf4df20d92d6b62dd176a2..28247bc57450aaf067fcb405674098eacb990166 100755
--- a/tools/eval.py
+++ b/tools/eval.py
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
-from ppocr.utils.save_load import init_model, load_pretrained_params
+from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.utility import print_dict
import tools.program as program
@@ -54,13 +54,13 @@ def main():
config['Architecture']["Head"]['out_channels'] = char_num
model = build_model(config['Architecture'])
- use_srn = config['Architecture']['algorithm'] == "SRN"
+ extra_input = config['Architecture']['algorithm'] in ["SRN", "SAR"]
if "model_type" in config['Architecture'].keys():
model_type = config['Architecture']['model_type']
else:
model_type = None
- best_model_dict = init_model(config, model)
+ best_model_dict = load_dygraph_params(config, model, logger, None)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
@@ -71,7 +71,7 @@ def main():
# start eval
metric = program.eval(model, valid_dataloader, post_process_class,
- eval_class, model_type, use_srn)
+ eval_class, model_type, extra_input)
logger.info('metric eval ***************')
for k, v in metric.items():
logger.info('{}:{}'.format(k, v))
diff --git a/tools/export_center.py b/tools/export_center.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46e8b9d58997b9b66c6ce81b2558ecd4cad0e81
--- /dev/null
+++ b/tools/export_center.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import pickle
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+from ppocr.data import build_dataloader
+from ppocr.modeling.architectures import build_model
+from ppocr.postprocess import build_post_process
+from ppocr.utils.save_load import init_model, load_dygraph_params
+from ppocr.utils.utility import print_dict
+import tools.program as program
+
+
+def main():
+ global_config = config['Global']
+ # build dataloader
+ config['Eval']['dataset']['name'] = config['Train']['dataset']['name']
+ config['Eval']['dataset']['data_dir'] = config['Train']['dataset'][
+ 'data_dir']
+ config['Eval']['dataset']['label_file_list'] = config['Train']['dataset'][
+ 'label_file_list']
+ eval_dataloader = build_dataloader(config, 'Eval', device, logger)
+
+ # build post process
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
+
+ # build model
+ # for rec algorithm
+ if hasattr(post_process_class, 'character'):
+ char_num = len(getattr(post_process_class, 'character'))
+ config['Architecture']["Head"]['out_channels'] = char_num
+
+ #set return_features = True
+ config['Architecture']["Head"]["return_feats"] = True
+
+ model = build_model(config['Architecture'])
+
+ best_model_dict = load_dygraph_params(config, model, logger, None)
+ if len(best_model_dict):
+ logger.info('metric in ckpt ***************')
+ for k, v in best_model_dict.items():
+ logger.info('{}:{}'.format(k, v))
+
+ # get features from train data
+ char_center = program.get_center(model, eval_dataloader, post_process_class)
+
+ #serialize to disk
+ with open("train_center.pkl", 'wb') as f:
+ pickle.dump(char_center, f)
+ return
+
+
+if __name__ == '__main__':
+ config, device, logger, vdl_writer = program.preprocess()
+ main()
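
`tools/export_center.py` serializes the per-character feature centers into `train_center.pkl` with `pickle`; reading them back is symmetric. A minimal sketch (the concrete key/value layout inside the pickle depends on `program.get_center`, which is not shown in this diff):

```python
import pickle

with open("train_center.pkl", "rb") as f:
    char_center = pickle.load(f)

# Inspect what was exported; the structure comes from program.get_center.
print(type(char_center), len(char_center))
```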
diff --git a/tools/export_model.py b/tools/export_model.py
index 785aca10e46200bda49bdff2b89ba00cafbe7a20..64a0d4036303716a632eb93c53f2478f32b42848 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -49,6 +49,12 @@ def export_single_model(model, arch_config, save_path, logger):
]
]
model = to_static(model, input_spec=other_shape)
+ elif arch_config["algorithm"] == "SAR":
+ other_shape = [
+ paddle.static.InputSpec(
+ shape=[None, 3, 48, 160], dtype="float32"),
+ ]
+ model = to_static(model, input_spec=other_shape)
else:
infer_shape = [3, -1, -1]
if arch_config["model_type"] == "rec":
@@ -60,6 +66,8 @@ def export_single_model(model, arch_config, save_path, logger):
"When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training"
)
infer_shape[-1] = 100
+ if arch_config["algorithm"] == "NRTR":
+ infer_shape = [1, 32, 100]
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
model = to_static(
@@ -93,6 +101,9 @@ def main():
for key in config["Architecture"]["Models"]:
config["Architecture"]["Models"][key]["Head"][
"out_channels"] = char_num
+                # just one final tensor needs to be exported for inference
+ config["Architecture"]["Models"][key][
+ "return_all_feats"] = False
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index 53e50bd6d1d1a2bd07b9f1204b9f56594c669d13..1c68494861e60b4aaef541a4e247071944cf420c 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -131,14 +131,9 @@ def main(args):
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
- except:
+ except Exception as E:
logger.info(traceback.format_exc())
- logger.info(
- "ERROR!!!! \n"
- "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
- "If your model has tps module: "
- "TPS does not support variable shape.\n"
- "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+ logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py
index 5c75e0c480eac6796d6d4b7075d1b38d254380fd..b24ad2bbb504caf1f262b4e47625348ce32d6fce 100755
--- a/tools/infer/predict_det.py
+++ b/tools/infer/predict_det.py
@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
-
+import json
logger = get_logger()
@@ -89,6 +89,14 @@ class TextDetector(object):
postprocess_params["sample_pts_num"] = 2
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
+ elif self.det_algorithm == "PSE":
+ postprocess_params['name'] = 'PSEPostProcess'
+ postprocess_params["thresh"] = args.det_pse_thresh
+ postprocess_params["box_thresh"] = args.det_pse_box_thresh
+ postprocess_params["min_area"] = args.det_pse_min_area
+ postprocess_params["box_type"] = args.det_pse_box_type
+ postprocess_params["scale"] = args.det_pse_scale
+ self.det_pse_box_type = args.det_pse_box_type
else:
logger.info("unknown det_algorithm:{}".format(self.det_algorithm))
sys.exit(0)
@@ -209,7 +217,7 @@ class TextDetector(object):
preds['f_score'] = outputs[1]
preds['f_tco'] = outputs[2]
preds['f_tvo'] = outputs[3]
- elif self.det_algorithm == 'DB':
+ elif self.det_algorithm in ['DB', 'PSE']:
preds['maps'] = outputs[0]
else:
raise NotImplementedError
@@ -217,7 +225,9 @@ class TextDetector(object):
#self.predictor.try_shrink_memory()
post_result = self.postprocess_op(preds, shape_list)
dt_boxes = post_result[0]['points']
- if self.det_algorithm == "SAST" and self.det_sast_polygon:
+ if (self.det_algorithm == "SAST" and
+ self.det_sast_polygon) or (self.det_algorithm == "PSE" and
+ self.det_pse_box_type == 'poly'):
dt_boxes = self.filter_tag_det_res_only_clip(dt_boxes, ori_im.shape)
else:
dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
@@ -243,6 +253,7 @@ if __name__ == "__main__":
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
+ save_results = []
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
@@ -256,8 +267,11 @@ if __name__ == "__main__":
if count > 0:
total_time += elapse
count += 1
-
- logger.info("Predict time of {}: {}".format(image_file, elapse))
+ save_pred = os.path.basename(image_file) + "\t" + str(
+ json.dumps(np.array(dt_boxes).astype(np.int32).tolist())) + "\n"
+ save_results.append(save_pred)
+ logger.info(save_pred)
+ logger.info("The predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_text_det_res(dt_boxes, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
@@ -265,5 +279,8 @@ if __name__ == "__main__":
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
+ with open(os.path.join(draw_img_save, "det_results.txt"), 'w') as f:
+ f.writelines(save_results)
+ f.close()
if args.benchmark:
text_detector.autolog.report()
diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py
index cd6c2005a7cc77c356e3f004cd586a84676ea7fa..5029d6059346a00062418d8d1b6cb029b0110643 100755
--- a/tools/infer/predict_e2e.py
+++ b/tools/infer/predict_e2e.py
@@ -74,7 +74,7 @@ class TextE2E(object):
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
- self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
+ self.predictor, self.input_tensor, self.output_tensors, _ = utility.create_predictor(
args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 97dfa5214628123d0c9b7edd7d94060a2bfd2a1e..936994a215d10d543537b29cb41bfa42b42590c7 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
import sys
-
+from PIL import Image
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
@@ -38,26 +38,34 @@ logger = get_logger()
class TextRecognizer(object):
def __init__(self, args):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
- self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
postprocess_params = {
'name': 'CTCLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
- "character_type": args.rec_char_type,
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
+ elif self.rec_algorithm == 'NRTR':
+ postprocess_params = {
+ 'name': 'NRTRLabelDecode',
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
+ elif self.rec_algorithm == "SAR":
+ postprocess_params = {
+ 'name': 'SARLabelDecode',
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
@@ -87,9 +95,19 @@ class TextRecognizer(object):
def resize_norm_img(self, img, max_wh_ratio):
imgC, imgH, imgW = self.rec_image_shape
+ if self.rec_algorithm == 'NRTR':
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+            # NRTR takes a single-channel 32x100 input scaled to [-1, 1)
+ image_pil = Image.fromarray(np.uint8(img))
+ img = image_pil.resize([100, 32], Image.ANTIALIAS)
+ img = np.array(img)
+ norm_img = np.expand_dims(img, -1)
+ norm_img = norm_img.transpose((2, 0, 1))
+ return norm_img.astype(np.float32) / 128. - 1.
+
assert imgC == img.shape[2]
- if self.character_type == "ch":
- imgW = int((32 * max_wh_ratio))
+ max_wh_ratio = max(max_wh_ratio, imgW / imgH)
+ imgW = int((32 * max_wh_ratio))
h, w = img.shape[:2]
ratio = w / float(h)
if math.ceil(imgH * ratio) > imgW:
@@ -177,6 +195,41 @@ class TextRecognizer(object):
return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
gsrm_slf_attn_bias2)
+ def resize_norm_img_sar(self, img, image_shape,
+ width_downsample_ratio=0.25):
+ imgC, imgH, imgW_min, imgW_max = image_shape
+ h = img.shape[0]
+ w = img.shape[1]
+ valid_ratio = 1.0
+ # make sure new_width is an integral multiple of width_divisor.
+ width_divisor = int(1 / width_downsample_ratio)
+ # resize
+ ratio = w / float(h)
+ resize_w = math.ceil(imgH * ratio)
+ if resize_w % width_divisor != 0:
+ resize_w = round(resize_w / width_divisor) * width_divisor
+ if imgW_min is not None:
+ resize_w = max(imgW_min, resize_w)
+ if imgW_max is not None:
+ valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
+ resize_w = min(imgW_max, resize_w)
+ resized_image = cv2.resize(img, (resize_w, imgH))
+ resized_image = resized_image.astype('float32')
+ # norm
+ if image_shape[0] == 1:
+ resized_image = resized_image / 255
+ resized_image = resized_image[np.newaxis, :]
+ else:
+ resized_image = resized_image.transpose((2, 0, 1)) / 255
+ resized_image -= 0.5
+ resized_image /= 0.5
+ resize_shape = resized_image.shape
+ padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
+ padding_im[:, :, 0:resize_w] = resized_image
+ pad_shape = padding_im.shape
+
+ return padding_im, resize_shape, pad_shape, valid_ratio
+
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@@ -199,11 +252,19 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
- if self.rec_algorithm != "SRN":
+ if self.rec_algorithm != "SRN" and self.rec_algorithm != "SAR":
norm_img = self.resize_norm_img(img_list[indices[ino]],
max_wh_ratio)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
+ elif self.rec_algorithm == "SAR":
+ norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
+ img_list[indices[ino]], self.rec_image_shape)
+ norm_img = norm_img[np.newaxis, :]
+ valid_ratio = np.expand_dims(valid_ratio, axis=0)
+                if ino == beg_img_no:
+                    valid_ratios = []
+                valid_ratios.append(valid_ratio)
+ norm_img_batch.append(norm_img)
else:
norm_img = self.process_image_srn(
img_list[indices[ino]], self.rec_image_shape, 8, 25)
@@ -249,17 +310,38 @@ class TextRecognizer(object):
if self.benchmark:
self.autolog.times.stamp()
preds = {"predict": outputs[2]}
+ elif self.rec_algorithm == "SAR":
+ valid_ratios = np.concatenate(valid_ratios)
+ inputs = [
+ norm_img_batch,
+ valid_ratios,
+ ]
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(input_names[
+ i])
+ input_tensor.copy_from_cpu(inputs[i])
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ if self.benchmark:
+ self.autolog.times.stamp()
+ preds = outputs[0]
else:
self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.run()
-
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
if self.benchmark:
self.autolog.times.stamp()
- preds = outputs[0]
+ if len(outputs) != 1:
+ preds = outputs
+ else:
+ preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
@@ -278,7 +360,7 @@ def main(args):
if args.warmup:
img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
for i in range(2):
- res = text_recognizer([img])
+ res = text_recognizer([img] * int(args.rec_batch_num))
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
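
To make the SAR resize/padding logic above concrete, the following standalone sketch reproduces just the width arithmetic of `resize_norm_img_sar` for one made-up crop, assuming a 4-value SAR shape of `3, 48, 48, 160` (channels, height, min width, max width):

```python
import math

imgC, imgH, imgW_min, imgW_max = 3, 48, 48, 160   # assumed SAR rec_image_shape
h, w = 32, 100                                     # made-up text crop size
width_downsample_ratio = 0.25
width_divisor = int(1 / width_downsample_ratio)    # 4

resize_w = math.ceil(imgH * (w / float(h)))        # 48 * (100/32) -> 150
if resize_w % width_divisor != 0:
    resize_w = round(resize_w / width_divisor) * width_divisor   # snapped to a multiple of 4 -> 152
resize_w = max(imgW_min, min(imgW_max, resize_w))  # clamped into [48, 160] -> 152
valid_ratio = min(1.0, resize_w / imgW_max)        # 0.95: fraction of the padded width that is real

print(resize_w, valid_ratio)   # the crop is resized to 48x152, then padded with -1 up to 48x160
```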
diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py
index eae0e27cd284ccce9f41f0c20b05dee09f46fc84..b5edd01589685a29a37dc20064b0d58e9d776fec 100755
--- a/tools/infer/predict_system.py
+++ b/tools/infer/predict_system.py
@@ -173,6 +173,9 @@ def main(args):
logger.info("The predict total time is {}".format(time.time() - _st))
logger.info("\nThe predict total time is {}".format(total_time))
+ if args.benchmark:
+ text_sys.text_detector.autolog.report()
+ text_sys.text_recognizer.autolog.report()
if __name__ == "__main__":
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 707328f28468db86c5061795d04713dc3b21a5cb..41a3c0f14b6378751a367a3709ad7943ee981a4e 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -35,7 +35,7 @@ def init_args():
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
- parser.add_argument("--min_subgraph_size", type=int, default=10)
+ parser.add_argument("--min_subgraph_size", type=int, default=15)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
@@ -63,11 +63,17 @@ def init_args():
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2)
parser.add_argument("--det_sast_polygon", type=str2bool, default=False)
+ # PSE parmas
+ parser.add_argument("--det_pse_thresh", type=float, default=0)
+ parser.add_argument("--det_pse_box_thresh", type=float, default=0.85)
+ parser.add_argument("--det_pse_min_area", type=float, default=16)
+ parser.add_argument("--det_pse_box_type", type=str, default='box')
+ parser.add_argument("--det_pse_scale", type=int, default=1)
+
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
- parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
@@ -236,11 +242,11 @@ def create_predictor(args, mode, logger):
max_input_shape.update(max_pact_shape)
opt_input_shape.update(opt_pact_shape)
elif mode == "rec":
- min_input_shape = {"x": [args.rec_batch_num, 3, 32, 10]}
+ min_input_shape = {"x": [1, 3, 32, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 32, 2000]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 32, 320]}
elif mode == "cls":
- min_input_shape = {"x": [args.rec_batch_num, 3, 48, 10]}
+ min_input_shape = {"x": [1, 3, 48, 10]}
max_input_shape = {"x": [args.rec_batch_num, 3, 48, 2000]}
opt_input_shape = {"x": [args.rec_batch_num, 3, 48, 320]}
else:
@@ -261,10 +267,11 @@ def create_predictor(args, mode, logger):
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
-
+ if args.precision == "fp16":
+ config.enable_mkldnn_bfloat16()
# enable memory optim
config.enable_memory_optim()
- #config.disable_glog_info()
+ config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
if mode == 'table':
diff --git a/tools/infer_det.py b/tools/infer_det.py
index a964cd28c934504ce79ea4873d3345295c1266e5..ce16da8dc5fffb3f5fdc633aeb00a386a2d60d4f 100755
--- a/tools/infer_det.py
+++ b/tools/infer_det.py
@@ -34,23 +34,21 @@ import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
-from ppocr.utils.save_load import init_model
+from ppocr.utils.save_load import init_model, load_dygraph_params
from ppocr.utils.utility import get_image_file_list
import tools.program as program
-def draw_det_res(dt_boxes, config, img, img_name):
+def draw_det_res(dt_boxes, config, img, img_name, save_path):
if len(dt_boxes) > 0:
import cv2
src_im = img
for box in dt_boxes:
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
- save_det_path = os.path.dirname(config['Global'][
- 'save_res_path']) + "/det_results/"
- if not os.path.exists(save_det_path):
- os.makedirs(save_det_path)
- save_path = os.path.join(save_det_path, os.path.basename(img_name))
+ if not os.path.exists(save_path):
+ os.makedirs(save_path)
+ save_path = os.path.join(save_path, os.path.basename(img_name))
cv2.imwrite(save_path, src_im)
logger.info("The detected Image saved in {}".format(save_path))
@@ -61,8 +59,7 @@ def main():
# build model
model = build_model(config['Architecture'])
- init_model(config, model)
-
+ _ = load_dygraph_params(config, model, logger, None)
# build post process
post_process_class = build_post_process(config['PostProcess'])
@@ -96,17 +93,41 @@ def main():
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, shape_list)
- boxes = post_result[0]['points']
- # write result
+
+ src_img = cv2.imread(file)
+
dt_boxes_json = []
- for box in boxes:
- tmp_json = {"transcription": ""}
- tmp_json['points'] = box.tolist()
- dt_boxes_json.append(tmp_json)
+ # parse boxes if post_result is a dict
+ if isinstance(post_result, dict):
+ det_box_json = {}
+ for k in post_result.keys():
+ boxes = post_result[k][0]['points']
+ dt_boxes_list = []
+ for box in boxes:
+ tmp_json = {"transcription": ""}
+ tmp_json['points'] = box.tolist()
+ dt_boxes_list.append(tmp_json)
+ det_box_json[k] = dt_boxes_list
+ save_det_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/det_results_{}/".format(k)
+ draw_det_res(boxes, config, src_img, file, save_det_path)
+ else:
+ boxes = post_result[0]['points']
+ dt_boxes_json = []
+ # write result
+ for box in boxes:
+ tmp_json = {"transcription": ""}
+ tmp_json['points'] = box.tolist()
+ dt_boxes_json.append(tmp_json)
+ save_det_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/det_results/"
+ draw_det_res(boxes, config, src_img, file, save_det_path)
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
- src_img = cv2.imread(file)
- draw_det_res(boxes, config, src_img, file)
+
+ save_det_path = os.path.dirname(config['Global'][
+ 'save_res_path']) + "/det_results/"
+ draw_det_res(boxes, config, src_img, file, save_det_path)
logger.info("success!")
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 09f5a0c767b15c312cdfbe8ed695ea06bdc8cdc4..29d4b530dfcfb8a3201e12b38c9b9f186f34b627 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -74,6 +74,10 @@ def main():
'image', 'encoder_word_pos', 'gsrm_word_pos',
'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
]
+ elif config['Architecture']['algorithm'] == "SAR":
+ op[op_name]['keep_keys'] = ['image', 'valid_ratio']
else:
op[op_name]['keep_keys'] = ['image']
transforms.append(op)
@@ -106,11 +110,16 @@ def main():
paddle.to_tensor(gsrm_slf_attn_bias1_list),
paddle.to_tensor(gsrm_slf_attn_bias2_list)
]
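+ # SAR additionally takes the valid-width ratio of the resized image as an extra input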
+ if config['Architecture']['algorithm'] == "SAR":
+ valid_ratio = np.expand_dims(batch[-1], axis=0)
+ img_metas = [paddle.to_tensor(valid_ratio)]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
if config['Architecture']['algorithm'] == "SRN":
preds = model(images, others)
+ elif config['Architecture']['algorithm'] == "SAR":
+ preds = model(images, img_metas)
else:
preds = model(images)
post_result = post_process_class(preds)
@@ -121,7 +130,7 @@ def main():
if len(post_result[key][0]) >= 2:
rec_info[key] = {
"label": post_result[key][0][0],
- "score": post_result[key][0][1],
+ "score": float(post_result[key][0][1]),
}
info = json.dumps(rec_info)
else:
diff --git a/tools/program.py b/tools/program.py
index 595fe4cb96c0379b1a33504e0ebdd85e70086340..798e6dff297ad1149942488cca1d5540f1924867 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -31,6 +31,7 @@ from ppocr.utils.stats import TrainingStats
from ppocr.utils.save_load import save_model
from ppocr.utils.utility import print_dict
from ppocr.utils.logging import get_logger
+from ppocr.utils import profiler
from ppocr.data import build_dataloader
import numpy as np
@@ -42,6 +43,13 @@ class ArgsParser(ArgumentParser):
self.add_argument("-c", "--config", help="configuration file to use")
self.add_argument(
"-o", "--opt", nargs='+', help="set configuration options")
+ self.add_argument(
+ '-p',
+ '--profiler_options',
+ type=str,
+ default=None,
+ help='Profiler options, in the format \"key1=value1;key2=value2;key3=value3\".'
+ )
def parse_args(self, argv=None):
args = super(ArgsParser, self).parse_args(argv)
@@ -158,6 +166,7 @@ def train(config,
epoch_num = config['Global']['epoch_num']
print_batch_step = config['Global']['print_batch_step']
eval_batch_step = config['Global']['eval_batch_step']
+ profiler_options = config['profiler_options']
global_step = 0
if 'global_step' in pre_best_model_dict:
@@ -186,10 +195,13 @@ def train(config,
model.train()
use_srn = config['Architecture']['algorithm'] == "SRN"
- try:
+ extra_input = config['Architecture'][
+ 'algorithm'] in ["SRN", "NRTR", "SAR", "SEED"]
+ try:
model_type = config['Architecture']['model_type']
- except:
+ except:
model_type = None
+ algorithm = config['Architecture']['algorithm']
if 'start_epoch' in best_model_dict:
start_epoch = best_model_dict['start_epoch']
@@ -206,6 +218,7 @@ def train(config,
max_iter = len(train_dataloader) - 1 if platform.system(
) == "Windows" else len(train_dataloader)
for idx, batch in enumerate(train_dataloader):
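+ # step the Paddle profiler every iteration; effectively a no-op unless -p/--profiler_options is given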
+ profiler.add_profiler_step(profiler_options)
train_reader_cost += time.time() - batch_start
if idx >= max_iter:
break
@@ -213,7 +226,7 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
- if use_srn or model_type == 'table':
+ if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(images)
@@ -277,7 +290,7 @@ def train(config,
post_process_class,
eval_class,
model_type,
- use_srn=use_srn)
+ extra_input=extra_input)
cur_metric_str = 'cur metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
logger.info(cur_metric_str)
@@ -348,8 +361,8 @@ def eval(model,
valid_dataloader,
post_process_class,
eval_class,
- model_type,
- use_srn=False):
+ model_type=None,
+ extra_input=False):
model.eval()
with paddle.no_grad():
total_frame = 0.0
@@ -362,7 +375,7 @@ def eval(model,
break
images = batch[0]
start = time.time()
- if use_srn or model_type == 'table':
+ if model_type == 'table' or extra_input:
preds = model(images, data=batch[1:])
else:
preds = model(images)
@@ -386,10 +399,76 @@ def eval(model,
return metric
+def update_center(char_center, post_result, preds):
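+ # for each sample whose predicted text matches its label, fold the per-time-step
+ # features into a running mean keyed by the predicted character index;
+ # char_center[index] stores [mean_feature, count]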
+ result, label = post_result
+ feats, logits = preds
+ logits = paddle.argmax(logits, axis=-1)
+ feats = feats.numpy()
+ logits = logits.numpy()
+
+ for idx_sample in range(len(label)):
+ if result[idx_sample][0] == label[idx_sample][0]:
+ feat = feats[idx_sample]
+ logit = logits[idx_sample]
+ for idx_time in range(len(logit)):
+ index = logit[idx_time]
+ if index in char_center.keys():
+ char_center[index][0] = (
+ char_center[index][0] * char_center[index][1] +
+ feat[idx_time]) / (char_center[index][1] + 1)
+ char_center[index][1] += 1
+ else:
+ char_center[index] = [feat[idx_time], 1]
+ return char_center
+
+
+def get_center(model, eval_dataloader, post_process_class):
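+ # run the model over eval_dataloader and accumulate per-character feature centers
+ # via update_center; the final loop keeps only the mean feature for each character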
+ pbar = tqdm(total=len(eval_dataloader), desc='get center:')
+ max_iter = len(eval_dataloader) - 1 if platform.system(
+ ) == "Windows" else len(eval_dataloader)
+ char_center = dict()
+ total_time = 0.0
+ for idx, batch in enumerate(eval_dataloader):
+ if idx >= max_iter:
+ break
+ images = batch[0]
+ start = time.time()
+ preds = model(images)
+
+ batch = [item.numpy() for item in batch]
+ # Obtain usable results from post-processing methods
+ total_time += time.time() - start
+ # Decode the predictions of the current batch
+ post_result = post_process_class(preds, batch[1])
+
+ # update char_center
+ char_center = update_center(char_center, post_result, preds)
+ pbar.update(1)
+
+ pbar.close()
+ for key in char_center.keys():
+ char_center[key] = char_center[key][0]
+ return char_center
+
+
def preprocess(is_train=False):
FLAGS = ArgsParser().parse_args()
+ profiler_options = FLAGS.profiler_options
config = load_config(FLAGS.config)
merge_config(FLAGS.opt)
+ profile_dic = {"profiler_options": FLAGS.profiler_options}
+ merge_config(profile_dic)
+
+ if is_train:
+ # save_config
+ save_model_dir = config['Global']['save_model_dir']
+ os.makedirs(save_model_dir, exist_ok=True)
+ with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
+ yaml.dump(
+ dict(config), f, default_flow_style=False, sort_keys=False)
+ log_file = '{}/train.log'.format(save_model_dir)
+ else:
+ log_file = None
+ logger = get_logger(name='root', log_file=log_file)
# check if set use_gpu=True in paddlepaddle cpu version
use_gpu = config['Global']['use_gpu']
@@ -398,24 +477,20 @@ def preprocess(is_train=False):
alg = config['Architecture']['algorithm']
assert alg in [
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
- 'CLS', 'PGNet', 'Distillation', 'TableAttn'
+ 'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
+ 'SEED'
]
+ windows_not_support_list = ['PSE']
+ if platform.system() == "Windows" and alg in windows_not_support_list:
+ logger.warning('{} is not supported on Windows yet'.format(
+ windows_not_support_list))
+ sys.exit()
device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
device = paddle.set_device(device)
config['Global']['distributed'] = dist.get_world_size() != 1
- if is_train:
- # save_config
- save_model_dir = config['Global']['save_model_dir']
- os.makedirs(save_model_dir, exist_ok=True)
- with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
- yaml.dump(
- dict(config), f, default_flow_style=False, sort_keys=False)
- log_file = '{}/train.log'.format(save_model_dir)
- else:
- log_file = None
- logger = get_logger(name='root', log_file=log_file)
+
if config['Global']['use_visualdl']:
from visualdl import LogWriter
save_model_dir = config['Global']['save_model_dir']