提交 97111112 编写于 作者: J Jethong

fix error

上级 051fe64a
......@@ -3,7 +3,7 @@ Global:
epoch_num: 600
log_smooth_window: 20
print_batch_step: 10
save_model_dir: ./output/pg_r50_vd_tt/
save_model_dir: ./output/pgnet_r50_vd_totaltext/
save_epoch_step: 10
# evaluation is run every 0 iterationss after the 1000th iteration
eval_batch_step: [ 0, 1000 ]
......@@ -18,7 +18,11 @@ Global:
save_inference_dir:
use_visualdl: False
infer_img:
save_res_path: ./output/pg_r50_vd_tt/predicts_pg.txt
valid_set: totaltext #two mode: totaltext valid curved words, partvgg valid non-curved words
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
character_dict_path: ppocr/utils/pgnet_dict.txt
character_type: EN
max_text_length: 50
Architecture:
model_type: e2e
......@@ -51,30 +55,26 @@ Optimizer:
PostProcess:
name: PGPostProcess
score_thresh: 0.8
cover_thresh: 0.1
nms_thresh: 0.2
Metric:
name: E2EMetric
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
character_dict_path: ppocr/utils/pgnet_dict.txt
main_indicator: f_score_e2e
Train:
dataset:
name: PGDateSet
label_file_list: [./train_data/total_text/train/]
name: PGDataSet
label_file_list: [.././train_data/total_text/train/]
ratio_list: [1.0]
data_format: icdar
data_format: icdar #two data format: icdar/textnet
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- PGProcessTrain:
batch_size: 14
batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24
min_text_size: 4
max_text_size: 512
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
......@@ -93,10 +93,7 @@ Eval:
img_mode: BGR
channel_first: False
- E2ELabelEncode:
Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
max_len: 50
- E2EResizeForTest:
valid_set: totaltext
max_side_len: 768
- NormalizeImage:
scale: 1./255.
......
......@@ -12,7 +12,8 @@ inference 模型(`paddle.jit.save`保存的模型)
- [一、训练模型转inference模型](#训练模型转inference模型)
- [检测模型转inference模型](#检测模型转inference模型)
- [识别模型转inference模型](#识别模型转inference模型)
- [方向分类模型转inference模型](#方向分类模型转inference模型)
- [方向分类模型转inference模型](#方向分类模型转inference模型)
- [端到端模型转inference模型](#端到端模型转inference模型)
- [二、文本检测模型推理](#文本检测模型推理)
- [1. 超轻量中文检测模型推理](#超轻量中文检测模型推理)
......@@ -27,10 +28,13 @@ inference 模型(`paddle.jit.save`保存的模型)
- [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理)
- [5. 多语言模型的推理](#多语言模型的推理)
- [四、方向分类模型推理](#方向识别模型推理)
- [四、端到端模型推理](#端到端模型推理)
- [1. PGNet端到端模型推理](#SAST文本检测模型推理)
- [五、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理)
- [、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
- [、文本检测、方向分类和文字识别串联推理](#文本检测、方向分类和文字识别串联推理)
- [1. 超轻量中文OCR模型推理](#超轻量中文OCR模型推理)
- [2. 其他模型推理](#其他模型推理)
......@@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
└── inference.pdmodel # 分类inference模型的program文件
```
<a name="端到端模型转inference模型"></a>
### 端到端模型转inference模型
下载端到端模型:
```
wget -P ./ch_lite/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_train.tar && tar xf ./ch_lite/ch_ppocr_mobile_v2.0_cls_train.tar -C ./ch_lite/
```
端到端模型转inference模型与检测的方式相同,如下:
```
# -c 后面设置训练算法的yml配置文件
# -o 配置可选参数
# Global.pretrained_model 参数设置待转换的训练模型地址,不用添加文件后缀 .pdmodel,.pdopt或.pdparams。
# Global.load_static_weights 参数需要设置为 False。
# Global.save_inference_dir参数设置转换的模型将保存的地址。
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./ch_lite/ch_ppocr_mobile_v2.0_cls_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/
```
转换成功后,在目录下有三个文件:
```
/inference/e2e/
├── inference.pdiparams # 分类inference模型的参数文件
├── inference.pdiparams.info # 分类inference模型的参数信息,可忽略
└── inference.pdmodel # 分类inference模型的program文件
```
<a name="文本检测模型推理"></a>
## 二、文本检测模型推理
......@@ -332,8 +362,45 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
```
<a name="端到端模型推理"></a>
## 四、端到端模型推理
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
<a name="SAST文本检测模型推理"></a>
### 1. PGNet端到端模型推理
#### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e_pgnet_ic15/"
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/det_res_img_10_sast.jpg)
#### (2). 弯曲文本检测模型(Total-Text)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e_pgnet_tt
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_pgnet_tt/" --e2e_pgnet_polygon=True
```
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pg.jpg)
**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
<a name="方向分类模型推理"></a>
## 、方向分类模型推理
## 、方向分类模型推理
下面将介绍方向分类模型推理。
......@@ -358,7 +425,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
```
<a name="文本检测、方向分类和文字识别串联推理"></a>
## 、文本检测、方向分类和文字识别串联推理
## 、文本检测、方向分类和文字识别串联推理
<a name="超轻量中文OCR模型推理"></a>
### 1. 超轻量中文OCR模型推理
......
......@@ -73,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
else:
use_shared_memory = True
if mode == "Train":
#Distribute data to multiple cards
# Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
else:
#Distribute data to single card
# Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
......
......@@ -34,28 +34,6 @@ class ClsLabelEncode(object):
return data
class E2ELabelEncode(object):
def __init__(self, Lexicon_Table, max_len, **kwargs):
self.Lexicon_Table = Lexicon_Table
self.max_len = max_len
self.pad_num = len(self.Lexicon_Table)
def __call__(self, data):
text_label_index_list, temp_text = [], []
texts = data['strs']
for text in texts:
text = text.upper()
temp_text = []
for c_ in text:
if c_ in self.Lexicon_Table:
temp_text.append(self.Lexicon_Table.index(c_))
temp_text = temp_text + [self.pad_num] * (self.max_len -
len(temp_text))
text_label_index_list.append(temp_text)
data['strs'] = np.array(text_label_index_list)
return data
class DetLabelEncode(object):
def __init__(self, **kwargs):
pass
......@@ -209,6 +187,32 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character
class E2ELabelEncode(BaseRecLabelEncode):
def __init__(self,
max_text_length,
character_dict_path=None,
character_type='EN',
use_space_char=False,
**kwargs):
super(E2ELabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
def __call__(self, data):
texts = data['strs']
temp_texts = []
for text in texts:
text = text.upper()
text = self.encode(text)
if text is None:
return None
text = text + [36] * (self.max_text_len - len(text)
) # use 36 to pad
temp_texts.append(text)
data['strs'] = np.array(temp_texts)
return data
class AttnLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
......
......@@ -21,6 +21,7 @@ __all__ = ['PGProcessTrain']
class PGProcessTrain(object):
def __init__(self,
character_dict_path,
batch_size=14,
min_crop_size=24,
min_text_size=10,
......@@ -30,13 +31,19 @@ class PGProcessTrain(object):
self.min_crop_size = min_crop_size
self.min_text_size = min_text_size
self.max_text_size = max_text_size
self.Lexicon_Table = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
self.Lexicon_Table = self.get_dict(character_dict_path)
self.img_id = 0
def get_dict(self, character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def quad_area(self, poly):
"""
compute area of a polygon
......@@ -853,7 +860,7 @@ class PGProcessTrain(object):
for i in range(len(label_list)):
label_list[i] = np.array(label_list[i])
if len(pos_list) <= 0 or len(pos_list) > 30:
if len(pos_list) <= 0 or len(pos_list) > 30: #一张图片中最多存在30行文本
return None
for __ in range(30 - len(pos_list), 0, -1):
pos_list.append(pos_list_temp)
......
......@@ -19,11 +19,15 @@ from __future__ import print_function
__all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import *
from ppocr.utils.e2e_utils.extract_textpoint import *
class E2EMetric(object):
def __init__(self, Lexicon_Table, main_indicator='f_score_e2e', **kwargs):
self.label_list = Lexicon_Table
def __init__(self,
character_dict_path,
main_indicator='f_score_e2e',
**kwargs):
self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list)
self.main_indicator = main_indicator
self.reset()
......
......@@ -228,11 +228,11 @@ class PGHead(nn.Layer):
f_score = self.conv1(f_score)
f_score = F.sigmoid(f_score)
# f_boder
f_boder = self.conv_f_boder1(x)
f_boder = self.conv_f_boder2(f_boder)
f_boder = self.conv_f_boder3(f_boder)
f_boder = self.conv2(f_boder)
# f_border
f_border = self.conv_f_boder1(x)
f_border = self.conv_f_boder2(f_border)
f_border = self.conv_f_boder3(f_border)
f_border = self.conv2(f_border)
f_char = self.conv_f_char1(x)
f_char = self.conv_f_char2(f_char)
......@@ -246,4 +246,9 @@ class PGHead(nn.Layer):
f_direction = self.conv_f_direc3(f_direction)
f_direction = self.conv4(f_direction)
return f_score, f_boder, f_direction, f_char
predicts = {}
predicts['f_score'] = f_score
predicts['f_border'] = f_border
predicts['f_char'] = f_char
predicts['f_direction'] = f_direction
return predicts
......@@ -30,30 +30,14 @@ import paddle
class PGPostProcess(object):
"""
The post process for SAST.
The post process for PGNet.
"""
def __init__(self,
score_thresh=0.5,
nms_thresh=0.2,
sample_pts_num=2,
shrink_ratio_of_width=0.3,
expand_scale=1.0,
tcl_map_thresh=0.5,
**kwargs):
self.result_path = ""
self.valid_set = 'totaltext'
self.Lexicon_Table = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
]
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
self.nms_thresh = nms_thresh
self.sample_pts_num = sample_pts_num
self.shrink_ratio_of_width = shrink_ratio_of_width
self.expand_scale = expand_scale
self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
......@@ -61,16 +45,23 @@ class PGPostProcess(object):
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
p_score, p_border, p_direction, p_char = outs_dict[:4]
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
src_h, src_w, ratio_h, ratio_w = shape_list[0]
if self.valid_set != 'totaltext':
is_curved = False
p_score = outs_dict['f_score']
p_border = outs_dict['f_border']
p_char = outs_dict['f_char']
p_direction = outs_dict['f_direction']
if isinstance(p_score, paddle.Tensor):
p_score = p_score[0].numpy()
p_border = p_border[0].numpy()
p_direction = p_direction[0].numpy()
p_char = p_char[0].numpy()
else:
is_curved = True
p_score = p_score[0]
p_border = p_border[0]
p_direction = p_direction[0]
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = shape_list[0]
is_curved = self.valid_set == "totaltext"
instance_yxs_list = generate_pivot_list(
p_score,
p_char,
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
......@@ -24,6 +24,17 @@ from itertools import groupby
from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
character_str = ""
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
character_str += line
dict_character = list(character_str)
return dict_character
def softmax(logits):
"""
logits: N x d
......@@ -164,7 +175,6 @@ def sort_and_expand_with_direction(pos_list, f_direction):
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
......@@ -207,7 +217,6 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
# expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
......@@ -268,7 +277,6 @@ def generate_pivot_list_curved(p_score,
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
......@@ -279,7 +287,6 @@ def generate_pivot_list_curved(p_score,
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
......@@ -290,7 +297,6 @@ def generate_pivot_list_curved(p_score,
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
......@@ -335,11 +341,9 @@ def generate_pivot_list_horizontal(p_score,
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 5:
continue
# add rule here
main_direction = extract_main_direction(pos_list,
f_direction) # y x
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
......@@ -370,7 +374,6 @@ def generate_pivot_list_horizontal(p_score,
f_direction)
all_pos_yxs.append(pos_list_sorted)
# use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
......@@ -417,7 +420,6 @@ def generate_pivot_list(p_score,
image_id=image_id)
# for refine module
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
......@@ -504,14 +506,12 @@ def generate_pivot_list_tt_inference(p_score,
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
......
......@@ -28,7 +28,6 @@ def resize_image(im, max_side_len=512):
resize_w = w
resize_h = h
# Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
......@@ -50,13 +49,11 @@ def resize_image(im, max_side_len=512):
def resize_image_min(im, max_side_len=512):
"""
"""
# print('--> Using resize_image_min')
h, w, _ = im.shape
resize_w = w
resize_h = h
# Fix the longer side
if resize_h < resize_w:
ratio = float(max_side_len) / resize_h
else:
......@@ -84,12 +81,7 @@ def resize_image_for_totaltext(im, max_side_len=512):
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
# Fix the longer side
# if resize_h > resize_w:
# ratio = float(max_side_len) / resize_h
# else:
# ratio = float(max_side_len) / resize_w
###
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
......@@ -114,7 +106,6 @@ def point_pair2poly(point_pair_list):
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
# constract poly
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
......
0
1
2
3
4
5
6
7
8
9
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import sys
import tools.infer.utility as utility
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
logger = get_logger()
class TextE2e(object):
def __init__(self, args):
self.args = args
self.e2e_algorithm = args.e2e_algorithm
pre_process_list = [{
'E2EResizeForTest': {
'max_side_len': 768,
'valid_set': 'totaltext'
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
}
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image', 'shape']
}
}]
postprocess_params = {}
if self.e2e_algorithm == "PGNet":
pre_process_list[0] = {
'E2EResizeForTest': {
'max_side_len': args.e2e_limit_side_len,
'valid_set': 'totaltext'
}
}
postprocess_params['name'] = 'PGPostProcess'
postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
postprocess_params["character_dict_path"] = args.e2e_char_dict_path
postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
if self.e2e_pgnet_polygon:
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
else:
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
else:
logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
sys.exit(0)
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
args, 'e2e', logger) # paddle.jit.load(args.det_model_dir)
# self.predictor.eval()
def clip_det_res(self, points, img_height, img_width):
for pno in range(points.shape[0]):
points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
return points
def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
img_height, img_width = image_shape[0:2]
dt_boxes_new = []
for box in dt_boxes:
box = self.clip_det_res(box, img_height, img_width)
dt_boxes_new.append(box)
dt_boxes = np.array(dt_boxes_new)
return dt_boxes
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
img, shape_list = data
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
print(img.shape)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
starttime = time.time()
self.input_tensor.copy_from_cpu(img)
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
preds = {}
if self.e2e_algorithm == 'PGNet':
preds['f_score'] = outputs[0]
preds['f_border'] = outputs[1]
preds['f_direction'] = outputs[2]
preds['f_char'] = outputs[3]
else:
raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list)
points, strs = post_result['points'], post_result['strs']
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
elapse = time.time() - starttime
return dt_boxes, strs, elapse
if __name__ == "__main__":
args = utility.parse_args()
image_file_list = get_image_file_list(args.image_dir)
text_detector = TextE2e(args)
count = 0
total_time = 0
draw_img_save = "./inference_results"
if not os.path.exists(draw_img_save):
os.makedirs(draw_img_save)
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
points, strs, elapse = text_detector(img)
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
src_im = utility.draw_e2e_res(points, strs, image_file)
img_name_pure = os.path.split(image_file)[-1]
img_path = os.path.join(draw_img_save,
"e2e_res_{}".format(img_name_pure))
cv2.imwrite(img_path, src_im)
logger.info("The visualized image saved in {}".format(img_path))
if count > 1:
logger.info("Avg Time: {}".format(total_time / (count - 1)))
......@@ -74,6 +74,21 @@ def parse_args():
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
# params for e2e
parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
parser.add_argument("--e2e_model_dir", type=str)
parser.add_argument("--e2e_limit_side_len", type=float, default=768)
parser.add_argument("--e2e_limit_type", type=str, default='max')
# PGNet parmas
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
parser.add_argument(
"--e2e_char_dict_path",
type=str,
default="./ppocr/utils/pgnet_dict.txt")
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=False)
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
......@@ -93,8 +108,10 @@ def create_predictor(args, mode, logger):
model_dir = args.det_model_dir
elif mode == 'cls':
model_dir = args.cls_model_dir
else:
elif mode == 'rec':
model_dir = args.rec_model_dir
else:
model_dir = args.e2e_model_dir
if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir))
......@@ -147,6 +164,22 @@ def create_predictor(args, mode, logger):
return predictor, input_tensor, output_tensors
def draw_e2e_res(dt_boxes, strs, img_path):
src_im = cv2.imread(img_path)
for box, str in zip(dt_boxes, strs):
box = box.astype(np.int32).reshape((-1, 1, 2))
cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
cv2.putText(
src_im,
str,
org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
fontFace=cv2.FONT_HERSHEY_COMPLEX,
fontScale=0.7,
color=(0, 255, 0),
thickness=1)
return src_im
def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
......
......@@ -71,7 +71,8 @@ def main():
init_model(config, model, logger)
# build post process
post_process_class = build_post_process(config['PostProcess'])
post_process_class = build_post_process(config['PostProcess'],
global_config)
# create data ops
transforms = []
......
2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册