提交 929b4f45 编写于 作者: W wangjingyeye

update pgnet

上级 04aaaa74
......@@ -13,6 +13,7 @@ Global:
save_inference_dir:
use_visualdl: False
infer_img:
infer_visual_type: EN # two mode: EN is for english datasets, CN is for chinese datasets
valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words
save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt
character_dict_path: ppocr/utils/ic15_dict.txt
......@@ -32,6 +33,7 @@ Architecture:
name: PGFPN
Head:
name: PGHead
tcc_channels: 37 # the length of character dict
Loss:
name: PGLoss
......@@ -45,16 +47,18 @@ Optimizer:
beta1: 0.9
beta2: 0.999
lr:
name: Cosine
learning_rate: 0.001
warmup_epoch: 50
regularizer:
name: 'L2'
factor: 0
factor: 0.0001
PostProcess:
name: PGPostProcess
score_thresh: 0.5
mode: fast # fast or slow two ways
tcc_type: v3 # same as PGProcessTrain: tcc_type
Metric:
name: E2EMetric
......@@ -76,9 +80,12 @@ Train:
- E2ELabelEncodeTrain:
- PGProcessTrain:
batch_size: 14 # same as loader: batch_size_per_card
use_resize: True
use_random_crop: False
min_crop_size: 24
min_text_size: 4
max_text_size: 512
tcc_type: v3 # two ways, v2 is original code, v3 is updated code
- KeepKeys:
keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order
loader:
......
......@@ -15,6 +15,8 @@
import math
import cv2
import numpy as np
from skimage.morphology._skeletonize import thin
from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2
__all__ = ['PGProcessTrain']
......@@ -26,17 +28,24 @@ class PGProcessTrain(object):
max_text_nums,
tcl_len,
batch_size=14,
use_resize=True,
use_random_crop=False,
min_crop_size=24,
min_text_size=4,
max_text_size=512,
tcc_type='v3',
**kwargs):
self.tcl_len = tcl_len
self.max_text_length = max_text_length
self.max_text_nums = max_text_nums
self.batch_size = batch_size
if use_random_crop is True:
self.min_crop_size = min_crop_size
self.use_random_crop = use_random_crop
self.min_text_size = min_text_size
self.max_text_size = max_text_size
self.use_resize = use_resize
self.tcc_type = tcc_type
self.Lexicon_Table = self.get_dict(character_dict_path)
self.pad_num = len(self.Lexicon_Table)
self.img_id = 0
......@@ -282,6 +291,95 @@ class PGProcessTrain(object):
pos_m[:keep] = 1.0
return pos_l, pos_m
def fit_and_gather_tcl_points_v3(self,
min_area_quad,
poly,
max_h,
max_w,
fixed_point_num=64,
img_id=0,
reference_height=3):
"""
Find the center point of poly as key_points, then fit and gather.
"""
det_mask = np.zeros((int(max_h / self.ds_ratio),
int(max_w / self.ds_ratio))).astype(np.float32)
# score_big_map
cv2.fillPoly(det_mask,
np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
det_mask = cv2.resize(
det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
det_mask = np.array(det_mask > 1e-3, dtype='float32')
f_direction = self.f_direction
skeleton_map = thin(det_mask.astype(np.uint8))
instance_count, instance_label_map = cv2.connectedComponents(
skeleton_map.astype(np.uint8), connectivity=8)
ys, xs = np.where(instance_label_map == 1)
pos_list = list(zip(ys, xs))
if len(pos_list) < 3:
return None
pos_list_sorted = sort_and_expand_with_direction_v2(
pos_list, f_direction, det_mask)
pos_list_sorted = np.array(pos_list_sorted)
length = len(pos_list_sorted) - 1
insert_num = 0
for index in range(length):
stride_y = np.abs(pos_list_sorted[index + insert_num][0] -
pos_list_sorted[index + 1 + insert_num][0])
stride_x = np.abs(pos_list_sorted[index + insert_num][1] -
pos_list_sorted[index + 1 + insert_num][1])
max_points = int(max(stride_x, stride_y))
stride = (pos_list_sorted[index + insert_num] -
pos_list_sorted[index + 1 + insert_num]) / (max_points)
insert_num_temp = max_points - 1
for i in range(int(insert_num_temp)):
insert_value = pos_list_sorted[index + insert_num] - (i + 1
) * stride
insert_index = index + i + 1 + insert_num
pos_list_sorted = np.insert(
pos_list_sorted, insert_index, insert_value, axis=0)
insert_num += insert_num_temp
pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype(
np.float32) # xy-> yx
point_num = len(pos_info)
if point_num > fixed_point_num:
keep_ids = [
int((point_num * 1.0 / fixed_point_num) * x)
for x in range(fixed_point_num)
]
pos_info = pos_info[keep_ids, :]
keep = int(min(len(pos_info), fixed_point_num))
reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) +
np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2
if np.random.rand() < 1:
dh = (np.random.rand(keep) - 0.5) * reference_height
offset = np.random.rand() - 0.5
dw = np.array([[0, offset * reference_width * 0.2]])
random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape(
[keep, 1])
random_float_w = dw.repeat(keep, axis=0)
pos_info += random_float_h
pos_info += random_float_w
pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)
# padding to fixed length
pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
pos_m[:keep] = 1.0
return pos_l, pos_m
def generate_direction_map(self, poly_quads, n_char, direction_map):
"""
"""
......@@ -334,6 +432,7 @@ class PGProcessTrain(object):
"""
Generate polygon.
"""
self.ds_ratio = ds_ratio
score_map_big = np.zeros(
(
h,
......@@ -384,7 +483,6 @@ class PGProcessTrain(object):
text_label = text_strs[poly_idx]
text_label = self.prepare_text_label(text_label,
self.Lexicon_Table)
text_label_index_list = [[self.Lexicon_Table.index(c_)]
for c_ in text_label
if c_ in self.Lexicon_Table]
......@@ -432,6 +530,22 @@ class PGProcessTrain(object):
# pos info
average_shrink_height = self.calculate_average_height(
stcl_quads)
if self.tcc_type == 'v3':
self.f_direction = direction_map[:, :, :-1].copy()
pos_res = self.fit_and_gather_tcl_points_v3(
min_area_quad,
stcl_quads,
max_h=h,
max_w=w,
fixed_point_num=64,
img_id=self.img_id,
reference_height=average_shrink_height)
if pos_res is None:
continue
pos_l, pos_m = pos_res[0], pos_res[1]
elif self.tcc_type == 'v2':
pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
min_area_quad,
poly,
......@@ -770,6 +884,20 @@ class PGProcessTrain(object):
text_polys[:, :, 0] *= asp_wx
text_polys[:, :, 1] *= asp_hy
if self.use_resize is True:
ori_h, ori_w, _ = im.shape
if max(ori_h, ori_w) < 200:
ratio = 200 / max(ori_h, ori_w)
im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
text_polys[:, :, 0] *= ratio
text_polys[:, :, 1] *= ratio
if max(ori_h, ori_w) > 512:
ratio = 512 / max(ori_h, ori_w)
im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio)))
text_polys[:, :, 0] *= ratio
text_polys[:, :, 1] *= ratio
elif self.use_random_crop is True:
h, w, _ = im.shape
if max(h, w) > 2048:
rd_scale = 2048.0 / max(h, w)
......@@ -790,7 +918,7 @@ class PGProcessTrain(object):
if text_polys.shape[0] == 0:
return None
# # continue for all ignore case
# continue for all ignore case
if np.sum((text_tags * 1.0)) >= text_tags.size:
return None
new_h, new_w, _ = im.shape
......
......@@ -89,12 +89,13 @@ class PGLoss(nn.Layer):
tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
tcl_pos = paddle.cast(tcl_pos, dtype=int)
f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
f_tcl_char = paddle.reshape(f_tcl_char,
[-1, 64, 37]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2)
f_tcl_char = paddle.reshape(
f_tcl_char, [-1, 64, self.pad_num + 1]) # len(Lexicon_Table)+1
f_tcl_char_fg, f_tcl_char_bg = paddle.split(
f_tcl_char, [self.pad_num, 1], axis=2)
f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
b, c, l = tcl_mask.shape
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l])
tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, self.pad_num * l])
tcl_mask_fg.stop_gradient = True
f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
-20.0)
......
......@@ -66,7 +66,7 @@ class PGHead(nn.Layer):
"""
"""
def __init__(self, in_channels, **kwargs):
def __init__(self, in_channels, tcc_channels=37, **kwargs):
super(PGHead, self).__init__()
self.conv_f_score1 = ConvBNLayer(
in_channels=in_channels,
......@@ -178,7 +178,7 @@ class PGHead(nn.Layer):
name="conv_f_char{}".format(5))
self.conv3 = nn.Conv2D(
in_channels=256,
out_channels=37,
out_channels=tcc_channels,
kernel_size=3,
stride=1,
padding=1,
......
......@@ -31,11 +31,12 @@ class PGPostProcess(object):
"""
def __init__(self, character_dict_path, valid_set, score_thresh, mode,
**kwargs):
tcc_type, **kwargs):
self.character_dict_path = character_dict_path
self.valid_set = valid_set
self.score_thresh = score_thresh
self.mode = mode
self.tcc_type = tcc_type
# c++ la-nms is faster, but only support python 3.5
self.is_python35 = False
......@@ -43,8 +44,13 @@ class PGPostProcess(object):
self.is_python35 = True
def __call__(self, outs_dict, shape_list):
post = PGNet_PostProcess(self.character_dict_path, self.valid_set,
self.score_thresh, outs_dict, shape_list)
post = PGNet_PostProcess(
self.character_dict_path,
self.valid_set,
self.score_thresh,
outs_dict,
shape_list,
tcc_type=self.tcc_type)
if self.mode == 'fast':
data = post.pg_postprocess_fast()
else:
......
......@@ -88,8 +88,33 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
return dst_str, keep_idx_list
def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
def instance_ctc_greedy_decoder(gather_info,
logits_map,
pts_num=4,
tcc_type='v3'):
_, _, C = logits_map.shape
if tcc_type == 'v3':
insert_num = 0
gather_info = np.array(gather_info)
length = len(gather_info) - 1
for index in range(length):
stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[
index + 1 + insert_num][0])
stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[
index + 1 + insert_num][1])
max_points = int(max(stride_x, stride_y))
stride = (gather_info[index + insert_num] -
gather_info[index + 1 + insert_num]) / (max_points)
insert_num_temp = max_points - 1
for i in range(int(insert_num_temp)):
insert_value = gather_info[index + insert_num] - (i + 1
) * stride
insert_index = index + i + 1 + insert_num
gather_info = np.insert(
gather_info, insert_index, insert_value, axis=0)
insert_num += insert_num_temp
gather_info = gather_info.tolist()
ys, xs = zip(*gather_info)
logits_seq = logits_map[list(ys), list(xs)]
probs_seq = logits_seq
......@@ -104,7 +129,8 @@ def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4):
def ctc_decoder_for_image(gather_info_list,
logits_map,
Lexicon_Table,
pts_num=6):
pts_num=6,
tcc_type='v3'):
"""
CTC decoder using multiple processes.
"""
......@@ -114,7 +140,7 @@ def ctc_decoder_for_image(gather_info_list,
if len(gather_info) < pts_num:
continue
dst_str, xys_list = instance_ctc_greedy_decoder(
gather_info, logits_map, pts_num=pts_num)
gather_info, logits_map, pts_num=pts_num, tcc_type='v3')
dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
if len(dst_str_readable) < 2:
continue
......@@ -356,7 +382,8 @@ def generate_pivot_list_fast(p_score,
p_char_maps,
f_direction,
Lexicon_Table,
score_thresh=0.5):
score_thresh=0.5,
tcc_type='v3'):
"""
return center point and end point of TCL instance; filter with the char maps;
"""
......@@ -384,7 +411,10 @@ def generate_pivot_list_fast(p_score,
p_char_maps = p_char_maps.transpose([1, 2, 0])
decoded_str, keep_yxs_list = ctc_decoder_for_image(
all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table)
all_pos_yxs,
logits_map=p_char_maps,
Lexicon_Table=Lexicon_Table,
tcc_type='v3')
return keep_yxs_list, decoded_str
......
......@@ -28,13 +28,19 @@ from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
class PGNet_PostProcess(object):
# two different post-process
def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict,
shape_list):
def __init__(self,
character_dict_path,
valid_set,
score_thresh,
outs_dict,
shape_list,
tcc_type='v3'):
self.Lexicon_Table = get_dict(character_dict_path)
self.valid_set = valid_set
self.score_thresh = score_thresh
self.outs_dict = outs_dict
self.shape_list = shape_list
self.tcc_type = tcc_type
def pg_postprocess_fast(self):
p_score = self.outs_dict['f_score']
......@@ -58,7 +64,8 @@ class PGNet_PostProcess(object):
p_char,
p_direction,
self.Lexicon_Table,
score_thresh=self.score_thresh)
score_thresh=self.score_thresh,
tcc_type=self.tcc_type)
poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
p_border, ratio_w, ratio_h,
src_w, src_h, self.valid_set)
......
......@@ -37,6 +37,46 @@ from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
from PIL import Image, ImageDraw, ImageFont
import math
def draw_e2e_res_for_chinese(image,
boxes,
txts,
config,
img_name,
font_path="./doc/simfang.ttf"):
h, w = image.height, image.width
img_left = image.copy()
img_right = Image.new('RGB', (w, h), (255, 255, 255))
import random
random.seed(0)
draw_left = ImageDraw.Draw(img_left)
draw_right = ImageDraw.Draw(img_right)
for idx, (box, txt) in enumerate(zip(boxes, txts)):
box = np.array(box)
box = [tuple(x) for x in box]
color = (random.randint(0, 255), random.randint(0, 255),
random.randint(0, 255))
draw_left.polygon(box, fill=color)
draw_right.polygon(box, outline=color)
font = ImageFont.truetype(font_path, 15, encoding="utf-8")
draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
img_left = Image.blend(image, img_left, 0.5)
img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
img_show.paste(img_left, (0, 0, w, h))
img_show.paste(img_right, (w, 0, w * 2, h))
save_e2e_path = os.path.dirname(config['Global'][
'save_res_path']) + "/e2e_results/"
if not os.path.exists(save_e2e_path):
os.makedirs(save_e2e_path)
save_path = os.path.join(save_e2e_path, os.path.basename(img_name))
cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1])
logger.info("The e2e Image saved in {}".format(save_path))
def draw_e2e_res(dt_boxes, strs, config, img, img_name):
......@@ -113,7 +153,19 @@ def main():
otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
fout.write(otstr.encode())
src_img = cv2.imread(file)
if global_config['infer_visual_type'] == 'EN':
draw_e2e_res(points, strs, config, src_img, file)
elif global_config['infer_visual_type'] == 'CN':
src_img = Image.fromarray(
cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
draw_e2e_res_for_chinese(
src_img,
points,
strs,
config,
file,
font_path="./doc/fonts/simfang.ttf")
logger.info("success!")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册