diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index c4c5226e796a42db723ce78ef65473e357c25dc6..4642f544868f720d413f7f5242740705bc9fd0a5 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -13,6 +13,7 @@ Global: save_inference_dir: use_visualdl: False infer_img: + infer_visual_type: EN # two mode: EN is for english datasets, CN is for chinese datasets valid_set: totaltext # two mode: totaltext valid curved words, partvgg valid non-curved words save_res_path: ./output/pgnet_r50_vd_totaltext/predicts_pgnet.txt character_dict_path: ppocr/utils/ic15_dict.txt @@ -32,6 +33,7 @@ Architecture: name: PGFPN Head: name: PGHead + character_dict_path: ppocr/utils/ic15_dict.txt # the same as Global:character_dict_path Loss: name: PGLoss @@ -45,16 +47,18 @@ Optimizer: beta1: 0.9 beta2: 0.999 lr: + name: Cosine learning_rate: 0.001 + warmup_epoch: 50 regularizer: name: 'L2' - factor: 0 - + factor: 0.0001 PostProcess: name: PGPostProcess score_thresh: 0.5 mode: fast # fast or slow two ways + point_gather_mode: align # same as PGProcessTrain: point_gather_mode Metric: name: E2EMetric @@ -76,9 +80,12 @@ Train: - E2ELabelEncodeTrain: - PGProcessTrain: batch_size: 14 # same as loader: batch_size_per_card + use_resize: True + use_random_crop: False min_crop_size: 24 min_text_size: 4 max_text_size: 512 + point_gather_mode: align # two mode: align and none, align mode is better than none mode - KeepKeys: keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order loader: diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py index 53031064c019ddce00c7546f898ac67a7f0459f9..f1e5f912b7a55dc3b9e883a9f4f8c5de482dcd5a 100644 --- a/ppocr/data/imaug/pg_process.py +++ b/ppocr/data/imaug/pg_process.py @@ -15,6 +15,8 @@ import math import cv2 import numpy as np +from skimage.morphology._skeletonize import thin +from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2 __all__ = ['PGProcessTrain'] @@ -26,17 +28,24 @@ class PGProcessTrain(object): max_text_nums, tcl_len, batch_size=14, + use_resize=True, + use_random_crop=False, min_crop_size=24, min_text_size=4, max_text_size=512, + point_gather_mode=None, **kwargs): self.tcl_len = tcl_len self.max_text_length = max_text_length self.max_text_nums = max_text_nums self.batch_size = batch_size - self.min_crop_size = min_crop_size + if use_random_crop is True: + self.min_crop_size = min_crop_size + self.use_random_crop = use_random_crop self.min_text_size = min_text_size self.max_text_size = max_text_size + self.use_resize = use_resize + self.point_gather_mode = point_gather_mode self.Lexicon_Table = self.get_dict(character_dict_path) self.pad_num = len(self.Lexicon_Table) self.img_id = 0 @@ -282,6 +291,95 @@ class PGProcessTrain(object): pos_m[:keep] = 1.0 return pos_l, pos_m + def fit_and_gather_tcl_points_v3(self, + min_area_quad, + poly, + max_h, + max_w, + fixed_point_num=64, + img_id=0, + reference_height=3): + """ + Find the center point of poly as key_points, then fit and gather. + """ + det_mask = np.zeros((int(max_h / self.ds_ratio), + int(max_w / self.ds_ratio))).astype(np.float32) + + # score_big_map + cv2.fillPoly(det_mask, + np.round(poly / self.ds_ratio).astype(np.int32), 1.0) + det_mask = cv2.resize( + det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio) + det_mask = np.array(det_mask > 1e-3, dtype='float32') + + f_direction = self.f_direction + skeleton_map = thin(det_mask.astype(np.uint8)) + instance_count, instance_label_map = cv2.connectedComponents( + skeleton_map.astype(np.uint8), connectivity=8) + + ys, xs = np.where(instance_label_map == 1) + pos_list = list(zip(ys, xs)) + if len(pos_list) < 3: + return None + pos_list_sorted = sort_and_expand_with_direction_v2( + pos_list, f_direction, det_mask) + + pos_list_sorted = np.array(pos_list_sorted) + length = len(pos_list_sorted) - 1 + insert_num = 0 + for index in range(length): + stride_y = np.abs(pos_list_sorted[index + insert_num][0] - + pos_list_sorted[index + 1 + insert_num][0]) + stride_x = np.abs(pos_list_sorted[index + insert_num][1] - + pos_list_sorted[index + 1 + insert_num][1]) + max_points = int(max(stride_x, stride_y)) + + stride = (pos_list_sorted[index + insert_num] - + pos_list_sorted[index + 1 + insert_num]) / (max_points) + insert_num_temp = max_points - 1 + + for i in range(int(insert_num_temp)): + insert_value = pos_list_sorted[index + insert_num] - (i + 1 + ) * stride + insert_index = index + i + 1 + insert_num + pos_list_sorted = np.insert( + pos_list_sorted, insert_index, insert_value, axis=0) + insert_num += insert_num_temp + + pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype( + np.float32) # xy-> yx + + point_num = len(pos_info) + if point_num > fixed_point_num: + keep_ids = [ + int((point_num * 1.0 / fixed_point_num) * x) + for x in range(fixed_point_num) + ] + pos_info = pos_info[keep_ids, :] + + keep = int(min(len(pos_info), fixed_point_num)) + reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) + + np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2 + if np.random.rand() < 1: + dh = (np.random.rand(keep) - 0.5) * reference_height + offset = np.random.rand() - 0.5 + dw = np.array([[0, offset * reference_width * 0.2]]) + random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape( + [keep, 1]) + random_float_w = dw.repeat(keep, axis=0) + pos_info += random_float_h + pos_info += random_float_w + pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1) + pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1) + + # padding to fixed length + pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32) + pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id + pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32) + pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32) + pos_m[:keep] = 1.0 + return pos_l, pos_m + def generate_direction_map(self, poly_quads, n_char, direction_map): """ """ @@ -334,6 +432,7 @@ class PGProcessTrain(object): """ Generate polygon. """ + self.ds_ratio = ds_ratio score_map_big = np.zeros( ( h, @@ -384,7 +483,6 @@ class PGProcessTrain(object): text_label = text_strs[poly_idx] text_label = self.prepare_text_label(text_label, self.Lexicon_Table) - text_label_index_list = [[self.Lexicon_Table.index(c_)] for c_ in text_label if c_ in self.Lexicon_Table] @@ -432,14 +530,30 @@ class PGProcessTrain(object): # pos info average_shrink_height = self.calculate_average_height( stcl_quads) - pos_l, pos_m = self.fit_and_gather_tcl_points_v2( - min_area_quad, - poly, - max_h=h, - max_w=w, - fixed_point_num=64, - img_id=self.img_id, - reference_height=average_shrink_height) + + if self.point_gather_mode == 'align': + self.f_direction = direction_map[:, :, :-1].copy() + pos_res = self.fit_and_gather_tcl_points_v3( + min_area_quad, + stcl_quads, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height) + if pos_res is None: + continue + pos_l, pos_m = pos_res[0], pos_res[1] + + else: + pos_l, pos_m = self.fit_and_gather_tcl_points_v2( + min_area_quad, + poly, + max_h=h, + max_w=w, + fixed_point_num=64, + img_id=self.img_id, + reference_height=average_shrink_height) label_l = text_label_index_list if len(text_label_index_list) < 2: @@ -770,27 +884,41 @@ class PGProcessTrain(object): text_polys[:, :, 0] *= asp_wx text_polys[:, :, 1] *= asp_hy - h, w, _ = im.shape - if max(h, w) > 2048: - rd_scale = 2048.0 / max(h, w) - im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) - text_polys *= rd_scale - h, w, _ = im.shape - if min(h, w) < 16: - return None - - # no background - im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( - im, - text_polys, - text_tags, - hv_tags, - text_strs, - crop_background=False) + if self.use_resize is True: + ori_h, ori_w, _ = im.shape + if max(ori_h, ori_w) < 200: + ratio = 200 / max(ori_h, ori_w) + im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio))) + text_polys[:, :, 0] *= ratio + text_polys[:, :, 1] *= ratio + + if max(ori_h, ori_w) > 512: + ratio = 512 / max(ori_h, ori_w) + im = cv2.resize(im, (int(ori_w * ratio), int(ori_h * ratio))) + text_polys[:, :, 0] *= ratio + text_polys[:, :, 1] *= ratio + elif self.use_random_crop is True: + h, w, _ = im.shape + if max(h, w) > 2048: + rd_scale = 2048.0 / max(h, w) + im = cv2.resize(im, dsize=None, fx=rd_scale, fy=rd_scale) + text_polys *= rd_scale + h, w, _ = im.shape + if min(h, w) < 16: + return None + + # no background + im, text_polys, text_tags, hv_tags, text_strs = self.crop_area( + im, + text_polys, + text_tags, + hv_tags, + text_strs, + crop_background=False) if text_polys.shape[0] == 0: return None - # # continue for all ignore case + # continue for all ignore case if np.sum((text_tags * 1.0)) >= text_tags.size: return None new_h, new_w, _ = im.shape diff --git a/ppocr/losses/e2e_pg_loss.py b/ppocr/losses/e2e_pg_loss.py index 10a8ed0aa907123b155976ba498426604f23c2b0..aff67b7ce3c208bf9c7b1371e095eac8c70ce9df 100644 --- a/ppocr/losses/e2e_pg_loss.py +++ b/ppocr/losses/e2e_pg_loss.py @@ -89,12 +89,13 @@ class PGLoss(nn.Layer): tcl_pos = paddle.reshape(tcl_pos, [-1, 3]) tcl_pos = paddle.cast(tcl_pos, dtype=int) f_tcl_char = paddle.gather_nd(f_char, tcl_pos) - f_tcl_char = paddle.reshape(f_tcl_char, - [-1, 64, 37]) # len(Lexicon_Table)+1 - f_tcl_char_fg, f_tcl_char_bg = paddle.split(f_tcl_char, [36, 1], axis=2) + f_tcl_char = paddle.reshape( + f_tcl_char, [-1, 64, self.pad_num + 1]) # len(Lexicon_Table)+1 + f_tcl_char_fg, f_tcl_char_bg = paddle.split( + f_tcl_char, [self.pad_num, 1], axis=2) f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0 b, c, l = tcl_mask.shape - tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, 36 * l]) + tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, self.pad_num * l]) tcl_mask_fg.stop_gradient = True f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * ( -20.0) diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py index 274e1cdac5172f45590c9f7d7b50522c74db6750..514962ef97e503d331b6351c6d314070dfd8b15f 100644 --- a/ppocr/modeling/heads/e2e_pg_head.py +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -66,8 +66,17 @@ class PGHead(nn.Layer): """ """ - def __init__(self, in_channels, **kwargs): + def __init__(self, + in_channels, + character_dict_path='ppocr/utils/ic15_dict.txt', + **kwargs): super(PGHead, self).__init__() + + # get character_length + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + character_length = len(lines) + 1 + self.conv_f_score1 = ConvBNLayer( in_channels=in_channels, out_channels=64, @@ -178,7 +187,7 @@ class PGHead(nn.Layer): name="conv_f_char{}".format(5)) self.conv3 = nn.Conv2D( in_channels=256, - out_channels=37, + out_channels=character_length, kernel_size=3, stride=1, padding=1, diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 0b1455181fddb0adb5347406bb2eb3093ee6fb30..058cf8b907de296094d3ed2fc7e6981939ced328 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -30,12 +30,18 @@ class PGPostProcess(object): The post process for PGNet. """ - def __init__(self, character_dict_path, valid_set, score_thresh, mode, + def __init__(self, + character_dict_path, + valid_set, + score_thresh, + mode, + point_gather_mode=None, **kwargs): self.character_dict_path = character_dict_path self.valid_set = valid_set self.score_thresh = score_thresh self.mode = mode + self.point_gather_mode = point_gather_mode # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False @@ -43,8 +49,13 @@ class PGPostProcess(object): self.is_python35 = True def __call__(self, outs_dict, shape_list): - post = PGNet_PostProcess(self.character_dict_path, self.valid_set, - self.score_thresh, outs_dict, shape_list) + post = PGNet_PostProcess( + self.character_dict_path, + self.valid_set, + self.score_thresh, + outs_dict, + shape_list, + point_gather_mode=self.point_gather_mode) if self.mode == 'fast': data = post.pg_postprocess_fast() else: diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py index 787cd3017fafa6fc554bead0cc05b5bfe682df42..a85b8e78ead00e64630b57400b9e5141eb0181a8 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint_fast.py +++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -88,8 +88,35 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): return dst_str, keep_idx_list -def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): +def instance_ctc_greedy_decoder(gather_info, + logits_map, + pts_num=4, + point_gather_mode=None): _, _, C = logits_map.shape + if point_gather_mode == 'align': + insert_num = 0 + gather_info = np.array(gather_info) + length = len(gather_info) - 1 + for index in range(length): + stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[ + index + 1 + insert_num][0]) + stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[ + index + 1 + insert_num][1]) + max_points = int(max(stride_x, stride_y)) + stride = (gather_info[index + insert_num] - + gather_info[index + 1 + insert_num]) / (max_points) + insert_num_temp = max_points - 1 + + for i in range(int(insert_num_temp)): + insert_value = gather_info[index + insert_num] - (i + 1 + ) * stride + insert_index = index + i + 1 + insert_num + gather_info = np.insert( + gather_info, insert_index, insert_value, axis=0) + insert_num += insert_num_temp + gather_info = gather_info.tolist() + else: + pass ys, xs = zip(*gather_info) logits_seq = logits_map[list(ys), list(xs)] probs_seq = logits_seq @@ -104,7 +131,8 @@ def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4): def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, - pts_num=6): + pts_num=6, + point_gather_mode=None): """ CTC decoder using multiple processes. """ @@ -114,7 +142,10 @@ def ctc_decoder_for_image(gather_info_list, if len(gather_info) < pts_num: continue dst_str, xys_list = instance_ctc_greedy_decoder( - gather_info, logits_map, pts_num=pts_num) + gather_info, + logits_map, + pts_num=pts_num, + point_gather_mode=point_gather_mode) dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) if len(dst_str_readable) < 2: continue @@ -356,7 +387,8 @@ def generate_pivot_list_fast(p_score, p_char_maps, f_direction, Lexicon_Table, - score_thresh=0.5): + score_thresh=0.5, + point_gather_mode=None): """ return center point and end point of TCL instance; filter with the char maps; """ @@ -384,7 +416,10 @@ def generate_pivot_list_fast(p_score, p_char_maps = p_char_maps.transpose([1, 2, 0]) decoded_str, keep_yxs_list = ctc_decoder_for_image( - all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table) + all_pos_yxs, + logits_map=p_char_maps, + Lexicon_Table=Lexicon_Table, + point_gather_mode=point_gather_mode) return keep_yxs_list, decoded_str diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py index a15503c0a88f735cc5f5eef924b0d022e5684eed..06a766b0e714e2792c0b0d3069963de998eb9eb7 100644 --- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -28,13 +28,19 @@ from extract_textpoint_fast import generate_pivot_list_fast, restore_poly class PGNet_PostProcess(object): # two different post-process - def __init__(self, character_dict_path, valid_set, score_thresh, outs_dict, - shape_list): + def __init__(self, + character_dict_path, + valid_set, + score_thresh, + outs_dict, + shape_list, + point_gather_mode=None): self.Lexicon_Table = get_dict(character_dict_path) self.valid_set = valid_set self.score_thresh = score_thresh self.outs_dict = outs_dict self.shape_list = shape_list + self.point_gather_mode = point_gather_mode def pg_postprocess_fast(self): p_score = self.outs_dict['f_score'] @@ -58,7 +64,8 @@ class PGNet_PostProcess(object): p_char, p_direction, self.Lexicon_Table, - score_thresh=self.score_thresh) + score_thresh=self.score_thresh, + point_gather_mode=self.point_gather_mode) poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, self.valid_set) diff --git a/tools/infer_e2e.py b/tools/infer_e2e.py index d3e6b28fca0a3ff32ea940747712d6c71aa290fd..37fdcbaadc2984c9cf4fb105b7122db31b99be30 100755 --- a/tools/infer_e2e.py +++ b/tools/infer_e2e.py @@ -37,6 +37,46 @@ from ppocr.postprocess import build_post_process from ppocr.utils.save_load import load_model from ppocr.utils.utility import get_image_file_list import tools.program as program +from PIL import Image, ImageDraw, ImageFont +import math + + +def draw_e2e_res_for_chinese(image, + boxes, + txts, + config, + img_name, + font_path="./doc/simfang.ttf"): + h, w = image.height, image.width + img_left = image.copy() + img_right = Image.new('RGB', (w, h), (255, 255, 255)) + + import random + + random.seed(0) + draw_left = ImageDraw.Draw(img_left) + draw_right = ImageDraw.Draw(img_right) + for idx, (box, txt) in enumerate(zip(boxes, txts)): + box = np.array(box) + box = [tuple(x) for x in box] + color = (random.randint(0, 255), random.randint(0, 255), + random.randint(0, 255)) + draw_left.polygon(box, fill=color) + draw_right.polygon(box, outline=color) + font = ImageFont.truetype(font_path, 15, encoding="utf-8") + draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) + img_left = Image.blend(image, img_left, 0.5) + img_show = Image.new('RGB', (w * 2, h), (255, 255, 255)) + img_show.paste(img_left, (0, 0, w, h)) + img_show.paste(img_right, (w, 0, w * 2, h)) + + save_e2e_path = os.path.dirname(config['Global'][ + 'save_res_path']) + "/e2e_results/" + if not os.path.exists(save_e2e_path): + os.makedirs(save_e2e_path) + save_path = os.path.join(save_e2e_path, os.path.basename(img_name)) + cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1]) + logger.info("The e2e Image saved in {}".format(save_path)) def draw_e2e_res(dt_boxes, strs, config, img, img_name): @@ -113,7 +153,19 @@ def main(): otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n" fout.write(otstr.encode()) src_img = cv2.imread(file) - draw_e2e_res(points, strs, config, src_img, file) + if global_config['infer_visual_type'] == 'EN': + draw_e2e_res(points, strs, config, src_img, file) + elif global_config['infer_visual_type'] == 'CN': + src_img = Image.fromarray( + cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)) + draw_e2e_res_for_chinese( + src_img, + points, + strs, + config, + file, + font_path="./doc/fonts/simfang.ttf") + logger.info("success!")