@@ -47,7 +47,7 @@ A_CTC Loss定义如下:
In the experiments, λ = 0.1. For the ACE loss implementation, see: [ace_loss.py](../../ppocr/losses/ace_loss.py)
## 3. C-CTC Loss
-C-CTC Loss is short for CTC Loss + Center Loss. Center Loss comes from the paper <A Discriminative Feature Learning Approach for Deep Face Recognition>. It was first used in face recognition to enlarge the inter-cluster distance and reduce the intra-class distance, and is an early and fairly commonly used algorithm in the Metric Learning field.
+C-CTC Loss is short for CTC Loss + Center Loss. Center Loss comes from the paper <A Discriminative Feature Learning Approach for Deep Face Recognition>. It was first used in face recognition to enlarge the inter-class distance and reduce the intra-class distance, and is an early and fairly commonly used algorithm in the Metric Learning field.
In Chinese OCR tasks, badcase analysis showed that a major difficulty of Chinese recognition is the large number of similar characters, which are easily misrecognized. This led us to ask whether ideas from Metric Learning could be borrowed to enlarge the inter-class distance between similar characters and thus improve recognition accuracy. However, Metric Learning is mainly used in image classification, where each training sample carries a single fixed label, while OCR recognition is essentially a sequence recognition task with no explicit alignment between features and labels, so how to combine the two remains a direction worth exploring.
After trying methods such as ArcMargin and CosMargin, we finally found that Center Loss helps further improve recognition accuracy. C-CTC Loss is defined as follows:
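The formula itself falls outside this hunk, but the combination can be sketched roughly as below. This is a minimal NumPy sketch under the assumption that per-timestep features have already been aligned to class labels; the actual implementation and the weighting coefficient live in the PaddleOCR loss code, and `lam` here is illustrative only:

```python
import numpy as np

def center_loss(features, labels, centers):
    """Mean squared distance of each feature to its own class center."""
    diff = features - centers[labels]            # (N, D) offsets to own center
    return 0.5 * float(np.mean(np.sum(diff ** 2, axis=1)))

def c_ctc_loss(ctc_loss_value, features, labels, centers, lam=0.25):
    # C-CTC Loss = CTC Loss + lambda * Center Loss (lam is an assumed value)
    return ctc_loss_value + lam * center_loss(features, labels, centers)

feats = np.array([[1.0, 0.0], [0.0, 1.0]])
labels = np.array([0, 1])
centers = np.array([[0.0, 0.0], [0.0, 1.0]])     # class-0 center is off by 1
print(c_ctc_loss(2.0, feats, labels, centers))   # 2.0 + 0.25 * 0.25 = 2.0625
```

Pulling similar characters toward distinct, well-separated centers is what enlarges the effective inter-class distance during training.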
diff --git a/doc/doc_ch/environment.md b/doc/doc_ch/environment.md
index a55008100312f50e0786dca49eaa186afc49b3b2..3a266c4bb8fe5516f844bea9f0aa21359d51660e 100644
--- a/doc/doc_ch/environment.md
+++ b/doc/doc_ch/environment.md
@@ -1,6 +1,13 @@
# 运行环境准备
+
Windows and Mac users are recommended to use Anaconda to set up the Python environment; Linux users are recommended to use Docker.
+Recommended environment:
+- PaddlePaddle >= 2.0.0 (2.1.2)
+- python3.7
+- CUDA10.1 / CUDA10.2
+- CUDNN 7.6
+
Users already familiar with Python environments can skip directly to Step 2 to install PaddlePaddle.
* [1. Python Environment Setup](#1)
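To quickly confirm that the local interpreter satisfies the recommended environment above, the following one-liner can be used (a convenience check, not part of the official setup steps):

```shell
# Exits non-zero if the active python3 is older than 3.7
python3 -c "import sys; assert sys.version_info[:2] >= (3, 7), sys.version"
```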
@@ -123,13 +130,13 @@ Windows和Mac用户推荐使用Anaconda搭建Python环境,Linux用户建议使
# !! Contents within this block are managed by 'conda init' !!
__conda_setup="$('/Users/xxx/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
- eval "$__conda_setup"
+ eval "$__conda_setup"
else
- if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
- . "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
- else
- export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
- fi
+ if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
+ . "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
+ else
+ export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
+ fi
fi
unset __conda_setup
# <<< conda initialize <<<
@@ -294,11 +301,12 @@ cd /home/Projects
# A docker container only needs to be created on the first run; this command is not needed on subsequent runs
# Create a docker container named ppocr and map the current directory to the container's /paddle directory
-If you want to use docker in a CPU-only environment, create the container with docker instead of nvidia-docker
-sudo docker run --name ppocr -v $PWD:/paddle --network=host -it paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash
+#If you want to use docker in a CPU-only environment, create the container with docker instead of nvidia-docker
+sudo docker run --name ppocr -v $PWD:/paddle --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
-If using CUDA10, run the following command to create the container; set the docker shared memory shm-size to 64G (at least 32G is recommended)
-sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=host -it paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash
+#If using CUDA10, run the following command to create the container; set the docker shared memory shm-size to 64G (at least 32G is recommended)
+# For CUDA11 + CUDNN8, the image registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda11.2-cudnn8 is recommended
+sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
# Use ctrl+P+Q to exit the docker container; re-enter it with the following command
sudo docker container exec -it ppocr /bin/bash
@@ -321,8 +329,3 @@ python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```
For more version requirements, please follow the instructions in the [PaddlePaddle installation documentation](https://www.paddlepaddle.org.cn/install/quick).
-
-
-
-
-
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index b9be1e4cb2d1b256a05b82ef5d6db49dfcb2f31f..4e0f1d131e2547f0d4a8bdf35c0f4a6f8bf2e7a3 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -273,7 +273,7 @@ python3 tools/export_model.py -c configs/rec/rec_r34_vd_none_bilstm_ctc.yml -o G
For CRNN text recognition model inference, execute the following command:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_crnn/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_crnn/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
![](../imgs_words_en/word_336.png)
@@ -288,7 +288,7 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
- The image resolution used in training differs: the model above was trained at [3, 32, 100], while Chinese models are trained at [3, 32, 320] to preserve recognition quality on long text. The inference program's default shape parameter is the resolution used for Chinese training, i.e. [3, 32, 320], so when running inference with the English model above, the shape of the recognition image must be set via the rec_image_shape parameter.
-- Character list: the experiments in the DTRB paper only cover the 26 lowercase English letters and 10 digits, 36 characters in total. All uppercase characters are converted to lowercase, and characters outside the list are ignored and treated as spaces. Therefore no character dictionary file is supplied; instead the dictionary is generated by the command below. Consequently, at inference time the parameter rec_char_type must be set to English, "en".
+- Character list: the experiments in the DTRB paper only cover the 26 lowercase English letters and 10 digits, 36 characters in total. All uppercase characters are converted to lowercase, and characters outside the list are ignored and treated as spaces. Therefore no character dictionary file is supplied; instead the dictionary is generated by the command below. Consequently, at inference time the parameter rec_char_dict_path must be set to the English dictionary "./ppocr/utils/ic15_dict.txt".
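The dictionary-free behavior described above (lower-casing input and dropping out-of-list characters) can be sketched as follows; `encode` is a hypothetical helper for illustration, not the PaddleOCR API:

```python
# 36-character dictionary: digits first, then lowercase letters
CHARSET = "0123456789abcdefghijklmnopqrstuvwxyz"
char_to_id = {c: i for i, c in enumerate(CHARSET)}

def encode(text):
    text = text.lower()                       # uppercase folded to lowercase
    return [char_to_id[c] for c in text
            if c in char_to_id]               # out-of-list characters dropped

print(encode("Super!"))   # 's','u','p','e','r' kept, '!' ignored
```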
```
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
@@ -303,15 +303,15 @@ dict_character = list(self.character_str)
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
--rec_model_dir="./inference/srn/" \
--rec_image_shape="1, 64, 256" \
- --rec_char_type="en" \
+ --rec_char_dict_path="./ppocr/utils/ic15_dict.txt" \
--rec_algorithm="SRN"
```
### 4. Inference with a Custom Text Recognition Dictionary
-If the text dictionary was modified during training, when predicting with the inference model you need to specify the dictionary path via `--rec_char_dict_path` and set `rec_char_type=ch`
+If the text dictionary was modified during training, when predicting with the inference model you need to specify the dictionary path via `--rec_char_dict_path`
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
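For reference, a recognition dictionary file is plain text with one character per line. The round-trip below is a sketch of that format (the file name is illustrative, and the read-back mirrors how such files are typically loaded):

```python
import os
import tempfile

chars = ["0", "1", "a", "b", "的"]
path = os.path.join(tempfile.mkdtemp(), "my_dict.txt")
with open(path, "w", encoding="utf-8") as f:
    f.write("\n".join(chars))          # one character per line

with open(path, "rb") as fin:          # read back, stripping line endings
    loaded = [line.decode("utf-8").strip("\n").strip("\r\n")
              for line in fin.readlines()]
print(loaded == chars)                 # True
```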
@@ -320,7 +320,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
You need to specify the visualization font path via `--vis_font_path`; default fonts for minority languages are provided under `doc/fonts/`. For example, Korean recognition:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
@@ -388,7 +388,7 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --de
The following command runs EAST text detection combined with STAR-Net text recognition:
```
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
After running the command, the recognition result image is as follows:
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index 52f978a734cbc750f4e2f36bb3ff28b2e67ab612..bb7d01712a85c92a02109e41814059e6c98c7cdc 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -159,7 +159,6 @@ PaddleOCR内置了一部分字典,可以按需使用。
- Custom dictionary
If you need a custom dict file, add the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml` and point it to your dictionary path.
-and set `character_type` to `ch`.
### 1.4 Adding the Space Category
@@ -246,8 +245,6 @@ Global:
...
# Add a custom dictionary; if you modify the dictionary, point the path to the new one
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- # Modify the character type
- character_type: ch
...
# Recognize spaces
use_space_char: True
@@ -311,18 +308,18 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
Grouped by language family, the languages currently supported by PaddleOCR are:
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht |
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to: [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
diff --git a/doc/doc_ch/training.md b/doc/doc_ch/training.md
index fb7f94a9e86cf392421ab6ed6f99cf2d49390096..c6c7b87d9925197b36a246c651ab7179ff9d2e81 100644
--- a/doc/doc_ch/training.md
+++ b/doc/doc_ch/training.md
@@ -129,3 +129,9 @@ PaddleOCR主要聚焦通用OCR,如果有垂类需求,您可以用PaddleOCR+
A: It is normal for the recognition model's acc to be 0 early in training; the metric rises after more training.
+
+***
+Click the links below to jump to the detailed training tutorials:
+- [Text detection model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
+- [Text recognition model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
+- [Text direction classifier training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
\ No newline at end of file
diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md
index aa78263e4b73a3ac35250e5483a394ab77450c90..ce76da9b2f39532b387e3e45ca2ff497b0408635 100644
--- a/doc/doc_en/config_en.md
+++ b/doc/doc_en/config_en.md
@@ -1,4 +1,4 @@
-# Configuration
+# Configuration
- [1. Optional Parameter List](#1-optional-parameter-list)
- [2. Introduction to Global Parameters of Configuration File](#2-intorduction-to-global-parameters-of-configuration-file)
@@ -37,9 +37,8 @@ Take rec_chinese_lite_train_v2.0.yml as an example
| checkpoints | set model parameter path | None | Used to load parameters after interruption to continue training|
| use_visualdl | Set whether to enable visualdl for visual log display | False | [Tutorial](https://www.paddlepaddle.org.cn/paddle/visualdl) |
| infer_img | Set inference image path or folder path | ./infer_img | \|
-| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | \ |
+| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | If character_dict_path is None, the model can only recognize digits and lowercase letters |
| max_text_length | Set the maximum length of text | 25 | \ |
-| character_type | Set character type | ch | en/ch, the default dict will be used for en, and the custom dict will be used for ch |
| use_space_char | Set whether to recognize spaces | True | Only supported in character_type=ch mode |
| label_list | Set the angle supported by the direction classifier | ['0','180'] | Only valid in angle classifier model |
| save_res_path | Set the save address of the test model results | ./output/det_db/predicts_db.txt | Only valid in the text detection model |
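For orientation, a minimal `Global` section exercising the parameters tabulated above might look like this (values are illustrative, not defaults for any shipped config):

```yaml
Global:
  use_gpu: true
  infer_img: ./infer_img
  character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt  # None -> digits + lowercase only
  max_text_length: 25
  use_space_char: true
  save_res_path: ./output/det_db/predicts_db.txt
```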
@@ -196,40 +195,39 @@ Italian is made up of Latin letters, so after executing the command, you will ge
use_gpu: True
epoch_num: 500
...
- character_type: it # language
character_dict_path: {path/of/dict} # path of dict
-
+
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/ # root directory of training data
label_file_list: ["./train_data/train_list.txt"] # train label path
...
-
+
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/ # root directory of val data
label_file_list: ["./train_data/val_list.txt"] # val label path
...
-
+
```
Currently, the multi-language algorithms supported by PaddleOCR are:
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
diff --git a/doc/doc_en/environment_en.md b/doc/doc_en/environment_en.md
index 96a46cce3010934689e8d95985ca434f49d18886..9aad92cafb809a0eec519808b1c1755403b39318 100644
--- a/doc/doc_en/environment_en.md
+++ b/doc/doc_en/environment_en.md
@@ -1,11 +1,18 @@
# Environment Preparation
+Recommended working environment:
+- PaddlePaddle >= 2.0.0 (2.1.2)
+- python3.7
+- CUDA10.1 / CUDA10.2
+- CUDNN 7.6
+
* [1. Python Environment Setup](#1)
+ [1.1 Windows](#1.1)
+ [1.2 Mac](#1.2)
+ [1.3 Linux](#1.3)
* [2. Install PaddlePaddle 2.0](#2)
+
## 1. Python Environment Setup
@@ -38,7 +45,7 @@
- Check conda to add environment variables and ignore the warning that
-
+
#### 1.1.2 Opening the terminal and creating the conda environment
@@ -69,7 +76,7 @@
# View the current location of python
where python
```
-
+
The above anaconda environment and python environment are installed
@@ -133,13 +140,13 @@ The above anaconda environment and python environment are installed
# !!! Contents within this block are managed by 'conda init' !!!
__conda_setup="$('/Users/xxx/opt/anaconda3/bin/conda' 'shell.bash' 'hook' 2> /dev/null)"
if [ $? -eq 0 ]; then
- eval "$__conda_setup"
+ eval "$__conda_setup"
else
- if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
- . "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
- else
- export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
- fi
+ if [ -f "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh" ]; then
+ . "/Users/xxx/opt/anaconda3/etc/profile.d/conda.sh"
+ else
+ export PATH="/Users/xxx/opt/anaconda3/bin:$PATH"
+ fi
fi
unset __conda_setup
# <<< conda initialize <<<
@@ -197,11 +204,10 @@ Linux users can choose to run either Anaconda or Docker. If you are familiar wit
- **Download Anaconda**.
- Download at: https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/?C=M&O=D
-
+
-
- Select the appropriate version for your operating system
- Type `uname -m` in the terminal to check the command set used by your system
@@ -216,12 +222,12 @@ Linux users can choose to run either Anaconda or Docker. If you are familiar wit
sudo yum install wget # CentOS
```
```bash
- # Then use wget to download from Tsinghua source
+ # Then use wget to download from Tsinghua source
# If you want to download Anaconda3-2021.05-Linux-x86_64.sh, the download command is as follows
wget https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/Anaconda3-2021.05-Linux-x86_64.sh
# If you want to download another version, you need to change the file name after the last 1 / to the version you want to download
```
-
+
- To install Anaconda.
- Type `sh Anaconda3-2021.05-Linux-x86_64.sh` at the command line
@@ -309,7 +315,18 @@ cd /home/Projects
# Create a docker container named ppocr and map the current directory to the /paddle directory of the container
# If using CPU, use docker instead of nvidia-docker to create docker
-sudo docker run --name ppocr -v $PWD:/paddle --network=host -it paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash
+sudo docker run --name ppocr -v $PWD:/paddle --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
+
+# If using GPU, use nvidia-docker to create docker
+# docker image registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda11.2-cudnn8 is recommended for CUDA11.2 + CUDNN8.
+sudo nvidia-docker run --name ppocr -v $PWD:/paddle --shm-size=64G --network=host -it registry.baidubce.com/paddlepaddle/paddle:2.1.3-gpu-cuda10.2-cudnn7 /bin/bash
+
+```
+You can also visit [DockerHub](https://hub.docker.com/r/paddlepaddle/paddle/tags/) to get the image that fits your machine.
+
+```
+# Press Ctrl+P+Q to exit the docker container; re-enter it with the following command:
+sudo docker container exec -it ppocr /bin/bash
```
@@ -329,4 +346,3 @@ python3 -m pip install paddlepaddle -i https://mirror.baidu.com/pypi/simple
```
For more software version requirements, please refer to the instructions in [Installation Document](https://www.paddlepaddle.org.cn/install/quick) for operation.
-
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index b445232feeefadc355e0f38b329050e26ccc0368..019ac4d0ac15aceed89286048d2c4d88a259e501 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -21,7 +21,7 @@ Next, we first introduce how to convert a trained model into an inference model,
- [2.2 DB Text Detection Model Inference](#DB_DETECTION)
- [2.3 East Text Detection Model Inference](#EAST_DETECTION)
- [2.4 Sast Text Detection Model Inference](#SAST_DETECTION)
-
+
- [3. Text Recognition Model Inference](#RECOGNITION_MODEL_INFERENCE)
- [3.1 Lightweight Chinese Text Recognition Model Inference](#LIGHTWEIGHT_RECOGNITION)
- [3.2 CTC-Based Text Recognition Model Inference](#CTC-BASED_RECOGNITION)
@@ -281,7 +281,7 @@ python3 tools/export_model.py -c configs/det/rec_r34_vd_none_bilstm_ctc.yml -o G
For CRNN text recognition model inference, execute the following commands:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
![](../imgs_words_en/word_336.png)
@@ -314,7 +314,7 @@ with the training, such as: --rec_image_shape="1, 64, 256"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
--rec_model_dir="./inference/srn/" \
--rec_image_shape="1, 64, 256" \
- --rec_char_type="en" \
+ --rec_char_dict_path="./ppocr/utils/ic15_dict.txt" \
--rec_algorithm="SRN"
```
@@ -323,7 +323,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
If the text dictionary is modified during training, when using the inference model to predict you need to specify the dictionary path via `--rec_char_dict_path`
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
@@ -333,7 +333,7 @@ If you need to predict other language models, when using inference model predict
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
@@ -399,7 +399,7 @@ If you want to try other detection algorithms or recognition algorithms, please
The following command uses the combination of the EAST text detection and STAR-Net text recognition:
```
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
After executing the command, the recognition result image is as follows:
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 84f5541562f9ce267da10abfad209ea1eb909a3e..51857ba16b7773ef38452fad6aa070f2117a9086 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -161,7 +161,7 @@ The current multi-language model is still in the demo stage and will continue to
If you like, you can submit the dictionary file to [dict](../../ppocr/utils/dict) and we will thank you in the Repo.
-To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml` and set `character_type` to `ch`.
+To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml`.
- Custom dictionary
@@ -172,8 +172,6 @@ If you need to customize dic file, please add character_dict_path field in confi
If you want to support the recognition of the `space` category, please set the `use_space_char` field in the yml file to `True`.
-**Note: use_space_char only takes effect when character_type=ch**
-
## 2.Training
@@ -250,7 +248,6 @@ Global:
# Add a custom dictionary, such as modify the dictionary, please point the path to the new dictionary
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
# Modify character type
- character_type: ch
...
# Whether to recognize spaces
use_space_char: True
@@ -312,18 +309,18 @@ Eval:
Currently, the multi-language algorithms supported by PaddleOCR are:
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
@@ -471,6 +468,3 @@ inference/det_db/
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
-
-
-
diff --git a/doc/doc_en/training_en.md b/doc/doc_en/training_en.md
index 106e41c0183b0cb20a038e154e155a25fdc6faa6..aa5500ac88fef97829b4f19c5421e36f18ae1812 100644
--- a/doc/doc_en/training_en.md
+++ b/doc/doc_en/training_en.md
@@ -147,3 +147,9 @@ There are several experiences for reference when constructing the data set:
A: It is normal for the acc to be 0 at the beginning of the recognition model training, and the indicator will come up after a longer training period.
+
+***
+Click the following links for detailed training tutorial:
+- [text detection model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
+- [text recognition model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
+- [text direction classification model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index ebf52ec4e1d8713fd4da407318b14e682952606d..0a4fad621a9038e71a9d43eb4e12f78e7e92d73d 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -21,6 +21,8 @@ import numpy as np
import string
import json
+from ppocr.utils.logging import get_logger
+
class ClsLabelEncode(object):
def __init__(self, label_list, **kwargs):
@@ -92,31 +94,23 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False):
- support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
- 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
- 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
- ]
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
- if character_type == "en":
+ self.lower = False
+
+ if character_dict_path is None:
+ logger = get_logger()
+ logger.warning(
+ "The character_dict_path is None, model can only recognize number and lower letters"
+ )
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type == "EN_symbol":
- # same with ASTER setting (use 94 char).
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
- elif character_type in support_character_type:
+ self.lower = True
+ else:
self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -125,7 +119,6 @@ class BaseRecLabelEncode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
- self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@@ -147,7 +140,7 @@ class BaseRecLabelEncode(object):
"""
if len(text) == 0 or len(text) > self.max_text_len:
return None
- if self.character_type == "en":
+ if self.lower:
text = text.lower()
text_list = []
for char in text:
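The refactor above replaces the `character_type` branches with a single check on `character_dict_path`. Stripped of logging and special tokens, the dictionary setup amounts to the following simplified sketch (not the actual class, and `build_dict` is a hypothetical name):

```python
def build_dict(character_dict_path=None, use_space_char=False):
    """Simplified version of the dictionary setup in BaseRecLabelEncode."""
    lower = False
    if character_dict_path is None:
        # No dict file: fall back to digits + lowercase letters,
        # and fold incoming labels to lowercase when encoding.
        character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
        lower = True
    else:
        with open(character_dict_path, "rb") as fin:
            character_str = "".join(
                line.decode("utf-8").strip("\n").strip("\r\n")
                for line in fin.readlines())
        if use_space_char:
            character_str += " "
    return list(character_str), lower

chars, lower = build_dict()
print(len(chars), lower)   # 36 True
```

Folding the old `character_type == "en"` behavior into the `character_dict_path is None` case is what makes the `self.lower` flag in `encode` sufficient.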
@@ -167,13 +160,11 @@ class NRTRLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='EN_symbol',
use_space_char=False,
**kwargs):
- super(NRTRLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(NRTRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
@@ -200,12 +191,10 @@ class CTCLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(CTCLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(CTCLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
@@ -231,12 +220,10 @@ class E2ELabelEncodeTest(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='EN',
use_space_char=False,
**kwargs):
- super(E2ELabelEncodeTest,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(E2ELabelEncodeTest, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
import json
@@ -305,12 +292,10 @@ class AttnLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(AttnLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(AttnLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -353,12 +338,10 @@ class SEEDLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(SEEDLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(SEEDLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.end_str = "eos"
@@ -385,12 +368,10 @@ class SRNLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length=25,
character_dict_path=None,
- character_type='en',
use_space_char=False,
**kwargs):
- super(SRNLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(SRNLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
@@ -598,12 +579,10 @@ class SARLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(SARLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(SARLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
        beg_end_str = "<BOS/EOS>"
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 71ed8976db7de24a489d1f75612a9a9a67995ba2..b4de6de95b09ced803375d9a3bb857194ef3e64b 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -87,17 +87,17 @@ class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
- character_type='ch',
+ character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
- self.character_type = character_type
+ self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
- if self.infer_mode and self.character_type == "ch":
+ if self.infer_mode and self.character_dict_path is not None:
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape, self.padding)
diff --git a/ppocr/losses/ace_loss.py b/ppocr/losses/ace_loss.py
index 9c868520e5bd7b398c7f248c416b70427baee0a6..bf15f8e3a7b355bd9e8b69435a5dae01fc75a892 100644
--- a/ppocr/losses/ace_loss.py
+++ b/ppocr/losses/ace_loss.py
@@ -32,6 +32,7 @@ class ACELoss(nn.Layer):
def __call__(self, predicts, batch):
if isinstance(predicts, (list, tuple)):
predicts = predicts[-1]
+
B, N = predicts.shape[:2]
div = paddle.to_tensor([N]).astype('float32')
@@ -42,9 +43,7 @@ class ACELoss(nn.Layer):
length = batch[2].astype("float32")
batch = batch[3].astype("float32")
batch[:, 0] = paddle.subtract(div, length)
-
batch = paddle.divide(batch, div)
loss = self.loss_func(aggregation_preds, batch)
-
return {"loss_ace": loss}
diff --git a/ppocr/losses/center_loss.py b/ppocr/losses/center_loss.py
index 72149df19f9e9a864ec31239177dc574648da3d5..cbef4df965e2659c6aa63c0c69cd8798143df485 100644
--- a/ppocr/losses/center_loss.py
+++ b/ppocr/losses/center_loss.py
@@ -27,7 +27,6 @@ class CenterLoss(nn.Layer):
"""
Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
"""
-
def __init__(self,
num_classes=6625,
feat_dim=96,
@@ -37,8 +36,7 @@ class CenterLoss(nn.Layer):
self.num_classes = num_classes
self.feat_dim = feat_dim
self.centers = paddle.randn(
- shape=[self.num_classes, self.feat_dim]).astype(
- "float64") #random center
+ shape=[self.num_classes, self.feat_dim]).astype("float64")
if init_center:
assert os.path.exists(
@@ -60,22 +58,23 @@ class CenterLoss(nn.Layer):
batch_size = feats_reshape.shape[0]
- #calc feat * feat
- dist1 = paddle.sum(paddle.square(feats_reshape), axis=1, keepdim=True)
- dist1 = paddle.expand(dist1, [batch_size, self.num_classes])
+ #calc l2 distance between feats and centers
+ square_feat = paddle.sum(paddle.square(feats_reshape),
+ axis=1,
+ keepdim=True)
+ square_feat = paddle.expand(square_feat, [batch_size, self.num_classes])
- #dist2 of centers
- dist2 = paddle.sum(paddle.square(self.centers), axis=1,
- keepdim=True) #num_classes
- dist2 = paddle.expand(dist2,
- [self.num_classes, batch_size]).astype("float64")
- dist2 = paddle.transpose(dist2, [1, 0])
+ square_center = paddle.sum(paddle.square(self.centers),
+ axis=1,
+ keepdim=True)
+ square_center = paddle.expand(
+ square_center, [self.num_classes, batch_size]).astype("float64")
+ square_center = paddle.transpose(square_center, [1, 0])
- #first x * x + y * y
- distmat = paddle.add(dist1, dist2)
- tmp = paddle.matmul(feats_reshape,
- paddle.transpose(self.centers, [1, 0]))
- distmat = distmat - 2.0 * tmp
+ distmat = paddle.add(square_feat, square_center)
+ feat_dot_center = paddle.matmul(feats_reshape,
+ paddle.transpose(self.centers, [1, 0]))
+ distmat = distmat - 2.0 * feat_dot_center
#generate the mask
classes = paddle.arange(self.num_classes).astype("int64")
@@ -83,7 +82,8 @@ class CenterLoss(nn.Layer):
paddle.unsqueeze(label, 1), (batch_size, self.num_classes))
mask = paddle.equal(
paddle.expand(classes, [batch_size, self.num_classes]),
- label).astype("float64") #get mask
+ label).astype("float64")
dist = paddle.multiply(distmat, mask)
+
loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
return {'loss_center': loss}
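The center-loss refactor above only renames `dist1`/`dist2` to `square_feat`/`square_center`; the computation is unchanged and rests on the identity ||f − c||² = ||f||² + ||c||² − 2·f·c applied pairwise between features and class centers. A minimal NumPy sketch (hypothetical shapes, not the Paddle implementation) showing that the expansion matches the direct definition:

```python
import numpy as np

def pairwise_sqdist(feats, centers):
    """Pairwise squared L2 distances between feats (B, D) and centers (C, D),
    mirroring square_feat + square_center - 2 * feat_dot_center in CenterLoss."""
    square_feat = np.sum(feats ** 2, axis=1, keepdims=True)        # (B, 1)
    square_center = np.sum(centers ** 2, axis=1, keepdims=True).T  # (1, C)
    feat_dot_center = feats @ centers.T                            # (B, C)
    return square_feat + square_center - 2.0 * feat_dot_center

rng = np.random.default_rng(0)
feats = rng.standard_normal((4, 6))
centers = rng.standard_normal((3, 6))
distmat = pairwise_sqdist(feats, centers)

# Check against the direct definition ||f - c||^2
direct = ((feats[:, None, :] - centers[None, :, :]) ** 2).sum(-1)
assert np.allclose(distmat, direct)
```

The expansion avoids materializing the (B, C, D) difference tensor, which is why the loss uses two sums and one matmul instead of a broadcasted subtraction.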
diff --git a/ppocr/modeling/backbones/rec_mv1_enhance.py b/ppocr/modeling/backbones/rec_mv1_enhance.py
index fe874fac1af439bfb47ba9050a61f02db302e224..04a909b8ccafd8e62f9a7076c7dedf63ff745303 100644
--- a/ppocr/modeling/backbones/rec_mv1_enhance.py
+++ b/ppocr/modeling/backbones/rec_mv1_enhance.py
@@ -1,4 +1,4 @@
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,26 +16,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-import paddle
-from paddle import ParamAttr
-import paddle.nn as nn
-import paddle.nn.functional as F
-from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
-from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
-from paddle.nn.initializer import KaimingNormal
import math
import numpy as np
import paddle
-from paddle import ParamAttr, reshape, transpose, concat, split
+from paddle import ParamAttr, reshape, transpose
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import Conv2D, BatchNorm, Linear, Dropout
from paddle.nn import AdaptiveAvgPool2D, MaxPool2D, AvgPool2D
from paddle.nn.initializer import KaimingNormal
-import math
-from paddle.nn.functional import hardswish, hardsigmoid
from paddle.regularizer import L2Decay
+from paddle.nn.functional import hardswish, hardsigmoid
class ConvBNLayer(nn.Layer):
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index c06159ca55600e7afe01a68ab43acd1919cf742c..ef1a43fd0ee65f3e55a8f72dfd2f96c478da1a9a 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -21,33 +21,15 @@ import re
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False):
- support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
- 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
- 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
- ]
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
-
+ def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
- if character_type == "en":
+ self.character_str = []
+ if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type == "EN_symbol":
- # same with ASTER setting (use 94 char).
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
- elif character_type in support_character_type:
- self.character_str = []
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
+ else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -57,9 +39,6 @@ class BaseRecLabelDecode(object):
self.character_str.append(" ")
dict_character = list(self.character_str)
- else:
- raise NotImplementedError
- self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@@ -102,13 +81,10 @@ class BaseRecLabelDecode(object):
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
@@ -136,13 +112,12 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
def __init__(self,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
model_name=["student"],
key=None,
**kwargs):
- super(DistillationCTCLabelDecode, self).__init__(
- character_dict_path, character_type, use_space_char)
+ super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
@@ -162,13 +137,9 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='EN_symbol',
- use_space_char=True,
- **kwargs):
+ def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
@@ -230,13 +201,10 @@ class NRTRLabelDecode(BaseRecLabelDecode):
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -313,13 +281,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
class SEEDLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SEEDLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -394,13 +359,10 @@ class SEEDLabelDecode(BaseRecLabelDecode):
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='en',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
@@ -616,13 +578,10 @@ class TableLabelDecode(object):
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SARLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
self.rm_symbol = kwargs.get('rm_symbol', False)
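After this refactor, the decoders above no longer branch on `character_type`: `BaseRecLabelDecode` falls back to digits plus lowercase letters when `character_dict_path` is `None`, and otherwise reads one character per line from the dict file. A standalone sketch of that simplified charset-building logic (not the PaddleOCR class itself; file handling condensed for illustration):

```python
def build_charset(character_dict_path=None, use_space_char=False):
    # With no dict file, fall back to digits + lowercase letters,
    # matching the character_dict_path is None branch above.
    if character_dict_path is None:
        chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
    else:
        # Otherwise: one character per line, decoded as UTF-8.
        with open(character_dict_path, "rb") as fin:
            chars = [line.decode("utf-8").strip("\n").strip("\r\n")
                     for line in fin.readlines()]
    if use_space_char:
        chars.append(" ")
    # Index lookup used when encoding labels to id sequences.
    return {c: i for i, c in enumerate(chars)}, chars

char_to_idx, charset = build_charset()
```

With the default (no dict path), the charset has 36 entries and `"a"` maps to index 10, which is the behavior the English-only fallback in the diff relies on.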
diff --git a/ppocr/utils/EN_symbol_dict.txt b/ppocr/utils/EN_symbol_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1aef43d6b842731a54cbe682ccda5c2dbfa694d9
--- /dev/null
+++ b/ppocr/utils/EN_symbol_dict.txt
@@ -0,0 +1,94 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+{
+|
+}
+~
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 311030f65f2dc2dad4a51821e64f2777e7621a0b..6758a59bad20f6ffa271766fc4d0df5ebf4c7a4b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
shapely
-scikit-image==0.17.2
+scikit-image==0.18.3
imgaug==0.4.0
pyclipper
lmdb
diff --git a/tests/readme.md b/tests/readme.md
deleted file mode 100644
index 127eef9fe34de37f5dfa608b3ed9903f2b826fa0..0000000000000000000000000000000000000000
--- a/tests/readme.md
+++ /dev/null
@@ -1,72 +0,0 @@
-
-# 从训练到推理部署工具链测试方法介绍
-
-test.sh和params.txt文件配合使用,完成OCR轻量检测和识别模型从训练到预测的流程测试。
-
-# 安装依赖
-- 安装PaddlePaddle >= 2.0
-- 安装PaddleOCR依赖
- ```
- pip3 install -r ../requirements.txt
- ```
-- 安装autolog
- ```
- git clone https://github.com/LDOUBLEV/AutoLog
- cd AutoLog
- pip3 install -r requirements.txt
- python3 setup.py bdist_wheel
- pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
- cd ../
- ```
-
-# 目录介绍
-
-```bash
-tests/
-├── ocr_det_params.txt # 测试OCR检测模型的参数配置文件
-├── ocr_rec_params.txt # 测试OCR识别模型的参数配置文件
-├── ocr_ppocr_mobile_params.txt # 测试OCR检测+识别模型串联的参数配置文件
-└── prepare.sh # 完成test.sh运行所需要的数据和模型下载
-└── test.sh # 测试主程序
-```
-
-# 使用方法
-
-test.sh包含四种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是:
-
-- 模式1:lite_train_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度;
-```shell
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'lite_train_infer'
-bash tests/test.sh ./tests/ocr_det_params.txt 'lite_train_infer'
-```
-
-- 模式2:whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理;
-```shell
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_infer'
-bash tests/test.sh ./tests/ocr_det_params.txt 'whole_infer'
-```
-
-- 模式3:infer 不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度;
-```shell
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'infer'
-# 用法1:
-bash tests/test.sh ./tests/ocr_det_params.txt 'infer'
-# 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号
-bash tests/test.sh ./tests/ocr_det_params.txt 'infer' '1'
-```
-
-- 模式4:whole_train_infer , CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度;
-```shell
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'whole_train_infer'
-bash tests/test.sh ./tests/ocr_det_params.txt 'whole_train_infer'
-```
-
-- 模式5:cpp_infer , CE: 验证inference model的c++预测是否走通;
-```shell
-bash tests/prepare.sh ./tests/ocr_det_params.txt 'cpp_infer'
-bash tests/test.sh ./tests/ocr_det_params.txt 'cpp_infer'
-```
-
-# 日志输出
-最终在```tests/output```目录下生成.log后缀的日志文件
-
diff --git a/tests/test_python.sh b/tests/test_python.sh
deleted file mode 100644
index 2f5fdf4e8dc63a0ebb33c7880bc08c24be759740..0000000000000000000000000000000000000000
--- a/tests/test_python.sh
+++ /dev/null
@@ -1,173 +0,0 @@
-#!/bin/bash
-source tests/common_func.sh
-
-FILENAME=$1
-dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
-
-# parser params
-IFS=$'\n'
-lines=(${dataline})
-
-# The training params
-model_name=$(func_parser_value "${lines[1]}")
-python=$(func_parser_value "${lines[2]}")
-gpu_list=$(func_parser_value "${lines[3]}")
-train_use_gpu_key=$(func_parser_key "${lines[4]}")
-train_use_gpu_value=$(func_parser_value "${lines[4]}")
-autocast_list=$(func_parser_value "${lines[5]}")
-autocast_key=$(func_parser_key "${lines[5]}")
-epoch_key=$(func_parser_key "${lines[6]}")
-epoch_num=$(func_parser_params "${lines[6]}")
-save_model_key=$(func_parser_key "${lines[7]}")
-train_batch_key=$(func_parser_key "${lines[8]}")
-train_batch_value=$(func_parser_params "${lines[8]}")
-pretrain_model_key=$(func_parser_key "${lines[9]}")
-pretrain_model_value=$(func_parser_value "${lines[9]}")
-train_model_name=$(func_parser_value "${lines[10]}")
-train_infer_img_dir=$(func_parser_value "${lines[11]}")
-train_param_key1=$(func_parser_key "${lines[12]}")
-train_param_value1=$(func_parser_value "${lines[12]}")
-
-trainer_list=$(func_parser_value "${lines[14]}")
-trainer_norm=$(func_parser_key "${lines[15]}")
-norm_trainer=$(func_parser_value "${lines[15]}")
-pact_key=$(func_parser_key "${lines[16]}")
-pact_trainer=$(func_parser_value "${lines[16]}")
-fpgm_key=$(func_parser_key "${lines[17]}")
-fpgm_trainer=$(func_parser_value "${lines[17]}")
-distill_key=$(func_parser_key "${lines[18]}")
-distill_trainer=$(func_parser_value "${lines[18]}")
-trainer_key1=$(func_parser_key "${lines[19]}")
-trainer_value1=$(func_parser_value "${lines[19]}")
-trainer_key2=$(func_parser_key "${lines[20]}")
-trainer_value2=$(func_parser_value "${lines[20]}")
-
-eval_py=$(func_parser_value "${lines[23]}")
-eval_key1=$(func_parser_key "${lines[24]}")
-eval_value1=$(func_parser_value "${lines[24]}")
-
-save_infer_key=$(func_parser_key "${lines[27]}")
-export_weight=$(func_parser_key "${lines[28]}")
-norm_export=$(func_parser_value "${lines[29]}")
-pact_export=$(func_parser_value "${lines[30]}")
-fpgm_export=$(func_parser_value "${lines[31]}")
-distill_export=$(func_parser_value "${lines[32]}")
-export_key1=$(func_parser_key "${lines[33]}")
-export_value1=$(func_parser_value "${lines[33]}")
-export_key2=$(func_parser_key "${lines[34]}")
-export_value2=$(func_parser_value "${lines[34]}")
-
-# parser inference model
-infer_model_dir_list=$(func_parser_value "${lines[36]}")
-infer_export_list=$(func_parser_value "${lines[37]}")
-infer_is_quant=$(func_parser_value "${lines[38]}")
-# parser inference
-inference_py=$(func_parser_value "${lines[39]}")
-use_gpu_key=$(func_parser_key "${lines[40]}")
-use_gpu_list=$(func_parser_value "${lines[40]}")
-use_mkldnn_key=$(func_parser_key "${lines[41]}")
-use_mkldnn_list=$(func_parser_value "${lines[41]}")
-cpu_threads_key=$(func_parser_key "${lines[42]}")
-cpu_threads_list=$(func_parser_value "${lines[42]}")
-batch_size_key=$(func_parser_key "${lines[43]}")
-batch_size_list=$(func_parser_value "${lines[43]}")
-use_trt_key=$(func_parser_key "${lines[44]}")
-use_trt_list=$(func_parser_value "${lines[44]}")
-precision_key=$(func_parser_key "${lines[45]}")
-precision_list=$(func_parser_value "${lines[45]}")
-infer_model_key=$(func_parser_key "${lines[46]}")
-image_dir_key=$(func_parser_key "${lines[47]}")
-infer_img_dir=$(func_parser_value "${lines[47]}")
-save_log_key=$(func_parser_key "${lines[48]}")
-benchmark_key=$(func_parser_key "${lines[49]}")
-benchmark_value=$(func_parser_value "${lines[49]}")
-infer_key1=$(func_parser_key "${lines[50]}")
-infer_value1=$(func_parser_value "${lines[50]}")
-
-
-LOG_PATH="./tests/output"
-mkdir -p ${LOG_PATH}
-status_log="${LOG_PATH}/results_python.log"
-
-
-function func_inference(){
- IFS='|'
- _python=$1
- _script=$2
- _model_dir=$3
- _log_path=$4
- _img_dir=$5
- _flag_quant=$6
- # inference
- for use_gpu in ${use_gpu_list[*]}; do
- if [ ${use_gpu} = "False" ] || [ ${use_gpu} = "cpu" ]; then
- for use_mkldnn in ${use_mkldnn_list[*]}; do
- if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
- continue
- fi
- for threads in ${cpu_threads_list[*]}; do
- for batch_size in ${batch_size_list[*]}; do
- _save_log_path="${_log_path}/infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
- set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
- set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
- set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
- set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
- set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
- set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
- command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
- eval $command
- last_status=${PIPESTATUS[0]}
- eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}"
- done
- done
- done
- elif [ ${use_gpu} = "True" ] || [ ${use_gpu} = "gpu" ]; then
- for use_trt in ${use_trt_list[*]}; do
- for precision in ${precision_list[*]}; do
- if [[ ${_flag_quant} = "False" ]] && [[ ${precision} =~ "int8" ]]; then
- continue
- fi
- if [[ ${precision} =~ "fp16" || ${precision} =~ "int8" ]] && [ ${use_trt} = "False" ]; then
- continue
- fi
- if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [ ${_flag_quant} = "True" ]; then
- continue
- fi
- for batch_size in ${batch_size_list[*]}; do
- _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_${batch_size}.log"
- set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
- set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
- set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
- set_tensorrt=$(func_set_params "${use_trt_key}" "${use_trt}")
- set_precision=$(func_set_params "${precision_key}" "${precision}")
- set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
- set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
- command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
- eval $command
- last_status=${PIPESTATUS[0]}
- eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}"
-
- done
- done
- done
- else
- echo "Does not support hardware other than CPU and GPU Currently!"
- fi
- done
-}
-
-
-# set cuda device
-GPUID=$2
-if [ ${#GPUID} -le 0 ];then
- env=" "
-else
- env="export CUDA_VISIBLE_DEVICES=${GPUID}"
-fi
-set CUDA_VISIBLE_DEVICES
-eval $env
-
-
-echo "################### run test ###################"
diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index 53e50bd6d1d1a2bd07b9f1204b9f56594c669d13..1c68494861e60b4aaef541a4e247071944cf420c 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -131,14 +131,9 @@ def main(args):
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
- except:
+ except Exception as E:
logger.info(traceback.format_exc())
- logger.info(
- "ERROR!!!! \n"
- "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
- "If your model has tps module: "
- "TPS does not support variable shape.\n"
- "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+ logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index dad70281ef7604f110d29963103068bba1c8fd9d..936994a215d10d543537b29cb41bfa42b42590c7 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -38,40 +38,34 @@ logger = get_logger()
class TextRecognizer(object):
def __init__(self, args):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
- self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
postprocess_params = {
'name': 'CTCLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == 'NRTR':
postprocess_params = {
'name': 'NRTRLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "SAR":
postprocess_params = {
'name': 'SARLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 538f55c42b223f9741c5c7006dd7d1478ce1920b..41a3c0f14b6378751a367a3709ad7943ee981a4e 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -74,7 +74,6 @@ def init_args():
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
- parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
@@ -268,10 +267,11 @@ def create_predictor(args, mode, logger):
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
-
+ if args.precision == "fp16":
+ config.enable_mkldnn_bfloat16()
# enable memory optim
config.enable_memory_optim()
- #config.disable_glog_info()
+ config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
if mode == 'table':