diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml
index 22548a3c98a8dbe0fb07a9a0e1b721dc1ea1298c..08c485a638759e6436bd1613fff81fa14c8a6db8 100644
--- a/configs/e2e/e2e_r50_vd_pg.yml
+++ b/configs/e2e/e2e_r50_vd_pg.yml
@@ -3,7 +3,7 @@ Global:
epoch_num: 600
log_smooth_window: 20
print_batch_step: 10
- save_model_dir: ./output/pg_r50_vd_tt/
+ save_model_dir: ./output/pgnet_r50_vd_totaltext/
save_epoch_step: 10
# evaluation is run every 1000 iterations after the 0th iteration
eval_batch_step: [ 0, 1000 ]
@@ -18,7 +18,11 @@ Global:
save_inference_dir:
use_visualdl: False
infer_img:
- save_res_path: ./output/pg_r50_vd_tt/predicts_pg.txt
+  valid_set: totaltext # two modes: totaltext evaluates curved text, partvgg evaluates non-curved text
+ save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
+ character_dict_path: ppocr/utils/pgnet_dict.txt
+ character_type: EN
+ max_text_length: 50
Architecture:
model_type: e2e
@@ -51,30 +55,26 @@ Optimizer:
PostProcess:
name: PGPostProcess
score_thresh: 0.8
- cover_thresh: 0.1
- nms_thresh: 0.2
-
Metric:
name: E2EMetric
- Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
+ character_dict_path: ppocr/utils/pgnet_dict.txt
main_indicator: f_score_e2e
Train:
dataset:
- name: PGDateSet
- label_file_list: [./train_data/total_text/train/]
+ name: PGDataSet
+  label_file_list: [./train_data/total_text/train/]
ratio_list: [1.0]
- data_format: icdar
+  data_format: icdar # two data formats: icdar / textnet
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- PGProcessTrain:
- batch_size: 14
+ batch_size: 14 # same as loader: batch_size_per_card
min_crop_size: 24
min_text_size: 4
max_text_size: 512
- Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
@@ -93,10 +93,7 @@ Eval:
img_mode: BGR
channel_first: False
- E2ELabelEncode:
- Lexicon_Table: [ '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z' ]
- max_len: 50
- E2EResizeForTest:
- valid_set: totaltext
max_side_len: 768
- NormalizeImage:
scale: 1./255.
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index 7968b355ea936d465b3c173c0fcdb3e08f12f16e..40f4d8c5119fb4be72573dd6a1f99ca59aeaf7aa 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -12,7 +12,8 @@ inference model (a model saved with `paddle.jit.save`)
- [1. Training Model to Inference Model](#训练模型转inference模型)
- [Detection model to inference model](#检测模型转inference模型)
- [Recognition model to inference model](#识别模型转inference模型)
- - [Angle classification model to inference model](#方向分类模型转inference模型)
+ - [Angle classification model to inference model](#方向分类模型转inference模型)
+ - [End-to-end model to inference model](#端到端模型转inference模型)
- [2. Text Detection Model Inference](#文本检测模型推理)
- [1. Lightweight Chinese detection model inference](#超轻量中文检测模型推理)
@@ -27,10 +28,13 @@ inference model (a model saved with `paddle.jit.save`)
- [4. Inference with a custom text recognition dictionary](#自定义文本识别字典的推理)
- [5. Multilingual model inference](#多语言模型的推理)
-- [4. Angle Classification Model Inference](#方向识别模型推理)
+- [4. End-to-End Model Inference](#端到端模型推理)
+ - [1. PGNet end-to-end model inference](#PGNet端到端模型推理)
+
+- [5. Angle Classification Model Inference](#方向识别模型推理)
- [1. Angle classification model inference](#方向分类模型推理)
-- [5. Concatenated Inference of Text Detection, Angle Classification and Text Recognition](#文本检测、方向分类和文字识别串联推理)
+- [6. Concatenated Inference of Text Detection, Angle Classification and Text Recognition](#文本检测、方向分类和文字识别串联推理)
- [1. Lightweight Chinese OCR model inference](#超轻量中文OCR模型推理)
- [2. Other model inference](#其他模型推理)
@@ -118,6 +122,32 @@ python3 tools/export_model.py -c configs/cls/cls_mv3.yml -o Global.pretrained_mo
├── inference.pdiparams.info # parameter info of the classification inference model, can be ignored
└── inference.pdmodel # program file of the classification inference model
```
+
+### End-to-end model to inference model
+
+Download the end-to-end PGNet model:
+```
+wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar && tar xf en_server_pgnetA.tar
+```
+
+An end-to-end model is converted to an inference model in the same way as a detection model:
+```
+# -c sets the yml configuration file of the training algorithm
+# -o sets optional parameters
+# Global.pretrained_model sets the path of the training model to convert; the .pdmodel, .pdopt and .pdparams file suffixes are not needed.
+# Global.load_static_weights must be set to False.
+# Global.save_inference_dir sets the directory where the converted model is saved.
+
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e/
+```
+
+After a successful conversion, the directory contains three files:
+```
+inference/e2e/
+    ├── inference.pdiparams         # parameter file of the e2e inference model
+    ├── inference.pdiparams.info    # parameter info of the e2e inference model, can be ignored
+    └── inference.pdmodel           # program file of the e2e inference model
+```
## 2. Text Detection Model Inference
@@ -332,8 +362,45 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" -
Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
```
+
+## 4. End-to-End Model Inference
+
+End-to-end model inference uses the PGNet configuration parameters by default. When a model other than PGNet is used, the corresponding parameters must be passed at inference time to adapt the algorithm; details are given below.
+
+### 1. PGNet end-to-end model inference
+#### (1) Quadrangle text detection model (ICDAR2015)
+First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the ICDAR2015 English dataset with a Resnet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), it can be converted with the following command:
+```
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_icdar15_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e_pgnet_ic15
+```
+**PGNet end-to-end model inference requires the parameter `--e2e_algorithm="PGNet"`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e_pgnet_ic15/"
+```
+The visualized results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. Example result:
+
+![](../imgs_results/e2e_res_img_10_pgnet.jpg)
+
+#### (2) Curved text detection model (Total-Text)
+First, convert the model saved during PGNet end-to-end training into an inference model. Taking the model trained on the Total-Text English dataset with a Resnet50_vd backbone as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/pgnet/en_server_pgnetA.tar)), it can be converted with the following command:
+
+```
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./en_server_pgnetA/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e_pgnet_tt
+```
+
+**PGNet end-to-end model inference requires the parameter `--e2e_algorithm="PGNet"` and, in addition, the parameter `--e2e_pgnet_polygon=True`**; run the following command:
+```
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_pgnet_tt/" --e2e_pgnet_polygon=True
+```
+The visualized end-to-end results are saved to the `./inference_results` folder by default, and the result file names are prefixed with 'e2e_res'. Example result:
+
+![](../imgs_results/e2e_res_img623_pg.jpg)
+
+**Note**: the Locality-Aware NMS in PGNet post-processing has both Python and C++ implementations in this codebase, and the C++ version is noticeably faster. Because of build constraints of the C++ NMS, the C++ version is only invoked in a Python 3.5 environment; in all other cases the Python version is used.
+
+
-## 4. Angle Classification Model Inference
+## 5. Angle Classification Model Inference
The following introduces angle classification model inference.
@@ -358,7 +425,7 @@ Predicts of ./doc/imgs_words/ch/word_4.jpg:['0', 0.9999982]
```
-## 5. Concatenated Inference of Text Detection, Angle Classification and Text Recognition
+## 6. Concatenated Inference of Text Detection, Angle Classification and Text Recognition
### 1. Lightweight Chinese OCR model inference
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 26a2a8dcfef24f48ab1331ece0b69bea9959f2ea..728b8317f54687ee76b519cba18f4d7807493821 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -73,14 +73,14 @@ def build_dataloader(config, mode, device, logger, seed=None):
else:
use_shared_memory = True
if mode == "Train":
- #Distribute data to multiple cards
+ # Distribute data to multiple cards
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last)
else:
- #Distribute data to single card
+ # Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 85aa8bb28a1b9c5d945b5d8cfa290975df1d7a48..6c2fc8e4bbfcedb5150dd7baf13db267d1d74aa2 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -34,28 +34,6 @@ class ClsLabelEncode(object):
return data
-class E2ELabelEncode(object):
- def __init__(self, Lexicon_Table, max_len, **kwargs):
- self.Lexicon_Table = Lexicon_Table
- self.max_len = max_len
- self.pad_num = len(self.Lexicon_Table)
-
- def __call__(self, data):
- text_label_index_list, temp_text = [], []
- texts = data['strs']
- for text in texts:
- text = text.upper()
- temp_text = []
- for c_ in text:
- if c_ in self.Lexicon_Table:
- temp_text.append(self.Lexicon_Table.index(c_))
- temp_text = temp_text + [self.pad_num] * (self.max_len -
- len(temp_text))
- text_label_index_list.append(temp_text)
- data['strs'] = np.array(text_label_index_list)
- return data
-
-
class DetLabelEncode(object):
def __init__(self, **kwargs):
pass
@@ -209,6 +187,32 @@ class CTCLabelEncode(BaseRecLabelEncode):
return dict_character
+class E2ELabelEncode(BaseRecLabelEncode):
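+    """Encode e2e text labels into fixed-length index sequences for training/eval."""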
+ def __init__(self,
+ max_text_length,
+ character_dict_path=None,
+ character_type='EN',
+ use_space_char=False,
+ **kwargs):
+ super(E2ELabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ character_type, use_space_char)
+
+ def __call__(self, data):
+ texts = data['strs']
+ temp_texts = []
+ for text in texts:
+ text = text.upper()
+ text = self.encode(text)
+ if text is None:
+ return None
+            text = text + [36] * (self.max_text_len - len(text)
+                                  )  # pad with 36, the index just past the 36-char dict
+ temp_texts.append(text)
+ data['strs'] = np.array(temp_texts)
+ return data
+
+
class AttnLabelEncode(BaseRecLabelEncode):
""" Convert between text-label and text-index """
diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py
index a496ed43b56972c6bf2feff8abc26e620abd643c..58837d7bfa3643babf0dd66951d6b15c5e32865e 100644
--- a/ppocr/data/imaug/pg_process.py
+++ b/ppocr/data/imaug/pg_process.py
@@ -21,6 +21,7 @@ __all__ = ['PGProcessTrain']
class PGProcessTrain(object):
def __init__(self,
+ character_dict_path,
batch_size=14,
min_crop_size=24,
min_text_size=10,
@@ -30,13 +31,19 @@ class PGProcessTrain(object):
self.min_crop_size = min_crop_size
self.min_text_size = min_text_size
self.max_text_size = max_text_size
- self.Lexicon_Table = [
- '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
- 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
- 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
- ]
+ self.Lexicon_Table = self.get_dict(character_dict_path)
self.img_id = 0
+ def get_dict(self, character_dict_path):
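+        """Read a character dictionary file (one character per line) into a list."""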
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
def quad_area(self, poly):
"""
compute area of a polygon
@@ -853,7 +860,7 @@ class PGProcessTrain(object):
for i in range(len(label_list)):
label_list[i] = np.array(label_list[i])
- if len(pos_list) <= 0 or len(pos_list) > 30:
+        if len(pos_list) <= 0 or len(pos_list) > 30:  # an image contains at most 30 text instances
return None
for __ in range(30 - len(pos_list), 0, -1):
pos_list.append(pos_list_temp)
diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py
index c6cd1db94370e966a56fb694265a83e46c5e9ee3..04b73e0c4652c263d59380a0feff1f29da6c6817 100644
--- a/ppocr/metrics/e2e_metric.py
+++ b/ppocr/metrics/e2e_metric.py
@@ -19,11 +19,15 @@ from __future__ import print_function
__all__ = ['E2EMetric']
from ppocr.utils.e2e_metric.Deteval import *
+from ppocr.utils.e2e_utils.extract_textpoint import get_dict
class E2EMetric(object):
- def __init__(self, Lexicon_Table, main_indicator='f_score_e2e', **kwargs):
- self.label_list = Lexicon_Table
+ def __init__(self,
+ character_dict_path,
+ main_indicator='f_score_e2e',
+ **kwargs):
+ self.label_list = get_dict(character_dict_path)
self.max_index = len(self.label_list)
self.main_indicator = main_indicator
self.reset()
diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py
index 106cdfa689680d8370c9ad2b4d51e0c5a8c74ba7..a3bc39aa6711f79f172e86912c42019d92543ed4 100644
--- a/ppocr/modeling/heads/e2e_pg_head.py
+++ b/ppocr/modeling/heads/e2e_pg_head.py
@@ -228,11 +228,11 @@ class PGHead(nn.Layer):
f_score = self.conv1(f_score)
f_score = F.sigmoid(f_score)
- # f_boder
- f_boder = self.conv_f_boder1(x)
- f_boder = self.conv_f_boder2(f_boder)
- f_boder = self.conv_f_boder3(f_boder)
- f_boder = self.conv2(f_boder)
+ # f_border
+ f_border = self.conv_f_boder1(x)
+ f_border = self.conv_f_boder2(f_border)
+ f_border = self.conv_f_boder3(f_border)
+ f_border = self.conv2(f_border)
f_char = self.conv_f_char1(x)
f_char = self.conv_f_char2(f_char)
@@ -246,4 +246,9 @@ class PGHead(nn.Layer):
f_direction = self.conv_f_direc3(f_direction)
f_direction = self.conv4(f_direction)
- return f_score, f_boder, f_direction, f_char
+ predicts = {}
+ predicts['f_score'] = f_score
+ predicts['f_border'] = f_border
+ predicts['f_char'] = f_char
+ predicts['f_direction'] = f_direction
+ return predicts
diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py
index 1f1ab60e0df044a9f731bbdb3a87ff89da5bdd99..6d1b7d7a106c76fcf7104abe432f99588a4043eb 100644
--- a/ppocr/postprocess/pg_postprocess.py
+++ b/ppocr/postprocess/pg_postprocess.py
@@ -30,30 +30,14 @@ import paddle
class PGPostProcess(object):
"""
- The post process for SAST.
+ The post process for PGNet.
"""
- def __init__(self,
- score_thresh=0.5,
- nms_thresh=0.2,
- sample_pts_num=2,
- shrink_ratio_of_width=0.3,
- expand_scale=1.0,
- tcl_map_thresh=0.5,
- **kwargs):
- self.result_path = ""
- self.valid_set = 'totaltext'
- self.Lexicon_Table = [
- '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C',
- 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P',
- 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'
- ]
+    def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
+ self.Lexicon_Table = get_dict(character_dict_path)
+ self.valid_set = valid_set
self.score_thresh = score_thresh
- self.nms_thresh = nms_thresh
- self.sample_pts_num = sample_pts_num
- self.shrink_ratio_of_width = shrink_ratio_of_width
- self.expand_scale = expand_scale
- self.tcl_map_thresh = tcl_map_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
@@ -61,16 +45,23 @@ class PGPostProcess(object):
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
- p_score, p_border, p_direction, p_char = outs_dict[:4]
- p_score = p_score[0].numpy()
- p_border = p_border[0].numpy()
- p_direction = p_direction[0].numpy()
- p_char = p_char[0].numpy()
- src_h, src_w, ratio_h, ratio_w = shape_list[0]
- if self.valid_set != 'totaltext':
- is_curved = False
+ p_score = outs_dict['f_score']
+ p_border = outs_dict['f_border']
+ p_char = outs_dict['f_char']
+ p_direction = outs_dict['f_direction']
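+        # dygraph evaluation passes paddle Tensors; exported inference models return numpy arrays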
+ if isinstance(p_score, paddle.Tensor):
+ p_score = p_score[0].numpy()
+ p_border = p_border[0].numpy()
+ p_direction = p_direction[0].numpy()
+ p_char = p_char[0].numpy()
else:
- is_curved = True
+ p_score = p_score[0]
+ p_border = p_border[0]
+ p_direction = p_direction[0]
+ p_char = p_char[0]
+
+ src_h, src_w, ratio_h, ratio_w = shape_list[0]
+ is_curved = self.valid_set == "totaltext"
instance_yxs_list = generate_pivot_list(
p_score,
p_char,
diff --git a/ppocr/utils/e2e_metric/polygon_fast.py b/ppocr/utils/e2e_metric/polygon_fast.py
index c78e2a1e5f81850a7e3d7f1bb3e825aba0573cc0..81c9ad70675bb37a95968283b6dc6f42f709df27 100755
--- a/ppocr/utils/e2e_metric/polygon_fast.py
+++ b/ppocr/utils/e2e_metric/polygon_fast.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py
index 2d793aa98ebef3835c83efa190ff9dee204771f4..5355280946c0eadc8bc097e5409f755e8a390e5a 100644
--- a/ppocr/utils/e2e_utils/extract_textpoint.py
+++ b/ppocr/utils/e2e_utils/extract_textpoint.py
@@ -24,6 +24,17 @@ from itertools import groupby
from skimage.morphology._skeletonize import thin
+def get_dict(character_dict_path):
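+    """Read a character dictionary file (one character per line) into a list."""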
+ character_str = ""
+ with open(character_dict_path, "rb") as fin:
+ lines = fin.readlines()
+ for line in lines:
+ line = line.decode('utf-8').strip("\n").strip("\r\n")
+ character_str += line
+ dict_character = list(character_str)
+ return dict_character
+
+
def softmax(logits):
"""
logits: N x d
@@ -164,7 +175,6 @@ def sort_and_expand_with_direction(pos_list, f_direction):
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
- # expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
@@ -207,7 +217,6 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
- # expand along
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
@@ -268,7 +277,6 @@ def generate_pivot_list_curved(p_score,
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
- # get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
@@ -279,7 +287,6 @@ def generate_pivot_list_curved(p_score,
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
- ### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
@@ -290,7 +297,6 @@ def generate_pivot_list_curved(p_score,
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
- # use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
@@ -335,11 +341,9 @@ def generate_pivot_list_horizontal(p_score,
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
- ### FIX-ME, eliminate outlier
if len(pos_list) < 5:
continue
- # add rule here
main_direction = extract_main_direction(pos_list,
f_direction) # y x
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
@@ -370,7 +374,6 @@ def generate_pivot_list_horizontal(p_score,
f_direction)
all_pos_yxs.append(pos_list_sorted)
- # use decoder to filter backgroud points.
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
@@ -417,7 +420,6 @@ def generate_pivot_list(p_score,
image_id=image_id)
-# for refine module
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
@@ -504,14 +506,12 @@ def generate_pivot_list_tt_inference(p_score,
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
- # get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
- ### FIX-ME, eliminate outlier
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
diff --git a/ppocr/utils/e2e_utils/visual.py b/ppocr/utils/e2e_utils/visual.py
index 6f8a429ef0fd85e14413bc1429e13a6ed81fc5f4..e6e4fd0667dbf4a42dbc0fd9bf26e6fd91be0d82 100644
--- a/ppocr/utils/e2e_utils/visual.py
+++ b/ppocr/utils/e2e_utils/visual.py
@@ -28,7 +28,6 @@ def resize_image(im, max_side_len=512):
resize_w = w
resize_h = h
- # Fix the longer side
if resize_h > resize_w:
ratio = float(max_side_len) / resize_h
else:
@@ -50,13 +49,11 @@ def resize_image(im, max_side_len=512):
def resize_image_min(im, max_side_len=512):
"""
"""
- # print('--> Using resize_image_min')
h, w, _ = im.shape
resize_w = w
resize_h = h
- # Fix the longer side
if resize_h < resize_w:
ratio = float(max_side_len) / resize_h
else:
@@ -84,12 +81,7 @@ def resize_image_for_totaltext(im, max_side_len=512):
ratio = 1.25
if h * ratio > max_side_len:
ratio = float(max_side_len) / resize_h
- # Fix the longer side
- # if resize_h > resize_w:
- # ratio = float(max_side_len) / resize_h
- # else:
- # ratio = float(max_side_len) / resize_w
- ###
+
resize_h = int(resize_h * ratio)
resize_w = int(resize_w * ratio)
@@ -114,7 +106,6 @@ def point_pair2poly(point_pair_list):
pair_info = (pair_length_list.max(), pair_length_list.min(),
pair_length_list.mean())
- # constract poly
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
diff --git a/ppocr/utils/pgnet_dict.txt b/ppocr/utils/pgnet_dict.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b52d16e64f1004e1fceccac2280bc6f6eabd6af3
--- /dev/null
+++ b/ppocr/utils/pgnet_dict.txt
@@ -0,0 +1,36 @@
+0
+1
+2
+3
+4
+5
+6
+7
+8
+9
+A
+B
+C
+D
+E
+F
+G
+H
+I
+J
+K
+L
+M
+N
+O
+P
+Q
+R
+S
+T
+U
+V
+W
+X
+Y
+Z
\ No newline at end of file
diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py
new file mode 100755
index 0000000000000000000000000000000000000000..1a92c4ab518a1c01d9d282443c09f7e3a7ecf008
--- /dev/null
+++ b/tools/infer/predict_e2e.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(__dir__)
+sys.path.append(os.path.abspath(os.path.join(__dir__, '../..')))
+
+os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
+
+import cv2
+import numpy as np
+import time
+import sys
+
+import tools.infer.utility as utility
+from ppocr.utils.logging import get_logger
+from ppocr.utils.utility import get_image_file_list, check_and_read_gif
+from ppocr.data import create_operators, transform
+from ppocr.postprocess import build_post_process
+
+logger = get_logger()
+
+
+class TextE2e(object):
+ def __init__(self, args):
+ self.args = args
+ self.e2e_algorithm = args.e2e_algorithm
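+        # default preprocessing pipeline; the resize op is replaced per algorithm below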
+ pre_process_list = [{
+ 'E2EResizeForTest': {
+ 'max_side_len': 768,
+ 'valid_set': 'totaltext'
+ }
+ }, {
+ 'NormalizeImage': {
+ 'std': [0.229, 0.224, 0.225],
+ 'mean': [0.485, 0.456, 0.406],
+ 'scale': '1./255.',
+ 'order': 'hwc'
+ }
+ }, {
+ 'ToCHWImage': None
+ }, {
+ 'KeepKeys': {
+ 'keep_keys': ['image', 'shape']
+ }
+ }]
+ postprocess_params = {}
+ if self.e2e_algorithm == "PGNet":
+ pre_process_list[0] = {
+ 'E2EResizeForTest': {
+ 'max_side_len': args.e2e_limit_side_len,
+ 'valid_set': 'totaltext'
+ }
+ }
+ postprocess_params['name'] = 'PGPostProcess'
+ postprocess_params["score_thresh"] = args.e2e_pgnet_score_thresh
+ postprocess_params["character_dict_path"] = args.e2e_char_dict_path
+ postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
+ self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
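+            # polygon mode targets curved text: expand TCL regions more, shrink width less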
+ if self.e2e_pgnet_polygon:
+ postprocess_params["expand_scale"] = 1.2
+ postprocess_params["shrink_ratio_of_width"] = 0.2
+ else:
+ postprocess_params["expand_scale"] = 1.0
+ postprocess_params["shrink_ratio_of_width"] = 0.3
+ else:
+ logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
+ sys.exit(0)
+
+ self.preprocess_op = create_operators(pre_process_list)
+ self.postprocess_op = build_post_process(postprocess_params)
+        self.predictor, self.input_tensor, self.output_tensors = utility.create_predictor(
+            args, 'e2e', logger)
+
+ def clip_det_res(self, points, img_height, img_width):
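+        """Clamp polygon points so they lie inside the image bounds."""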
+ for pno in range(points.shape[0]):
+ points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))
+ points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))
+ return points
+
+ def filter_tag_det_res_only_clip(self, dt_boxes, image_shape):
+ img_height, img_width = image_shape[0:2]
+ dt_boxes_new = []
+ for box in dt_boxes:
+ box = self.clip_det_res(box, img_height, img_width)
+ dt_boxes_new.append(box)
+ dt_boxes = np.array(dt_boxes_new)
+ return dt_boxes
+
+ def __call__(self, img):
+ ori_im = img.copy()
+ data = {'image': img}
+ data = transform(data, self.preprocess_op)
+ img, shape_list = data
+ if img is None:
+            return None, None, 0
+ img = np.expand_dims(img, axis=0)
+ shape_list = np.expand_dims(shape_list, axis=0)
+ img = img.copy()
+ starttime = time.time()
+
+ self.input_tensor.copy_from_cpu(img)
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+
+ preds = {}
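+        # assumes the exported model returns maps in head order: f_score, f_border, f_direction, f_char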
+ if self.e2e_algorithm == 'PGNet':
+ preds['f_score'] = outputs[0]
+ preds['f_border'] = outputs[1]
+ preds['f_direction'] = outputs[2]
+ preds['f_char'] = outputs[3]
+ else:
+ raise NotImplementedError
+
+ post_result = self.postprocess_op(preds, shape_list)
+ points, strs = post_result['points'], post_result['strs']
+ dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
+ elapse = time.time() - starttime
+ return dt_boxes, strs, elapse
+
+
+if __name__ == "__main__":
+ args = utility.parse_args()
+ image_file_list = get_image_file_list(args.image_dir)
+ text_detector = TextE2e(args)
+ count = 0
+ total_time = 0
+ draw_img_save = "./inference_results"
+ if not os.path.exists(draw_img_save):
+ os.makedirs(draw_img_save)
+ for image_file in image_file_list:
+ img, flag = check_and_read_gif(image_file)
+ if not flag:
+ img = cv2.imread(image_file)
+ if img is None:
+ logger.info("error in loading image:{}".format(image_file))
+ continue
+ points, strs, elapse = text_detector(img)
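+        # skip the first image so predictor warm-up is excluded from the average time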
+ if count > 0:
+ total_time += elapse
+ count += 1
+ logger.info("Predict time of {}: {}".format(image_file, elapse))
+ src_im = utility.draw_e2e_res(points, strs, image_file)
+ img_name_pure = os.path.split(image_file)[-1]
+ img_path = os.path.join(draw_img_save,
+ "e2e_res_{}".format(img_name_pure))
+ cv2.imwrite(img_path, src_im)
+ logger.info("The visualized image saved in {}".format(img_path))
+ if count > 1:
+ logger.info("Avg Time: {}".format(total_time / (count - 1)))
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index a4a91efdd2ec04e2a2959c77a444549ad413c13d..9aa0afed635481859cd31d461a97c451ca72acdc 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -74,6 +74,21 @@ def parse_args():
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
+ # params for e2e
+ parser.add_argument("--e2e_algorithm", type=str, default='PGNet')
+ parser.add_argument("--e2e_model_dir", type=str)
+ parser.add_argument("--e2e_limit_side_len", type=float, default=768)
+ parser.add_argument("--e2e_limit_type", type=str, default='max')
+
+    # PGNet params
+ parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
+ parser.add_argument(
+ "--e2e_char_dict_path",
+ type=str,
+ default="./ppocr/utils/pgnet_dict.txt")
+ parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
+ parser.add_argument("--e2e_pgnet_polygon", type=bool, default=False)
+
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
parser.add_argument("--cls_model_dir", type=str)
@@ -93,8 +108,10 @@ def create_predictor(args, mode, logger):
model_dir = args.det_model_dir
elif mode == 'cls':
model_dir = args.cls_model_dir
- else:
+ elif mode == 'rec':
model_dir = args.rec_model_dir
+ else:
+ model_dir = args.e2e_model_dir
if model_dir is None:
logger.info("not find {} model file path {}".format(mode, model_dir))
@@ -147,6 +164,22 @@ def create_predictor(args, mode, logger):
return predictor, input_tensor, output_tensors
+def draw_e2e_res(dt_boxes, strs, img_path):
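+    """Draw end-to-end polygons and their recognized strings onto the source image."""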
+ src_im = cv2.imread(img_path)
+    for box, txt in zip(dt_boxes, strs):
+ box = box.astype(np.int32).reshape((-1, 1, 2))
+ cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+ cv2.putText(
+ src_im,
+            txt,
+ org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
+ fontFace=cv2.FONT_HERSHEY_COMPLEX,
+ fontScale=0.7,
+ color=(0, 255, 0),
+ thickness=1)
+ return src_im
+
+
def draw_text_det_res(dt_boxes, img_path):
src_im = cv2.imread(img_path)
for box in dt_boxes:
diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py
index c40b8e02341afd8d1204aa5335bb7b0963e5899a..b7503adb94eb797d4fb12cf47b377fa72d02158b 100755
--- a/tools/infer_e2e.py
+++ b/tools/infer_e2e.py
@@ -71,7 +71,8 @@ def main():
init_model(config, model, logger)
# build post process
- post_process_class = build_post_process(config['PostProcess'])
+ post_process_class = build_post_process(config['PostProcess'],
+ global_config)
# create data ops
transforms = []
diff --git a/train_data/total_text/train/poly/2.txt b/train_data/total_text/train/poly/2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..961d9680b203a498405fb2276e03adcc9ba3ec8c
--- /dev/null
+++ b/train_data/total_text/train/poly/2.txt
@@ -0,0 +1,2 @@
+2.0,165.0,20.0,167.0,39.0,170.0,57.0,173.0,76.0,176.0,94.0,179.0,113.0,182.0,109.0,218.0,90.0,215.0,72.0,213.0,54.0,210.0,36.0,208.0,18.0,205.0,0.0,203.0 izza
+2.0,411.0,30.0,412.0,58.0,414.0,87.0,416.0,115.0,418.0,143.0,420.0,172.0,422.0,172.0,476.0,143.0,474.0,114.0,472.0,86.0,471.0,57.0,469.0,28.0,467.0,0.0,466.0 ISA
diff --git a/train_data/total_text/train/rgb/2.jpg b/train_data/total_text/train/rgb/2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..f3bc7a06e911ef87c0831e779d20e44b6d2bbea5
Binary files /dev/null and b/train_data/total_text/train/rgb/2.jpg differ