diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml
index 08c485a638759e6436bd1613fff81fa14c8a6db8..be7529d71367b2ba6d0207c443a2ea55c710a8fe 100644
--- a/configs/e2e/e2e_r50_vd_pg.yml
+++ b/configs/e2e/e2e_r50_vd_pg.yml
@@ -18,11 +18,13 @@ Global:
save_inference_dir:
use_visualdl: False
infer_img:
- valid_set: totaltext #two mode: totaltext valid curved words, partvgg valid non-curved words
+  valid_set: totaltext # two modes: totaltext evaluates curved text, partvgg evaluates non-curved text
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
- character_dict_path: ppocr/utils/pgnet_dict.txt
+ character_dict_path: ppocr/utils/ic15_dict.txt
character_type: EN
- max_text_length: 50
+  max_text_length: 50 # max length of a single text sequence
+  max_text_nums: 30 # max number of text instances in one image
+  tcl_len: 64 # length of the padded TCL (text center line) point sequence
Architecture:
model_type: e2e
@@ -33,13 +35,15 @@ Architecture:
layers: 50
Neck:
name: PGFPN
- model_name: large
Head:
name: PGHead
- model_name: large
Loss:
name: PGLoss
+  tcl_bs: 64 # number of TCL RoIs kept per batch (see org_tcl_rois)
+  max_text_length: 50 # must match Global.max_text_length
+  max_text_nums: 30 # must match Global.max_text_nums
+  pad_num: 36 # dict length; used as the padding and CTC blank index
Optimizer:
name: Adam
@@ -54,10 +58,10 @@ Optimizer:
PostProcess:
name: PGPostProcess
- score_thresh: 0.8
+ score_thresh: 0.5
Metric:
name: E2EMetric
- character_dict_path: ppocr/utils/pgnet_dict.txt
+ character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e
Train:
diff --git a/doc/doc_ch/e2e.md b/doc/doc_ch/e2e.md
index a0695697e39345b391a9bb37114136ee8e5743dc..3927865de982d8fbc9472f5afda322338e48c503 100644
--- a/doc/doc_ch/e2e.md
+++ b/doc/doc_ch/e2e.md
@@ -9,8 +9,10 @@
After decompressing the dataset and downloading the annotation file, PaddleOCR/train_data/part_vgg_synth/train/ contains one folder and one file:
```
/PaddleOCR/train_data/part_vgg_synth/train/
-  └─ image/                training images of the partvgg dataset
-  └─ train_annotation_info.txt   training annotations of the partvgg dataset
+  |- image/                training images of the partvgg dataset
+    |- 119_nile_110_31.png
+    | ...
+  |- train_annotation_info.txt    training annotations of the partvgg dataset
```
The provided annotation file has the following format, with fields separated by "\t":
@@ -18,7 +20,7 @@
" 图像文件名 图像标注信息--四点标注 图像标注信息--识别标注
119_nile_110_31 140.2 222.5 266.0 194.6 278.7 251.8 152.9 279.7 Path: 32.9 133.1 106.0 130.8 106.4 143.8 33.3 146.1 were 21.8 81.9 106.9 80.4 107.7 123.2 22.6 124.7 why
```
-In the annotation txt file, each line represents one sample. Taking the first line as an example: the first field is the file name under the sibling directory image/; each subsequent group of 9 fields is one annotation, of which the first 8 are the (x, y) coordinates of the four corner points of the text box, arranged clockwise starting from the top-left point.
+In the annotation txt file, each line represents one sample. Taking the first line as an example: the first field is the file name prefix of the image under the sibling directory image/; each subsequent group of 9 fields is one annotation, of which the first 8 are the (x, y) coordinates of the four corner points of the text box, arranged clockwise starting from the top-left point.
The last field is the transcription of the text. **When it is "###", the text box is invalid and is skipped during training.**
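+
+Below is a minimal parsing sketch (illustrative only, not part of PaddleOCR; the helper name parse_partvgg_line is hypothetical, and it assumes transcriptions contain no spaces, so whitespace splitting is safe):
+```
+import numpy as np
+
+def parse_partvgg_line(line):
+    tokens = line.strip().split()
+    img_prefix, rest = tokens[0], tokens[1:]
+    boxes, texts = [], []
+    for i in range(0, len(rest), 9):  # each group: 8 coordinates + 1 transcription
+        group = rest[i:i + 9]
+        # four (x, y) corner points, clockwise from the top-left
+        boxes.append(np.array(group[:8], dtype=np.float32).reshape(4, 2))
+        texts.append(group[8])  # "###" marks an invalid box, skipped in training
+    return img_prefix, boxes, texts
+```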
@@ -26,8 +28,12 @@
After decompressing the dataset and downloading the annotation files, PaddleOCR/train_data/total_text/train/ contains two folders:
```
/PaddleOCR/train_data/total_text/train/
-  └─ rgb/            training images of the total_text dataset
-  └─ poly/           training annotations of the total_text dataset
+  |- rgb/            training images of the total_text dataset
+    |- gt_0.png
+    | ...
+  |- poly/           training annotations of the total_text dataset
+ |- gt_0.txt
+ | ...
```
The provided annotation file has the following format, with fields separated by "\t":
@@ -36,7 +42,7 @@
1004.0,689.0,1019.0,698.0,1034.0,708.0,1049.0,718.0,1064.0,728.0,1079.0,738.0,1095.0,748.0,1094.0,774.0,1079.0,765.0,1065.0,756.0,1050.0,747.0,1036.0,738.0,1021.0,729.0,1007.0,721.0 EST
1102.0,755.0,1116.0,764.0,1131.0,773.0,1146.0,783.0,1161.0,792.0,1176.0,801.0,1191.0,811.0,1193.0,837.0,1178.0,828.0,1164.0,819.0,1150.0,810.0,1135.0,801.0,1121.0,792.0,1107.0,784.0 1972
```
-In the annotation files, each txt file corresponds to one sample, the file name the image file name under the sibling directory rgb/. Taking the first line as an example: the first 28 fields are the (x, y) coordinates of the fourteen points of the text box, arranged clockwise starting from the top-left point.
+In the annotation files, each txt file corresponds to one sample, and the file name is the same as the image file name under the sibling directory rgb/. Taking the first line as an example: the first 28 fields are the (x, y) coordinates of the fourteen points of the text box, arranged clockwise starting from the top-left point.
The last field is the transcription of the text. **When it is "###", the text box is invalid and is skipped during training.**
If you want to train on other datasets, you can create annotation files in the same format.
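+
+Below is a minimal parsing sketch for one line of a poly annotation file (illustrative only, not part of PaddleOCR; the helper name parse_totaltext_line is hypothetical):
+```
+import numpy as np
+
+def parse_totaltext_line(line):
+    coords, text = line.strip().split('\t')
+    # fourteen (x, y) points, clockwise from the top-left
+    poly = np.array(coords.split(','), dtype=np.float32).reshape(14, 2)
+    return poly, text  # "###" marks an invalid box, skipped in training
+```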
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index 40f4d8c5119fb4be72573dd6a1f99ca59aeaf7aa..f06524090871f55dac5b2b3ef99bdce0c0ace749 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -29,7 +29,7 @@ inference model (a model saved with `paddle.jit.save`)
- [5. Multilingual model inference](#多语言模型的推理)
- [4. End-to-end model inference](#端到端模型推理)
-    - [1. PGNet end-to-end model inference](#SAST文本检测模型推理)
+    - [1. PGNet end-to-end model inference](#PGNet端到端模型推理)
- [5. Angle classification model inference](#方向识别模型推理)
    - [1. Angle classification model inference](#方向分类模型推理)
@@ -366,7 +366,7 @@ Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
## 4. End-to-end model inference
End-to-end model inference uses the PGNet configuration parameters by default. When running a model other than PGNet, pass the corresponding parameters at inference time to adapt the algorithm; see the details below.
-
+<a name="PGNet端到端模型推理"></a>
### 1. PGNet end-to-end model inference
#### (1). Quadrilateral text detection model (ICDAR2015)
First, convert the model saved during PGNet end-to-end training into an inference model. Taking a model with a ResNet50_vd backbone trained on the ICDAR2015 English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)), you can convert it with the following command:
@@ -375,28 +375,26 @@ python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrai
```
**PGNet end-to-end model inference requires setting the parameter `--e2e_algorithm="PGNet"`.** Run the following command:
```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e_pgnet_ic15/"
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
```
The visualized text detection results are saved to the `./inference_results` folder by default; result file names are prefixed with 'e2e_res'. An example result:
-![](../imgs_results/det_res_img_10_sast.jpg)
+![](../imgs_results/e2e_res_img_10_pgnet.jpg)
#### (2). Curved text detection model (Total-Text)
First, convert the model saved during PGNet end-to-end training into an inference model. Taking a model with a ResNet50_vd backbone trained on the Total-Text English dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)), you can convert it with the following command:
```
-python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e_pgnet_tt
+python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**PGNet end-to-end model inference requires setting `--e2e_algorithm="PGNet"` and, in addition, `--e2e_pgnet_polygon=True`.** Run the following command:
```
-python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_pgnet_tt/" --e2e_pgnet_polygon=True
+python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
The visualized end-to-end results are saved to the `./inference_results` folder by default; result file names are prefixed with 'e2e_res'. An example result:
-![](../imgs_results/e2e_res_img623_pg.jpg)
-
-**Note**: This codebase provides both Python and C++ versions of the SAST post-processing Locality-Aware NMS; the C++ version is significantly faster. Due to NMS compilation constraints, the C++ NMS is only invoked in a Python 3.5 environment; in all other cases the Python NMS is used.
+![](../imgs_results/e2e_res_img623_pgnet.jpg)
diff --git a/doc/imgs_results/e2e_res_img623_pg.jpg b/doc/imgs_results/e2e_res_img623_pg.jpg
deleted file mode 100644
index 84fca124363353313750984b4cf64ce2c2cad70b..0000000000000000000000000000000000000000
Binary files a/doc/imgs_results/e2e_res_img623_pg.jpg and /dev/null differ
diff --git a/doc/imgs_results/e2e_res_img623_pgnet.jpg b/doc/imgs_results/e2e_res_img623_pgnet.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b45dc05f7bfae05bbaa338f59397c2458b80638b
Binary files /dev/null and b/doc/imgs_results/e2e_res_img623_pgnet.jpg differ
diff --git a/doc/imgs_results/e2e_res_img_10_pgnet.jpg b/doc/imgs_results/e2e_res_img_10_pgnet.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a0962993f81628da04f0838aed6240599f9eaec2
Binary files /dev/null and b/doc/imgs_results/e2e_res_img_10_pgnet.jpg differ
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index 6c2fc8e4bbfcedb5150dd7baf13db267d1d74aa2..cbb110090cfff3ebee4b30b009f88fc9aaba1617 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -197,17 +197,17 @@ class E2ELabelEncode(BaseRecLabelEncode):
super(E2ELabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
+        self.pad_num = len(self.dict)  # index used for padding (equals the dict length)
def __call__(self, data):
texts = data['strs']
temp_texts = []
for text in texts:
- text = text.upper()
+ text = text.lower()
text = self.encode(text)
if text is None:
return None
- text = text + [36] * (self.max_text_len - len(text)
- ) # use 36 to pad
+ text = text + [self.pad_num] * (self.max_text_len - len(text))
temp_texts.append(text)
data['strs'] = np.array(temp_texts)
return data
diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py
index 58837d7bfa3643babf0dd66951d6b15c5e32865e..09382f7ed7e6c0c6bd9ff704cf42358a61c4165a 100644
--- a/ppocr/data/imaug/pg_process.py
+++ b/ppocr/data/imaug/pg_process.py
@@ -22,16 +22,23 @@ __all__ = ['PGProcessTrain']
class PGProcessTrain(object):
def __init__(self,
character_dict_path,
+ max_text_length,
+ max_text_nums,
+ tcl_len,
batch_size=14,
min_crop_size=24,
min_text_size=10,
max_text_size=512,
**kwargs):
+ self.tcl_len = tcl_len
+ self.max_text_length = max_text_length
+ self.max_text_nums = max_text_nums
self.batch_size = batch_size
self.min_crop_size = min_crop_size
self.min_text_size = min_text_size
self.max_text_size = max_text_size
self.Lexicon_Table = self.get_dict(character_dict_path)
+ self.pad_num = len(self.Lexicon_Table)
self.img_id = 0
def get_dict(self, character_dict_path):
@@ -371,7 +380,6 @@ class PGProcessTrain(object):
continue
if tag:
- # continue
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else:
@@ -577,7 +585,7 @@ class PGProcessTrain(object):
Prepare text label with the given Lexicon_Table.
"""
if len(Lexicon_Table) == 36:
- return label_str.upper()
+ return label_str.lower()
else:
return label_str
@@ -846,23 +854,23 @@ class PGProcessTrain(object):
return None
pos_list_temp = np.zeros([64, 3])
pos_mask_temp = np.zeros([64, 1])
- label_list_temp = np.zeros([50, 1]) + 36
+ label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
for i, label in enumerate(label_list):
n = len(label)
- if n > 50:
- label_list[i] = label[:50]
+ if n > self.max_text_length:
+ label_list[i] = label[:self.max_text_length]
continue
- while n < 50:
- label.append([36])
+ while n < self.max_text_length:
+ label.append([self.pad_num])
n += 1
for i in range(len(label_list)):
label_list[i] = np.array(label_list[i])
-        if len(pos_list) <= 0 or len(pos_list) > 30:  # at most 30 text instances in one image
+ if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
return None
- for __ in range(30 - len(pos_list), 0, -1):
+ for __ in range(self.max_text_nums - len(pos_list), 0, -1):
pos_list.append(pos_list_temp)
pos_mask.append(pos_mask_temp)
label_list.append(label_list_temp)
diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py
index be4614e70e41d2ac7962920844cf30327a8407a3..680ab0e60a394df0b4f86d334c616ca338ec5d93 100644
--- a/ppocr/losses/e2e_pg_loss.py
+++ b/ppocr/losses/e2e_pg_loss.py
@@ -18,102 +18,26 @@ from __future__ import print_function
from paddle import nn
import paddle
-import numpy as np
-import copy
from .det_basic_loss import DiceLoss
+from ppocr.utils.e2e_utils.extract_batchsize import pre_process
class PGLoss(nn.Layer):
- def __init__(self, eps=1e-6, **kwargs):
+ def __init__(self,
+ tcl_bs,
+ max_text_length,
+ max_text_nums,
+ pad_num,
+ eps=1e-6,
+ **kwargs):
super(PGLoss, self).__init__()
+ self.tcl_bs = tcl_bs
+ self.max_text_nums = max_text_nums
+ self.max_text_length = max_text_length
+ self.pad_num = pad_num
self.dice_loss = DiceLoss(eps=eps)
- def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists):
- """
- """
- pos_lists_, pos_masks_, label_lists_ = [], [], []
- img_bs = batch_size
- tcl_bs = 64
- ngpu = int(batch_size / img_bs)
- img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
- pos_lists_split, pos_masks_split, label_lists_split = [], [], []
- for i in range(ngpu):
- pos_lists_split.append([])
- pos_masks_split.append([])
- label_lists_split.append([])
-
- for i in range(img_ids.shape[0]):
- img_id = img_ids[i]
- gpu_id = int(img_id / img_bs)
- img_id = img_id % img_bs
- pos_list = pos_lists[i].copy()
- pos_list[:, 0] = img_id
- pos_lists_split[gpu_id].append(pos_list)
- pos_masks_split[gpu_id].append(pos_masks[i].copy())
- label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
- # repeat or delete
- for i in range(ngpu):
- vp_len = len(pos_lists_split[i])
- if vp_len <= tcl_bs:
- for j in range(0, tcl_bs - vp_len):
- pos_list = pos_lists_split[i][j].copy()
- pos_lists_split[i].append(pos_list)
- pos_mask = pos_masks_split[i][j].copy()
- pos_masks_split[i].append(pos_mask)
- label_list = copy.deepcopy(label_lists_split[i][j])
- label_lists_split[i].append(label_list)
- else:
- for j in range(0, vp_len - tcl_bs):
- c_len = len(pos_lists_split[i])
- pop_id = np.random.permutation(c_len)[0]
- pos_lists_split[i].pop(pop_id)
- pos_masks_split[i].pop(pop_id)
- label_lists_split[i].pop(pop_id)
- # merge
- for i in range(ngpu):
- pos_lists_.extend(pos_lists_split[i])
- pos_masks_.extend(pos_masks_split[i])
- label_lists_.extend(label_lists_split[i])
- return pos_lists_, pos_masks_, label_lists_
-
- def pre_process(self, label_list, pos_list, pos_mask):
- max_len = 30 # the max texts in a single image
- max_str_len = 50 # the max len in a single text
- pad_num = 36 # padding num
- label_list = label_list.numpy()
- batch, _, _, _ = label_list.shape
- pos_list = pos_list.numpy()
- pos_mask = pos_mask.numpy()
- pos_list_t = []
- pos_mask_t = []
- label_list_t = []
- for i in range(batch):
- for j in range(max_len):
- if pos_mask[i, j].any():
- pos_list_t.append(pos_list[i][j])
- pos_mask_t.append(pos_mask[i][j])
- label_list_t.append(label_list[i][j])
- pos_list, pos_mask, label_list = self.org_tcl_rois(
- batch, pos_list_t, pos_mask_t, label_list_t)
- label = []
- tt = [l.tolist() for l in label_list]
- for i in range(batch):
- k = 0
- for j in range(max_str_len):
- if tt[i][j][0] != pad_num:
- k += 1
- else:
- break
- label.append(k)
- label = paddle.to_tensor(label)
- label = paddle.cast(label, dtype='int64')
- pos_list = paddle.to_tensor(pos_list)
- pos_mask = paddle.to_tensor(pos_mask)
- label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
- label_list = paddle.cast(label_list, dtype='int32')
- return pos_list, pos_mask, label_list, label
-
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
l_border, num_or_sections=[4, 1], axis=1)
@@ -183,7 +107,7 @@ class PGLoss(nn.Layer):
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
- blank=36,
+ blank=self.pad_num,
reduction='none')
cost = cost.mean()
return cost
@@ -192,12 +116,14 @@ class PGLoss(nn.Layer):
images, tcl_maps, tcl_label_maps, border_maps \
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
# for all the batch_size
- pos_list, pos_mask, label_list, label_t = self.pre_process(
- label_list, pos_list, pos_mask)
+ pos_list, pos_mask, label_list, label_t = pre_process(
+ label_list, pos_list, pos_mask, self.max_text_length,
+ self.max_text_nums, self.pad_num, self.tcl_bs)
- f_score, f_boder, f_direction, f_char = predicts
+ f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
+ predicts['f_char']
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
- border_loss = self.border_loss(f_boder, border_maps, tcl_maps,
+ border_loss = self.border_loss(f_border, border_maps, tcl_maps,
training_masks)
direction_loss = self.direction_loss(f_direction, direction_maps,
tcl_maps, training_masks)
diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py
index a3bc39aa6711f79f172e86912c42019d92543ed4..0da9de7580a0ceb473f971b2246c966497026a5d 100644
--- a/ppocr/modeling/heads/e2e_pg_head.py
+++ b/ppocr/modeling/heads/e2e_pg_head.py
@@ -66,9 +66,8 @@ class PGHead(nn.Layer):
"""
"""
- def __init__(self, in_channels, model_name, **kwargs):
+ def __init__(self, in_channels, **kwargs):
super(PGHead, self).__init__()
- self.model_name = model_name
self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py
index 6d1b7d7a106c76fcf7104abe432f99588a4043eb..2cc7dc24dc69db46fdca85a98137a9194ae1fc0b 100644
--- a/ppocr/postprocess/pg_postprocess.py
+++ b/ppocr/postprocess/pg_postprocess.py
@@ -23,8 +23,7 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
-from ppocr.utils.e2e_utils.extract_textpoint import *
-from ppocr.utils.e2e_utils.visual import *
+from ppocr.utils.e2e_utils.extract_textpoint import get_dict, generate_pivot_list, restore_poly
import paddle
@@ -34,16 +33,10 @@ class PGPostProcess(object):
"""
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
-
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
- # c++ la-nms is faster, but only support python 3.5
- self.is_python35 = False
- if sys.version_info.major == 3 and sys.version_info.minor == 5:
- self.is_python35 = True
-
def __call__(self, outs_dict, shape_list):
p_score = outs_dict['f_score']
p_border = outs_dict['f_border']
@@ -61,96 +54,15 @@ class PGPostProcess(object):
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = shape_list[0]
- is_curved = self.valid_set == "totaltext"
- instance_yxs_list = generate_pivot_list(
+ instance_yxs_list, seq_strs = generate_pivot_list(
p_score,
p_char,
p_direction,
- score_thresh=self.score_thresh,
- is_backbone=True,
- is_curved=is_curved)
- p_char = np.expand_dims(p_char, axis=0)
- p_char = paddle.to_tensor(p_char)
- char_seq_idx_set = []
- for i in range(len(instance_yxs_list)):
- gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
- f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
- featyre_seq = paddle.gather_nd(f_char_map, gather_info_lod)
- featyre_seq = np.expand_dims(featyre_seq.numpy(), axis=0)
- t = len(featyre_seq[0])
- featyre_seq = paddle.to_tensor(featyre_seq)
- l = np.array([[t]]).astype(np.int64)
- length = paddle.to_tensor(l)
- seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
- input=featyre_seq, blank=36, input_length=length)
- seq_pred1 = seq_pred[0].numpy().tolist()[0]
- seq_len = seq_pred[1].numpy()[0][0]
- temp_t = []
- for x in seq_pred1[:seq_len]:
- temp_t.append(x)
- char_seq_idx_set.append(temp_t)
- seq_strs = []
- for char_idx_set in char_seq_idx_set:
- pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
- seq_strs.append(pr_str)
- poly_list = []
- keep_str_list = []
- all_point_list = []
- all_point_pair_list = []
- for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
- if len(yx_center_line) == 1:
- yx_center_line.append(yx_center_line[-1])
-
- offset_expand = 1.0
- if self.valid_set == 'totaltext':
- offset_expand = 1.2
-
- point_pair_list = []
- for batch_id, y, x in yx_center_line:
- offset = p_border[:, y, x].reshape(2, 2)
- if offset_expand != 1.0:
- offset_length = np.linalg.norm(
- offset, axis=1, keepdims=True)
- expand_length = np.clip(
- offset_length * (offset_expand - 1),
- a_min=0.5,
- a_max=3.0)
- offset_detal = offset / offset_length * expand_length
- offset = offset + offset_detal
- ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
- [ratio_w, ratio_h]).reshape(-1, 2)
- point_pair_list.append(point_pair)
-
- all_point_list.append([
- int(round(x * 4.0 / ratio_w)),
- int(round(y * 4.0 / ratio_h))
- ])
- all_point_pair_list.append(point_pair.round().astype(np.int32)
- .tolist())
-
- detected_poly, pair_length_info = point_pair2poly(point_pair_list)
- detected_poly = expand_poly_along_width(
- detected_poly, shrink_ratio_of_width=0.2)
- detected_poly[:, 0] = np.clip(
- detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(
- detected_poly[:, 1], a_min=0, a_max=src_h)
-
- if len(keep_str) < 2:
- continue
-
- keep_str_list.append(keep_str)
- if self.valid_set == 'partvgg':
- middle_point = len(detected_poly) // 2
- detected_poly = detected_poly[
- [0, middle_point - 1, middle_point, -1], :]
- poly_list.append(detected_poly)
- elif self.valid_set == 'totaltext':
- poly_list.append(detected_poly)
- else:
- print('--> Not supported format.')
- exit(-1)
+ self.Lexicon_Table,
+ score_thresh=self.score_thresh)
+ poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
+ p_border, ratio_w, ratio_h,
+ src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'strs': keep_str_list,
diff --git a/ppocr/utils/e2e_utils/extract_batchsize.py b/ppocr/utils/e2e_utils/extract_batchsize.py
new file mode 100644
index 0000000000000000000000000000000000000000..e99a833ea76a81e02d39b16fe1a01e22f15bf3a4
--- /dev/null
+++ b/ppocr/utils/e2e_utils/extract_batchsize.py
@@ -0,0 +1,87 @@
+import paddle
+import numpy as np
+import copy
+
+
+def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
+    """
+    Regroup TCL RoIs by image and repeat or randomly drop instances so that
+    each group contains exactly tcl_bs RoIs.
+    """
+ pos_lists_, pos_masks_, label_lists_ = [], [], []
+ img_bs = batch_size
+ ngpu = int(batch_size / img_bs)
+ img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
+ pos_lists_split, pos_masks_split, label_lists_split = [], [], []
+ for i in range(ngpu):
+ pos_lists_split.append([])
+ pos_masks_split.append([])
+ label_lists_split.append([])
+
+ for i in range(img_ids.shape[0]):
+ img_id = img_ids[i]
+ gpu_id = int(img_id / img_bs)
+ img_id = img_id % img_bs
+ pos_list = pos_lists[i].copy()
+ pos_list[:, 0] = img_id
+ pos_lists_split[gpu_id].append(pos_list)
+ pos_masks_split[gpu_id].append(pos_masks[i].copy())
+ label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
+ # repeat or delete
+ for i in range(ngpu):
+ vp_len = len(pos_lists_split[i])
+ if vp_len <= tcl_bs:
+ for j in range(0, tcl_bs - vp_len):
+ pos_list = pos_lists_split[i][j].copy()
+ pos_lists_split[i].append(pos_list)
+ pos_mask = pos_masks_split[i][j].copy()
+ pos_masks_split[i].append(pos_mask)
+ label_list = copy.deepcopy(label_lists_split[i][j])
+ label_lists_split[i].append(label_list)
+ else:
+ for j in range(0, vp_len - tcl_bs):
+ c_len = len(pos_lists_split[i])
+ pop_id = np.random.permutation(c_len)[0]
+ pos_lists_split[i].pop(pop_id)
+ pos_masks_split[i].pop(pop_id)
+ label_lists_split[i].pop(pop_id)
+ # merge
+ for i in range(ngpu):
+ pos_lists_.extend(pos_lists_split[i])
+ pos_masks_.extend(pos_masks_split[i])
+ label_lists_.extend(label_lists_split[i])
+ return pos_lists_, pos_masks_, label_lists_
+
+
+def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
+ pad_num, tcl_bs):
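+    """
+    Flatten the per-image TCL RoIs kept by pos_mask, balance them to exactly
+    tcl_bs instances via org_tcl_rois, and compute each instance's label
+    length (the first pad_num value marks the end of a label sequence).
+    """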
+ label_list = label_list.numpy()
+ batch, _, _, _ = label_list.shape
+ pos_list = pos_list.numpy()
+ pos_mask = pos_mask.numpy()
+ pos_list_t = []
+ pos_mask_t = []
+ label_list_t = []
+ for i in range(batch):
+ for j in range(max_text_nums):
+ if pos_mask[i, j].any():
+ pos_list_t.append(pos_list[i][j])
+ pos_mask_t.append(pos_mask[i][j])
+ label_list_t.append(label_list[i][j])
+ pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
+ label_list_t, tcl_bs)
+ label = []
+ tt = [l.tolist() for l in label_list]
+ for i in range(tcl_bs):
+ k = 0
+ for j in range(max_text_length):
+ if tt[i][j][0] != pad_num:
+ k += 1
+ else:
+ break
+ label.append(k)
+ label = paddle.to_tensor(label)
+ label = paddle.cast(label, dtype='int64')
+ pos_list = paddle.to_tensor(pos_list)
+ pos_mask = paddle.to_tensor(pos_mask)
+ label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
+ label_list = paddle.cast(label_list, dtype='int32')
+ return pos_list, pos_mask, label_list, label
diff --git a/ppocr/utils/e2e_utils/extract_textpoint.py b/ppocr/utils/e2e_utils/extract_textpoint.py
index 5355280946c0eadc8bc097e5409f755e8a390e5a..d64f1e838e6d2b976e42ccf1954976e0564b5aa0 100644
--- a/ppocr/utils/e2e_utils/extract_textpoint.py
+++ b/ppocr/utils/e2e_utils/extract_textpoint.py
@@ -17,11 +17,9 @@ from __future__ import division
from __future__ import print_function
import cv2
-import math
-
import numpy as np
from itertools import groupby
-from skimage.morphology._skeletonize import thin
+from cv2.ximgproc import thinning as thin  # requires opencv-contrib-python
def get_dict(character_dict_path):
@@ -35,87 +33,39 @@ def get_dict(character_dict_path):
return dict_character
-def softmax(logits):
- """
- logits: N x d
- """
- max_value = np.max(logits, axis=1, keepdims=True)
- exp = np.exp(logits - max_value)
- exp_sum = np.sum(exp, axis=1, keepdims=True)
- dist = exp / exp_sum
- return dist
-
-
-def get_keep_pos_idxs(labels, remove_blank=None):
- """
- Remove duplicate and get pos idxs of keep items.
- The value of keep_blank should be [None, 95].
- """
- duplicate_len_list = []
- keep_pos_idx_list = []
- keep_char_idx_list = []
- for k, v_ in groupby(labels):
- current_len = len(list(v_))
- if k != remove_blank:
- current_idx = int(sum(duplicate_len_list) + current_len // 2)
- keep_pos_idx_list.append(current_idx)
- keep_char_idx_list.append(k)
- duplicate_len_list.append(current_len)
- return keep_char_idx_list, keep_pos_idx_list
-
-
-def remove_blank(labels, blank=0):
- new_labels = [x for x in labels if x != blank]
- return new_labels
-
-
-def insert_blank(labels, blank=0):
- new_labels = [blank]
- for l in labels:
- new_labels += [l, blank]
- return new_labels
-
-
-def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
- """
- CTC greedy (best path) decoder.
- """
- raw_str = np.argmax(np.array(probs_seq), axis=1)
- remove_blank_in_pos = None if keep_blank_in_idxs else blank
- dedup_str, keep_idx_list = get_keep_pos_idxs(
- raw_str, remove_blank=remove_blank_in_pos)
- dst_str = remove_blank(dedup_str, blank=blank)
- return dst_str, keep_idx_list
-
-
-def instance_ctc_greedy_decoder(gather_info,
- logits_map,
- keep_blank_in_idxs=True):
- """
- gather_info: [[x, y], [x, y] ...]
- logits_map: H x W X (n_chars + 1)
- """
+def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
- logits_seq = logits_map[list(ys), list(xs)] # n x 96
- probs_seq = softmax(logits_seq)
- dst_str, keep_idx_list = ctc_greedy_decoder(
- probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
+    logits_seq = logits_map[list(ys), list(xs)]
+    labels = np.argmax(logits_seq, axis=1)  # greedy decode; argmax needs no softmax
+    dst_str = [k for k, v_ in groupby(labels) if k != C - 1]  # collapse repeats, drop blank
+    delta = len(gather_info) // (pts_num - 1)
+    keep_idx_list = [0] + [delta * (i + 1) for i in range(pts_num - 2)] + [-1]
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
-def ctc_decoder_for_image(gather_info_list, logits_map,
- keep_blank_in_idxs=True):
+def ctc_decoder_for_image(gather_info_list,
+ logits_map,
+ Lexicon_Table,
+ pts_num=6):
"""
CTC decoder using multiple processes.
"""
- decoder_results = []
+ decoder_str = []
+ decoder_xys = []
for gather_info in gather_info_list:
- res = instance_ctc_greedy_decoder(
- gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
- decoder_results.append(res)
- return decoder_results
+ if len(gather_info) < pts_num:
+ continue
+ dst_str, xys_list = instance_ctc_greedy_decoder(
+ gather_info, logits_map, pts_num=pts_num)
+ dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
+ if len(dst_str_readable) < 2:
+ continue
+ decoder_str.append(dst_str_readable)
+ decoder_xys.append(xys_list)
+ return decoder_str, decoder_xys
def sort_with_direction(pos_list, f_direction):
@@ -157,57 +107,6 @@ def sort_with_direction(pos_list, f_direction):
return sorted_point, np.array(sorted_direction)
-def add_id(pos_list, image_id=0):
- """
- Add id for gather feature, for inference.
- """
- new_list = []
- for item in pos_list:
- new_list.append((image_id, item[0], item[1]))
- return new_list
-
-
-def sort_and_expand_with_direction(pos_list, f_direction):
- """
- f_direction: h x w x 2
- pos_list: [[y, x], [y, x], [y, x] ...]
- """
- h, w, _ = f_direction.shape
- sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
-
- point_num = len(sorted_list)
- sub_direction_len = max(point_num // 3, 2)
- left_direction = point_direction[:sub_direction_len, :]
- right_dirction = point_direction[point_num - sub_direction_len:, :]
-
- left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
- left_average_len = np.linalg.norm(left_average_direction)
- left_start = np.array(sorted_list[0])
- left_step = left_average_direction / (left_average_len + 1e-6)
-
- right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
- right_average_len = np.linalg.norm(right_average_direction)
- right_step = right_average_direction / (right_average_len + 1e-6)
- right_start = np.array(sorted_list[-1])
-
- append_num = max(
- int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
- left_list = []
- right_list = []
- for i in range(append_num):
- ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
- 'int32').tolist()
- if ly < h and lx < w and (ly, lx) not in left_list:
- left_list.append((ly, lx))
- ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
- 'int32').tolist()
- if ry < h and rx < w and (ry, rx) not in right_list:
- right_list.append((ry, rx))
-
- all_list = left_list[::-1] + sorted_list + right_list
- return all_list
-
-
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
@@ -260,262 +159,125 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
return all_list
-def generate_pivot_list_curved(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_expand=True,
- is_backbone=False,
- image_id=0):
- """
- return center point and end point of TCL instance; filter with the char maps;
- """
- p_score = p_score[0]
- f_direction = f_direction.transpose(1, 2, 0)
- p_tcl_map = (p_score > score_thresh) * 1.0
- skeleton_map = thin(p_tcl_map)
- instance_count, instance_label_map = cv2.connectedComponents(
- skeleton_map.astype(np.uint8), connectivity=8)
-
- all_pos_yxs = []
- center_pos_yxs = []
- end_points_yxs = []
- instance_center_pos_yxs = []
- if instance_count > 0:
- for instance_id in range(1, instance_count):
- pos_list = []
- ys, xs = np.where(instance_label_map == instance_id)
- pos_list = list(zip(ys, xs))
-
- if len(pos_list) < 3:
- continue
-
- if is_expand:
- pos_list_sorted = sort_and_expand_with_direction_v2(
- pos_list, f_direction, p_tcl_map)
- else:
- pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
- all_pos_yxs.append(pos_list_sorted)
-
- p_char_maps = p_char_maps.transpose([1, 2, 0])
- decode_res = ctc_decoder_for_image(
- all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
- for decoded_str, keep_yxs_list in decode_res:
- if is_backbone:
- keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
- instance_center_pos_yxs.append(keep_yxs_list_with_id)
- else:
- end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
- center_pos_yxs.extend(keep_yxs_list)
-
- if is_backbone:
- return instance_center_pos_yxs
- else:
- return center_pos_yxs, end_points_yxs
-
-
-def generate_pivot_list_horizontal(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_backbone=False,
- image_id=0):
- """
- return center point and end point of TCL instance; filter with the char maps;
- """
- p_score = p_score[0]
- f_direction = f_direction.transpose(1, 2, 0)
- p_tcl_map_bi = (p_score > score_thresh) * 1.0
- instance_count, instance_label_map = cv2.connectedComponents(
- p_tcl_map_bi.astype(np.uint8), connectivity=8)
-
- # get TCL Instance
- all_pos_yxs = []
- center_pos_yxs = []
- end_points_yxs = []
- instance_center_pos_yxs = []
-
- if instance_count > 0:
- for instance_id in range(1, instance_count):
- pos_list = []
- ys, xs = np.where(instance_label_map == instance_id)
- pos_list = list(zip(ys, xs))
-
- if len(pos_list) < 5:
- continue
-
- main_direction = extract_main_direction(pos_list,
- f_direction) # y x
- reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
- is_h_angle = abs(np.sum(
- main_direction * reference_directin)) < math.cos(math.pi / 180 *
- 70)
-
- point_yxs = np.array(pos_list)
- max_y, max_x = np.max(point_yxs, axis=0)
- min_y, min_x = np.min(point_yxs, axis=0)
- is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
-
- pos_list_final = []
- if is_h_len:
- xs = np.unique(xs)
- for x in xs:
- ys = instance_label_map[:, x].copy().reshape((-1, ))
- y = int(np.where(ys == instance_id)[0].mean())
- pos_list_final.append((y, x))
- else:
- ys = np.unique(ys)
- for y in ys:
- xs = instance_label_map[y, :].copy().reshape((-1, ))
- x = int(np.where(xs == instance_id)[0].mean())
- pos_list_final.append((y, x))
-
- pos_list_sorted, _ = sort_with_direction(pos_list_final,
- f_direction)
- all_pos_yxs.append(pos_list_sorted)
-
- p_char_maps = p_char_maps.transpose([1, 2, 0])
- decode_res = ctc_decoder_for_image(
- all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
- for decoded_str, keep_yxs_list in decode_res:
- if is_backbone:
- keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
- instance_center_pos_yxs.append(keep_yxs_list_with_id)
+def point_pair2poly(point_pair_list):
+ """
+ Transfer vertical point_pairs into poly point in clockwise.
+ """
+ point_num = len(point_pair_list) * 2
+ point_list = [0] * point_num
+ for idx, point_pair in enumerate(point_pair_list):
+ point_list[idx] = point_pair[0]
+ point_list[point_num - 1 - idx] = point_pair[1]
+ return np.array(point_list).reshape(-1, 2)
+
+
+def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
+ ratio_pair = np.array(
+ [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
+ p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
+ p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
+ return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
+
+
+def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
+ """
+ expand poly along width.
+ """
+ point_num = poly.shape[0]
+ left_quad = np.array(
+ [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
+ left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
+ (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
+ left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
+ right_quad = np.array(
+ [
+ poly[point_num // 2 - 2], poly[point_num // 2 - 1],
+ poly[point_num // 2], poly[point_num // 2 + 1]
+ ],
+ dtype=np.float32)
+ right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
+ (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
+ right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
+ poly[0] = left_quad_expand[0]
+ poly[-1] = left_quad_expand[-1]
+ poly[point_num // 2 - 1] = right_quad_expand[1]
+ poly[point_num // 2] = right_quad_expand[2]
+ return poly
+
+
+def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
+ src_h, valid_set):
+ poly_list = []
+ keep_str_list = []
+ for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
+ if len(keep_str) < 2:
+ print('--> too short, {}'.format(keep_str))
+ continue
+
+ offset_expand = 1.0
+ if valid_set == 'totaltext':
+ offset_expand = 1.2
+
+ point_pair_list = []
+ for y, x in yx_center_line:
+ offset = p_border[:, y, x].reshape(2, 2) * offset_expand
+ ori_yx = np.array([y, x], dtype=np.float32)
+ point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
+ [ratio_w, ratio_h]).reshape(-1, 2)
+ point_pair_list.append(point_pair)
+
+ detected_poly = point_pair2poly(point_pair_list)
+ detected_poly = expand_poly_along_width(
+ detected_poly, shrink_ratio_of_width=0.2)
+ detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
+ detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
+
+ keep_str_list.append(keep_str)
+ if valid_set == 'partvgg':
+ middle_point = len(detected_poly) // 2
+ detected_poly = detected_poly[
+ [0, middle_point - 1, middle_point, -1], :]
+ poly_list.append(detected_poly)
+ elif valid_set == 'totaltext':
+ poly_list.append(detected_poly)
else:
- end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
- center_pos_yxs.extend(keep_yxs_list)
-
- if is_backbone:
- return instance_center_pos_yxs
- else:
- return center_pos_yxs, end_points_yxs
+ print('--> Not supported format.')
+ exit(-1)
+ return poly_list, keep_str_list
def generate_pivot_list(p_score,
p_char_maps,
f_direction,
- score_thresh=0.5,
- is_backbone=False,
- is_curved=True,
- image_id=0):
- """
- Warp all the function together.
- """
- if is_curved:
- return generate_pivot_list_curved(
- p_score,
- p_char_maps,
- f_direction,
- score_thresh=score_thresh,
- is_expand=True,
- is_backbone=is_backbone,
- image_id=image_id)
- else:
- return generate_pivot_list_horizontal(
- p_score,
- p_char_maps,
- f_direction,
- score_thresh=score_thresh,
- is_backbone=is_backbone,
- image_id=image_id)
-
-
-def extract_main_direction(pos_list, f_direction):
- """
- f_direction: h x w x 2
- pos_list: [[y, x], [y, x], [y, x] ...]
- """
- pos_list = np.array(pos_list)
- point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
- point_direction = point_direction[:, ::-1] # x, y -> y, x
- average_direction = np.mean(point_direction, axis=0, keepdims=True)
- average_direction = average_direction / (
- np.linalg.norm(average_direction) + 1e-6)
- return average_direction
-
-
-def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
- """
- f_direction: h x w x 2
- pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
- """
- pos_list_full = np.array(pos_list).reshape(-1, 3)
- pos_list = pos_list_full[:, 1:]
- point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
- point_direction = point_direction[:, ::-1] # x, y -> y, x
- average_direction = np.mean(point_direction, axis=0, keepdims=True)
- pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
- sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
- return sorted_list
-
-
-def sort_by_direction_with_image_id(pos_list, f_direction):
- """
- f_direction: h x w x 2
- pos_list: [[y, x], [y, x], [y, x] ...]
- """
-
- def sort_part_with_direction(pos_list_full, point_direction):
- pos_list_full = np.array(pos_list_full).reshape(-1, 3)
- pos_list = pos_list_full[:, 1:]
- point_direction = np.array(point_direction).reshape(-1, 2)
- average_direction = np.mean(point_direction, axis=0, keepdims=True)
- pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
- sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
- sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
- return sorted_list, sorted_direction
-
- pos_list = np.array(pos_list).reshape(-1, 3)
- point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
- point_direction = point_direction[:, ::-1] # x, y -> y, x
- sorted_point, sorted_direction = sort_part_with_direction(pos_list,
- point_direction)
-
- point_num = len(sorted_point)
- if point_num >= 16:
- middle_num = point_num // 2
- first_part_point = sorted_point[:middle_num]
- first_point_direction = sorted_direction[:middle_num]
- sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
- first_part_point, first_point_direction)
-
- last_part_point = sorted_point[middle_num:]
- last_point_direction = sorted_direction[middle_num:]
- sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
- last_part_point, last_point_direction)
- sorted_point = sorted_fist_part_point + sorted_last_part_point
- sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
-
- return sorted_point
-
-
-def generate_pivot_list_tt_inference(p_score,
- p_char_maps,
- f_direction,
- score_thresh=0.5,
- is_backbone=False,
- is_curved=True,
- image_id=0):
+ Lexicon_Table,
+ score_thresh=0.5):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
- p_tcl_map = (p_score > score_thresh) * 1.0
- skeleton_map = thin(p_tcl_map)
+ ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
+ cv2.THRESH_BINARY)
+ skeleton_map = thin(p_tcl_map.astype('uint8'))
instance_count, instance_label_map = cv2.connectedComponents(
- skeleton_map.astype(np.uint8), connectivity=8)
+ skeleton_map, connectivity=8)
+ # get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
+
if len(pos_list) < 3:
continue
+
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
- pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
- all_pos_yxs.append(pos_list_sorted_with_id)
- return all_pos_yxs
+ all_pos_yxs.append(pos_list_sorted)
+
+ p_char_maps = p_char_maps.transpose([1, 2, 0])
+ decoded_str, keep_yxs_list = ctc_decoder_for_image(
+ all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
+ return keep_yxs_list, decoded_str
diff --git a/ppocr/utils/pgnet_dict.txt b/ppocr/utils/pgnet_dict.txt
deleted file mode 100644
index b52d16e64f1004e1fceccac2280bc6f6eabd6af3..0000000000000000000000000000000000000000
--- a/ppocr/utils/pgnet_dict.txt
+++ /dev/null
@@ -1,36 +0,0 @@
-0
-1
-2
-3
-4
-5
-6
-7
-8
-9
-A
-B
-C
-D
-E
-F
-G
-H
-I
-J
-K
-L
-M
-N
-O
-P
-Q
-R
-S
-T
-U
-V
-W
-X
-Y
-Z
\ No newline at end of file
diff --git a/tools/infer/predict_e2e.py b/tools/infer/predict_e2e.py
index 1a92c4ab518a1c01d9d282443c09f7e3a7ecf008..6744e7e20c64379c8b482b826066dffe64f1923a 100755
--- a/tools/infer/predict_e2e.py
+++ b/tools/infer/predict_e2e.py
@@ -39,10 +39,7 @@ class TextE2e(object):
self.args = args
self.e2e_algorithm = args.e2e_algorithm
pre_process_list = [{
- 'E2EResizeForTest': {
- 'max_side_len': 768,
- 'valid_set': 'totaltext'
- }
+ 'E2EResizeForTest': {}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
@@ -70,12 +67,6 @@ class TextE2e(object):
postprocess_params["character_dict_path"] = args.e2e_char_dict_path
postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
- if self.e2e_pgnet_polygon:
- postprocess_params["expand_scale"] = 1.2
- postprocess_params["shrink_ratio_of_width"] = 0.2
- else:
- postprocess_params["expand_scale"] = 1.0
- postprocess_params["shrink_ratio_of_width"] = 0.3
else:
logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
sys.exit(0)
@@ -109,7 +101,6 @@ class TextE2e(object):
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
- print(img.shape)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
starttime = time.time()
@@ -123,13 +114,12 @@ class TextE2e(object):
preds = {}
if self.e2e_algorithm == 'PGNet':
- preds['f_score'] = outputs[0]
- preds['f_border'] = outputs[1]
+ preds['f_border'] = outputs[0]
+ preds['f_char'] = outputs[1]
preds['f_direction'] = outputs[2]
- preds['f_char'] = outputs[3]
+ preds['f_score'] = outputs[3]
else:
raise NotImplementedError
-
post_result = self.postprocess_op(preds, shape_list)
points, strs = post_result['points'], post_result['strs']
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 9aa0afed635481859cd31d461a97c451ca72acdc..6cb075e8be639fffbbc5376b2fd8c6ce3597e4aa 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -83,11 +83,9 @@ def parse_args():
    # PGNet params
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
parser.add_argument(
- "--e2e_char_dict_path",
- type=str,
- default="./ppocr/utils/pgnet_dict.txt")
+ "--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
- parser.add_argument("--e2e_pgnet_polygon", type=bool, default=False)
+    parser.add_argument("--e2e_pgnet_polygon", type=str2bool, default=True)
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)