diff --git a/tests/common_func.sh b/PTDN/common_func.sh
similarity index 100%
rename from tests/common_func.sh
rename to PTDN/common_func.sh
diff --git a/tests/compare_results.py b/PTDN/compare_results.py
similarity index 100%
rename from tests/compare_results.py
rename to PTDN/compare_results.py
diff --git a/tests/configs/det_mv3_db.yml b/PTDN/configs/det_mv3_db.yml
similarity index 100%
rename from tests/configs/det_mv3_db.yml
rename to PTDN/configs/det_mv3_db.yml
diff --git a/tests/configs/det_r50_vd_db.yml b/PTDN/configs/det_r50_vd_db.yml
similarity index 100%
rename from tests/configs/det_r50_vd_db.yml
rename to PTDN/configs/det_r50_vd_db.yml
diff --git a/tests/configs/ppocr_det_mobile_params.txt b/PTDN/configs/ppocr_det_mobile_params.txt
similarity index 95%
rename from tests/configs/ppocr_det_mobile_params.txt
rename to PTDN/configs/ppocr_det_mobile_params.txt
index 5edb14cdbf8eef87b5b5558cbd8d1a2ff54ae919..3d2117d7ca9b444f55b9c9f343647026af7e97c6 100644
--- a/tests/configs/ppocr_det_mobile_params.txt
+++ b/PTDN/configs/ppocr_det_mobile_params.txt
@@ -1,9 +1,9 @@
===========================train_params===========================
model_name:ocr_det
python:python3.7
-gpu_list:0|0,1
-Global.use_gpu:True|True
-Global.auto_cast:null
+gpu_list:0|0,1|10.21.226.181,10.21.226.133;0,1
+Global.use_gpu:True|True|True
+Global.auto_cast:fp32|amp
Global.epoch_num:lite_train_infer=1|whole_train_infer=300
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_infer=2|whole_train_infer=4
@@ -65,6 +65,8 @@ inference:./deploy/cpp_infer/build/ppocr det
null:null
--benchmark:True
===========================serving_params===========================
+model_name:ocr_det
+python:python3.7
trans_model:-m paddle_serving_client.convert
--dirname:./inference/ch_ppocr_mobile_v2.0_det_infer/
--model_filename:inference.pdmodel
@@ -82,17 +84,17 @@ pipline:pipeline_http_client.py --image_dir=../../doc/imgs
===========================kl_quant_params===========================
infer_model:./inference/ch_ppocr_mobile_v2.0_det_infer/
infer_export:tools/export_model.py -c configs/det/ch_ppocr_v2.0/ch_det_mv3_db_v2.0.yml -o
-infer_quant:False
+infer_quant:True
inference:tools/infer/predict_det.py
--use_gpu:True|False
--enable_mkldnn:True|False
--cpu_threads:1|6
--rec_batch_num:1
--use_tensorrt:False|True
---precision:fp32|fp16|int8
+--precision:int8
--det_model_dir:
--image_dir:./inference/ch_det_data_50/all-sum-510/
null:null
--benchmark:True
null:null
-null:null
\ No newline at end of file
+null:null
diff --git a/tests/configs/ppocr_det_server_params.txt b/PTDN/configs/ppocr_det_server_params.txt
similarity index 57%
rename from tests/configs/ppocr_det_server_params.txt
rename to PTDN/configs/ppocr_det_server_params.txt
index b3df1735e50d941b34eeb274c28eb4ce50d79292..bba4ef44f769ed16671ead55a0eba6ee986aaaaa 100644
--- a/tests/configs/ppocr_det_server_params.txt
+++ b/PTDN/configs/ppocr_det_server_params.txt
@@ -49,4 +49,35 @@ inference:tools/infer/predict_det.py
--save_log_path:null
--benchmark:True
null:null
-
+===========================cpp_infer_params===========================
+use_opencv:True
+infer_model:./inference/ch_ppocr_server_v2.0_det_infer/
+infer_quant:False
+inference:./deploy/cpp_infer/build/ppocr det
+--use_gpu:True|False
+--enable_mkldnn:True|False
+--cpu_threads:1|6
+--rec_batch_num:1
+--use_tensorrt:False|True
+--precision:fp32|fp16
+--det_model_dir:
+--image_dir:./inference/ch_det_data_50/all-sum-510/
+null:null
+--benchmark:True
+===========================serving_params===========================
+model_name:ocr_det_server
+python:python3.7
+trans_model:-m paddle_serving_client.convert
+--dirname:./inference/ch_ppocr_server_v2.0_det_infer/
+--model_filename:inference.pdmodel
+--params_filename:inference.pdiparams
+--serving_server:./deploy/pdserving/ppocr_det_mobile_2.0_serving/
+--serving_client:./deploy/pdserving/ppocr_det_mobile_2.0_client/
+serving_dir:./deploy/pdserving
+web_service:web_service_det.py --config=config.yml --opt op.det.concurrency=1
+op.det.local_service_conf.devices:null|0
+op.det.local_service_conf.use_mkldnn:True|False
+op.det.local_service_conf.thread_num:1|6
+op.det.local_service_conf.use_trt:False|True
+op.det.local_service_conf.precision:fp32|fp16|int8
+pipline:pipeline_http_client.py --image_dir=../../doc/imgs
diff --git a/tests/configs/ppocr_rec_mobile_params.txt b/PTDN/configs/ppocr_rec_mobile_params.txt
similarity index 98%
rename from tests/configs/ppocr_rec_mobile_params.txt
rename to PTDN/configs/ppocr_rec_mobile_params.txt
index f9c407897269d4729b9cab7313c45fe69712c62d..f3f3a54e14e042693d28559e487852a079f77bdd 100644
--- a/tests/configs/ppocr_rec_mobile_params.txt
+++ b/PTDN/configs/ppocr_rec_mobile_params.txt
@@ -65,6 +65,8 @@ inference:./deploy/cpp_infer/build/ppocr rec
null:null
--benchmark:True
===========================serving_params===========================
+model_name:ocr_rec
+python:python3.7
trans_model:-m paddle_serving_client.convert
--dirname:./inference/ch_ppocr_mobile_v2.0_rec_infer/
--model_filename:inference.pdmodel
@@ -78,4 +80,4 @@ op.rec.local_service_conf.use_mkldnn:True|False
op.rec.local_service_conf.thread_num:1|6
op.rec.local_service_conf.use_trt:False|True
op.rec.local_service_conf.precision:fp32|fp16|int8
-pipline:pipeline_http_client.py --image_dir=../../doc/imgs_words_en
\ No newline at end of file
+pipline:pipeline_http_client.py --image_dir=../../doc/imgs_words_en
diff --git a/tests/configs/ppocr_rec_server_params.txt b/PTDN/configs/ppocr_rec_server_params.txt
similarity index 93%
rename from tests/configs/ppocr_rec_server_params.txt
rename to PTDN/configs/ppocr_rec_server_params.txt
index 7d151fcf0b793bd0bf63ac925c9ef3cf0ff56557..77961e8e651e0d770dae64860cc129aa2d50dcf2 100644
--- a/tests/configs/ppocr_rec_server_params.txt
+++ b/PTDN/configs/ppocr_rec_server_params.txt
@@ -65,12 +65,14 @@ inference:./deploy/cpp_infer/build/ppocr rec
null:null
--benchmark:True
===========================serving_params===========================
+model_name:ocr_server_rec
+python:python3.7
trans_model:-m paddle_serving_client.convert
--dirname:./inference/ch_ppocr_server_v2.0_rec_infer/
--model_filename:inference.pdmodel
--params_filename:inference.pdiparams
---serving_server:./deploy/pdserving/ppocr_rec_server_2.0_serving/
---serving_client:./deploy/pdserving/ppocr_rec_server_2.0_client/
+--serving_server:./deploy/pdserving/ppocr_rec_mobile_2.0_serving/
+--serving_client:./deploy/pdserving/ppocr_rec_mobile_2.0_client/
serving_dir:./deploy/pdserving
web_service:web_service_rec.py --config=config.yml --opt op.rec.concurrency=1
op.rec.local_service_conf.devices:null|0
@@ -78,4 +80,4 @@ op.rec.local_service_conf.use_mkldnn:True|False
op.rec.local_service_conf.thread_num:1|6
op.rec.local_service_conf.use_trt:False|True
op.rec.local_service_conf.precision:fp32|fp16|int8
-pipline:pipeline_http_client.py --image_dir=../../doc/imgs_words_en
\ No newline at end of file
+pipline:pipeline_http_client.py --image_dir=../../doc/imgs_words_en
diff --git a/tests/configs/ppocr_sys_mobile_params.txt b/PTDN/configs/ppocr_sys_mobile_params.txt
similarity index 100%
rename from tests/configs/ppocr_sys_mobile_params.txt
rename to PTDN/configs/ppocr_sys_mobile_params.txt
diff --git a/tests/configs/ppocr_sys_server_params.txt b/PTDN/configs/ppocr_sys_server_params.txt
similarity index 100%
rename from tests/configs/ppocr_sys_server_params.txt
rename to PTDN/configs/ppocr_sys_server_params.txt
diff --git a/tests/configs/rec_icdar15_r34_train.yml b/PTDN/configs/rec_icdar15_r34_train.yml
similarity index 100%
rename from tests/configs/rec_icdar15_r34_train.yml
rename to PTDN/configs/rec_icdar15_r34_train.yml
diff --git a/PTDN/docs/compare_cpp_right.png b/PTDN/docs/compare_cpp_right.png
new file mode 100644
index 0000000000000000000000000000000000000000..f9d0ba8ef8007ebc95ebffe2d593ff9e90066343
Binary files /dev/null and b/PTDN/docs/compare_cpp_right.png differ
diff --git a/PTDN/docs/compare_cpp_wrong.png b/PTDN/docs/compare_cpp_wrong.png
new file mode 100644
index 0000000000000000000000000000000000000000..621d446bbbe9ba10c3069ef5e59c463b714d42ad
Binary files /dev/null and b/PTDN/docs/compare_cpp_wrong.png differ
diff --git a/tests/docs/compare_right.png b/PTDN/docs/compare_right.png
similarity index 100%
rename from tests/docs/compare_right.png
rename to PTDN/docs/compare_right.png
diff --git a/tests/docs/compare_wrong.png b/PTDN/docs/compare_wrong.png
similarity index 100%
rename from tests/docs/compare_wrong.png
rename to PTDN/docs/compare_wrong.png
diff --git a/tests/docs/guide.png b/PTDN/docs/guide.png
similarity index 100%
rename from tests/docs/guide.png
rename to PTDN/docs/guide.png
diff --git a/PTDN/docs/install.md b/PTDN/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..28b92426fa04da79ce63381fffa9f52a0f42813f
--- /dev/null
+++ b/PTDN/docs/install.md
@@ -0,0 +1,48 @@
+
+## Environment Setup
+
+This tutorial covers setting up the runtime environment for the basic functional tests under the PTDN directory.
+
+Recommended environment:
+- CUDA 10.1
+- CUDNN 7.6
+- TensorRT 6.1.0.5 / 7.1
+
+
+Installation via the Docker image is recommended. Start a container with the following command; the current directory is mapped to the `/paddle` directory inside the container.
+```
+nvidia-docker run --name paddle -it -v $PWD:/paddle paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82 /bin/bash
+cd /paddle
+
+# Install the Paddle build with TensorRT support
+pip3.7 install https://paddle-wheel.bj.bcebos.com/with-trt/2.1.3/linux-gpu-cuda10.1-cudnn7-mkl-gcc8.2-trt6-avx/paddlepaddle_gpu-2.1.3.post101-cp37-cp37m-linux_x86_64.whl
+
+# Install AutoLog
+git clone https://github.com/LDOUBLEV/AutoLog
+cd AutoLog
+pip3.7 install -r requirements.txt
+python3.7 setup.py bdist_wheel
+pip3.7 install ./dist/auto_log-1.0.0-py3-none-any.whl
+
+
+# Download the OCR code
+cd ../
+git clone https://github.com/PaddlePaddle/PaddleOCR
+
+```
+
+Install the PaddleOCR dependencies:
+```
+cd PaddleOCR
+pip3.7 install -r requirements.txt
+```
+
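+A quick sanity check (optional) confirms that the wheel imported correctly; `paddle.utils.run_check()` is PaddlePaddle's built-in self-test:
+```
+python3.7 -c "import paddle; print(paddle.__version__); paddle.utils.run_check()"
+```
+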
+## FAQ:
+Q. You are using Paddle compiled with TensorRT, but TensorRT dynamic library is not found. Ignore this if TensorRT is not needed.
+
+A. Usually the installed Paddle build includes TensorRT but the TensorRT runtime library cannot be found in the local environment. Download the TensorRT library, extract it, and add it to the LD_LIBRARY_PATH environment variable,
+for example:
+```
+export LD_LIBRARY_PATH=/usr/local/python3.7.0/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/paddle/package/TensorRT-6.0.1.5/lib
+```
+Alternatively, the downloaded TensorRT version may not match the TRT version Paddle was compiled against; in that case download a matching TRT version.
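+
+To confirm that the runtime library is actually visible, list the TensorRT libraries under the path added above (the path is taken from the example; adjust it to your install):
+```
+echo $LD_LIBRARY_PATH
+ls /paddle/package/TensorRT-6.0.1.5/lib/libnvinfer*
+```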
diff --git a/PTDN/docs/test.png b/PTDN/docs/test.png
new file mode 100644
index 0000000000000000000000000000000000000000..f99f23d7050eb61879cf317c0d7728ef14531b08
Binary files /dev/null and b/PTDN/docs/test.png differ
diff --git a/PTDN/docs/test_inference_cpp.md b/PTDN/docs/test_inference_cpp.md
new file mode 100644
index 0000000000000000000000000000000000000000..25db1b5b6b1aa101a8f8969cfae3efc02e542971
--- /dev/null
+++ b/PTDN/docs/test_inference_cpp.md
@@ -0,0 +1,60 @@
+# C++ Inference Functional Test
+
+The main program for the C++ inference functional test is `test_inference_cpp.sh`, which tests model inference based on the C++ prediction library.
+
+## 1. Test Summary
+
+Depending on whether quantization was used during training, the models covered by this test fall into `normal models` and `quantized models`. The C++ inference functionality for the two model types is summarized below:
+
+| Model type | Device | Batch size | TensorRT | MKL-DNN | CPU multi-threading |
+| ---- | ---- | ---- | :----: | :----: | :----: |
+| Normal model | GPU | 1/6 | fp32/fp16 | - | - |
+| Normal model | CPU | 1/6 | - | fp32 | Supported |
+| Quantized model | GPU | 1/6 | int8 | - | - |
+| Quantized model | CPU | 1/6 | - | int8 | Supported |
+
+## 2. Test Procedure
+### 2.1 Functional test
+First run `prepare.sh` to prepare the data and models, then run `test_inference_cpp.sh` to run the test. Log files matching `cpp_infer_*.log` are generated under the ```PTDN/output``` directory.
+
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt "cpp_infer"
+
+# Usage 1:
+bash PTDN/test_inference_cpp.sh ./PTDN/configs/ppocr_det_mobile_params.txt
+# Usage 2: run inference on a specified GPU card; the third argument is the GPU card id
+bash PTDN/test_inference_cpp.sh ./PTDN/configs/ppocr_det_mobile_params.txt '1'
+```
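+
+After the run, the generated log files can be listed directly (file name pattern taken from the description above):
+```shell
+ls PTDN/output/cpp_infer_*.log
+```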
+
+
+### 2.2 Accuracy test
+
+The compare_results.py script checks whether the model predictions meet expectations. The main steps are:
+- extract the predicted coordinates from the log;
+- load the previously saved coordinates from the local file;
+- compare the two results against the accuracy expectation; an error is raised when the deviation exceeds the configured thresholds.
+
+#### Usage
+Run the command:
+```shell
+python3.7 PTDN/compare_results.py --gt_file=./PTDN/results/cpp_*.txt --log_file=./PTDN/output/cpp_*.log --atol=1e-3 --rtol=1e-3
+```
+
+Parameter description:
+- gt_file: path of the previously saved prediction results; files ending in *.txt are supported and indexed automatically. By default the files live under the PTDN/results/ folder.
+- log_file: path of the prediction log saved by the infer mode of the PTDN/test_inference_cpp.sh script. The log contains the prediction results, e.g. text boxes, predicted text, classes, etc. The cpp_infer_*.log pattern is also accepted.
+- atol: absolute tolerance
+- rtol: relative tolerance
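+
+The check is tolerance-based; as a rough sketch (assuming numpy-style `allclose` semantics, i.e. a value passes when |pred - gt| <= atol + rtol * |gt|), the two thresholds interact as follows:
+```shell
+# Illustrative only, not part of the test suite
+python3.7 -c "import numpy as np; print(np.allclose([640.0002], [640.0], atol=1e-3, rtol=1e-3))"
+```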
+
+#### Output
+
+A successful run looks like the following figure:
+
+
+Output when the results do not match:
+
+
+
+## 3. More Tutorials
+
+This document is for functional testing. For a more detailed tutorial on C++ inference, see: [Server-side C++ inference](https://github.com/PaddlePaddle/PaddleOCR/tree/dygraph/deploy/cpp_infer)
diff --git a/PTDN/docs/test_serving.md b/PTDN/docs/test_serving.md
new file mode 100644
index 0000000000000000000000000000000000000000..c6b35630392249ea969585c69a9e4c3d35f1cf52
--- /dev/null
+++ b/PTDN/docs/test_serving.md
@@ -0,0 +1,78 @@
+# PaddleServing Inference Functional Test
+
+The main program for the PaddleServing functional test is `test_serving.sh`, which tests deployment based on PaddleServing.
+
+## 1. Test Summary
+
+Depending on whether quantization was used during training, the models covered by this test fall into `normal models` and `quantized models`. The PaddleServing deployment functionality for the two model types is summarized below:
+
+| Model type | Device | Batch size | TensorRT | MKL-DNN | CPU multi-threading |
+| ---- | ---- | ---- | :----: | :----: | :----: |
+| Normal model | GPU | 1/6 | fp32/fp16 | - | - |
+| Normal model | CPU | 1/6 | - | fp32 | Supported |
+| Quantized model | GPU | 1/6 | int8 | - | - |
+| Quantized model | CPU | 1/6 | - | int8 | Supported |
+
+## 2. Test Procedure
+### 2.1 Functional test
+First run `prepare.sh` to prepare the data and models, then run `test_serving.sh` to run the test. Log files matching `server_infer_*.log` are generated under the ```PTDN/output``` directory.
+
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt "serving_infer"
+
+# Usage:
+bash PTDN/test_serving.sh ./PTDN/configs/ppocr_det_mobile_params.txt
+```
+
+#### Output
+
+The status of each test run is written to `PTDN/output/results_serving.log`:
+On success it prints:
+
+```
+Run successfully with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../PTDN/output/server_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log 2>&1 !
+Run successfully with command - xxxxx
+...
+```
+
+On failure it prints:
+
+```
+Run failed with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../PTDN/output/server_infer_cpu_usemkldnn_True_threads_1_batchsize_1.log 2>&1 !
+Run failed with command - python3.7 pipeline_http_client.py --image_dir=../../doc/imgs > ../../PTDN/output/server_infer_cpu_usemkldnn_True_threads_6_batchsize_1.log 2>&1 !
+Run failed with command - xxxxx
+...
+```
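+
+To count the failed runs across all configurations, the summary log can be filtered (log path and message format as shown above):
+```shell
+grep -c "Run failed" PTDN/output/results_serving.log
+```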
+
+The detailed prediction results are stored under the PTDN/output/ folder; for example, `server_infer_gpu_usetrt_True_precision_fp16_batchsize_1.log` contains the coordinates of the detected boxes:
+
+```
+{'err_no': 0, 'err_msg': '', 'key': ['dt_boxes'], 'value': ['[[[ 78. 642.]\n [409. 640.]\n [409. 657.]\n
+[ 78. 659.]]\n\n [[ 75. 614.]\n [211. 614.]\n [211. 635.]\n [ 75. 635.]]\n\n
+[[103. 554.]\n [135. 554.]\n [135. 575.]\n [103. 575.]]\n\n [[ 75. 531.]\n
+[347. 531.]\n [347. 549.]\n [ 75. 549.] ]\n\n [[ 76. 503.]\n [309. 498.]\n
+[309. 521.]\n [ 76. 526.]]\n\n [[163. 462.]\n [317. 462.]\n [317. 493.]\n
+[163. 493.]]\n\n [[324. 431.]\n [414. 431.]\n [414. 452.]\n [324. 452.]]\n\n
+[[ 76. 412.]\n [208. 408.]\n [209. 424.]\n [ 76. 428.]]\n\n [[307. 409.]\n
+[428. 409.]\n [428. 426.]\n [307 . 426.]]\n\n [[ 74. 385.]\n [217. 382.]\n
+[217. 400.]\n [ 74. 403.]]\n\n [[308. 381.]\n [427. 380.]\n [427. 400.]\n
+[308. 401.]]\n\n [[ 74. 363.]\n [195. 362.]\n [195. 378.]\n [ 74. 379.]]\n\n
+[[303. 359.]\n [423. 357.]\n [423. 375.]\n [303. 377.]]\n\n [[ 70. 336.]\n
+[239. 334.]\n [239. 354.]\ n [ 70. 356.]]\n\n [[ 70. 312.]\n [204. 310.]\n
+[204. 327.]\n [ 70. 330.]]\n\n [[303. 308.]\n [419. 306.]\n [419. 326.]\n
+[303. 328.]]\n\n [[113. 2 72.]\n [246. 270.]\n [247. 299.]\n [113. 301.]]\n\n
+ [[361. 269.]\n [384. 269.]\n [384. 296.]\n [361. 296.]]\n\n [[ 70. 250.]\n
+ [243. 246.]\n [243. 265.]\n [ 70. 269.]]\n\n [[ 65. 221.]\n [187. 220.]\n
+[187. 240.]\n [ 65. 241.]]\n\n [[337. 216.]\n [382. 216.]\n [382. 240.]\n
+[337. 240.]]\n\n [ [ 65. 196.]\n [247. 193.]\n [247. 213.]\n [ 65. 216.]]\n\n
+[[296. 197.]\n [423. 191.]\n [424. 209.]\n [296. 215.]]\n\n [[ 65. 167.]\n [244. 167.]\n
+[244. 186.]\n [ 65. 186.]]\n\n [[ 67. 139.]\n [290. 139.]\n [290. 159.]\n [ 67. 159.]]\n\n
+[[ 68. 113.]\n [410. 113.]\n [410. 128.]\n [ 68. 129.] ]\n\n [[277. 87.]\n [416. 87.]\n
+[416. 108.]\n [277. 108.]]\n\n [[ 79. 28.]\n [132. 28.]\n [132. 62.]\n [ 79. 62.]]\n\n
+[[163. 17.]\n [410. 14.]\n [410. 50.]\n [163. 53.]]]']}
+```
+
+
+## 3. More Tutorials
+
+This document is for functional testing. For a more detailed tutorial on Serving-based deployment, see: [PPOCR service deployment](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/deploy/pdserving/README_CN.md)
diff --git a/PTDN/docs/test_train_inference_python.md b/PTDN/docs/test_train_inference_python.md
new file mode 100644
index 0000000000000000000000000000000000000000..89885ddfa3c1f36a120d713e39689767f8fc6342
--- /dev/null
+++ b/PTDN/docs/test_train_inference_python.md
@@ -0,0 +1,119 @@
+# Basic Training and Inference Functional Test
+
+The main program for the basic training and inference functional test is `test_train_inference_python.sh`, which tests Python-based model training, evaluation, inference and other basic functions, including pruning, quantization, and distillation.
+
+## 1. Test Summary
+
+- Training:
+
+| Algorithm | Model name | Single machine, single GPU | Single machine, multi-GPU | Multi-machine, multi-GPU | Model compression (single machine, multi-GPU) |
+| :---- | :---- | :---- | :---- | :---- | :---- |
+| DB | ch_ppocr_mobile_v2.0_det | regular training <br> mixed precision | regular training <br> mixed precision | regular training <br> mixed precision | regular training: FPGM pruning, PACT quantization <br> offline quantization (no training needed) |
+| DB | ch_ppocr_server_v2.0_det | regular training <br> mixed precision | regular training <br> mixed precision | regular training <br> mixed precision | regular training: FPGM pruning, PACT quantization <br> offline quantization (no training needed) |
+| CRNN | ch_ppocr_mobile_v2.0_rec | regular training <br> mixed precision | regular training <br> mixed precision | regular training <br> mixed precision | regular training: PACT quantization <br> offline quantization (no training needed) |
+| CRNN | ch_ppocr_server_v2.0_rec | regular training <br> mixed precision | regular training <br> mixed precision | regular training <br> mixed precision | regular training: PACT quantization <br> offline quantization (no training needed) |
+| PP-OCR | ch_ppocr_mobile_v2.0 | regular training <br> mixed precision | regular training <br> mixed precision | regular training <br> mixed precision | - |
+| PP-OCR | ch_ppocr_server_v2.0 | regular training <br> mixed precision | regular training <br> mixed precision | regular training <br> mixed precision | - |
+| PP-OCRv2 | ch_PP-OCRv2 | regular training <br> mixed precision | regular training <br> mixed precision | regular training <br> mixed precision | - |
+
+
+- Inference: depending on whether quantization was used during training, the resulting models fall into `normal models` and `quantized models`. The inference functionality for the two model types is summarized below.
+
+| Model type | Device | Batch size | TensorRT | MKL-DNN | CPU multi-threading |
+| ---- | ---- | ---- | :----: | :----: | :----: |
+| Normal model | GPU | 1/6 | fp32/fp16 | - | - |
+| Normal model | CPU | 1/6 | - | fp32 | Supported |
+| Quantized model | GPU | 1/6 | int8 | - | - |
+| Quantized model | CPU | 1/6 | - | int8 | Supported |
+
+
+## 2. Test Procedure
+### 2.1 Install dependencies
+- Install PaddlePaddle >= 2.0
+- Install the PaddleOCR dependencies
+ ```
+ pip3 install -r ../requirements.txt
+ ```
+- Install autolog (a standardized logging tool); a quick import check is shown after the commands below
+ ```
+ git clone https://github.com/LDOUBLEV/AutoLog
+ cd AutoLog
+ pip3 install -r requirements.txt
+ python3 setup.py bdist_wheel
+ pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
+ cd ../
+ ```
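+  A quick import check (optional) confirms the wheel installed correctly; `auto_log` is the package name of the wheel built above:
+  ```
+  python3 -c "import auto_log; print(auto_log.__file__)"
+  ```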
+
+
+### 2.2 Functional test
+First run `prepare.sh` to prepare the data and models, then run `test_train_inference_python.sh` to run the test. Log files matching `python_infer_*.log` are generated under the ```PTDN/output``` directory.
+
+
+`test_train_inference_python.sh` has 5 run modes. Each mode uses different data and tests either speed or accuracy:
+
+- Mode 1: lite_train_infer, train with a small amount of data, used to quickly verify that the training-to-inference pipeline runs end to end; accuracy and speed are not verified;
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
+bash PTDN/test_train_inference_python.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
+```
+
+- Mode 2: whole_infer, train with a small amount of data and predict on a moderate amount of data, used to verify that the trained model runs inference and that the inference speed is reasonable;
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'whole_infer'
+bash PTDN/test_train_inference_python.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'whole_infer'
+```
+
+- Mode 3: infer, no training, predict on the full dataset; runs evaluation of the released model and dynamic-to-static export, and checks the inference model's prediction time and accuracy;
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'infer'
+# Usage 1:
+bash PTDN/test_train_inference_python.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'infer'
+# Usage 2: run inference on a specified GPU card; the third argument is the GPU card id
+bash PTDN/test_train_inference_python.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'infer' '1'
+```
+
+- Mode 4: whole_train_infer, CE: train on the full dataset and predict on the full dataset, verifying training accuracy, inference accuracy, and inference speed;
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
+bash PTDN/test_train_inference_python.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
+```
+
+- Mode 5: klquant_infer, test offline (KL) quantization;
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'klquant_infer'
+bash PTDN/test_train_inference_python.sh PTDN/configs/ppocr_det_mobile_params.txt 'klquant_infer'
+```
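+
+Whichever mode is used, the generated inference logs can be listed directly (file name pattern taken from the description above):
+```shell
+ls PTDN/output/python_infer_*.log
+```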
+
+
+### 2.3 Accuracy test
+
+The compare_results.py script checks whether the model predictions meet expectations. The main steps are:
+- extract the predicted coordinates from the log;
+- load the previously saved coordinates from the local file;
+- compare the two results against the accuracy expectation; an error is raised when the deviation exceeds the configured thresholds.
+
+#### Usage
+Run the command:
+```shell
+python3.7 PTDN/compare_results.py --gt_file=./PTDN/results/python_*.txt --log_file=./PTDN/output/python_*.log --atol=1e-3 --rtol=1e-3
+```
+
+Parameter description:
+- gt_file: path of the previously saved prediction results; files ending in *.txt are supported and indexed automatically. By default the files live under the PTDN/results/ folder.
+- log_file: path of the prediction log saved by the infer mode of the PTDN/test_train_inference_python.sh script. The log contains the prediction results, e.g. text boxes, predicted text, classes, etc. The python_infer_*.log pattern is also accepted.
+- atol: absolute tolerance
+- rtol: relative tolerance
+
+#### Output
+
+A successful run looks like the following figure:
+
+
+Output when the results do not match:
+
+
+
+## 3. More Tutorials
+This document is for functional testing. For more complete training and inference tutorials, see:
+[Model training](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/training.md)
+[Inference with the Python prediction engine](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/inference.md)
diff --git a/tests/prepare.sh b/PTDN/prepare.sh
similarity index 99%
rename from tests/prepare.sh
rename to PTDN/prepare.sh
index abb84c881e52ca8076f218e926d41679b6578d09..d842f4f573d0b1bd697bdad9b67a765ebcf6da6c 100644
--- a/tests/prepare.sh
+++ b/PTDN/prepare.sh
@@ -134,5 +134,5 @@ if [ ${MODE} = "serving_infer" ];then
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_det_infer.tar
wget -nc -P ./inference https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_server_v2.0_rec_infer.tar
- cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar cd ../
+ cd ./inference && tar xf ch_ppocr_mobile_v2.0_det_infer.tar && tar xf ch_ppocr_mobile_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_rec_infer.tar && tar xf ch_ppocr_server_v2.0_det_infer.tar && cd ../
fi
diff --git a/PTDN/readme.md b/PTDN/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..71e888a2fe05a0a6d700b40250dd80d5f6d041e0
--- /dev/null
+++ b/PTDN/readme.md
@@ -0,0 +1,110 @@
+
+# Inference and Deployment Navigation
+
+## 1. Introduction
+
+Besides basic model training and prediction, PaddlePaddle also provides high-performance inference and deployment tools for multiple devices and platforms. This document provides PTDN (Paddle Train Deploy Navigation), the inference and deployment navigation for all models in PaddleOCR, so users can look up the deployment coverage of each model and run one-click tests.
+
+
+
+
+
+## 2. Summary
+
+The coverage is summarized below. Filled-in entries can be tested with this tool in one click; empty entries are still being worked on.
+
+**Field descriptions:**
+- Basic training & inference: model training and Paddle Inference Python prediction.
+- More training modes: multi-machine multi-GPU and mixed precision.
+- Model compression: pruning, offline/online quantization, and distillation.
+- Other deployments: Paddle Inference C++ prediction, Paddle Serving deployment, Paddle-Lite deployment, etc.
+
+For more detail on support for inference acceleration features such as MKL-DNN and TensorRT, see the [more tutorials](#more) section of each test tool.
+
+| Algorithm paper | Model name | Model type | Basic <br> training & inference | More <br> training modes | Model compression | Other deployments |
+| :--- | :--- | :----: | :--------: | :---- | :---- | :---- |
+| DB | ch_ppocr_mobile_v2.0_det | detection | supported | multi-machine multi-GPU <br> mixed precision | FPGM pruning <br> offline quantization | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
+| DB | ch_ppocr_server_v2.0_det | detection | supported | multi-machine multi-GPU <br> mixed precision | FPGM pruning <br> offline quantization | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
+| DB | ch_PP-OCRv2_det | detection |
+| CRNN | ch_ppocr_mobile_v2.0_rec | recognition | supported | multi-machine multi-GPU <br> mixed precision | PACT quantization <br> offline quantization | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
+| CRNN | ch_ppocr_server_v2.0_rec | recognition | supported | multi-machine multi-GPU <br> mixed precision | PACT quantization <br> offline quantization | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
+| CRNN | ch_PP-OCRv2_rec | recognition |
+| PP-OCR | ch_ppocr_mobile_v2.0 | detection+recognition | supported | multi-machine multi-GPU <br> mixed precision | - | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
+| PP-OCR | ch_ppocr_server_v2.0 | detection+recognition | supported | multi-machine multi-GPU <br> mixed precision | - | Paddle Inference: C++ <br> Paddle Serving: Python, C++ <br> Paddle-Lite: <br> (1) ARM CPU(C++) |
+| PP-OCRv2 | ch_PP-OCRv2 | detection+recognition |
+| DB | det_mv3_db_v2.0 | detection |
+| DB | det_r50_vd_db_v2.0 | detection |
+| EAST | det_mv3_east_v2.0 | detection |
+| EAST | det_r50_vd_east_v2.0 | detection |
+| PSENet | det_mv3_pse_v2.0 | detection |
+| PSENet | det_r50_vd_pse_v2.0 | detection |
+| SAST | det_r50_vd_sast_totaltext_v2.0 | detection |
+| Rosetta | rec_mv3_none_none_ctc_v2.0 | recognition |
+| Rosetta | rec_r34_vd_none_none_ctc_v2.0 | recognition |
+| CRNN | rec_mv3_none_bilstm_ctc_v2.0 | recognition |
+| CRNN | rec_r34_vd_none_bilstm_ctc_v2.0 | recognition |
+| StarNet | rec_mv3_tps_bilstm_ctc_v2.0 | recognition |
+| StarNet | rec_r34_vd_tps_bilstm_ctc_v2.0 | recognition |
+| RARE | rec_mv3_tps_bilstm_att_v2.0 | recognition |
+| RARE | rec_r34_vd_tps_bilstm_att_v2.0 | recognition |
+| SRN | rec_r50fpn_vd_none_srn | recognition |
+| NRTR | rec_mtb_nrtr | recognition |
+| SAR | rec_r31_sar | recognition |
+| PGNet | rec_r34_vd_none_none_ctc_v2.0 | end-to-end |
+
+
+
+## 3. Using the One-Click Test Tool
+### Directory layout
+
+```shell
+PTDN/
+├── configs/ # configuration files
+    ├── det_mv3_db.yml # yml for training the mobile ppocr detection model
+    ├── det_r50_vd_db.yml # yml for training the server ppocr detection model
+    ├── rec_icdar15_r34_train.yml # yml for training the server ppocr recognition model
+    ├── ppocr_sys_mobile_params.txt # parameter config for the mobile ppocr detection+recognition pipeline test
+    ├── ppocr_det_mobile_params.txt # parameter config for the mobile ppocr detection model test
+    ├── ppocr_rec_mobile_params.txt # parameter config for the mobile ppocr recognition model test
+    ├── ppocr_sys_server_params.txt # parameter config for the server ppocr detection+recognition pipeline test
+    ├── ppocr_det_server_params.txt # parameter config for the server ppocr detection model test
+    ├── ppocr_rec_server_params.txt # parameter config for the server ppocr recognition model test
+    ├── ...
+├── results/ # pre-saved prediction results, used for accuracy comparison with the actual predictions
+    ├── python_ppocr_det_mobile_results_fp32.txt # saved fp32 results of Python inference for the mobile ppocr detection model
+    ├── python_ppocr_det_mobile_results_fp16.txt # saved fp16 results of Python inference for the mobile ppocr detection model
+    ├── cpp_ppocr_det_mobile_results_fp32.txt # saved fp32 results of C++ inference for the mobile ppocr detection model
+    ├── cpp_ppocr_det_mobile_results_fp16.txt # saved fp16 results of C++ inference for the mobile ppocr detection model
+    ├── ...
+├── prepare.sh # downloads the data and models required by the test_*.sh scripts
+├── test_train_inference_python.sh # main program for testing Python training and inference
+├── test_inference_cpp.sh # main program for testing C++ inference
+├── test_serving.sh # main program for testing Serving deployment
+├── test_lite.sh # main program for testing Lite deployment
+├── compare_results.py # checks whether the accuracy error between predictions in the logs and the pre-saved results is within bounds
+└── readme.md # this documentation
+```
+
+### Test workflow
+With this tool you can test whether the different features are supported and whether the prediction results match. The workflow is as follows (a concrete end-to-end example is given after the list):
+
+
+
+
+1. Run prepare.sh to prepare the data and models required by the test;
+2. Run the test script `test_*.sh` for the feature under test; the resulting logs show whether each configuration ran successfully;
+3. Use `compare_results.py` to compare the predictions in the logs with the results pre-saved under the results directory and check whether the prediction accuracy meets expectations (within the tolerance).
+
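+For example, for the mobile detection model the three steps map to the following commands (taken verbatim from docs/test_train_inference_python.md; other models use their own params file):
+
+```shell
+bash PTDN/prepare.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'infer'
+bash PTDN/test_train_inference_python.sh ./PTDN/configs/ppocr_det_mobile_params.txt 'infer'
+python3.7 PTDN/compare_results.py --gt_file=./PTDN/results/python_*.txt --log_file=./PTDN/output/python_*.log --atol=1e-3 --rtol=1e-3
+```
+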
+There are 4 main test programs with the following functions:
+- `test_train_inference_python.sh`: tests Python-based model training, evaluation, inference and other basic functions, including pruning, quantization, and distillation.
+- `test_inference_cpp.sh`: tests C++-based model inference.
+- `test_serving.sh`: tests service deployment based on Paddle Serving.
+- `test_lite.sh`: tests on-device inference deployment based on Paddle-Lite.
+
+
+#### More tutorials
+The functional tests cover training-related settings such as mixed precision, pruning, and quantization, as well as inference-related options such as MKL-DNN and TensorRT. Follow the links below for details and usage tutorials:
+[Using test_train_inference_python](docs/test_train_inference_python.md)
+[Using test_inference_cpp](docs/test_inference_cpp.md)
+[Using test_serving](docs/test_serving.md)
+[Using test_lite](docs/test_lite.md)
diff --git a/tests/results/ppocr_det_mobile_results_fp16_cpp.txt b/PTDN/results/cpp_ppocr_det_mobile_results_fp16.txt
similarity index 100%
rename from tests/results/ppocr_det_mobile_results_fp16_cpp.txt
rename to PTDN/results/cpp_ppocr_det_mobile_results_fp16.txt
diff --git a/tests/results/ppocr_det_mobile_results_fp32_cpp.txt b/PTDN/results/cpp_ppocr_det_mobile_results_fp32.txt
similarity index 100%
rename from tests/results/ppocr_det_mobile_results_fp32_cpp.txt
rename to PTDN/results/cpp_ppocr_det_mobile_results_fp32.txt
diff --git a/tests/results/ppocr_det_mobile_results_fp16.txt b/PTDN/results/python_ppocr_det_mobile_results_fp16.txt
similarity index 100%
rename from tests/results/ppocr_det_mobile_results_fp16.txt
rename to PTDN/results/python_ppocr_det_mobile_results_fp16.txt
diff --git a/tests/results/ppocr_det_mobile_results_fp32.txt b/PTDN/results/python_ppocr_det_mobile_results_fp32.txt
similarity index 100%
rename from tests/results/ppocr_det_mobile_results_fp32.txt
rename to PTDN/results/python_ppocr_det_mobile_results_fp32.txt
diff --git a/tests/test_cpp.sh b/PTDN/test_inference_cpp.sh
similarity index 96%
rename from tests/test_cpp.sh
rename to PTDN/test_inference_cpp.sh
index f755858cd0f781d7ef0de8089aefbee86ef83f82..124bdacb7dad04bdea07a62ba9c86b248be5a06d 100644
--- a/tests/test_cpp.sh
+++ b/PTDN/test_inference_cpp.sh
@@ -56,7 +56,11 @@ function func_cpp_inference(){
fi
for threads in ${cpp_cpu_threads_list[*]}; do
for batch_size in ${cpp_batch_size_list[*]}; do
- _save_log_path="${_log_path}/cpp_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
+ precision="fp32"
+ if [ ${use_mkldnn} = "False" ] && [ ${_flag_quant} = "True" ]; then
+                        precision="int8"
+ fi
+ _save_log_path="${_log_path}/cpp_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
set_infer_data=$(func_set_params "${cpp_image_dir_key}" "${_img_dir}")
set_benchmark=$(func_set_params "${cpp_benchmark_key}" "${cpp_benchmark_value}")
set_batchsize=$(func_set_params "${cpp_batch_size_key}" "${batch_size}")
diff --git a/tests/test_serving.sh b/PTDN/test_serving.sh
similarity index 65%
rename from tests/test_serving.sh
rename to PTDN/test_serving.sh
index 8998ee7ee4c7f14c1a7df86611b709647c0d1a05..af66d70d7b0a255c33d1114a3951adb92407b8d1 100644
--- a/tests/test_serving.sh
+++ b/PTDN/test_serving.sh
@@ -1,45 +1,45 @@
#!/bin/bash
-source tests/common_func.sh
+source PTDN/common_func.sh
FILENAME=$1
-dataline=$(awk 'NR==67, NR==81{print}' $FILENAME)
+dataline=$(awk 'NR==67, NR==83{print}' $FILENAME)
# parser params
IFS=$'\n'
lines=(${dataline})
# parser serving
-trans_model_py=$(func_parser_value "${lines[1]}")
-infer_model_dir_key=$(func_parser_key "${lines[2]}")
-infer_model_dir_value=$(func_parser_value "${lines[2]}")
-model_filename_key=$(func_parser_key "${lines[3]}")
-model_filename_value=$(func_parser_value "${lines[3]}")
-params_filename_key=$(func_parser_key "${lines[4]}")
-params_filename_value=$(func_parser_value "${lines[4]}")
-serving_server_key=$(func_parser_key "${lines[5]}")
-serving_server_value=$(func_parser_value "${lines[5]}")
-serving_client_key=$(func_parser_key "${lines[6]}")
-serving_client_value=$(func_parser_value "${lines[6]}")
-serving_dir_value=$(func_parser_value "${lines[7]}")
-web_service_py=$(func_parser_value "${lines[8]}")
-web_use_gpu_key=$(func_parser_key "${lines[9]}")
-web_use_gpu_list=$(func_parser_value "${lines[9]}")
-web_use_mkldnn_key=$(func_parser_key "${lines[10]}")
-web_use_mkldnn_list=$(func_parser_value "${lines[10]}")
-web_cpu_threads_key=$(func_parser_key "${lines[11]}")
-web_cpu_threads_list=$(func_parser_value "${lines[11]}")
-web_use_trt_key=$(func_parser_key "${lines[12]}")
-web_use_trt_list=$(func_parser_value "${lines[12]}")
-web_precision_key=$(func_parser_key "${lines[13]}")
-web_precision_list=$(func_parser_value "${lines[13]}")
-pipeline_py=$(func_parser_value "${lines[14]}")
+model_name=$(func_parser_value "${lines[1]}")
+python=$(func_parser_value "${lines[2]}")
+trans_model_py=$(func_parser_value "${lines[3]}")
+infer_model_dir_key=$(func_parser_key "${lines[4]}")
+infer_model_dir_value=$(func_parser_value "${lines[4]}")
+model_filename_key=$(func_parser_key "${lines[5]}")
+model_filename_value=$(func_parser_value "${lines[5]}")
+params_filename_key=$(func_parser_key "${lines[6]}")
+params_filename_value=$(func_parser_value "${lines[6]}")
+serving_server_key=$(func_parser_key "${lines[7]}")
+serving_server_value=$(func_parser_value "${lines[7]}")
+serving_client_key=$(func_parser_key "${lines[8]}")
+serving_client_value=$(func_parser_value "${lines[8]}")
+serving_dir_value=$(func_parser_value "${lines[9]}")
+web_service_py=$(func_parser_value "${lines[10]}")
+web_use_gpu_key=$(func_parser_key "${lines[11]}")
+web_use_gpu_list=$(func_parser_value "${lines[11]}")
+web_use_mkldnn_key=$(func_parser_key "${lines[12]}")
+web_use_mkldnn_list=$(func_parser_value "${lines[12]}")
+web_cpu_threads_key=$(func_parser_key "${lines[13]}")
+web_cpu_threads_list=$(func_parser_value "${lines[13]}")
+web_use_trt_key=$(func_parser_key "${lines[14]}")
+web_use_trt_list=$(func_parser_value "${lines[14]}")
+web_precision_key=$(func_parser_key "${lines[15]}")
+web_precision_list=$(func_parser_value "${lines[15]}")
+pipeline_py=$(func_parser_value "${lines[16]}")
-
-LOG_PATH="./tests/output"
-mkdir -p ${LOG_PATH}
+LOG_PATH="../../PTDN/output"
+mkdir -p ./PTDN/output
status_log="${LOG_PATH}/results_serving.log"
-
function func_serving(){
IFS='|'
_python=$1
@@ -65,12 +65,12 @@ function func_serving(){
continue
fi
for threads in ${web_cpu_threads_list[*]}; do
- _save_log_path="${_log_path}/server_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_1.log"
+ _save_log_path="${LOG_PATH}/server_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_1.log"
set_cpu_threads=$(func_set_params "${web_cpu_threads_key}" "${threads}")
- web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} &>${_save_log_path} &"
+ web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${web_use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} &"
eval $web_service_cmd
sleep 2s
- pipeline_cmd="${python} ${pipeline_py}"
+ pipeline_cmd="${python} ${pipeline_py} > ${_save_log_path} 2>&1 "
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
@@ -93,13 +93,13 @@ function func_serving(){
if [[ ${use_trt} = "False" || ${precision} =~ "int8" ]] && [[ ${_flag_quant} = "True" ]]; then
continue
fi
- _save_log_path="${_log_path}/infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_1.log"
+ _save_log_path="${LOG_PATH}/server_infer_gpu_usetrt_${use_trt}_precision_${precision}_batchsize_1.log"
set_tensorrt=$(func_set_params "${web_use_trt_key}" "${use_trt}")
set_precision=$(func_set_params "${web_precision_key}" "${precision}")
- web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} &>${_save_log_path} & "
+ web_service_cmd="${python} ${web_service_py} ${web_use_gpu_key}=${use_gpu} ${set_tensorrt} ${set_precision} & "
eval $web_service_cmd
sleep 2s
- pipeline_cmd="${python} ${pipeline_py}"
+ pipeline_cmd="${python} ${pipeline_py} > ${_save_log_path} 2>&1"
eval $pipeline_cmd
last_status=${PIPESTATUS[0]}
eval "cat ${_save_log_path}"
@@ -129,3 +129,7 @@ eval $env
echo "################### run test ###################"
+
+export Count=0
+IFS="|"
+func_serving "${web_service_cmd}"
diff --git a/tests/test_python.sh b/PTDN/test_train_inference_python.sh
similarity index 83%
rename from tests/test_python.sh
rename to PTDN/test_train_inference_python.sh
index 39b043b809016b245954a835d37789dcc28d7265..756e1f89d74c1df8de50cf8e23fd3d9c95bd20c5 100644
--- a/tests/test_python.sh
+++ b/PTDN/test_train_inference_python.sh
@@ -5,11 +5,7 @@ FILENAME=$1
# MODE be one of ['lite_train_infer' 'whole_infer' 'whole_train_infer', 'infer', 'klquant_infer']
MODE=$2
-if [ ${MODE} = "klquant_infer" ]; then
- dataline=$(awk 'NR==82, NR==98{print}' $FILENAME)
-else
- dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
-fi
+dataline=$(awk 'NR==1, NR==51{print}' $FILENAME)
# parser params
IFS=$'\n'
@@ -93,6 +89,8 @@ infer_value1=$(func_parser_value "${lines[50]}")
# parser klquant_infer
if [ ${MODE} = "klquant_infer" ]; then
+ dataline=$(awk 'NR==82, NR==98{print}' $FILENAME)
+ lines=(${dataline})
# parser inference model
infer_model_dir_list=$(func_parser_value "${lines[1]}")
infer_export_list=$(func_parser_value "${lines[2]}")
@@ -143,18 +141,28 @@ function func_inference(){
fi
for threads in ${cpu_threads_list[*]}; do
for batch_size in ${batch_size_list[*]}; do
- _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_batchsize_${batch_size}.log"
- set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
- set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
- set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
- set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
- set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
- set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
- command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_infer_params1} > ${_save_log_path} 2>&1 "
- eval $command
- last_status=${PIPESTATUS[0]}
- eval "cat ${_save_log_path}"
- status_check $last_status "${command}" "${status_log}"
+ for precision in ${precision_list[*]}; do
+ if [ ${use_mkldnn} = "False" ] && [ ${precision} = "fp16" ]; then
+ continue
+ fi # skip when enable fp16 but disable mkldnn
+ if [ ${_flag_quant} = "True" ] && [ ${precision} != "int8" ]; then
+ continue
+ fi # skip when quant model inference but precision is not int8
+ set_precision=$(func_set_params "${precision_key}" "${precision}")
+
+ _save_log_path="${_log_path}/python_infer_cpu_usemkldnn_${use_mkldnn}_threads_${threads}_precision_${precision}_batchsize_${batch_size}.log"
+ set_infer_data=$(func_set_params "${image_dir_key}" "${_img_dir}")
+ set_benchmark=$(func_set_params "${benchmark_key}" "${benchmark_value}")
+ set_batchsize=$(func_set_params "${batch_size_key}" "${batch_size}")
+ set_cpu_threads=$(func_set_params "${cpu_threads_key}" "${threads}")
+ set_model_dir=$(func_set_params "${infer_model_key}" "${_model_dir}")
+ set_infer_params1=$(func_set_params "${infer_key1}" "${infer_value1}")
+ command="${_python} ${_script} ${use_gpu_key}=${use_gpu} ${use_mkldnn_key}=${use_mkldnn} ${set_cpu_threads} ${set_model_dir} ${set_batchsize} ${set_infer_data} ${set_benchmark} ${set_precision} ${set_infer_params1} > ${_save_log_path} 2>&1 "
+ eval $command
+ last_status=${PIPESTATUS[0]}
+ eval "cat ${_save_log_path}"
+ status_check $last_status "${command}" "${status_log}"
+ done
done
done
done
@@ -224,6 +232,9 @@ if [ ${MODE} = "infer" ] || [ ${MODE} = "klquant_infer" ]; then
fi
#run inference
is_quant=${infer_quant_flag[Count]}
+ if [ ${MODE} = "klquant_infer" ]; then
+ is_quant="True"
+ fi
func_inference "${python}" "${inference_py}" "${save_infer_dir}" "${LOG_PATH}" "${infer_img_dir}" ${is_quant}
Count=$(($Count + 1))
done
@@ -234,6 +245,7 @@ else
for gpu in ${gpu_list[*]}; do
use_gpu=${USE_GPU_KEY[Count]}
Count=$(($Count + 1))
+ ips=""
if [ ${gpu} = "-1" ];then
env=""
elif [ ${#gpu} -le 1 ];then
@@ -253,6 +265,11 @@ else
env=" "
fi
for autocast in ${autocast_list[*]}; do
+ if [ ${autocast} = "amp" ]; then
+ set_amp_config="Global.use_amp=True Global.scale_loss=1024.0 Global.use_dynamic_loss_scaling=True"
+ else
+ set_amp_config=" "
+ fi
for trainer in ${trainer_list[*]}; do
flag_quant=False
if [ ${trainer} = ${pact_key} ]; then
@@ -279,7 +296,6 @@ else
if [ ${run_train} = "null" ]; then
continue
fi
-
set_autocast=$(func_set_params "${autocast_key}" "${autocast}")
set_epoch=$(func_set_params "${epoch_key}" "${epoch_num}")
set_pretrain=$(func_set_params "${pretrain_model_key}" "${pretrain_model_value}")
@@ -295,11 +311,11 @@ else
set_save_model=$(func_set_params "${save_model_key}" "${save_log}")
if [ ${#gpu} -le 2 ];then # train with cpu or single gpu
- cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} "
- elif [ ${#gpu} -le 15 ];then # train with multi-gpu
- cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1}"
+ cmd="${python} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config} "
+ elif [ ${#ips} -le 26 ];then # train with multi-gpu
+ cmd="${python} -m paddle.distributed.launch --gpus=${gpu} ${run_train} ${set_use_gpu} ${set_save_model} ${set_epoch} ${set_pretrain} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
else # train with multi-machine
- cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1}"
+ cmd="${python} -m paddle.distributed.launch --ips=${ips} --gpus=${gpu} ${set_use_gpu} ${run_train} ${set_save_model} ${set_pretrain} ${set_epoch} ${set_autocast} ${set_batchsize} ${set_train_params1} ${set_amp_config}"
fi
# run train
eval "unset CUDA_VISIBLE_DEVICES"
diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
index 38f77f7372c4e422b5601deb5119c24fd1e3f787..e2aa50106ff60aa61858a22ba6fdd03b8cd04d85 100644
--- a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
+++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml
@@ -14,7 +14,6 @@ Global:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
index d2308fd5747f3fadf3bb1c98c5602c67d5e63eca..ab48b99791d00785d143cd933ccc31b3f69d0f8f 100644
--- a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
+++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_distillation.yml
@@ -14,7 +14,6 @@ Global:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
diff --git a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
index 8b568637a189ac47438b84e89fc55ddc643ab297..7161203035b2324c7afc56b2b0c743428558a098 100644
--- a/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
+++ b/configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml
@@ -14,7 +14,6 @@ Global:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- character_type: ch
max_text_length: 25
infer_mode: false
use_space_char: true
diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml
index 717c16814bac2f6fca78aa63566df12bd8cbf67b..c76063d5cedc31985404ddfff5147e1e0c100d20 100644
--- a/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml
+++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml
@@ -15,7 +15,6 @@ Global:
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- character_type: ch
max_text_length: 25
infer_mode: False
use_space_char: True
diff --git a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
index 660465f301047110db7001db7a32e687f2917b61..563ce110b865adabf320616227bdf8d2eb465c11 100644
--- a/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
+++ b/configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml
@@ -15,7 +15,6 @@ Global:
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- character_type: ch
max_text_length: 25
infer_mode: False
use_space_char: True
diff --git a/configs/rec/multi_language/rec_arabic_lite_train.yml b/configs/rec/multi_language/rec_arabic_lite_train.yml
index 6dcfd1b69988b09c7dfc05cdbacce9756ea1f7cb..a746260e0001e34b1f50fb066885091b3686cb4d 100644
--- a/configs/rec/multi_language/rec_arabic_lite_train.yml
+++ b/configs/rec/multi_language/rec_arabic_lite_train.yml
@@ -15,7 +15,6 @@ Global:
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/arabic_dict.txt
- character_type: arabic
max_text_length: 25
infer_mode: false
use_space_char: true
diff --git a/configs/rec/multi_language/rec_cyrillic_lite_train.yml b/configs/rec/multi_language/rec_cyrillic_lite_train.yml
index 52527c1dfb9a306429bbab9241c623581d546e45..98544f627111340b61abd210ea5b4d7979511a15 100644
--- a/configs/rec/multi_language/rec_cyrillic_lite_train.yml
+++ b/configs/rec/multi_language/rec_cyrillic_lite_train.yml
@@ -15,7 +15,6 @@ Global:
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/cyrillic_dict.txt
- character_type: cyrillic
max_text_length: 25
infer_mode: false
use_space_char: true
diff --git a/configs/rec/multi_language/rec_devanagari_lite_train.yml b/configs/rec/multi_language/rec_devanagari_lite_train.yml
index e1a7c829c3e6d3c3a57f1d501cdd80a560703ec7..518b9f19ccaccb6405f7e9cb4d783b441e8c7ae7 100644
--- a/configs/rec/multi_language/rec_devanagari_lite_train.yml
+++ b/configs/rec/multi_language/rec_devanagari_lite_train.yml
@@ -15,7 +15,6 @@ Global:
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/devanagari_dict.txt
- character_type: devanagari
max_text_length: 25
infer_mode: false
use_space_char: true
diff --git a/configs/rec/multi_language/rec_en_number_lite_train.yml b/configs/rec/multi_language/rec_en_number_lite_train.yml
index fff4dfcd905b406964bb07cf14017af22f40e91e..ff1fb8698163d00fae57e682059da47d2007505d 100644
--- a/configs/rec/multi_language/rec_en_number_lite_train.yml
+++ b/configs/rec/multi_language/rec_en_number_lite_train.yml
@@ -16,7 +16,6 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/en_dict.txt
- character_type: EN
max_text_length: 25
infer_mode: False
use_space_char: True
diff --git a/configs/rec/multi_language/rec_french_lite_train.yml b/configs/rec/multi_language/rec_french_lite_train.yml
index 63378d38a0d31fc77c33173e0ed864f28c5c3a8b..217369d30bc3ac6e09c2a580facbd0395e0ce727 100644
--- a/configs/rec/multi_language/rec_french_lite_train.yml
+++ b/configs/rec/multi_language/rec_french_lite_train.yml
@@ -16,7 +16,6 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict/french_dict.txt
- character_type: french
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/multi_language/rec_german_lite_train.yml b/configs/rec/multi_language/rec_german_lite_train.yml
index 1651510c5e4597e82298135d2f6c64aa747cf961..67520f5fb668327fdbd0cddb68cb6a3d6d3d112e 100644
--- a/configs/rec/multi_language/rec_german_lite_train.yml
+++ b/configs/rec/multi_language/rec_german_lite_train.yml
@@ -16,7 +16,6 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict/german_dict.txt
- character_type: german
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/multi_language/rec_japan_lite_train.yml b/configs/rec/multi_language/rec_japan_lite_train.yml
index bb47584edbc70f68d8d2d89dced3ec9b12f0e1cb..448aff1ebd0b418191c622cee97346931a86929b 100644
--- a/configs/rec/multi_language/rec_japan_lite_train.yml
+++ b/configs/rec/multi_language/rec_japan_lite_train.yml
@@ -16,7 +16,6 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict/japan_dict.txt
- character_type: japan
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/multi_language/rec_korean_lite_train.yml b/configs/rec/multi_language/rec_korean_lite_train.yml
index 77f15524f78cd7f1c3dcf4988960e718422f5d89..8118119da8f15102ad4c8485b7e26b9436d65cda 100644
--- a/configs/rec/multi_language/rec_korean_lite_train.yml
+++ b/configs/rec/multi_language/rec_korean_lite_train.yml
@@ -16,7 +16,6 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict/korean_dict.txt
- character_type: korean
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/multi_language/rec_latin_lite_train.yml b/configs/rec/multi_language/rec_latin_lite_train.yml
index e71112b4b4f0afd3ceab9f10078bc5d518ee9e59..04fe6d1a49ea06341b2218123d2319a5962b934b 100644
--- a/configs/rec/multi_language/rec_latin_lite_train.yml
+++ b/configs/rec/multi_language/rec_latin_lite_train.yml
@@ -15,7 +15,6 @@ Global:
use_visualdl: false
infer_img: null
character_dict_path: ppocr/utils/dict/latin_dict.txt
- character_type: latin
max_text_length: 25
infer_mode: false
use_space_char: true
diff --git a/configs/rec/rec_icdar15_train.yml b/configs/rec/rec_icdar15_train.yml
index 17a4d76483635d648ebb8cb897f621a186dcd516..893f7382f8b82f3c2d5f10cdf10735645fd3a5ee 100644
--- a/configs/rec/rec_icdar15_train.yml
+++ b/configs/rec/rec_icdar15_train.yml
@@ -15,7 +15,6 @@ Global:
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path: ppocr/utils/en_dict.txt
- character_type: EN
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_mtb_nrtr.yml b/configs/rec/rec_mtb_nrtr.yml
index 8639a28a931247ee34f2e3842407fd1d2e065950..04267500854310dc6d5df9318bb8c056c65cd5b5 100644
--- a/configs/rec/rec_mtb_nrtr.yml
+++ b/configs/rec/rec_mtb_nrtr.yml
@@ -14,11 +14,10 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
- character_dict_path:
- character_type: EN_symbol
+ character_dict_path: ppocr/utils/EN_symbol_dict.txt
max_text_length: 25
infer_mode: False
- use_space_char: True
+ use_space_char: False
save_res_path: ./output/rec/predicts_nrtr.txt
Optimizer:
diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml
index 9e0bd23edba053b44fc7241c0a587ced5cd1ac76..9a950923b0cd4292f3f4d70ae51abc60c59dc615 100644
--- a/configs/rec/rec_mv3_none_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml
@@ -14,8 +14,7 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
- character_dict_path:
- character_type: en
+ character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml
index 904afe1134b565d6459cdcda4cbfa43ae4925b92..28f0252adb4b74f88f8c6203521adb66c851e6b0 100644
--- a/configs/rec/rec_mv3_none_none_ctc.yml
+++ b/configs/rec/rec_mv3_none_none_ctc.yml
@@ -15,7 +15,6 @@ Global:
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
- character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_mv3_tps_bilstm_att.yml b/configs/rec/rec_mv3_tps_bilstm_att.yml
index feaeb0545c687774938521e4c45c026207172f11..6c347e765fe04ca3e5330de6cabb9998855436c9 100644
--- a/configs/rec/rec_mv3_tps_bilstm_att.yml
+++ b/configs/rec/rec_mv3_tps_bilstm_att.yml
@@ -14,8 +14,7 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
- character_dict_path:
- character_type: en
+ character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
index 65ab23c42aff54ee548867e3482d7400603551ad..9d1ebbe4e2ce25d746ff9d6993bf820347a3558a 100644
--- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
@@ -15,7 +15,6 @@ Global:
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
- character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_r31_sar.yml b/configs/rec/rec_r31_sar.yml
index 41609fdf28e78f5340ab08878c2b8b23f46020d2..65e7877b28da80e0730f551b07d60b8a8c0ac48e 100644
--- a/configs/rec/rec_r31_sar.yml
+++ b/configs/rec/rec_r31_sar.yml
@@ -15,7 +15,6 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict90.txt
- character_type: EN_symbol
max_text_length: 30
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
index 331bb36ed84b83dc62a0f9b15524457238dedc13..9fdb5e99acec4ab5b2c3ff4b29158a41c766844b 100644
--- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
@@ -14,8 +14,7 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
- character_dict_path:
- character_type: en
+ character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml
index 695a46958f669e4cb9508646080b45ac0767b8c9..0af2b2ff21938ce9b1750bd0fd8e27dabfd39998 100644
--- a/configs/rec/rec_r34_vd_none_none_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_none_ctc.yml
@@ -15,7 +15,6 @@ Global:
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
character_dict_path:
- character_type: en
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_r34_vd_tps_bilstm_att.yml b/configs/rec/rec_r34_vd_tps_bilstm_att.yml
index fdd3588c844ffd7ed61de73077ae2994f0ad498d..8919aae75720d1e2f786957dd44e2d5d6dcbb5af 100644
--- a/configs/rec/rec_r34_vd_tps_bilstm_att.yml
+++ b/configs/rec/rec_r34_vd_tps_bilstm_att.yml
@@ -14,8 +14,7 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
- character_dict_path:
- character_type: en
+ character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
index 67108a6eaca2dd6f239261f5184341e5ade00dc0..c21fe61fbe62bab940bdb5ec1fef7833f402cb6c 100644
--- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
@@ -14,8 +14,7 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
- character_dict_path:
- character_type: en
+ character_dict_path:
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml
index fa7b1ae4e5fed41d3aa3670d6672cca01b63c359..b685362dedbcd6022fa247fe1499017647fa1546 100644
--- a/configs/rec/rec_r50_fpn_srn.yml
+++ b/configs/rec/rec_r50_fpn_srn.yml
@@ -14,8 +14,7 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
# for data or label process
- character_dict_path:
- character_type: en
+ character_dict_path:
max_text_length: 25
num_heads: 8
infer_mode: False
diff --git a/configs/rec/rec_resnet_stn_bilstm_att.yml b/configs/rec/rec_resnet_stn_bilstm_att.yml
index 1f6e534a6878a7ae84fc7fa7e1d975077f164d80..0f599258d46e2ce89a6b7deccf8287a2ec0f7e4e 100644
--- a/configs/rec/rec_resnet_stn_bilstm_att.yml
+++ b/configs/rec/rec_resnet_stn_bilstm_att.yml
@@ -14,8 +14,7 @@ Global:
use_visualdl: False
infer_img: doc/imgs_words_en/word_10.png
# for data or label process
- character_dict_path:
- character_type: EN_symbol
+ character_dict_path: ppocr/utils/EN_symbol_dict.txt
max_text_length: 100
infer_mode: False
use_space_char: False
diff --git a/doc/doc_ch/config.md b/doc/doc_ch/config.md
index 600d5bdb120444ec89222360af02adb3f96a8640..dcd0318ed908375c896d7a6730cd72db4cc4b848 100644
--- a/doc/doc_ch/config.md
+++ b/doc/doc_ch/config.md
@@ -37,10 +37,9 @@
| checkpoints | Path of model parameters to load | None | Used to resume training after an interruption |
| use_visualdl | Whether to enable VisualDL for visual log display | False | [Tutorial](https://www.paddlepaddle.org.cn/paddle/visualdl) |
| infer_img | Path of the image or folder to predict | ./infer_img | \|
-| character_dict_path | Path of the character dictionary | ./ppocr/utils/ppocr_keys_v1.txt | \ |
+| character_dict_path | Path of the character dictionary | ./ppocr/utils/ppocr_keys_v1.txt | If empty, lowercase letters and digits are used as the default dictionary |
| max_text_length | Maximum text length | 25 | \ |
-| character_type | Character type | ch | en/ch; en uses the default dict, ch uses a custom dict |
-| use_space_char | Whether to recognize spaces | True | Spaces are only supported when character_type=ch |
+| use_space_char | Whether to recognize spaces | True |  |
| label_list | Angles supported by the direction classifier | ['0','180'] | Only effective for the direction classifier |
| save_res_path | Path to save the detection model's results | ./output/det_db/predicts_db.txt | Only effective for detection models |
@@ -177,7 +176,7 @@ PaddleOCR currently supports recognition for 80 languages (besides Chinese); `configs/rec/multi
--dict {path/of/dict} \ # path of the dictionary file
-o Global.use_gpu=False # whether to use GPU
...
-
+
```
Italian is written with the Latin alphabet, so after running the command you get a configuration file named rec_latin_lite_train.yml.
@@ -191,38 +190,37 @@ PaddleOCR currently supports recognition for 80 languages (besides Chinese); `configs/rec/multi
use_gpu: True
epoch_num: 500
...
- character_type: it # language to be recognized
 character_dict_path: {path/of/dict} # path of the dictionary file
-
+
Train:
dataset:
name: SimpleDataSet
    data_dir: train_data/ # root directory of the data
    label_file_list: ["./train_data/train_list.txt"] # path of the training-set label file
...
-
+
Eval:
dataset:
name: SimpleDataSet
    data_dir: train_data/ # root directory of the data
    label_file_list: ["./train_data/val_list.txt"] # path of the validation-set label file
...
-
+
```
The multi-language algorithms currently supported by PaddleOCR are:
-| Config file | Algorithm | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Traditional Chinese | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case-sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin alphabet | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic alphabet | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Cyrillic alphabet | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Devanagari alphabet | devanagari |
+| Config file | Algorithm | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Traditional Chinese |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case-sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin alphabet |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic alphabet |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Cyrillic alphabet |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Devanagari alphabet |
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
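
The updated table entry documents the new fallback: when `character_dict_path` is left empty, recognition falls back to a 36-character dictionary of digits plus lowercase letters; a space is only appended via `use_space_char` when a dictionary file is given. A minimal sketch of that behaviour, using the same file-reading pattern as the code changes later in this patch:

```python
# Minimal sketch of the documented fallback: with no dictionary file, only
# digits and lowercase letters (36 characters) can be recognized.
def build_charset(character_dict_path=None, use_space_char=False):
    if character_dict_path is None:
        chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
    else:
        with open(character_dict_path, "rb") as fin:
            chars = [line.decode("utf-8").strip("\n").strip("\r\n") for line in fin]
        if use_space_char:
            chars.append(" ")  # space is only added on top of a dictionary file
    return chars

print(len(build_charset()))  # 36
```
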
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index b9be1e4cb2d1b256a05b82ef5d6db49dfcb2f31f..4e0f1d131e2547f0d4a8bdf35c0f4a6f8bf2e7a3 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -273,7 +273,7 @@ python3 tools/export_model.py -c configs/rec/rec_r34_vd_none_bilstm_ctc.yml -o G
CRNN 文本识别模型推理,可以执行如下命令:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_crnn/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/rec_crnn/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
![](../imgs_words_en/word_336.png)
@@ -288,7 +288,7 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073)
- 训练时采用的图像分辨率不同,训练上述模型采用的图像分辨率是[3,32,100],而中文模型训练时,为了保证长文本的识别效果,训练时采用的图像分辨率是[3, 32, 320]。预测推理程序默认的的形状参数是训练中文采用的图像分辨率,即[3, 32, 320]。因此,这里推理上述英文模型时,需要通过参数rec_image_shape设置识别图像的形状。
-- 字符列表,DTRB论文中实验只是针对26个小写英文本母和10个数字进行实验,总共36个字符。所有大小字符都转成了小写字符,不在上面列表的字符都忽略,认为是空格。因此这里没有输入字符字典,而是通过如下命令生成字典.因此在推理时需要设置参数rec_char_type,指定为英文"en"。
+- 字符列表,DTRB论文中的实验只针对26个小写英文字母和10个数字,总共36个字符。所有大小写字符都转成了小写字符,不在上面列表中的字符都会被忽略,认为是空格。因此这里没有输入字符字典,而是通过如下代码生成字典。因此在推理时需要设置参数rec_char_dict_path,指定为英文字典"./ppocr/utils/ic15_dict.txt"。
```
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
@@ -303,15 +303,15 @@ dict_character = list(self.character_str)
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
--rec_model_dir="./inference/srn/" \
--rec_image_shape="1, 64, 256" \
- --rec_char_type="en" \
+ --rec_char_dict_path="./ppocr/utils/ic15_dict.txt" \
--rec_algorithm="SRN"
```
### 4. 自定义文本识别字典的推理
-如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch`
+如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
@@ -320,7 +320,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
@@ -388,7 +388,7 @@ python3 tools/infer/predict_system.py --image_dir="./doc/imgs/00018069.jpg" --de
下面给出基于EAST文本检测和STAR-Net文本识别执行命令:
```
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
执行命令后,识别结果图像如下:
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index 52f978a734cbc750f4e2f36bb3ff28b2e67ab612..bb7d01712a85c92a02109e41814059e6c98c7cdc 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -159,7 +159,6 @@ PaddleOCR内置了一部分字典,可以按需使用。
- 自定义字典
如需自定义dic文件,请在 `configs/rec/rec_icdar15_train.yml` 中添加 `character_dict_path` 字段, 指向您的字典路径。
-并将 `character_type` 设置为 `ch`。
### 1.4 添加空格类别
@@ -246,8 +245,6 @@ Global:
...
# 添加自定义字典,如修改字典请将路径指向新字典
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
- # 修改字符类型
- character_type: ch
...
# 识别空格
use_space_char: True
@@ -311,18 +308,18 @@ PaddleOCR目前已支持80种(除中文外)语种识别,`configs/rec/multi
按语系划分,目前PaddleOCR支持的语种有:
-| 配置文件 | 算法名称 | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 中文繁体 | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 英语(区分大小写) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 法语 | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 德语 | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 日语 | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 韩语 | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 拉丁字母 | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 阿拉伯字母 | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 斯拉夫字母 | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 梵文字母 | devanagari |
+| 配置文件 | 算法名称 | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 中文繁体 |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 英语(区分大小写) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 法语 |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 德语 |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 日语 |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 韩语 |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 拉丁字母 |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 阿拉伯字母 |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 斯拉夫字母 |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | 梵文字母 |
更多支持语种请参考: [多语言模型](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_ch/multi_languages.md#%E8%AF%AD%E7%A7%8D%E7%BC%A9%E5%86%99)
diff --git a/doc/doc_ch/training.md b/doc/doc_ch/training.md
index fb7f94a9e86cf392421ab6ed6f99cf2d49390096..c6c7b87d9925197b36a246c651ab7179ff9d2e81 100644
--- a/doc/doc_ch/training.md
+++ b/doc/doc_ch/training.md
@@ -129,3 +129,9 @@ PaddleOCR主要聚焦通用OCR,如果有垂类需求,您可以用PaddleOCR+
A:识别模型训练初期acc为0是正常的,多训一段时间指标就上来了。
+
+***
+具体的训练教程可点击下方链接跳转:
+- [文本检测模型训练](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
+- [文本识别模型训练](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
+- [文本方向分类器训练](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
\ No newline at end of file
diff --git a/doc/doc_en/config_en.md b/doc/doc_en/config_en.md
index aa78263e4b73a3ac35250e5483a394ab77450c90..ce76da9b2f39532b387e3e45ca2ff497b0408635 100644
--- a/doc/doc_en/config_en.md
+++ b/doc/doc_en/config_en.md
@@ -1,4 +1,4 @@
-# Configuration
+# Configuration
- [1. Optional Parameter List](#1-optional-parameter-list)
- [2. Intorduction to Global Parameters of Configuration File](#2-intorduction-to-global-parameters-of-configuration-file)
@@ -37,9 +37,8 @@ Take rec_chinese_lite_train_v2.0.yml as an example
| checkpoints | set model parameter path | None | Used to load parameters after interruption to continue training|
| use_visualdl | Set whether to enable visualdl for visual log display | False | [Tutorial](https://www.paddlepaddle.org.cn/paddle/visualdl) |
| infer_img | Set inference image path or folder path | ./infer_img | \|
-| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | \ |
+| character_dict_path | Set dictionary path | ./ppocr/utils/ppocr_keys_v1.txt | If character_dict_path is None, the model can only recognize digits and lowercase letters |
| max_text_length | Set the maximum length of text | 25 | \ |
-| character_type | Set character type | ch | en/ch, the default dict will be used for en, and the custom dict will be used for ch |
| use_space_char | Set whether to recognize spaces | True | Only support in character_type=ch mode |
| label_list | Set the angle supported by the direction classifier | ['0','180'] | Only valid in angle classifier model |
| save_res_path | Set the save address of the test model results | ./output/det_db/predicts_db.txt | Only valid in the text detection model |
@@ -196,40 +195,39 @@ Italian is made up of Latin letters, so after executing the command, you will ge
use_gpu: True
epoch_num: 500
...
- character_type: it # language
character_dict_path: {path/of/dict} # path of dict
-
+
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/ # root directory of training data
label_file_list: ["./train_data/train_list.txt"] # train label path
...
-
+
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/ # root directory of val data
label_file_list: ["./train_data/val_list.txt"] # val label path
...
-
+
```
Currently, the multi-language algorithms supported by PaddleOCR are:
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index b445232feeefadc355e0f38b329050e26ccc0368..019ac4d0ac15aceed89286048d2c4d88a259e501 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -21,7 +21,7 @@ Next, we first introduce how to convert a trained model into an inference model,
- [2.2 DB Text Detection Model Inference](#DB_DETECTION)
- [2.3 East Text Detection Model Inference](#EAST_DETECTION)
- [2.4 Sast Text Detection Model Inference](#SAST_DETECTION)
-
+
- [3. Text Recognition Model Inference](#RECOGNITION_MODEL_INFERENCE)
- [3.1 Lightweight Chinese Text Recognition Model Reference](#LIGHTWEIGHT_RECOGNITION)
- [3.2 CTC-Based Text Recognition Model Inference](#CTC-BASED_RECOGNITION)
@@ -281,7 +281,7 @@ python3 tools/export_model.py -c configs/det/rec_r34_vd_none_bilstm_ctc.yml -o G
For CRNN text recognition model inference, execute the following commands:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
![](../imgs_words_en/word_336.png)
@@ -314,7 +314,7 @@ with the training, such as: --rec_image_shape="1, 64, 256"
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \
--rec_model_dir="./inference/srn/" \
--rec_image_shape="1, 64, 256" \
- --rec_char_type="en" \
+ --rec_char_dict_path="./ppocr/utils/ic15_dict.txt" \
--rec_algorithm="SRN"
```
@@ -323,7 +323,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch`
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_dict_path="your text dict path"
```
@@ -333,7 +333,7 @@ If you need to predict other language models, when using inference model predict
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
@@ -399,7 +399,7 @@ If you want to try other detection algorithms or recognition algorithms, please
The following command uses the combination of the EAST text detection and STAR-Net text recognition:
```
-python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
+python3 tools/infer/predict_system.py --image_dir="./doc/imgs_en/img_10.jpg" --det_model_dir="./inference/det_east/" --det_algorithm="EAST" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_dict_path="./ppocr/utils/ic15_dict.txt"
```
After executing the command, the recognition result image is as follows:
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 84f5541562f9ce267da10abfad209ea1eb909a3e..51857ba16b7773ef38452fad6aa070f2117a9086 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -161,7 +161,7 @@ The current multi-language model is still in the demo stage and will continue to
If you like, you can submit the dictionary file to [dict](../../ppocr/utils/dict) and we will thank you in the Repo.
-To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml` and set `character_type` to `ch`.
+To customize the dict file, please modify the `character_dict_path` field in `configs/rec/rec_icdar15_train.yml`.
- Custom dictionary
@@ -172,8 +172,6 @@ If you need to customize dic file, please add character_dict_path field in confi
If you want to support the recognition of the `space` category, please set the `use_space_char` field in the yml file to `True`.
-**Note: use_space_char only takes effect when character_type=ch**
-
## 2.Training
@@ -250,7 +248,6 @@ Global:
# Add a custom dictionary, such as modify the dictionary, please point the path to the new dictionary
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
# Modify character type
- character_type: ch
...
# Whether to recognize spaces
use_space_char: True
@@ -312,18 +309,18 @@ Eval:
Currently, the multi-language algorithms supported by PaddleOCR are:
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
-| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional | chinese_cht|
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) | EN |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
-| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin | latin |
-| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic | ar |
-| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic | cyrillic |
-| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari | devanagari |
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | chinese traditional |
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English(Case sensitive) |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+| rec_latin_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Latin |
+| rec_arabic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | arabic |
+| rec_cyrillic_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | cyrillic |
+| rec_devanagari_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | devanagari |
For more supported languages, please refer to : [Multi-language model](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.1/doc/doc_en/multi_languages_en.md#4-support-languages-and-abbreviations)
@@ -471,6 +468,3 @@ inference/det_db/
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="ch" --rec_char_dict_path="your text dict path"
```
-
-
-
diff --git a/doc/doc_en/training_en.md b/doc/doc_en/training_en.md
index 106e41c0183b0cb20a038e154e155a25fdc6faa6..aa5500ac88fef97829b4f19c5421e36f18ae1812 100644
--- a/doc/doc_en/training_en.md
+++ b/doc/doc_en/training_en.md
@@ -147,3 +147,9 @@ There are several experiences for reference when constructing the data set:
A: It is normal for the acc to be 0 at the beginning of the recognition model training, and the indicator will come up after a longer training period.
+
+***
+Click the following links for detailed training tutorials:
+- [text detection model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/detection.md)
+- [text recognition model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/recognition.md)
+- [text direction classification model training](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.3/doc/doc_ch/angle_class.md)
diff --git a/doc/joinus.PNG b/doc/joinus.PNG
index 974a4bd008d7b103de044cf8b4dbf37f09a0d06b..202ad0a5c6edf2190b71d5a7a544f1df94f866c4 100644
Binary files a/doc/joinus.PNG and b/doc/joinus.PNG differ
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index ebf52ec4e1d8713fd4da407318b14e682952606d..0a4fad621a9038e71a9d43eb4e12f78e7e92d73d 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -21,6 +21,8 @@ import numpy as np
import string
import json
+from ppocr.utils.logging import get_logger
+
class ClsLabelEncode(object):
def __init__(self, label_list, **kwargs):
@@ -92,31 +94,23 @@ class BaseRecLabelEncode(object):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False):
- support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
- 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
- 'mr', 'ne', 'latin', 'arabic', 'cyrillic', 'devanagari'
- ]
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
self.max_text_len = max_text_length
self.beg_str = "sos"
self.end_str = "eos"
- if character_type == "en":
+ self.lower = False
+
+ if character_dict_path is None:
+ logger = get_logger()
+ logger.warning(
+ "The character_dict_path is None, model can only recognize number and lower letters"
+ )
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type == "EN_symbol":
- # same with ASTER setting (use 94 char).
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
- elif character_type in support_character_type:
+ self.lower = True
+ else:
self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -125,7 +119,6 @@ class BaseRecLabelEncode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
- self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@@ -147,7 +140,7 @@ class BaseRecLabelEncode(object):
"""
if len(text) == 0 or len(text) > self.max_text_len:
return None
- if self.character_type == "en":
+ if self.lower:
text = text.lower()
text_list = []
for char in text:
@@ -167,13 +160,11 @@ class NRTRLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='EN_symbol',
use_space_char=False,
**kwargs):
- super(NRTRLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(NRTRLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
@@ -200,12 +191,10 @@ class CTCLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(CTCLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(CTCLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
text = data['label']
@@ -231,12 +220,10 @@ class E2ELabelEncodeTest(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='EN',
use_space_char=False,
**kwargs):
- super(E2ELabelEncodeTest,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(E2ELabelEncodeTest, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def __call__(self, data):
import json
@@ -305,12 +292,10 @@ class AttnLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(AttnLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(AttnLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -353,12 +338,10 @@ class SEEDLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(SEEDLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(SEEDLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
self.end_str = "eos"
@@ -385,12 +368,10 @@ class SRNLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length=25,
character_dict_path=None,
- character_type='en',
use_space_char=False,
**kwargs):
- super(SRNLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(SRNLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
dict_character = dict_character + [self.beg_str, self.end_str]
@@ -598,12 +579,10 @@ class SARLabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
**kwargs):
- super(SARLabelEncode,
- self).__init__(max_text_length, character_dict_path,
- character_type, use_space_char)
+ super(SARLabelEncode, self).__init__(
+ max_text_length, character_dict_path, use_space_char)
def add_special_char(self, dict_character):
beg_end_str = ""
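
With `character_type` removed, every `*LabelEncode` subclass now forwards only `max_text_length`, `character_dict_path` and `use_space_char` to the base class, and lower-casing is decided by whether a dictionary file was supplied rather than by `character_type == "en"`. A condensed, self-contained sketch of the new control flow (attribute names follow the diff; the real class carries more machinery):

```python
class EncodeSketch:
    """Condensed sketch of the reworked BaseRecLabelEncode control flow."""

    def __init__(self, max_text_length, character_dict_path=None, use_space_char=False):
        self.max_text_len = max_text_length
        # lower-casing now depends on the dictionary, not on character_type
        self.lower = character_dict_path is None
        if character_dict_path is None:
            chars = list("0123456789abcdefghijklmnopqrstuvwxyz")
        else:
            with open(character_dict_path, "rb") as fin:
                chars = [l.decode("utf-8").strip("\n").strip("\r\n") for l in fin]
            if use_space_char:
                chars.append(" ")
        self.dict = {c: i for i, c in enumerate(chars)}

    def encode(self, text):
        if len(text) == 0 or len(text) > self.max_text_len:
            return None
        if self.lower:
            text = text.lower()
        # characters outside the dictionary are silently dropped
        idx = [self.dict[c] for c in text if c in self.dict]
        return idx or None

print(EncodeSketch(25).encode("OCR2021"))  # lower-cased, digits + letters only
```
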
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 71ed8976db7de24a489d1f75612a9a9a67995ba2..b4de6de95b09ced803375d9a3bb857194ef3e64b 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -87,17 +87,17 @@ class RecResizeImg(object):
def __init__(self,
image_shape,
infer_mode=False,
- character_type='ch',
+ character_dict_path='./ppocr/utils/ppocr_keys_v1.txt',
padding=True,
**kwargs):
self.image_shape = image_shape
self.infer_mode = infer_mode
- self.character_type = character_type
+ self.character_dict_path = character_dict_path
self.padding = padding
def __call__(self, data):
img = data['image']
- if self.infer_mode and self.character_type == "ch":
+ if self.infer_mode and self.character_dict_path is not None:
norm_img = resize_norm_img_chinese(img, self.image_shape)
else:
norm_img = resize_norm_img(img, self.image_shape, self.padding)
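
`RecResizeImg` keeps the same two resize paths but now chooses between them based on `character_dict_path`: with `infer_mode=True` and a dictionary configured it uses `resize_norm_img_chinese`, otherwise `resize_norm_img`. An illustrative op entry as it might appear when building the transform list (the values are placeholders, not taken from a shipped config):

```python
# Illustrative RecResizeImg entry after the change: character_dict_path
# replaces character_type; infer_mode=True plus a dictionary selects the
# resize_norm_img_chinese path, otherwise resize_norm_img is used.
rec_resize_op = {
    "RecResizeImg": {
        "image_shape": [3, 32, 320],
        "infer_mode": True,
        "character_dict_path": "./ppocr/utils/ppocr_keys_v1.txt",
        "padding": True,
    }
}
```
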
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index c06159ca55600e7afe01a68ab43acd1919cf742c..ef1a43fd0ee65f3e55a8f72dfd2f96c478da1a9a 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -21,33 +21,15 @@ import re
class BaseRecLabelDecode(object):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False):
- support_character_type = [
- 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
- 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
- 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
- 'ne', 'EN', 'latin', 'arabic', 'cyrillic', 'devanagari'
- ]
- assert character_type in support_character_type, "Only {} are supported now but get {}".format(
- support_character_type, character_type)
-
+ def __init__(self, character_dict_path=None, use_space_char=False):
self.beg_str = "sos"
self.end_str = "eos"
- if character_type == "en":
+ self.character_str = []
+ if character_dict_path is None:
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type == "EN_symbol":
- # same with ASTER setting (use 94 char).
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
- elif character_type in support_character_type:
- self.character_str = []
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
- character_type)
+ else:
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -57,9 +39,6 @@ class BaseRecLabelDecode(object):
self.character_str.append(" ")
dict_character = list(self.character_str)
- else:
- raise NotImplementedError
- self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
@@ -102,13 +81,10 @@ class BaseRecLabelDecode(object):
class CTCLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(CTCLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, tuple):
@@ -136,13 +112,12 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
def __init__(self,
character_dict_path=None,
- character_type='ch',
use_space_char=False,
model_name=["student"],
key=None,
**kwargs):
- super(DistillationCTCLabelDecode, self).__init__(
- character_dict_path, character_type, use_space_char)
+ super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
+ use_space_char)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
@@ -162,13 +137,9 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
class NRTRLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='EN_symbol',
- use_space_char=True,
- **kwargs):
+ def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
super(NRTRLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def __call__(self, preds, label=None, *args, **kwargs):
@@ -230,13 +201,10 @@ class NRTRLabelDecode(BaseRecLabelDecode):
class AttnLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(AttnLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -313,13 +281,10 @@ class AttnLabelDecode(BaseRecLabelDecode):
class SEEDLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SEEDLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
def add_special_char(self, dict_character):
self.beg_str = "sos"
@@ -394,13 +359,10 @@ class SEEDLabelDecode(BaseRecLabelDecode):
class SRNLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='en',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SRNLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
self.max_text_length = kwargs.get('max_text_length', 25)
def __call__(self, preds, label=None, *args, **kwargs):
@@ -616,13 +578,10 @@ class TableLabelDecode(object):
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
- def __init__(self,
- character_dict_path=None,
- character_type='ch',
- use_space_char=False,
+ def __init__(self, character_dict_path=None, use_space_char=False,
**kwargs):
super(SARLabelDecode, self).__init__(character_dict_path,
- character_type, use_space_char)
+ use_space_char)
self.rm_symbol = kwargs.get('rm_symbol', False)
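
The decode side mirrors the encode side: `BaseRecLabelDecode` and its subclasses now accept only `character_dict_path` and `use_space_char`. A minimal usage sketch, assuming a PaddleOCR checkout at this revision with `paddlepaddle` and `numpy` installed:

```python
# Minimal usage sketch of the simplified decoder constructor.
import numpy as np
from ppocr.postprocess.rec_postprocess import CTCLabelDecode

# No dictionary file: digits + lowercase fallback (36 chars + CTC blank).
decoder = CTCLabelDecode(character_dict_path=None, use_space_char=False)
num_classes = len(decoder.character)           # 37 including the blank token

# Fake CTC output: batch of 1, 4 time steps, num_classes scores per step.
preds = np.random.rand(1, 4, num_classes).astype("float32")
print(decoder(preds))                          # [(text, confidence)]
```
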
diff --git a/ppocr/utils/EN_symbol_dict.txt b/ppocr/utils/EN_symbol_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1aef43d6b842731a54cbe682ccda5c2dbfa694d9
--- /dev/null
+++ b/ppocr/utils/EN_symbol_dict.txt
@@ -0,0 +1,94 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+a
+b
+c
+d
+e
+f
+g
+h
+i
+j
+k
+l
+m
+n
+o
+p
+q
+r
+s
+t
+u
+v
+w
+x
+y
+z
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
+!
+"
+#
+$
+%
+&
+'
+(
+)
+*
++
+,
+-
+.
+/
+:
+;
+<
+=
+>
+?
+@
+[
+\
+]
+^
+_
+`
+{
+|
+}
+~
\ No newline at end of file
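
The new `ppocr/utils/EN_symbol_dict.txt` carries the same ASTER-style 94-character set that the removed `EN_symbol` branch used to build from `string.printable[:-6]` (digits, letters and ASCII punctuation, no whitespace). A quick sanity check, assuming it is run from a PaddleOCR checkout:

```python
# Sanity check: EN_symbol_dict.txt should contain exactly the 94 characters
# previously produced by string.printable[:-6].
import string

with open("ppocr/utils/EN_symbol_dict.txt", "rb") as fin:
    chars = [line.decode("utf-8").strip("\n").strip("\r\n") for line in fin]

expected = list(string.printable[:-6])  # digits + letters + punctuation
assert len(chars) == 94
assert sorted(chars) == sorted(expected)
print("EN_symbol_dict.txt matches string.printable[:-6]")
```
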
diff --git a/requirements.txt b/requirements.txt
index 311030f65f2dc2dad4a51821e64f2777e7621a0b..6758a59bad20f6ffa271766fc4d0df5ebf4c7a4b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
shapely
-scikit-image==0.17.2
+scikit-image==0.18.3
imgaug==0.4.0
pyclipper
lmdb
diff --git a/tests/docs/test.png b/tests/docs/test.png
deleted file mode 100644
index a27c16ec75fbfb437240b9e05510b6fe74a5766b..0000000000000000000000000000000000000000
Binary files a/tests/docs/test.png and /dev/null differ
diff --git a/tests/docs/test_cpp.md b/tests/docs/test_cpp.md
deleted file mode 100644
index d8380671c01502b18b57523f012c5e096dc70fe0..0000000000000000000000000000000000000000
--- a/tests/docs/test_cpp.md
+++ /dev/null
@@ -1,56 +0,0 @@
-# C++预测功能测试
-
-C++预测功能测试的主程序为`test_cpp.sh`,可以测试基于C++预测库的模型推理功能。
-
-## 测试结论汇总
-
-| 算法名称 | 模型名称 |device | batchsize | mkldnn | cpu多线程 | tensorrt | 离线量化 |
-| ---- | ---- | ---- | ---- | ---- | ---- | ----| --- |
-| DB |ch_ppocr_mobile_v2.0_det| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-| DB |ch_ppocr_server_v2.0_det| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-| CRNN |ch_ppocr_mobile_v2.0_rec| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-| CRNN |ch_ppocr_server_v2.0_rec| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-|PP-OCR|ch_ppocr_server_v2.0 | CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-|PP-OCR|ch_ppocr_server_v2.0 | CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-
-
-
-## 1. 功能测试
-先运行`prepare.sh`准备数据和模型,然后运行`test_cpp.sh`进行测试,最终在```tests/output```目录下生成`cpp_infer_*.log`后缀的日志文件。
-
-```shell
-bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt
-
-# 用法1:
-bash tests/test_cpp.sh ./tests/configs/ppocr_det_mobile_params.txt
-# 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号
-bash tests/test_cpp.sh ./tests/configs/ppocr_det_mobile_params.txt '1'
-```
-
-
-## 2. 精度测试
-
-使用compare_results.py脚本比较模型预测的结果是否符合预期,主要步骤包括:
-- 提取日志中的预测坐标;
-- 从本地文件中提取保存好的坐标结果;
-- 比较上述两个结果是否符合精度预期,误差大于设置阈值时会报错。
-
-### 使用方式
-运行命令:
-```shell
-python3.7 tests/compare_results.py --gt_file=./tests/results/*.txt --log_file=./tests/output/infer_*.log --atol=1e-3 --rtol=1e-3
-```
-
-参数介绍:
-- gt_file: 指向事先保存好的预测结果路径,支持*.txt 结尾,会自动索引*.txt格式的文件,文件默认保存在tests/result/ 文件夹下
-- log_file: 指向运行tests/test.sh 脚本的infer模式保存的预测日志,预测日志中打印的有预测结果,比如:文本框,预测文本,类别等等,同样支持infer_*.log格式传入
-- atol: 设置的绝对误差
-- rtol: 设置的相对误差
-
-### 运行结果
-
-正常运行效果如下图:
-
-
-出现不一致结果时的运行输出:
-
diff --git a/tests/docs/test_python.md b/tests/docs/test_python.md
deleted file mode 100644
index 87c58395c6038ff68bd172a469a788d5886adcab..0000000000000000000000000000000000000000
--- a/tests/docs/test_python.md
+++ /dev/null
@@ -1,107 +0,0 @@
-# Python功能测试
-
-Python功能测试的主程序为`test_python.sh`,可以测试基于Python的模型训练、评估、推理等基本功能,包括裁剪、量化、蒸馏。
-
-## 测试结论汇总
-
-- 训练相关:
-
-| 算法名称 | 模型名称 | 单机单卡 | 单机多卡 | 多机多卡 | 模型压缩(单机多卡) |
-| :---- | :---- | :---- | :---- | :---- | :---- |
-| DB | ch_ppocr_mobile_v2.0_det| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
-| DB | ch_ppocr_server_v2.0_det| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
-| CRNN | ch_ppocr_mobile_v2.0_rec| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
-| CRNN | ch_ppocr_server_v2.0_rec| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
-|PP-OCR| ch_ppocr_mobile_v2.0| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
-|PP-OCR| ch_ppocr_server_v2.0| 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练 <br> 混合精度 | 正常训练:FPGM裁剪、PACT量化 |
-
-
-- 预测相关:
-
-| 算法名称 | 模型名称 |device | batchsize | mkldnn | cpu多线程 | tensorrt | 离线量化 |
-| ---- | ---- | ---- | ---- | ---- | ---- | ----| --- |
-| DB |ch_ppocr_mobile_v2.0_det| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-| DB |ch_ppocr_server_v2.0_det| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-| CRNN |ch_ppocr_mobile_v2.0_rec| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-| CRNN |ch_ppocr_server_v2.0_rec| CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-|PP-OCR|ch_ppocr_server_v2.0 | CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-|PP-OCR|ch_ppocr_server_v2.0 | CPU/GPU | 1/6 | 支持 | 支持 | fp32/fp16/int8 | 支持 |
-
-
-
-## 1. 安装依赖
-- 安装PaddlePaddle >= 2.0
-- 安装PaddleOCR依赖
- ```
- pip3 install -r ../requirements.txt
- ```
-- 安装autolog(规范化日志输出工具)
- ```
- git clone https://github.com/LDOUBLEV/AutoLog
- cd AutoLog
- pip3 install -r requirements.txt
- python3 setup.py bdist_wheel
- pip3 install ./dist/auto_log-1.0.0-py3-none-any.whl
- cd ../
- ```
-
-
-## 2. 功能测试
-先运行`prepare.sh`准备数据和模型,然后运行`test_python.sh`进行测试,最终在```tests/output```目录下生成`infer_*.log`格式的日志文件。
-
-test_python.sh包含四种运行模式,每种模式的运行数据不同,分别用于测试速度和精度,分别是:
-
-- 模式1:lite_train_infer,使用少量数据训练,用于快速验证训练到预测的走通流程,不验证精度和速度;
-```shell
-bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
-bash tests/test_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'lite_train_infer'
-```
-
-- 模式2:whole_infer,使用少量数据训练,一定量数据预测,用于验证训练后的模型执行预测,预测速度是否合理;
-```shell
-bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_infer'
-bash tests/test_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_infer'
-```
-
-- 模式3:infer 不训练,全量数据预测,走通开源模型评估、动转静,检查inference model预测时间和精度;
-```shell
-bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'infer'
-# 用法1:
-bash tests/test_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'infer'
-# 用法2: 指定GPU卡预测,第三个传入参数为GPU卡号
-bash tests/test_python.sh ./tests/configs/ppocr_det_mobile_params.txt 'infer' '1'
-```
-
-- 模式4:whole_train_infer , CE: 全量数据训练,全量数据预测,验证模型训练精度,预测精度,预测速度;
-```shell
-bash tests/prepare.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
-bash tests/test.sh ./tests/configs/ppocr_det_mobile_params.txt 'whole_train_infer'
-```
-
-
-## 3. 精度测试
-
-使用compare_results.py脚本比较模型预测的结果是否符合预期,主要步骤包括:
-- 提取日志中的预测坐标;
-- 从本地文件中提取保存好的坐标结果;
-- 比较上述两个结果是否符合精度预期,误差大于设置阈值时会报错。
-
-### 使用方式
-运行命令:
-```shell
-python3.7 tests/compare_results.py --gt_file=./tests/results/*.txt --log_file=./tests/output/infer_*.log --atol=1e-3 --rtol=1e-3
-```
-
-参数介绍:
-- gt_file: 指向事先保存好的预测结果路径,支持*.txt 结尾,会自动索引*.txt格式的文件,文件默认保存在tests/result/ 文件夹下
-- log_file: 指向运行tests/test.sh 脚本的infer模式保存的预测日志,预测日志中打印的有预测结果,比如:文本框,预测文本,类别等等,同样支持infer_*.log格式传入
-- atol: 设置的绝对误差
-- rtol: 设置的相对误差
-
-### 运行结果
-
-正常运行效果如下图:
-
-
-出现不一致结果时的运行输出:
-
diff --git a/tests/readme.md b/tests/readme.md
deleted file mode 100644
index b7138a6801c3a589f5ca2ed0e8a6bafb08db3fec..0000000000000000000000000000000000000000
--- a/tests/readme.md
+++ /dev/null
@@ -1,93 +0,0 @@
-
-# 推理部署导航
-
-飞桨除了基本的模型训练和预测,还提供了支持多端多平台的高性能推理部署工具。本文档提供了PaddleOCR中所有模型的推理部署导航,方便用户查阅每种模型的推理部署打通情况,并可以进行一键测试。
-
-
-
-
-
-打通情况汇总如下,已填写的部分表示可以使用本工具进行一键测试,未填写的表示正在支持中。
-
-| 算法论文 | 模型名称 | 模型类型 | python训练预测 | 其他 |
-| :--- | :--- | :---- | :-------- | :---- |
-| DB |ch_ppocr_mobile_v2.0_det | 检测 | 支持 | Paddle Inference: C++预测 <br> Paddle Serving: Python, C++ <br> Paddle-Lite: Python, C++ / ARM CPU |
-| DB |ch_ppocr_server_v2.0_det | 检测 | 支持 | Paddle Inference: C++预测 <br> Paddle Serving: Python, C++ <br> Paddle-Lite: Python, C++ / ARM CPU |
-| DB |ch_PP-OCRv2_det | 检测 |
-| CRNN |ch_ppocr_mobile_v2.0_rec | 识别 | 支持 | Paddle Inference: C++预测 <br> Paddle Serving: Python, C++ <br> Paddle-Lite: Python, C++ / ARM CPU |
-| CRNN |ch_ppocr_server_v2.0_rec | 识别 | 支持 | Paddle Inference: C++预测 <br> Paddle Serving: Python, C++ <br> Paddle-Lite: Python, C++ / ARM CPU |
-| CRNN |ch_PP-OCRv2_rec | 识别 |
-| DB |det_mv3_db_v2.0 | 检测 |
-| DB |det_r50_vd_db_v2.0 | 检测 |
-| EAST |det_mv3_east_v2.0 | 检测 |
-| EAST |det_r50_vd_east_v2.0 | 检测 |
-| PSENet |det_mv3_pse_v2.0 | 检测 |
-| PSENet |det_r50_vd_pse_v2.0 | 检测 |
-| SAST |det_r50_vd_sast_totaltext_v2.0 | 检测 |
-| Rosetta|rec_mv3_none_none_ctc_v2.0 | 识别 |
-| Rosetta|rec_r34_vd_none_none_ctc_v2.0 | 识别 |
-| CRNN |rec_mv3_none_bilstm_ctc_v2.0 | 识别 |
-| CRNN |rec_r34_vd_none_bilstm_ctc_v2.0| 识别 |
-| StarNet|rec_mv3_tps_bilstm_ctc_v2.0 | 识别 |
-| StarNet|rec_r34_vd_tps_bilstm_ctc_v2.0 | 识别 |
-| RARE |rec_mv3_tps_bilstm_att_v2.0 | 识别 |
-| RARE |rec_r34_vd_tps_bilstm_att_v2.0 | 识别 |
-| SRN |rec_r50fpn_vd_none_srn | 识别 |
-| NRTR |rec_mtb_nrtr | 识别 |
-| SAR |rec_r31_sar | 识别 |
-| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端|
-
-
-
-## 一键测试工具使用
-### 目录介绍
-
-```shell
-tests/
-├── configs/ # 配置文件目录
- ├── det_mv3_db.yml # 测试mobile版ppocr检测模型训练的yml文件
- ├── det_r50_vd_db.yml # 测试server版ppocr检测模型训练的yml文件
- ├── rec_icdar15_r34_train.yml # 测试server版ppocr识别模型训练的yml文件
- ├── ppocr_sys_mobile_params.txt # 测试mobile版ppocr检测+识别模型串联的参数配置文件
- ├── ppocr_det_mobile_params.txt # 测试mobile版ppocr检测模型的参数配置文件
- ├── ppocr_rec_mobile_params.txt # 测试mobile版ppocr识别模型的参数配置文件
- ├── ppocr_sys_server_params.txt # 测试server版ppocr检测+识别模型串联的参数配置文件
- ├── ppocr_det_server_params.txt # 测试server版ppocr检测模型的参数配置文件
- ├── ppocr_rec_server_params.txt # 测试server版ppocr识别模型的参数配置文件
- ├── ...
-├── results/ # 预先保存的预测结果,用于和实际预测结果进行精读比对
- ├── ppocr_det_mobile_results_fp32.txt # 预存的mobile版ppocr检测模型fp32精度的结果
- ├── ppocr_det_mobile_results_fp16.txt # 预存的mobile版ppocr检测模型fp16精度的结果
- ├── ppocr_det_mobile_results_fp32_cpp.txt # 预存的mobile版ppocr检测模型c++预测的fp32精度的结果
- ├── ppocr_det_mobile_results_fp16_cpp.txt # 预存的mobile版ppocr检测模型c++预测的fp16精度的结果
- ├── ...
-├── prepare.sh # 完成test_*.sh运行所需要的数据和模型下载
-├── test_python.sh # 测试python训练预测的主程序
-├── test_cpp.sh # 测试c++预测的主程序
-├── test_serving.sh # 测试serving部署预测的主程序
-├── test_lite.sh # 测试lite部署预测的主程序
-├── compare_results.py # 用于对比log中的预测结果与results中的预存结果精度误差是否在限定范围内
-└── readme.md # 使用文档
-```
-
-### 测试流程
-使用本工具,可以测试不同功能的支持情况,以及预测结果是否对齐,测试流程如下:
-
-
-
-
-1. 运行prepare.sh准备测试所需数据和模型;
-2. 运行要测试的功能对应的测试脚本`test_*.sh`,产出log,由log可以看到不同配置是否运行成功;
-3. 用`compare_results.py`对比log中的预测结果和预存在results目录下的结果,判断预测精度是否符合预期(在误差范围内)。
-
-其中,有4个测试主程序,功能如下:
-- `test_python.sh`:测试基于Python的模型训练、评估、推理等基本功能,包括裁剪、量化、蒸馏。
-- `test_cpp.sh`:测试基于C++的模型推理。
-- `test_serving.sh`:测试基于Paddle Serving的服务化部署功能。
-- `test_lite.sh`:测试基于Paddle-Lite的端侧预测部署功能。
-
-各功能测试中涉及GPU/CPU、mkldnn、Tensorrt等多种参数配置,点击相应链接了解更多细节和使用教程:
-[test_python使用](docs/test_python.md)
-[test_cpp使用](docs/test_cpp.md)
-[test_serving使用](docs/test_serving.md)
-[test_lite使用](docs/test_lite.md)
diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py
index 53e50bd6d1d1a2bd07b9f1204b9f56594c669d13..1c68494861e60b4aaef541a4e247071944cf420c 100755
--- a/tools/infer/predict_cls.py
+++ b/tools/infer/predict_cls.py
@@ -131,14 +131,9 @@ def main(args):
img_list.append(img)
try:
img_list, cls_res, predict_time = text_classifier(img_list)
- except:
+ except Exception as E:
logger.info(traceback.format_exc())
- logger.info(
- "ERROR!!!! \n"
- "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
- "If your model has tps module: "
- "TPS does not support variable shape.\n"
- "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
+ logger.info(E)
exit()
for ino in range(len(img_list)):
logger.info("Predicts of {}:{}".format(valid_image_file_list[ino],
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index dad70281ef7604f110d29963103068bba1c8fd9d..936994a215d10d543537b29cb41bfa42b42590c7 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -38,40 +38,34 @@ logger = get_logger()
class TextRecognizer(object):
def __init__(self, args):
self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
- self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
postprocess_params = {
'name': 'CTCLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
if self.rec_algorithm == "SRN":
postprocess_params = {
'name': 'SRNLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "RARE":
postprocess_params = {
'name': 'AttnLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == 'NRTR':
postprocess_params = {
'name': 'NRTRLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
elif self.rec_algorithm == "SAR":
postprocess_params = {
'name': 'SARLabelDecode',
- "character_type": args.rec_char_type,
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 538f55c42b223f9741c5c7006dd7d1478ce1920b..41a3c0f14b6378751a367a3709ad7943ee981a4e 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -74,7 +74,6 @@ def init_args():
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
- parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument(
@@ -268,10 +267,11 @@ def create_predictor(args, mode, logger):
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
-
+ if args.precision == "fp16":
+ config.enable_mkldnn_bfloat16()
# enable memory optim
config.enable_memory_optim()
- #config.disable_glog_info()
+ config.disable_glog_info()
config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
if mode == 'table':
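
Two notes on the `utility.py` hunk: `--rec_char_type` is dropped from the CLI, and on the MKL-DNN path `--precision fp16` now enables bfloat16 kernels through `enable_mkldnn_bfloat16()`. A hedged sketch of a CPU inference config exercising that branch (assumes `paddlepaddle` is installed; the model path is illustrative):

```python
# Sketch: CPU inference config with MKL-DNN bfloat16, mirroring the updated
# create_predictor() branch (model path is illustrative).
from paddle import inference

model_dir = "./inference/ch_ppocr_mobile_v2.0_det_infer"
config = inference.Config(model_dir + "/inference.pdmodel",
                          model_dir + "/inference.pdiparams")
config.disable_gpu()
config.set_cpu_math_library_num_threads(6)
config.set_mkldnn_cache_capacity(10)   # cache shapes to limit memory growth
config.enable_mkldnn()
config.enable_mkldnn_bfloat16()        # taken when --precision fp16 on CPU
config.enable_memory_optim()
config.disable_glog_info()
predictor = inference.create_predictor(config)
```
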
diff --git a/tools/program.py b/tools/program.py
index 798e6dff297ad1149942488cca1d5540f1924867..6456aad5dcda764816e5af7b5becf30cc7192af4 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -159,7 +159,8 @@ def train(config,
eval_class,
pre_best_model_dict,
logger,
- vdl_writer=None):
+ vdl_writer=None,
+ scaler=None):
cal_metric_during_train = config['Global'].get('cal_metric_during_train',
False)
log_smooth_window = config['Global']['log_smooth_window']
@@ -226,14 +227,29 @@ def train(config,
images = batch[0]
if use_srn:
model_average = True
- if model_type == 'table' or extra_input:
- preds = model(images, data=batch[1:])
+
+ # use amp
+ if scaler:
+ with paddle.amp.auto_cast():
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ else:
+ preds = model(images)
else:
- preds = model(images)
+ if model_type == 'table' or extra_input:
+ preds = model(images, data=batch[1:])
+ else:
+ preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
- avg_loss.backward()
- optimizer.step()
+
+ if scaler:
+ scaled_avg_loss = scaler.scale(avg_loss)
+ scaled_avg_loss.backward()
+ scaler.minimize(optimizer, scaled_avg_loss)
+ else:
+ avg_loss.backward()
+ optimizer.step()
optimizer.clear_grad()
train_batch_cost += time.time() - batch_start
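
When a scaler is passed in, the training loop wraps the forward pass in `paddle.amp.auto_cast()` and runs backward/step through the `GradScaler`. A stripped-down version of that step outside `program.train` (assumes `paddlepaddle` is installed, ideally a GPU build since AMP targets GPU runs; the tiny model and data are placeholders, not PaddleOCR components):

```python
# Stripped-down AMP training step mirroring the scaler branch in program.train.
import paddle

model = paddle.nn.Linear(10, 2)
optimizer = paddle.optimizer.Adam(parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024, use_dynamic_loss_scaling=True)

x, y = paddle.randn([4, 10]), paddle.randn([4, 2])

with paddle.amp.auto_cast():                 # forward pass under auto-cast
    loss = paddle.nn.functional.mse_loss(model(x), y)

scaled = scaler.scale(loss)                  # scale the loss
scaled.backward()                            # backward on the scaled loss
scaler.minimize(optimizer, scaled)           # unscale + optimizer step
optimizer.clear_grad()
```
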
diff --git a/tools/train.py b/tools/train.py
index 05d295aa99718c25b94a123c23d08c2904fe8c6a..d182af2988cb29511be40a079d2b3e06605ebe28 100755
--- a/tools/train.py
+++ b/tools/train.py
@@ -102,10 +102,27 @@ def main(config, device, logger, vdl_writer):
if valid_dataloader is not None:
logger.info('valid dataloader has {} iters'.format(
len(valid_dataloader)))
+
+ use_amp = config["Global"].get("use_amp", False)
+ if use_amp:
+ AMP_RELATED_FLAGS_SETTING = {
+ 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+ 'FLAGS_max_inplace_grad_add': 8,
+ }
+ paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+ scale_loss = config["Global"].get("scale_loss", 1.0)
+ use_dynamic_loss_scaling = config["Global"].get(
+ "use_dynamic_loss_scaling", False)
+ scaler = paddle.amp.GradScaler(
+ init_loss_scaling=scale_loss,
+ use_dynamic_loss_scaling=use_dynamic_loss_scaling)
+ else:
+ scaler = None
+
# start train
program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class,
- eval_class, pre_best_model_dict, logger, vdl_writer)
+ eval_class, pre_best_model_dict, logger, vdl_writer, scaler)
def test_reader(config, device, logger):
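
On the `train.py` side, AMP is opt-in through `Global` config keys: `use_amp` turns it on, while `scale_loss` and `use_dynamic_loss_scaling` configure the `GradScaler`, which is then passed through to `program.train`. A minimal sketch of that wiring with a hand-written config dict (values are illustrative, defaults match the `.get()` calls in the diff; the cuDNN flag assumes a CUDA build of `paddlepaddle`):

```python
# Sketch of the AMP wiring added to tools/train.py (config values illustrative).
import paddle

global_cfg = {"use_amp": True, "scale_loss": 1024.0, "use_dynamic_loss_scaling": True}

scaler = None
if global_cfg.get("use_amp", False):
    paddle.fluid.set_flags({
        "FLAGS_cudnn_batchnorm_spatial_persistent": 1,  # requires a CUDA build
        "FLAGS_max_inplace_grad_add": 8,
    })
    scaler = paddle.amp.GradScaler(
        init_loss_scaling=global_cfg.get("scale_loss", 1.0),
        use_dynamic_loss_scaling=global_cfg.get("use_dynamic_loss_scaling", False))

# scaler (or None) is then forwarded as the new last argument of program.train(...)
print(type(scaler).__name__)
```
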