提交 88964dc9 编写于 作者: J Jethong

add visual png

上级 97111112
......@@ -18,11 +18,13 @@ Global:
save_inference_dir:
use_visualdl: False
infer_img:
valid_set: totaltext #two mode: totaltext valid curved words, partvgg valid non-curved words
valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
character_dict_path: ppocr/utils/pgnet_dict.txt
character_dict_path: ppocr/utils/ic15_dict.txt
character_type: EN
max_text_length: 50
max_text_length: 50 # the max length in seq
max_text_nums: 30 # the max seq nums in a pic
tcl_len: 64
Architecture:
model_type: e2e
......@@ -33,13 +35,15 @@ Architecture:
layers: 50
Neck:
name: PGFPN
model_name: large
Head:
name: PGHead
model_name: large
Loss:
name: PGLoss
tcl_bs: 64
max_text_length: 50 # the same as Global: max_text_length
max_text_nums: 30 # the same as Global:max_text_nums
pad_num: 36 # the length of dict for pad
Optimizer:
name: Adam
......@@ -54,10 +58,10 @@ Optimizer:
PostProcess:
name: PGPostProcess
score_thresh: 0.8
score_thresh: 0.5
Metric:
name: E2EMetric
character_dict_path: ppocr/utils/pgnet_dict.txt
character_dict_path: ppocr/utils/ic15_dict.txt
main_indicator: f_score_e2e
Train:
......
......@@ -9,8 +9,10 @@
解压数据集和下载标注文件后,PaddleOCR/train_data/part_vgg_synth/train/ 有一个文件夹和一个文件,分别是:
```
/PaddleOCR/train_data/part_vgg_synth/train/
└─ image/ partvgg数据集的训练数据
└─ train_annotation_info.txt partvgg数据集的测试标注
|- image/ partvgg数据集的训练数据
|- 119_nile_110_31.png
| ...
|- train_annotation_info.txt partvgg数据集的测试标注
```
提供的标注文件格式如下,中间用"\t"分隔:
......@@ -18,7 +20,7 @@
" 图像文件名 图像标注信息--四点标注 图像标注信息--识别标注
119_nile_110_31 140.2 222.5 266.0 194.6 278.7 251.8 152.9 279.7 Path: 32.9 133.1 106.0 130.8 106.4 143.8 33.3 146.1 were 21.8 81.9 106.9 80.4 107.7 123.2 22.6 124.7 why
```
标注文件txt当中,其中每一行代表一组数据,以第一行为例。第一个代表同级目录image/下面的文件名, 后面每9个代表一组标注信息,前8个代表文本框的四个点坐标(x,y),从左上角的点开始顺时针排列。
标注文件txt当中,其中每一行代表一组数据,以第一行为例。第一个代表同级目录image/下面的文件名前缀, 后面每9个代表一组标注信息,前8个代表文本框的四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
......@@ -26,8 +28,12 @@
解压数据集和下载标注文件后,PaddleOCR/train_data/total_text/train/ 有两个文件夹,分别是:
```
/PaddleOCR/train_data/total_text/train/
└─ rgb/ total_text数据集的训练数据
└─ poly/ total_text数据集的测试标注
|- rgb/ total_text数据集的训练数据
|- gt_0.png
| ...
|-poly/ total_text数据集的测试标注
|- gt_0.txt
| ...
```
提供的标注文件格式如下,中间用"\t"分隔:
......@@ -36,7 +42,7 @@
1004.0,689.0,1019.0,698.0,1034.0,708.0,1049.0,718.0,1064.0,728.0,1079.0,738.0,1095.0,748.0,1094.0,774.0,1079.0,765.0,1065.0,756.0,1050.0,747.0,1036.0,738.0,1021.0,729.0,1007.0,721.0 EST
1102.0,755.0,1116.0,764.0,1131.0,773.0,1146.0,783.0,1161.0,792.0,1176.0,801.0,1191.0,811.0,1193.0,837.0,1178.0,828.0,1164.0,819.0,1150.0,810.0,1135.0,801.0,1121.0,792.0,1107.0,784.0 1972
```
标注文件当中,其中每一个txt文件代表一组数据,文件名同级目录rgb/下面的文件名。以第一行为例,前面28个代表文本框的十四个点坐标(x,y),从左上角的点开始顺时针排列。
标注文件当中,其中每一个txt文件代表一组数据,文件名就是同级目录rgb/下面的文件名。以第一行为例,前面28个代表文本框的十四个点坐标(x,y),从左上角的点开始顺时针排列。
最后一个代表文字的识别结果,**当其内容为“###”时,表示该文本框无效,在训练时会跳过。**
如果您想在其他数据集上训练,可以按照上述形式构建标注文件。
......
......@@ -29,7 +29,7 @@ inference 模型(`paddle.jit.save`保存的模型)
- [5. 多语言模型的推理](#多语言模型的推理)
- [四、端到端模型推理](#端到端模型推理)
- [1. PGNet端到端模型推理](#SAST文本检测模型推理)
- [1. PGNet端到端模型推理](#PGNet端到端模型推理)
- [五、方向分类模型推理](#方向识别模型推理)
- [1. 方向分类模型推理](#方向分类模型推理)
......@@ -366,7 +366,7 @@ Predicts of ./doc/imgs_words/korean/1.jpg:('바탕으로', 0.9948904)
## 四、端到端模型推理
端到端模型推理,默认使用PGNet模型的配置参数。当不使用PGNet模型时,在推理时,需要通过传入相应的参数进行算法适配,细节参考下文。
<a name="SAST文本检测模型推理"></a>
<a name="PGNet端到端模型推理"></a>
### 1. PGNet端到端模型推理
#### (1). 四边形文本检测模型(ICDAR2015)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在ICDAR2015英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_icdar15_v2.0_train.tar)),可以使用如下命令进行转换:
......@@ -375,28 +375,26 @@ python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrai
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`**,可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e_pgnet_ic15/"
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img_10.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=False
```
可视化文本检测结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/det_res_img_10_sast.jpg)
![](../imgs_results/e2e_res_img_10_pgnet.jpg)
#### (2). 弯曲文本检测模型(Total-Text)
首先将PGNet端到端训练过程中保存的模型,转换成inference model。以基于Resnet50_vd骨干网络,在Total-Text英文数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_sast_totaltext_v2.0_train.tar)),可以使用如下命令进行转换:
```
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e_pgnet_tt
python3 tools/export_model.py -c configs/e2e/e2e_r50_vd_pg.yml -o Global.pretrained_model=./det_r50_vd_sast_totaltext_v2.0_train/best_accuracy Global.load_static_weights=False Global.save_inference_dir=./inference/e2e
```
**PGNet端到端模型推理,需要设置参数`--e2e_algorithm="PGNet"`,同时,还需要增加参数`--e2e_pgnet_polygon=True`,**可以执行如下命令:
```
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e_pgnet_tt/" --e2e_pgnet_polygon=True
python3 tools/infer/predict_e2e.py --e2e_algorithm="PGNet" --image_dir="./doc/imgs_en/img623.jpg" --e2e_model_dir="./inference/e2e/" --e2e_pgnet_polygon=True
```
可视化文本端到端结果默认保存到`./inference_results`文件夹里面,结果文件的名称前缀为'e2e_res'。结果示例如下:
![](../imgs_results/e2e_res_img623_pg.jpg)
**注意**:本代码库中,SAST后处理Locality-Aware NMS有python和c++两种版本,c++版速度明显快于python版。由于c++版本nms编译版本问题,只有python3.5环境下会调用c++版nms,其他情况将调用python版nms。
![](../imgs_results/e2e_res_img623_pgnet.jpg)
<a name="方向分类模型推理"></a>
......
......@@ -197,17 +197,17 @@ class E2ELabelEncode(BaseRecLabelEncode):
super(E2ELabelEncode,
self).__init__(max_text_length, character_dict_path,
character_type, use_space_char)
self.pad_num = len(self.dict) # the length to pad
def __call__(self, data):
texts = data['strs']
temp_texts = []
for text in texts:
text = text.upper()
text = text.lower()
text = self.encode(text)
if text is None:
return None
text = text + [36] * (self.max_text_len - len(text)
) # use 36 to pad
text = text + [self.pad_num] * (self.max_text_len - len(text))
temp_texts.append(text)
data['strs'] = np.array(temp_texts)
return data
......
......@@ -22,16 +22,23 @@ __all__ = ['PGProcessTrain']
class PGProcessTrain(object):
def __init__(self,
character_dict_path,
max_text_length,
max_text_nums,
tcl_len,
batch_size=14,
min_crop_size=24,
min_text_size=10,
max_text_size=512,
**kwargs):
self.tcl_len = tcl_len
self.max_text_length = max_text_length
self.max_text_nums = max_text_nums
self.batch_size = batch_size
self.min_crop_size = min_crop_size
self.min_text_size = min_text_size
self.max_text_size = max_text_size
self.Lexicon_Table = self.get_dict(character_dict_path)
self.pad_num = len(self.Lexicon_Table)
self.img_id = 0
def get_dict(self, character_dict_path):
......@@ -290,7 +297,7 @@ class PGProcessTrain(object):
height_list.append(quad_h)
norm_width = max(sum(width_list) / n_char, 1.0)
average_height = max(sum(height_list) / len(height_list), 1.0)
k = 1
for quad in poly_quads:
direct_vector_full = (
(quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
......@@ -302,6 +309,8 @@ class PGProcessTrain(object):
cv2.fillPoly(direction_map,
quad.round().astype(np.int32)[np.newaxis, :, :],
direction_label)
cv2.imwrite("output/{}.png".format(k), direction_map * 255.0)
k += 1
return direction_map
def calculate_average_height(self, poly_quads):
......@@ -371,7 +380,6 @@ class PGProcessTrain(object):
continue
if tag:
# continue
cv2.fillPoly(training_mask,
poly.astype(np.int32)[np.newaxis, :, :], 0.15)
else:
......@@ -577,7 +585,7 @@ class PGProcessTrain(object):
Prepare text lablel by given Lexicon_Table.
"""
if len(Lexicon_Table) == 36:
return label_str.upper()
return label_str.lower()
else:
return label_str
......@@ -846,23 +854,23 @@ class PGProcessTrain(object):
return None
pos_list_temp = np.zeros([64, 3])
pos_mask_temp = np.zeros([64, 1])
label_list_temp = np.zeros([50, 1]) + 36
label_list_temp = np.zeros([self.max_text_length, 1]) + self.pad_num
for i, label in enumerate(label_list):
n = len(label)
if n > 50:
label_list[i] = label[:50]
if n > self.max_text_length:
label_list[i] = label[:self.max_text_length]
continue
while n < 50:
label.append([36])
while n < self.max_text_length:
label.append([self.pad_num])
n += 1
for i in range(len(label_list)):
label_list[i] = np.array(label_list[i])
if len(pos_list) <= 0 or len(pos_list) > 30: #一张图片中最多存在30行文本
if len(pos_list) <= 0 or len(pos_list) > self.max_text_nums:
return None
for __ in range(30 - len(pos_list), 0, -1):
for __ in range(self.max_text_nums - len(pos_list), 0, -1):
pos_list.append(pos_list_temp)
pos_mask.append(pos_mask_temp)
label_list.append(label_list_temp)
......
......@@ -156,6 +156,7 @@ class PGDataSet(Dataset):
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except Exception as e:
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
......
......@@ -18,102 +18,26 @@ from __future__ import print_function
from paddle import nn
import paddle
import numpy as np
import copy
from .det_basic_loss import DiceLoss
from ppocr.utils.e2e_utils.extract_batchsize import *
class PGLoss(nn.Layer):
def __init__(self, eps=1e-6, **kwargs):
def __init__(self,
tcl_bs,
max_text_length,
max_text_nums,
pad_num,
eps=1e-6,
**kwargs):
super(PGLoss, self).__init__()
self.tcl_bs = tcl_bs
self.max_text_nums = max_text_nums
self.max_text_length = max_text_length
self.pad_num = pad_num
self.dice_loss = DiceLoss(eps=eps)
def org_tcl_rois(self, batch_size, pos_lists, pos_masks, label_lists):
"""
"""
pos_lists_, pos_masks_, label_lists_ = [], [], []
img_bs = batch_size
tcl_bs = 64
ngpu = int(batch_size / img_bs)
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
for i in range(ngpu):
pos_lists_split.append([])
pos_masks_split.append([])
label_lists_split.append([])
for i in range(img_ids.shape[0]):
img_id = img_ids[i]
gpu_id = int(img_id / img_bs)
img_id = img_id % img_bs
pos_list = pos_lists[i].copy()
pos_list[:, 0] = img_id
pos_lists_split[gpu_id].append(pos_list)
pos_masks_split[gpu_id].append(pos_masks[i].copy())
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
# repeat or delete
for i in range(ngpu):
vp_len = len(pos_lists_split[i])
if vp_len <= tcl_bs:
for j in range(0, tcl_bs - vp_len):
pos_list = pos_lists_split[i][j].copy()
pos_lists_split[i].append(pos_list)
pos_mask = pos_masks_split[i][j].copy()
pos_masks_split[i].append(pos_mask)
label_list = copy.deepcopy(label_lists_split[i][j])
label_lists_split[i].append(label_list)
else:
for j in range(0, vp_len - tcl_bs):
c_len = len(pos_lists_split[i])
pop_id = np.random.permutation(c_len)[0]
pos_lists_split[i].pop(pop_id)
pos_masks_split[i].pop(pop_id)
label_lists_split[i].pop(pop_id)
# merge
for i in range(ngpu):
pos_lists_.extend(pos_lists_split[i])
pos_masks_.extend(pos_masks_split[i])
label_lists_.extend(label_lists_split[i])
return pos_lists_, pos_masks_, label_lists_
def pre_process(self, label_list, pos_list, pos_mask):
max_len = 30 # the max texts in a single image
max_str_len = 50 # the max len in a single text
pad_num = 36 # padding num
label_list = label_list.numpy()
batch, _, _, _ = label_list.shape
pos_list = pos_list.numpy()
pos_mask = pos_mask.numpy()
pos_list_t = []
pos_mask_t = []
label_list_t = []
for i in range(batch):
for j in range(max_len):
if pos_mask[i, j].any():
pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j])
pos_list, pos_mask, label_list = self.org_tcl_rois(
batch, pos_list_t, pos_mask_t, label_list_t)
label = []
tt = [l.tolist() for l in label_list]
for i in range(batch):
k = 0
for j in range(max_str_len):
if tt[i][j][0] != pad_num:
k += 1
else:
break
label.append(k)
label = paddle.to_tensor(label)
label = paddle.cast(label, dtype='int64')
pos_list = paddle.to_tensor(pos_list)
pos_mask = paddle.to_tensor(pos_mask)
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
label_list = paddle.cast(label_list, dtype='int32')
return pos_list, pos_mask, label_list, label
def border_loss(self, f_border, l_border, l_score, l_mask):
l_border_split, l_border_norm = paddle.tensor.split(
l_border, num_or_sections=[4, 1], axis=1)
......@@ -183,7 +107,7 @@ class PGLoss(nn.Layer):
labels=tcl_label,
input_lengths=input_lengths,
label_lengths=label_t,
blank=36,
blank=self.pad_num,
reduction='none')
cost = cost.mean()
return cost
......@@ -192,12 +116,14 @@ class PGLoss(nn.Layer):
images, tcl_maps, tcl_label_maps, border_maps \
, direction_maps, training_masks, label_list, pos_list, pos_mask = labels
# for all the batch_size
pos_list, pos_mask, label_list, label_t = self.pre_process(
label_list, pos_list, pos_mask)
pos_list, pos_mask, label_list, label_t = pre_process(
label_list, pos_list, pos_mask, self.max_text_length,
self.max_text_nums, self.pad_num, self.tcl_bs)
f_score, f_boder, f_direction, f_char = predicts
f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
predicts['f_char']
score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
border_loss = self.border_loss(f_boder, border_maps, tcl_maps,
border_loss = self.border_loss(f_border, border_maps, tcl_maps,
training_masks)
direction_loss = self.direction_loss(f_direction, direction_maps,
tcl_maps, training_masks)
......
......@@ -66,9 +66,8 @@ class PGHead(nn.Layer):
"""
"""
def __init__(self, in_channels, model_name, **kwargs):
def __init__(self, in_channels, **kwargs):
super(PGHead, self).__init__()
self.model_name = model_name
self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels,
out_channels=64,
......
......@@ -23,8 +23,7 @@ __dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
from ppocr.utils.e2e_utils.extract_textpoint import *
from ppocr.utils.e2e_utils.visual import *
from ppocr.utils.e2e_utils.extract_textpoint import get_dict, generate_pivot_list, restore_poly
import paddle
......@@ -34,16 +33,10 @@ class PGPostProcess(object):
"""
def __init__(self, character_dict_path, valid_set, score_thresh, **kwargs):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
if sys.version_info.major == 3 and sys.version_info.minor == 5:
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
p_score = outs_dict['f_score']
p_border = outs_dict['f_border']
......@@ -61,96 +54,15 @@ class PGPostProcess(object):
p_char = p_char[0]
src_h, src_w, ratio_h, ratio_w = shape_list[0]
is_curved = self.valid_set == "totaltext"
instance_yxs_list = generate_pivot_list(
instance_yxs_list, seq_strs = generate_pivot_list(
p_score,
p_char,
p_direction,
score_thresh=self.score_thresh,
is_backbone=True,
is_curved=is_curved)
p_char = np.expand_dims(p_char, axis=0)
p_char = paddle.to_tensor(p_char)
char_seq_idx_set = []
for i in range(len(instance_yxs_list)):
gather_info_lod = paddle.to_tensor(instance_yxs_list[i])
f_char_map = paddle.transpose(p_char, [0, 2, 3, 1])
featyre_seq = paddle.gather_nd(f_char_map, gather_info_lod)
featyre_seq = np.expand_dims(featyre_seq.numpy(), axis=0)
t = len(featyre_seq[0])
featyre_seq = paddle.to_tensor(featyre_seq)
l = np.array([[t]]).astype(np.int64)
length = paddle.to_tensor(l)
seq_pred = paddle.fluid.layers.ctc_greedy_decoder(
input=featyre_seq, blank=36, input_length=length)
seq_pred1 = seq_pred[0].numpy().tolist()[0]
seq_len = seq_pred[1].numpy()[0][0]
temp_t = []
for x in seq_pred1[:seq_len]:
temp_t.append(x)
char_seq_idx_set.append(temp_t)
seq_strs = []
for char_idx_set in char_seq_idx_set:
pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
seq_strs.append(pr_str)
poly_list = []
keep_str_list = []
all_point_list = []
all_point_pair_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(yx_center_line) == 1:
yx_center_line.append(yx_center_line[-1])
offset_expand = 1.0
if self.valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for batch_id, y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2)
if offset_expand != 1.0:
offset_length = np.linalg.norm(
offset, axis=1, keepdims=True)
expand_length = np.clip(
offset_length * (offset_expand - 1),
a_min=0.5,
a_max=3.0)
offset_detal = offset / offset_length * expand_length
offset = offset + offset_detal
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
all_point_list.append([
int(round(x * 4.0 / ratio_w)),
int(round(y * 4.0 / ratio_h))
])
all_point_pair_list.append(point_pair.round().astype(np.int32)
.tolist())
detected_poly, pair_length_info = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(
detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(
detected_poly[:, 1], a_min=0, a_max=src_h)
if len(keep_str) < 2:
continue
keep_str_list.append(keep_str)
if self.valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif self.valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
print('--> Not supported format.')
exit(-1)
self.Lexicon_Table,
score_thresh=self.score_thresh)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set)
data = {
'points': poly_list,
'strs': keep_str_list,
......
import paddle
import numpy as np
import copy
def org_tcl_rois(batch_size, pos_lists, pos_masks, label_lists, tcl_bs):
"""
"""
pos_lists_, pos_masks_, label_lists_ = [], [], []
img_bs = batch_size
ngpu = int(batch_size / img_bs)
img_ids = np.array(pos_lists, dtype=np.int32)[:, 0, 0].copy()
pos_lists_split, pos_masks_split, label_lists_split = [], [], []
for i in range(ngpu):
pos_lists_split.append([])
pos_masks_split.append([])
label_lists_split.append([])
for i in range(img_ids.shape[0]):
img_id = img_ids[i]
gpu_id = int(img_id / img_bs)
img_id = img_id % img_bs
pos_list = pos_lists[i].copy()
pos_list[:, 0] = img_id
pos_lists_split[gpu_id].append(pos_list)
pos_masks_split[gpu_id].append(pos_masks[i].copy())
label_lists_split[gpu_id].append(copy.deepcopy(label_lists[i]))
# repeat or delete
for i in range(ngpu):
vp_len = len(pos_lists_split[i])
if vp_len <= tcl_bs:
for j in range(0, tcl_bs - vp_len):
pos_list = pos_lists_split[i][j].copy()
pos_lists_split[i].append(pos_list)
pos_mask = pos_masks_split[i][j].copy()
pos_masks_split[i].append(pos_mask)
label_list = copy.deepcopy(label_lists_split[i][j])
label_lists_split[i].append(label_list)
else:
for j in range(0, vp_len - tcl_bs):
c_len = len(pos_lists_split[i])
pop_id = np.random.permutation(c_len)[0]
pos_lists_split[i].pop(pop_id)
pos_masks_split[i].pop(pop_id)
label_lists_split[i].pop(pop_id)
# merge
for i in range(ngpu):
pos_lists_.extend(pos_lists_split[i])
pos_masks_.extend(pos_masks_split[i])
label_lists_.extend(label_lists_split[i])
return pos_lists_, pos_masks_, label_lists_
def pre_process(label_list, pos_list, pos_mask, max_text_length, max_text_nums,
pad_num, tcl_bs):
label_list = label_list.numpy()
batch, _, _, _ = label_list.shape
pos_list = pos_list.numpy()
pos_mask = pos_mask.numpy()
pos_list_t = []
pos_mask_t = []
label_list_t = []
for i in range(batch):
for j in range(max_text_nums):
if pos_mask[i, j].any():
pos_list_t.append(pos_list[i][j])
pos_mask_t.append(pos_mask[i][j])
label_list_t.append(label_list[i][j])
pos_list, pos_mask, label_list = org_tcl_rois(batch, pos_list_t, pos_mask_t,
label_list_t, tcl_bs)
label = []
tt = [l.tolist() for l in label_list]
for i in range(tcl_bs):
k = 0
for j in range(max_text_length):
if tt[i][j][0] != pad_num:
k += 1
else:
break
label.append(k)
label = paddle.to_tensor(label)
label = paddle.cast(label, dtype='int64')
pos_list = paddle.to_tensor(pos_list)
pos_mask = paddle.to_tensor(pos_mask)
label_list = paddle.squeeze(paddle.to_tensor(label_list), axis=2)
label_list = paddle.cast(label_list, dtype='int32')
return pos_list, pos_mask, label_list, label
......@@ -17,11 +17,9 @@ from __future__ import division
from __future__ import print_function
import cv2
import math
import numpy as np
from itertools import groupby
from skimage.morphology._skeletonize import thin
from cv2.ximgproc import thinning as thin
def get_dict(character_dict_path):
......@@ -35,87 +33,39 @@ def get_dict(character_dict_path):
return dict_character
def softmax(logits):
"""
logits: N x d
"""
max_value = np.max(logits, axis=1, keepdims=True)
exp = np.exp(logits - max_value)
exp_sum = np.sum(exp, axis=1, keepdims=True)
dist = exp / exp_sum
return dist
def get_keep_pos_idxs(labels, remove_blank=None):
"""
Remove duplicate and get pos idxs of keep items.
The value of keep_blank should be [None, 95].
"""
duplicate_len_list = []
keep_pos_idx_list = []
keep_char_idx_list = []
for k, v_ in groupby(labels):
current_len = len(list(v_))
if k != remove_blank:
current_idx = int(sum(duplicate_len_list) + current_len // 2)
keep_pos_idx_list.append(current_idx)
keep_char_idx_list.append(k)
duplicate_len_list.append(current_len)
return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
new_labels = [x for x in labels if x != blank]
return new_labels
def insert_blank(labels, blank=0):
new_labels = [blank]
for l in labels:
new_labels += [l, blank]
return new_labels
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
"""
CTC greedy (best path) decoder.
"""
raw_str = np.argmax(np.array(probs_seq), axis=1)
remove_blank_in_pos = None if keep_blank_in_idxs else blank
dedup_str, keep_idx_list = get_keep_pos_idxs(
raw_str, remove_blank=remove_blank_in_pos)
dst_str = remove_blank(dedup_str, blank=blank)
return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info,
logits_map,
keep_blank_in_idxs=True):
"""
gather_info: [[x, y], [x, y] ...]
logits_map: H x W X (n_chars + 1)
"""
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
_, _, C = logits_map.shape
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)] # n x 96
probs_seq = softmax(logits_seq)
dst_str, keep_idx_list = ctc_greedy_decoder(
probs_seq, blank=C - 1, keep_blank_in_idxs=keep_blank_in_idxs)
logits_seq = logits_map[list(ys), list(xs)]
probs_seq = logits_seq
labels = np.argmax(probs_seq, axis=1)
dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
detal = len(gather_info) // (pts_num - 1)
keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list, logits_map,
keep_blank_in_idxs=True):
def ctc_decoder_for_image(gather_info_list,
logits_map,
Lexicon_Table,
pts_num=6):
"""
CTC decoder using multiple processes.
"""
decoder_results = []
decoder_str = []
decoder_xys = []
for gather_info in gather_info_list:
res = instance_ctc_greedy_decoder(
gather_info, logits_map, keep_blank_in_idxs=keep_blank_in_idxs)
decoder_results.append(res)
return decoder_results
if len(gather_info) < pts_num:
continue
dst_str, xys_list = instance_ctc_greedy_decoder(
gather_info, logits_map, pts_num=pts_num)
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2:
continue
decoder_str.append(dst_str_readable)
decoder_xys.append(xys_list)
return decoder_str, decoder_xys
def sort_with_direction(pos_list, f_direction):
......@@ -157,57 +107,6 @@ def sort_with_direction(pos_list, f_direction):
return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
"""
Add id for gather feature, for inference.
"""
new_list = []
for item in pos_list:
new_list.append((image_id, item[0], item[1]))
return new_list
def sort_and_expand_with_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
h, w, _ = f_direction.shape
sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
point_num = len(sorted_list)
sub_direction_len = max(point_num // 3, 2)
left_direction = point_direction[:sub_direction_len, :]
right_dirction = point_direction[point_num - sub_direction_len:, :]
left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
left_average_len = np.linalg.norm(left_average_direction)
left_start = np.array(sorted_list[0])
left_step = left_average_direction / (left_average_len + 1e-6)
right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
right_average_len = np.linalg.norm(right_average_direction)
right_step = right_average_direction / (right_average_len + 1e-6)
right_start = np.array(sorted_list[-1])
append_num = max(
int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
left_list = []
right_list = []
for i in range(append_num):
ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
'int32').tolist()
if ly < h and lx < w and (ly, lx) not in left_list:
left_list.append((ly, lx))
ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
'int32').tolist()
if ry < h and rx < w and (ry, rx) not in right_list:
right_list.append((ry, rx))
all_list = left_list[::-1] + sorted_list + right_list
return all_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
"""
f_direction: h x w x 2
......@@ -260,262 +159,125 @@ def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
return all_list
def generate_pivot_list_curved(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_expand=True,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
continue
if is_expand:
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
else:
pos_list_sorted, _ = sort_with_direction(pos_list, f_direction)
all_pos_yxs.append(pos_list_sorted)
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
def generate_pivot_list_horizontal(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
image_id=0):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map_bi = (p_score > score_thresh) * 1.0
instance_count, instance_label_map = cv2.connectedComponents(
p_tcl_map_bi.astype(np.uint8), connectivity=8)
# get TCL Instance
all_pos_yxs = []
center_pos_yxs = []
end_points_yxs = []
instance_center_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
if len(pos_list) < 5:
continue
main_direction = extract_main_direction(pos_list,
f_direction) # y x
reference_directin = np.array([0, 1]).reshape([-1, 2]) # y x
is_h_angle = abs(np.sum(
main_direction * reference_directin)) < math.cos(math.pi / 180 *
70)
point_yxs = np.array(pos_list)
max_y, max_x = np.max(point_yxs, axis=0)
min_y, min_x = np.min(point_yxs, axis=0)
is_h_len = (max_y - min_y) < 1.5 * (max_x - min_x)
pos_list_final = []
if is_h_len:
xs = np.unique(xs)
for x in xs:
ys = instance_label_map[:, x].copy().reshape((-1, ))
y = int(np.where(ys == instance_id)[0].mean())
pos_list_final.append((y, x))
else:
ys = np.unique(ys)
for y in ys:
xs = instance_label_map[y, :].copy().reshape((-1, ))
x = int(np.where(xs == instance_id)[0].mean())
pos_list_final.append((y, x))
pos_list_sorted, _ = sort_with_direction(pos_list_final,
f_direction)
all_pos_yxs.append(pos_list_sorted)
p_char_maps = p_char_maps.transpose([1, 2, 0])
decode_res = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, keep_blank_in_idxs=True)
for decoded_str, keep_yxs_list in decode_res:
if is_backbone:
keep_yxs_list_with_id = add_id(keep_yxs_list, image_id=image_id)
instance_center_pos_yxs.append(keep_yxs_list_with_id)
def point_pair2poly(point_pair_list):
"""
Transfer vertical point_pairs into poly point in clockwise.
"""
point_num = len(point_pair_list) * 2
point_list = [0] * point_num
for idx, point_pair in enumerate(point_pair_list):
point_list[idx] = point_pair[0]
point_list[point_num - 1 - idx] = point_pair[1]
return np.array(point_list).reshape(-1, 2)
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
ratio_pair = np.array(
[[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
"""
expand poly along width.
"""
point_num = poly.shape[0]
left_quad = np.array(
[poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
(np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
right_quad = np.array(
[
poly[point_num // 2 - 2], poly[point_num // 2 - 1],
poly[point_num // 2], poly[point_num // 2 + 1]
],
dtype=np.float32)
right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
(np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
poly[0] = left_quad_expand[0]
poly[-1] = left_quad_expand[-1]
poly[point_num // 2 - 1] = right_quad_expand[1]
poly[point_num // 2] = right_quad_expand[2]
return poly
def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
src_h, valid_set):
poly_list = []
keep_str_list = []
for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
if len(keep_str) < 2:
print('--> too short, {}'.format(keep_str))
continue
offset_expand = 1.0
if valid_set == 'totaltext':
offset_expand = 1.2
point_pair_list = []
for y, x in yx_center_line:
offset = p_border[:, y, x].reshape(2, 2) * offset_expand
ori_yx = np.array([y, x], dtype=np.float32)
point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
[ratio_w, ratio_h]).reshape(-1, 2)
point_pair_list.append(point_pair)
detected_poly = point_pair2poly(point_pair_list)
detected_poly = expand_poly_along_width(
detected_poly, shrink_ratio_of_width=0.2)
detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
keep_str_list.append(keep_str)
if valid_set == 'partvgg':
middle_point = len(detected_poly) // 2
detected_poly = detected_poly[
[0, middle_point - 1, middle_point, -1], :]
poly_list.append(detected_poly)
elif valid_set == 'totaltext':
poly_list.append(detected_poly)
else:
end_points_yxs.extend((keep_yxs_list[0], keep_yxs_list[-1]))
center_pos_yxs.extend(keep_yxs_list)
if is_backbone:
return instance_center_pos_yxs
else:
return center_pos_yxs, end_points_yxs
print('--> Not supported format.')
exit(-1)
return poly_list, keep_str_list
def generate_pivot_list(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
"""
Warp all the function together.
"""
if is_curved:
return generate_pivot_list_curved(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_expand=True,
is_backbone=is_backbone,
image_id=image_id)
else:
return generate_pivot_list_horizontal(
p_score,
p_char_maps,
f_direction,
score_thresh=score_thresh,
is_backbone=is_backbone,
image_id=image_id)
def extract_main_direction(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
pos_list = np.array(pos_list)
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
average_direction = average_direction / (
np.linalg.norm(average_direction) + 1e-6)
return average_direction
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
"""
pos_list_full = np.array(pos_list).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
return sorted_list
def sort_by_direction_with_image_id(pos_list, f_direction):
"""
f_direction: h x w x 2
pos_list: [[y, x], [y, x], [y, x] ...]
"""
def sort_part_with_direction(pos_list_full, point_direction):
pos_list_full = np.array(pos_list_full).reshape(-1, 3)
pos_list = pos_list_full[:, 1:]
point_direction = np.array(point_direction).reshape(-1, 2)
average_direction = np.mean(point_direction, axis=0, keepdims=True)
pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
return sorted_list, sorted_direction
pos_list = np.array(pos_list).reshape(-1, 3)
point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
point_direction = point_direction[:, ::-1] # x, y -> y, x
sorted_point, sorted_direction = sort_part_with_direction(pos_list,
point_direction)
point_num = len(sorted_point)
if point_num >= 16:
middle_num = point_num // 2
first_part_point = sorted_point[:middle_num]
first_point_direction = sorted_direction[:middle_num]
sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
first_part_point, first_point_direction)
last_part_point = sorted_point[middle_num:]
last_point_direction = sorted_direction[middle_num:]
sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
last_part_point, last_point_direction)
sorted_point = sorted_fist_part_point + sorted_last_part_point
sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
return sorted_point
def generate_pivot_list_tt_inference(p_score,
p_char_maps,
f_direction,
score_thresh=0.5,
is_backbone=False,
is_curved=True,
image_id=0):
Lexicon_Table,
score_thresh=0.5):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
p_score = p_score[0]
f_direction = f_direction.transpose(1, 2, 0)
p_tcl_map = (p_score > score_thresh) * 1.0
skeleton_map = thin(p_tcl_map)
ret, p_tcl_map = cv2.threshold(p_score, score_thresh, 255,
cv2.THRESH_BINARY)
skeleton_map = thin(p_tcl_map.astype('uint8'))
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
skeleton_map, connectivity=8)
# get TCL Instance
all_pos_yxs = []
if instance_count > 0:
for instance_id in range(1, instance_count):
pos_list = []
ys, xs = np.where(instance_label_map == instance_id)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
continue
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, p_tcl_map)
pos_list_sorted_with_id = add_id(pos_list_sorted, image_id=image_id)
all_pos_yxs.append(pos_list_sorted_with_id)
return all_pos_yxs
all_pos_yxs.append(pos_list_sorted)
p_char_maps = p_char_maps.transpose([1, 2, 0])
decoded_str, keep_yxs_list = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
return keep_yxs_list, decoded_str
0
1
2
3
4
5
6
7
8
9
A
B
C
D
E
F
G
H
I
J
K
L
M
N
O
P
Q
R
S
T
U
V
W
X
Y
Z
\ No newline at end of file
......@@ -39,10 +39,7 @@ class TextE2e(object):
self.args = args
self.e2e_algorithm = args.e2e_algorithm
pre_process_list = [{
'E2EResizeForTest': {
'max_side_len': 768,
'valid_set': 'totaltext'
}
'E2EResizeForTest': {}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
......@@ -70,12 +67,6 @@ class TextE2e(object):
postprocess_params["character_dict_path"] = args.e2e_char_dict_path
postprocess_params["valid_set"] = args.e2e_pgnet_valid_set
self.e2e_pgnet_polygon = args.e2e_pgnet_polygon
if self.e2e_pgnet_polygon:
postprocess_params["expand_scale"] = 1.2
postprocess_params["shrink_ratio_of_width"] = 0.2
else:
postprocess_params["expand_scale"] = 1.0
postprocess_params["shrink_ratio_of_width"] = 0.3
else:
logger.info("unknown e2e_algorithm:{}".format(self.e2e_algorithm))
sys.exit(0)
......@@ -102,6 +93,7 @@ class TextE2e(object):
return dt_boxes
def __call__(self, img):
ori_im = img.copy()
data = {'image': img}
data = transform(data, self.preprocess_op)
......@@ -109,7 +101,6 @@ class TextE2e(object):
if img is None:
return None, 0
img = np.expand_dims(img, axis=0)
print(img.shape)
shape_list = np.expand_dims(shape_list, axis=0)
img = img.copy()
starttime = time.time()
......@@ -123,13 +114,12 @@ class TextE2e(object):
preds = {}
if self.e2e_algorithm == 'PGNet':
preds['f_score'] = outputs[0]
preds['f_border'] = outputs[1]
preds['f_border'] = outputs[0]
preds['f_char'] = outputs[1]
preds['f_direction'] = outputs[2]
preds['f_char'] = outputs[3]
preds['f_score'] = outputs[3]
else:
raise NotImplementedError
post_result = self.postprocess_op(preds, shape_list)
points, strs = post_result['points'], post_result['strs']
dt_boxes = self.filter_tag_det_res_only_clip(points, ori_im.shape)
......
......@@ -83,11 +83,9 @@ def parse_args():
# PGNet parmas
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5)
parser.add_argument(
"--e2e_char_dict_path",
type=str,
default="./ppocr/utils/pgnet_dict.txt")
"--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt")
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext')
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=False)
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True)
# params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册