未验证 提交 4cac91eb 编写于 作者: D dyning 提交者: GitHub

Merge pull request #132 from tink2123/add_rec_score

Add rec score
...@@ -36,6 +36,9 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力 ...@@ -36,6 +36,9 @@ PaddleOCR旨在打造一套丰富、领先、且实用的OCR工具库,助力
#### 2.inference模型下载 #### 2.inference模型下载
*windows 环境下如果没有安装wget,下载模型时可将链接复制到浏览器中下载,并解压放置在相应目录下*
#### (1)超轻量级中文OCR模型下载 #### (1)超轻量级中文OCR模型下载
``` ```
mkdir inference && cd inference mkdir inference && cd inference
...@@ -63,6 +66,9 @@ cd .. ...@@ -63,6 +66,9 @@ cd ..
# 设置PYTHONPATH环境变量 # 设置PYTHONPATH环境变量
export PYTHONPATH=. export PYTHONPATH=.
# windows下设置环境变量
SET PYTHONPATH=.
# 预测image_dir指定的单张图像 # 预测image_dir指定的单张图像
python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_mv3_db/" --rec_model_dir="./inference/ch_rec_mv3_crnn/" python3 tools/infer/predict_system.py --image_dir="./doc/imgs/11.jpg" --det_model_dir="./inference/ch_det_mv3_db/" --rec_model_dir="./inference/ch_rec_mv3_crnn/"
......
...@@ -10,4 +10,3 @@ EvalReader: ...@@ -10,4 +10,3 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,LMDBReader reader_function: ppocr.data.rec.dataset_traversal,LMDBReader
lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/ lmdb_sets_dir: ./train_data/data_lmdb_release/evaluation/
infer_img: ./infer_img
...@@ -15,9 +15,11 @@ Global: ...@@ -15,9 +15,11 @@ Global:
character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt character_dict_path: ./ppocr/utils/ppocr_keys_v1.txt
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_chinese_reader.yml reader_yml: ./configs/rec/rec_chinese_reader.yml
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -11,4 +11,3 @@ EvalReader: ...@@ -11,4 +11,3 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
...@@ -11,4 +11,3 @@ EvalReader: ...@@ -11,4 +11,3 @@ EvalReader:
TestReader: TestReader:
reader_function: ppocr.data.rec.dataset_traversal,SimpleReader reader_function: ppocr.data.rec.dataset_traversal,SimpleReader
infer_img: ./infer_img
...@@ -14,9 +14,11 @@ Global: ...@@ -14,9 +14,11 @@ Global:
character_type: en character_type: en
loss_type: ctc loss_type: ctc
reader_yml: ./configs/rec/rec_icdar15_reader.yml reader_yml: ./configs/rec/rec_icdar15_reader.yml
pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy pretrain_weights: ./pretrain_models/rec_mv3_none_bilstm_ctc/best_accuracy
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -13,11 +13,14 @@ Global: ...@@ -13,11 +13,14 @@ Global:
max_text_length: 25 max_text_length: 25
character_type: en character_type: en
loss_type: attention loss_type: attention
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -13,10 +13,12 @@ Global: ...@@ -13,10 +13,12 @@ Global:
max_text_length: 25 max_text_length: 25
character_type: en character_type: en
loss_type: ctc loss_type: ctc
tps: true
reader_yml: ./configs/rec/rec_benchmark_reader.yml reader_yml: ./configs/rec/rec_benchmark_reader.yml
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
......
...@@ -17,7 +17,9 @@ Global: ...@@ -17,7 +17,9 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,7 @@ Global: ...@@ -17,6 +17,7 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,8 @@ Global: ...@@ -17,6 +17,8 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -17,6 +17,8 @@ Global: ...@@ -17,6 +17,8 @@ Global:
pretrain_weights: pretrain_weights:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
infer_img:
Architecture: Architecture:
function: ppocr.modeling.architectures.rec_model,RecModel function: ppocr.modeling.architectures.rec_model,RecModel
......
...@@ -46,6 +46,9 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/Res ...@@ -46,6 +46,9 @@ wget -P ./pretrain_models/ https://paddle-imagenet-models-name.bj.bcebos.com/Res
``` ```
**启动训练** **启动训练**
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
``` ```
python3 tools/train.py -c configs/det/det_mv3_db.yml python3 tools/train.py -c configs/det/det_mv3_db.yml
``` ```
......
...@@ -165,6 +165,16 @@ STAR-Net文本识别模型推理,可以执行如下命令: ...@@ -165,6 +165,16 @@ STAR-Net文本识别模型推理,可以执行如下命令:
``` ```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en" python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/starnet/" --rec_image_shape="3, 32, 100" --rec_char_type="en"
``` ```
### 3.基于Attention损失的识别模型推理
基于Attention损失的识别模型与ctc不同,需要额外设置识别算法参数 --rec_algorithm="RARE"
RARE 文本识别模型推理,可以执行如下命令:
```
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" --rec_model_dir="./inference/sare/" --rec_image_shape="3, 32, 100" --rec_char_type="en" --rec_algorithm="RARE"
```
![](imgs_words_en/word_336.png) ![](imgs_words_en/word_336.png)
执行命令后,上面图像的识别结果如下: 执行命令后,上面图像的识别结果如下:
......
...@@ -8,6 +8,8 @@ PaddleOCR 工作环境 ...@@ -8,6 +8,8 @@ PaddleOCR 工作环境
建议使用我们提供的docker运行PaddleOCR,有关docker使用请参考[链接](https://docs.docker.com/get-started/) 建议使用我们提供的docker运行PaddleOCR,有关docker使用请参考[链接](https://docs.docker.com/get-started/)
*如您希望使用 mac 或 windows直接运行预测代码,可以从第2步开始执行。*
1. (建议)准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。 1. (建议)准备docker环境。第一次使用这个镜像,会自动下载该镜像,请耐心等待。
``` ```
# 切换到工作目录下 # 切换到工作目录下
...@@ -54,6 +56,10 @@ python3 -m pip install paddlepaddle-gpu==1.7.2.post97 -i https://pypi.tuna.tsing ...@@ -54,6 +56,10 @@ python3 -m pip install paddlepaddle-gpu==1.7.2.post97 -i https://pypi.tuna.tsing
如果您的机器安装的是CUDA10,请运行以下命令安装 如果您的机器安装的是CUDA10,请运行以下命令安装
python3 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple python3 -m pip install paddlepaddle-gpu==1.7.2.post107 -i https://pypi.tuna.tsinghua.edu.cn/simple
如果您的机器是CPU,请运行以下命令安装
python3 -m pip install paddlepaddle==1.7.2 -i https://pypi.tuna.tsinghua.edu.cn/simple
更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。 更多的版本需求,请参照[安装文档](https://www.paddlepaddle.org.cn/install/quick)中的说明进行操作。
``` ```
......
...@@ -41,6 +41,8 @@ PaddleOCR 提供了一份用于训练 icdar2015 数据集的标签文件,通 ...@@ -41,6 +41,8 @@ PaddleOCR 提供了一份用于训练 icdar2015 数据集的标签文件,通
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_train.txt
# 测试集标签 # 测试集标签
wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt wget -P ./train_data/ic15_data https://paddleocr.bj.bcebos.com/dataset/rec_gt_test.txt
``` ```
最终训练集应有如下文件结构: 最终训练集应有如下文件结构:
...@@ -111,6 +113,8 @@ tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar ...@@ -111,6 +113,8 @@ tar -xf rec_mv3_none_bilstm_ctc.tar && rm -rf rec_mv3_none_bilstm_ctc.tar
开始训练: 开始训练:
*如果您安装的是cpu版本,请将配置文件中的 `use_gpu` 字段修改为false*
``` ```
# 设置PYTHONPATH路径 # 设置PYTHONPATH路径
export PYTHONPATH=$PYTHONPATH:. export PYTHONPATH=$PYTHONPATH:.
...@@ -168,10 +172,11 @@ Global: ...@@ -168,10 +172,11 @@ Global:
评估数据集可以通过 `configs/rec/rec_icdar15_reader.yml` 修改EvalReader中的 `label_file_path` 设置。 评估数据集可以通过 `configs/rec/rec_icdar15_reader.yml` 修改EvalReader中的 `label_file_path` 设置。
*注意* 评估时必须确保配置文件中 infer_img 字段为空
``` ```
export CUDA_VISIBLE_DEVICES=0 export CUDA_VISIBLE_DEVICES=0
# GPU 评估, Global.checkpoints 为待测权重 # GPU 评估, Global.checkpoints 为待测权重
python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy python3 tools/eval.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy
``` ```
### 预测 ### 预测
...@@ -184,7 +189,7 @@ python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkp ...@@ -184,7 +189,7 @@ python3 tools/eval.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkp
``` ```
# 预测英文结果 # 预测英文结果
python3 tools/infer_rec.py -c configs/rec/rec_chinese_lite_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy TestReader.infer_img=doc/imgs_words/en/word_1.jpg python3 tools/infer_rec.py -c configs/rec/rec_icdar15_train.yml -o Global.checkpoints={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words/en/word_1.png
``` ```
预测图片: 预测图片:
......
...@@ -61,8 +61,6 @@ class TrainReader(object): ...@@ -61,8 +61,6 @@ class TrainReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if len(batch_outs) != 0:
yield batch_outs
return batch_iter_reader return batch_iter_reader
......
...@@ -17,6 +17,8 @@ import cv2 ...@@ -17,6 +17,8 @@ import cv2
import numpy as np import numpy as np
import json import json
import sys import sys
from ppocr.utils.utility import initial_logger
logger = initial_logger()
from .data_augment import AugmentData from .data_augment import AugmentData
from .random_crop_data import RandomCropData from .random_crop_data import RandomCropData
...@@ -100,6 +102,7 @@ class DBProcessTrain(object): ...@@ -100,6 +102,7 @@ class DBProcessTrain(object):
img_path, gt_label = self.convert_label_infor(label_infor) img_path, gt_label = self.convert_label_infor(label_infor)
imgvalue = cv2.imread(img_path) imgvalue = cv2.imread(img_path)
if imgvalue is None: if imgvalue is None:
logger.info("{} does not exist!".format(img_path))
return None return None
data = self.make_data_dict(imgvalue, gt_label) data = self.make_data_dict(imgvalue, gt_label)
data = AugmentData(data) data = AugmentData(data)
......
...@@ -41,13 +41,18 @@ class LMDBReader(object): ...@@ -41,13 +41,18 @@ class LMDBReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.drop_last = False
self.use_tps = False
if "tps" in params:
self.ues_tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == "eval": self.drop_last = True
else:
self.batch_size = params['test_batch_size_per_card'] self.batch_size = params['test_batch_size_per_card']
elif params['mode'] == "test": self.drop_last = False
self.batch_size = 1 self.infer_img = params['infer_img']
self.infer_img = params["infer_img"]
def load_hierarchical_lmdb_dataset(self): def load_hierarchical_lmdb_dataset(self):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
...@@ -100,13 +105,18 @@ class LMDBReader(object): ...@@ -100,13 +105,18 @@ class LMDBReader(object):
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img yield norm_img
else: else:
lmdb_sets = self.load_hierarchical_lmdb_dataset() lmdb_sets = self.load_hierarchical_lmdb_dataset()
...@@ -126,9 +136,13 @@ class LMDBReader(object): ...@@ -126,9 +136,13 @@ class LMDBReader(object):
if sample_info is None: if sample_info is None:
continue continue
img, label = sample_info img, label = sample_info
outs = process_image(img, self.image_shape, label, outs = process_image(
self.char_ops, self.loss_type, img=img,
self.max_text_length) image_shape=self.image_shape,
label=label,
char_ops=self.char_ops,
loss_type=self.loss_type,
max_text_length=self.max_text_length)
if outs is None: if outs is None:
continue continue
yield outs yield outs
...@@ -136,6 +150,7 @@ class LMDBReader(object): ...@@ -136,6 +150,7 @@ class LMDBReader(object):
if finish_read_num == len(lmdb_sets): if finish_read_num == len(lmdb_sets):
break break
self.close_lmdb_dataset(lmdb_sets) self.close_lmdb_dataset(lmdb_sets)
def batch_iter_reader(): def batch_iter_reader():
batch_outs = [] batch_outs = []
for outs in sample_iter_reader(): for outs in sample_iter_reader():
...@@ -143,10 +158,11 @@ class LMDBReader(object): ...@@ -143,10 +158,11 @@ class LMDBReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if len(batch_outs) != 0: if not self.drop_last:
yield batch_outs if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -165,26 +181,34 @@ class SimpleReader(object): ...@@ -165,26 +181,34 @@ class SimpleReader(object):
self.loss_type = params['loss_type'] self.loss_type = params['loss_type']
self.max_text_length = params['max_text_length'] self.max_text_length = params['max_text_length']
self.mode = params['mode'] self.mode = params['mode']
self.infer_img = params['infer_img']
self.use_tps = False
if "tps" in params:
self.ues_tps = True
if params['mode'] == 'train': if params['mode'] == 'train':
self.batch_size = params['train_batch_size_per_card'] self.batch_size = params['train_batch_size_per_card']
elif params['mode'] == 'eval': self.drop_last = True
self.batch_size = params['test_batch_size_per_card']
else: else:
self.batch_size = 1 self.batch_size = params['test_batch_size_per_card']
self.infer_img = params['infer_img'] self.drop_last = False
def __call__(self, process_id): def __call__(self, process_id):
if self.mode != 'train': if self.mode != 'train':
process_id = 0 process_id = 0
def sample_iter_reader(): def sample_iter_reader():
if self.mode == 'test': if self.mode != 'train' and self.infer_img is not None:
image_file_list = get_image_file_list(self.infer_img) image_file_list = get_image_file_list(self.infer_img)
for single_img in image_file_list: for single_img in image_file_list:
img = cv2.imread(single_img) img = cv2.imread(single_img)
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
norm_img = process_image(img, self.image_shape) norm_img = process_image(
img=img,
image_shape=self.image_shape,
char_ops=self.char_ops,
tps=self.use_tps,
infer_mode=True)
yield norm_img yield norm_img
else: else:
with open(self.label_file_path, "rb") as fin: with open(self.label_file_path, "rb") as fin:
...@@ -192,7 +216,7 @@ class SimpleReader(object): ...@@ -192,7 +216,7 @@ class SimpleReader(object):
img_num = len(label_infor_list) img_num = len(label_infor_list)
img_id_list = list(range(img_num)) img_id_list = list(range(img_num))
random.shuffle(img_id_list) random.shuffle(img_id_list)
if sys.platform=="win32": if sys.platform == "win32":
print("multiprocess is not fully compatible with Windows." print("multiprocess is not fully compatible with Windows."
"num_workers will be 1.") "num_workers will be 1.")
self.num_workers = 1 self.num_workers = 1
...@@ -204,7 +228,7 @@ class SimpleReader(object): ...@@ -204,7 +228,7 @@ class SimpleReader(object):
if img is None: if img is None:
logger.info("{} does not exist!".format(img_path)) logger.info("{} does not exist!".format(img_path))
continue continue
if img.shape[-1]==1 or len(list(img.shape))==2: if img.shape[-1] == 1 or len(list(img.shape)) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
label = substr[1] label = substr[1]
...@@ -222,9 +246,10 @@ class SimpleReader(object): ...@@ -222,9 +246,10 @@ class SimpleReader(object):
if len(batch_outs) == self.batch_size: if len(batch_outs) == self.batch_size:
yield batch_outs yield batch_outs
batch_outs = [] batch_outs = []
if len(batch_outs) != 0: if not self.drop_last:
yield batch_outs if len(batch_outs) != 0:
yield batch_outs
if self.mode != 'test': if self.infer_img is None:
return batch_iter_reader return batch_iter_reader
return sample_iter_reader return sample_iter_reader
...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape): ...@@ -48,6 +48,32 @@ def resize_norm_img(img, image_shape):
return padding_im return padding_im
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
max_wh_ratio = 0
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
imgW = int(32 * max_wh_ratio)
if math.ceil(imgH * ratio) > imgW:
resized_w = imgW
else:
resized_w = int(math.ceil(imgH * ratio))
resized_image = cv2.resize(img, (resized_w, imgH))
resized_image = resized_image.astype('float32')
if image_shape[0] == 1:
resized_image = resized_image / 255
resized_image = resized_image[np.newaxis, :]
else:
resized_image = resized_image.transpose((2, 0, 1)) / 255
resized_image -= 0.5
resized_image /= 0.5
padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
def get_img_data(value): def get_img_data(value):
"""get_img_data""" """get_img_data"""
if not value: if not value:
...@@ -66,8 +92,13 @@ def process_image(img, ...@@ -66,8 +92,13 @@ def process_image(img,
label=None, label=None,
char_ops=None, char_ops=None,
loss_type=None, loss_type=None,
max_text_length=None): max_text_length=None,
norm_img = resize_norm_img(img, image_shape) tps=None,
infer_mode=False):
if not infer_mode or char_ops.character_type == "en" or tps != None:
norm_img = resize_norm_img(img, image_shape)
else:
norm_img = resize_norm_img_chinese(img, image_shape)
norm_img = norm_img[np.newaxis, :] norm_img = norm_img[np.newaxis, :]
if label is not None: if label is not None:
char_num = char_ops.get_char_num() char_num = char_ops.get_char_num()
......
...@@ -30,6 +30,8 @@ class RecModel(object): ...@@ -30,6 +30,8 @@ class RecModel(object):
global_params = params['Global'] global_params = params['Global']
char_num = global_params['char_ops'].get_char_num() char_num = global_params['char_ops'].get_char_num()
global_params['char_num'] = char_num global_params['char_num'] = char_num
self.char_type = global_params['character_type']
self.infer_img = global_params['infer_img']
if "TPS" in params: if "TPS" in params:
tps_params = deepcopy(params["TPS"]) tps_params = deepcopy(params["TPS"])
tps_params.update(global_params) tps_params.update(global_params)
...@@ -60,8 +62,8 @@ class RecModel(object): ...@@ -60,8 +62,8 @@ class RecModel(object):
def create_feed(self, mode): def create_feed(self, mode):
image_shape = deepcopy(self.image_shape) image_shape = deepcopy(self.image_shape)
image_shape.insert(0, -1) image_shape.insert(0, -1)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if mode == "train": if mode == "train":
image = fluid.data(name='image', shape=image_shape, dtype='float32')
if self.loss_type == "attention": if self.loss_type == "attention":
label_in = fluid.data( label_in = fluid.data(
name='label_in', name='label_in',
...@@ -86,6 +88,16 @@ class RecModel(object): ...@@ -86,6 +88,16 @@ class RecModel(object):
use_double_buffer=True, use_double_buffer=True,
iterable=False) iterable=False)
else: else:
if self.char_type == "ch" and self.infer_img:
image_shape[-1] = -1
if self.tps != None:
logger.info(
"WARNRNG!!!\n"
"TPS does not support variable shape in chinese!"
"We set img_shape to be the same , it may affect the inference effect"
)
image_shape = deepcopy(self.image_shape)
image = fluid.data(name='image', shape=image_shape, dtype='float32')
labels = None labels = None
loader = None loader = None
return image, labels, loader return image, labels, loader
...@@ -110,7 +122,11 @@ class RecModel(object): ...@@ -110,7 +122,11 @@ class RecModel(object):
return loader, outputs return loader, outputs
elif mode == "export": elif mode == "export":
predict = predicts['predict'] predict = predicts['predict']
predict = fluid.layers.softmax(predict) if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return [image, {'decoded_out': decoded_out, 'predicts': predict}] return [image, {'decoded_out': decoded_out, 'predicts': predict}]
else: else:
return loader, {'decoded_out': decoded_out} predict = predicts['predict']
if self.loss_type == "ctc":
predict = fluid.layers.softmax(predict)
return loader, {'decoded_out': decoded_out, 'predicts': predict}
...@@ -123,6 +123,8 @@ class AttentionPredict(object): ...@@ -123,6 +123,8 @@ class AttentionPredict(object):
full_ids = fluid.layers.fill_constant_batch_size_like( full_ids = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='int64', value=1) input=init_state, shape=[-1, 1], dtype='int64', value=1)
full_scores = fluid.layers.fill_constant_batch_size_like(
input=init_state, shape=[-1, 1], dtype='float32', value=1)
cond = layers.less_than(x=counter, y=array_len) cond = layers.less_than(x=counter, y=array_len)
while_op = layers.While(cond=cond) while_op = layers.While(cond=cond)
...@@ -171,6 +173,9 @@ class AttentionPredict(object): ...@@ -171,6 +173,9 @@ class AttentionPredict(object):
new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1) new_ids = fluid.layers.concat([full_ids, topk_indices], axis=1)
fluid.layers.assign(new_ids, full_ids) fluid.layers.assign(new_ids, full_ids)
new_scores = fluid.layers.concat([full_scores, topk_scores], axis=1)
fluid.layers.assign(new_scores, full_scores)
layers.increment(x=counter, value=1, in_place=True) layers.increment(x=counter, value=1, in_place=True)
# update the memories # update the memories
...@@ -184,7 +189,7 @@ class AttentionPredict(object): ...@@ -184,7 +189,7 @@ class AttentionPredict(object):
length_cond = layers.less_than(x=counter, y=array_len) length_cond = layers.less_than(x=counter, y=array_len)
finish_cond = layers.logical_not(layers.is_empty(x=topk_indices)) finish_cond = layers.logical_not(layers.is_empty(x=topk_indices))
layers.logical_and(x=length_cond, y=finish_cond, out=cond) layers.logical_and(x=length_cond, y=finish_cond, out=cond)
return full_ids return full_ids, full_scores
def __call__(self, inputs, labels=None, mode=None): def __call__(self, inputs, labels=None, mode=None):
encoder_features = self.encoder(inputs) encoder_features = self.encoder(inputs)
...@@ -223,10 +228,10 @@ class AttentionPredict(object): ...@@ -223,10 +228,10 @@ class AttentionPredict(object):
decoder_size, char_num) decoder_size, char_num)
_, decoded_out = layers.topk(input=predict, k=1) _, decoded_out = layers.topk(input=predict, k=1)
decoded_out = layers.lod_reset(decoded_out, y=label_out) decoded_out = layers.lod_reset(decoded_out, y=label_out)
predicts = {'predict': predict, 'decoded_out': decoded_out} predicts = {'predict':predict, 'decoded_out':decoded_out}
else: else:
ids = self.gru_attention_infer( ids, predict = self.gru_attention_infer(
decoder_boot, self.max_length, char_num, word_vector_dim, decoder_boot, self.max_length, char_num, word_vector_dim,
encoded_vector, encoded_proj, decoder_size) encoded_vector, encoded_proj, decoder_size)
predicts = {'decoded_out': ids} predicts = {'predict':predict, 'decoded_out':ids}
return predicts return predicts
...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode): ...@@ -48,7 +48,7 @@ def eval_rec_run(exe, config, eval_info_dict, mode):
total_sample_num = 0 total_sample_num = 0
total_acc_num = 0 total_acc_num = 0
total_batch_num = 0 total_batch_num = 0
if mode == "test": if mode == "eval":
is_remove_duplicate = False is_remove_duplicate = False
else: else:
is_remove_duplicate = True is_remove_duplicate = True
...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict): ...@@ -91,11 +91,11 @@ def test_rec_benchmark(exe, config, eval_info_dict):
total_correct_number = 0 total_correct_number = 0
eval_data_acc_info = {} eval_data_acc_info = {}
for eval_data in eval_data_list: for eval_data in eval_data_list:
config['EvalReader']['lmdb_sets_dir'] = \ config['TestReader']['lmdb_sets_dir'] = \
eval_data_dir + "/" + eval_data eval_data_dir + "/" + eval_data
eval_reader = reader_main(config=config, mode="eval") eval_reader = reader_main(config=config, mode="test")
eval_info_dict['reader'] = eval_reader eval_info_dict['reader'] = eval_reader
metrics = eval_rec_run(exe, config, eval_info_dict, "eval") metrics = eval_rec_run(exe, config, eval_info_dict, "test")
total_evaluation_data_number += metrics['total_sample_num'] total_evaluation_data_number += metrics['total_sample_num']
total_correct_number += metrics['total_acc_num'] total_correct_number += metrics['total_acc_num']
eval_data_acc_info[eval_data] = metrics eval_data_acc_info[eval_data] = metrics
......
...@@ -32,10 +32,16 @@ class TextRecognizer(object): ...@@ -32,10 +32,16 @@ class TextRecognizer(object):
self.rec_image_shape = image_shape self.rec_image_shape = image_shape
self.character_type = args.rec_char_type self.character_type = args.rec_char_type
self.rec_batch_num = args.rec_batch_num self.rec_batch_num = args.rec_batch_num
self.rec_algorithm = args.rec_algorithm
char_ops_params = {} char_ops_params = {}
char_ops_params["character_type"] = args.rec_char_type char_ops_params["character_type"] = args.rec_char_type
char_ops_params["character_dict_path"] = args.rec_char_dict_path char_ops_params["character_dict_path"] = args.rec_char_dict_path
char_ops_params['loss_type'] = 'ctc' if self.rec_algorithm != "RARE":
char_ops_params['loss_type'] = 'ctc'
self.loss_type = 'ctc'
else:
char_ops_params['loss_type'] = 'attention'
self.loss_type = 'attention'
self.char_ops = CharacterOps(char_ops_params) self.char_ops = CharacterOps(char_ops_params)
def resize_norm_img(self, img, max_wh_ratio): def resize_norm_img(self, img, max_wh_ratio):
...@@ -80,26 +86,43 @@ class TextRecognizer(object): ...@@ -80,26 +86,43 @@ class TextRecognizer(object):
starttime = time.time() starttime = time.time()
self.input_tensor.copy_from_cpu(norm_img_batch) self.input_tensor.copy_from_cpu(norm_img_batch)
self.predictor.zero_copy_run() self.predictor.zero_copy_run()
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
rec_idx_lod = self.output_tensors[0].lod()[0] if self.loss_type == "ctc":
predict_batch = self.output_tensors[1].copy_to_cpu() rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_lod = self.output_tensors[1].lod()[0] rec_idx_lod = self.output_tensors[0].lod()[0]
elapse = time.time() - starttime predict_batch = self.output_tensors[1].copy_to_cpu()
predict_time += elapse predict_lod = self.output_tensors[1].lod()[0]
starttime = time.time() elapse = time.time() - starttime
for rno in range(len(rec_idx_lod) - 1): predict_time += elapse
beg = rec_idx_lod[rno] for rno in range(len(rec_idx_lod) - 1):
end = rec_idx_lod[rno + 1] beg = rec_idx_lod[rno]
rec_idx_tmp = rec_idx_batch[beg:end, 0] end = rec_idx_lod[rno + 1]
preds_text = self.char_ops.decode(rec_idx_tmp) rec_idx_tmp = rec_idx_batch[beg:end, 0]
beg = predict_lod[rno] preds_text = self.char_ops.decode(rec_idx_tmp)
end = predict_lod[rno + 1] beg = predict_lod[rno]
probs = predict_batch[beg:end, :] end = predict_lod[rno + 1]
ind = np.argmax(probs, axis=1) probs = predict_batch[beg:end, :]
blank = probs.shape[1] ind = np.argmax(probs, axis=1)
valid_ind = np.where(ind != (blank - 1))[0] blank = probs.shape[1]
score = np.mean(probs[valid_ind, ind[valid_ind]]) valid_ind = np.where(ind != (blank - 1))[0]
rec_res.append([preds_text, score]) score = np.mean(probs[valid_ind, ind[valid_ind]])
rec_res.append([preds_text, score])
else:
rec_idx_batch = self.output_tensors[0].copy_to_cpu()
predict_batch = self.output_tensors[1].copy_to_cpu()
elapse = time.time() - starttime
predict_time += elapse
for rno in range(len(rec_idx_batch)):
end_pos = np.where(rec_idx_batch[rno, :] == 1)[0]
if len(end_pos) <= 1:
preds = rec_idx_batch[rno, 1:]
score = np.mean(predict_batch[rno, 1:])
else:
preds = rec_idx_batch[rno, 1:end_pos[1]]
score = np.mean(predict_batch[rno, 1:end_pos[1]])
preds_text = self.char_ops.decode(preds)
rec_res.append([preds_text, score])
return rec_res, predict_time return rec_res, predict_time
...@@ -116,7 +139,17 @@ if __name__ == "__main__": ...@@ -116,7 +139,17 @@ if __name__ == "__main__":
continue continue
valid_image_file_list.append(image_file) valid_image_file_list.append(image_file)
img_list.append(img) img_list.append(img)
rec_res, predict_time = text_recognizer(img_list) try:
rec_res, predict_time = text_recognizer(img_list)
except Exception as e:
print(e)
logger.info(
"ERROR!!!! \n"
"Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n"
"If your model has tps module: "
"TPS does not support variable shape.\n"
"Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ")
exit()
for ino in range(len(img_list)): for ino in range(len(img_list)):
print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino])) print("Predicts of %s:%s" % (valid_image_file_list[ino], rec_res[ino]))
print("Total predict time for %d images:%.3f" % print("Total predict time for %d images:%.3f" %
......
...@@ -21,6 +21,7 @@ import time ...@@ -21,6 +21,7 @@ import time
import multiprocessing import multiprocessing
import numpy as np import numpy as np
def set_paddle_flags(**kwargs): def set_paddle_flags(**kwargs):
for key, value in kwargs.items(): for key, value in kwargs.items():
if os.environ.get(key, None) is None: if os.environ.get(key, None) is None:
...@@ -54,6 +55,7 @@ def main(): ...@@ -54,6 +55,7 @@ def main():
program.merge_config(FLAGS.opt) program.merge_config(FLAGS.opt)
logger.info(config) logger.info(config)
char_ops = CharacterOps(config['Global']) char_ops = CharacterOps(config['Global'])
loss_type = config['Global']['loss_type']
config['Global']['char_ops'] = char_ops config['Global']['char_ops'] = char_ops
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
...@@ -78,35 +80,44 @@ def main(): ...@@ -78,35 +80,44 @@ def main():
init_model(config, eval_prog, exe) init_model(config, eval_prog, exe)
blobs = reader_main(config, 'test')() blobs = reader_main(config, 'test')()
infer_img = config['TestReader']['infer_img'] infer_img = config['Global']['infer_img']
infer_list = get_image_file_list(infer_img) infer_list = get_image_file_list(infer_img)
max_img_num = len(infer_list) max_img_num = len(infer_list)
if len(infer_list) == 0: if len(infer_list) == 0:
logger.info("Can not find img in infer_img dir.") logger.info("Can not find img in infer_img dir.")
for i in range(max_img_num): for i in range(max_img_num):
print("infer_img:",infer_list[i]) print("infer_img:%s" % infer_list[i])
img = next(blobs) img = next(blobs)
predict = exe.run(program=eval_prog, predict = exe.run(program=eval_prog,
feed={"image": img}, feed={"image": img},
fetch_list=fetch_varname_list, fetch_list=fetch_varname_list,
return_numpy=False) return_numpy=False)
if loss_type == "ctc":
preds = np.array(predict[0]) preds = np.array(predict[0])
if preds.shape[1] == 1:
preds = preds.reshape(-1) preds = preds.reshape(-1)
preds_lod = predict[0].lod()[0] preds_lod = predict[0].lod()[0]
preds_text = char_ops.decode(preds) preds_text = char_ops.decode(preds)
else: probs = np.array(predict[1])
ind = np.argmax(probs, axis=1)
blank = probs.shape[1]
valid_ind = np.where(ind != (blank - 1))[0]
score = np.mean(probs[valid_ind, ind[valid_ind]])
elif loss_type == "attention":
preds = np.array(predict[0])
probs = np.array(predict[1])
end_pos = np.where(preds[0, :] == 1)[0] end_pos = np.where(preds[0, :] == 1)[0]
if len(end_pos) <= 1: if len(end_pos) <= 1:
preds_text = preds[0, 1:] preds = preds[0, 1:]
score = np.mean(probs[0, 1:])
else: else:
preds_text = preds[0, 1:end_pos[1]] preds = preds[0, 1:end_pos[1]]
preds_text = preds_text.reshape(-1) score = np.mean(probs[0, 1:end_pos[1]])
preds_text = char_ops.decode(preds_text) preds = preds.reshape(-1)
preds_text = char_ops.decode(preds)
print("\t index:",preds) print("\t index:", preds)
print("\t word :",preds_text) print("\t word :", preds_text)
print("\t score :", score)
# save for inference model # save for inference model
target_var = [] target_var = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册