提交 a28ef7f0 编写于 作者: 张欣-男's avatar 张欣-男

Merge remote-tracking branch 'upstream/develop' into zxdev

...@@ -246,7 +246,8 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[文本识 ...@@ -246,7 +246,8 @@ PaddleOCR文本识别算法的训练和使用请参考文档教程中[文本识
## 许可证书 ## 许可证书
本项目的发布受<a href="https://github.com/PaddlePaddle/PaddleOCR/blob/master/LICENSE">Apache 2.0 license</a>许可认证。 本项目的发布受<a href="https://github.com/PaddlePaddle/PaddleOCR/blob/master/LICENSE">Apache 2.0 license</a>许可认证。
## 如何贡献代码 ## 贡献代码
我们非常欢迎你为PaddleOCR贡献代码,也十分感谢你的反馈。 我们非常欢迎你为PaddleOCR贡献代码,也十分感谢你的反馈。
- 非常感谢 [Khanh Tran](https://github.com/xxxpsyduck) 贡献了英文文档。 - 非常感谢 [Khanh Tran](https://github.com/xxxpsyduck) 贡献了英文文档。
- 非常感谢 [zhangxin](https://github.com/ZhangXinNan)([Blog](https://blog.csdn.net/sdlypyzq)) 贡献新的可视化方式、添加.gitgnore、处理手动设置PYTHONPATH环境变量的问题
...@@ -255,3 +255,4 @@ This project is released under <a href="https://github.com/PaddlePaddle/PaddleOC ...@@ -255,3 +255,4 @@ This project is released under <a href="https://github.com/PaddlePaddle/PaddleOC
We welcome all the contributions to PaddleOCR and appreciate for your feedback very much. We welcome all the contributions to PaddleOCR and appreciate for your feedback very much.
- Many thanks to [Khanh Tran](https://github.com/xxxpsyduck) for contributing the English documentation. - Many thanks to [Khanh Tran](https://github.com/xxxpsyduck) for contributing the English documentation.
- Many thanks to [zhangxin](https://github.com/ZhangXinNan) for contributing the new visualize function、add .gitgnore and discard set PYTHONPATH manually.
...@@ -23,7 +23,7 @@ wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_la ...@@ -23,7 +23,7 @@ wget -P ./train_data/ https://paddleocr.bj.bcebos.com/dataset/test_icdar2015_la
└─ test_icdar2015_label.txt icdar数据集的测试标注 └─ test_icdar2015_label.txt icdar数据集的测试标注
``` ```
提供的标注文件格式为: 提供的标注文件格式为,其中中间是"\t"分隔
``` ```
" 图像文件名 json.dumps编码的图像标注信息" " 图像文件名 json.dumps编码的图像标注信息"
ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}] ch4_test_images/img_61.jpg [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]], ...}]
...@@ -35,7 +35,7 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中 ...@@ -35,7 +35,7 @@ json.dumps编码前的图像标注信息是包含多个字典的list,字典中
## 快速启动训练 ## 快速启动训练
首先下载pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd, 首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet50_vd,
您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。 您可以根据需求使用[PaddleClas](https://github.com/PaddlePaddle/PaddleClas/tree/master/ppcls/modeling/architectures)中的模型更换backbone。
``` ```
cd PaddleOCR/ cd PaddleOCR/
...@@ -62,7 +62,7 @@ tar xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models ...@@ -62,7 +62,7 @@ tar xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_models
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false* *如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
``` ```
python3 tools/train.py -c configs/det/det_mv3_db.yml python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
``` ```
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。 上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
...@@ -73,6 +73,15 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml ...@@ -73,6 +73,15 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001 python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
``` ```
**断点训练**
如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:
```
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model
```
**注意**:Global.checkpoints的优先级高于Global.pretrain_weights的优先级,即同时指定两个参数时,优先加载Global.checkpoints指定的模型,如果Global.checkpoints指定的模型路径有误,会加载Global.pretrain_weights指定的模型。
## 指标评估 ## 指标评估
PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。 PaddleOCR计算三个OCR检测相关的指标,分别是:Precision、Recall、Hmean。
......
...@@ -192,6 +192,13 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" ...@@ -192,6 +192,13 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
``` ```
### 4.自定义文本识别字典的推理
如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
```
## 文本检测、识别串联推理 ## 文本检测、识别串联推理
### 1.超轻量中文OCR模型推理 ### 1.超轻量中文OCR模型推理
......
...@@ -69,6 +69,17 @@ You can also use the `-o` parameter to change the training parameters without mo ...@@ -69,6 +69,17 @@ You can also use the `-o` parameter to change the training parameters without mo
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001 python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
``` ```
**load trained model and conntinue training**
If you expect to load trained model and continue the training again, you can specify the `Global.checkpoints` parameter as the model path to be loaded.
For example:
```
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model
```
**Note**:The priority of Global.checkpoints is higher than the priority of Global.pretrain_weights, that is, when two parameters are specified at the same time, the model specified by Global.checkpoints will be loaded first. If the model path specified by Global.checkpoints is wrong, the one specified by Global.pretrain_weights will be loaded.
## Evaluation Indicator ## Evaluation Indicator
PaddleOCR calculates three indicators for evaluating performance of OCR detection task: Precision, Recall, and Hmean. PaddleOCR calculates three indicators for evaluating performance of OCR detection task: Precision, Recall, and Hmean.
......
...@@ -182,6 +182,13 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" ...@@ -182,6 +182,13 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str) dict_character = list(self.character_str)
``` ```
### 4.Recognition model inference using custom text dictionary file
If the text dictionary is replaced during training, you need to specify the text dictionary path by setting the parameter `rec_char_dict_path` when using your inference model to predict.
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./your inference model" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_char_dict_path="your text dict path"
```
## Text detection and recognition inference concatenation ## Text detection and recognition inference concatenation
### 1. Ultra-lightweight Chinese OCR model inference ### 1. Ultra-lightweight Chinese OCR model inference
......
...@@ -95,8 +95,10 @@ class EvalTestReader(object): ...@@ -95,8 +95,10 @@ class EvalTestReader(object):
for img_path in img_list: for img_path in img_list:
img = cv2.imread(img_path) img = cv2.imread(img_path)
if img is None: if img is None:
logger.info("load image error:" + img_path) logger.info("{} does not exist!".format(img_path))
continue continue
elif len(list(img.shape)) == 2 or img.shape[2] == 1:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
outs = process_function(img) outs = process_function(img)
outs.append(img_path) outs.append(img_path)
batch_outs.append(outs) batch_outs.append(outs)
......
...@@ -104,6 +104,8 @@ class DBProcessTrain(object): ...@@ -104,6 +104,8 @@ class DBProcessTrain(object):
if imgvalue is None: if imgvalue is None:
logger.info("{} does not exist!".format(img_path)) logger.info("{} does not exist!".format(img_path))
return None return None
if len(list(imgvalue.shape)) == 2 or imgvalue.shape[2] == 1:
imgvalue = cv2.cvtColor(imgvalue, cv2.COLOR_GRAY2BGR)
data = self.make_data_dict(imgvalue, gt_label) data = self.make_data_dict(imgvalue, gt_label)
data = AugmentData(data) data = AugmentData(data)
data = RandomCropData(data, self.image_shape[1:]) data = RandomCropData(data, self.image_shape[1:])
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
import math import math
import cv2 import cv2
import numpy as np import numpy as np
from ppocr.utils.utility import initial_logger
logger = initial_logger()
def get_bounding_box_rect(pos): def get_bounding_box_rect(pos):
...@@ -101,9 +103,13 @@ def process_image(img, ...@@ -101,9 +103,13 @@ def process_image(img,
norm_img = resize_norm_img_chinese(img, image_shape) norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
if label is not None: if label is not None:
char_num = char_ops.get_char_num() # char_num = char_ops.get_char_num()
text = char_ops.encode(label) text = char_ops.encode(label)
if len(text) == 0 or len(text) > max_text_length: if len(text) == 0 or len(text) > max_text_length:
logger.info(
"Warning in ppocr/data/rec/img_tools.py:line106: Wrong data type."
"Excepted string with length between 0 and {}, but "
"got '{}' ".format(max_text_length, label))
return None return None
else: else:
if loss_type == "ctc": if loss_type == "ctc":
......
...@@ -114,7 +114,10 @@ def merge_config(config): ...@@ -114,7 +114,10 @@ def merge_config(config):
global_config[key] = value global_config[key] = value
else: else:
sub_keys = key.split('.') sub_keys = key.split('.')
assert (sub_keys[0] in global_config), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(global_config.keys(), sub_keys[0]) assert (
sub_keys[0] in global_config
), "the sub_keys can only be one of global_config: {}, but get: {}, please check your running command".format(
global_config.keys(), sub_keys[0])
cur = global_config[sub_keys[0]] cur = global_config[sub_keys[0]]
for idx, sub_key in enumerate(sub_keys[1:]): for idx, sub_key in enumerate(sub_keys[1:]):
assert (sub_key in cur) assert (sub_key in cur)
...@@ -177,7 +180,6 @@ def build(config, main_prog, startup_prog, mode): ...@@ -177,7 +180,6 @@ def build(config, main_prog, startup_prog, mode):
optimizer.minimize(opt_loss) optimizer.minimize(opt_loss)
opt_loss_name = opt_loss.name opt_loss_name = opt_loss.name
global_lr = optimizer._global_learning_rate() global_lr = optimizer._global_learning_rate()
global_lr.persistable = True
fetch_name_list.insert(0, "lr") fetch_name_list.insert(0, "lr")
fetch_varname_list.insert(0, global_lr.name) fetch_varname_list.insert(0, global_lr.name)
return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name) return (dataloader, fetch_name_list, fetch_varname_list, opt_loss_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册