From 4c0b08733d41946d4c4817878511f23d2f68feb0 Mon Sep 17 00:00:00 2001 From: wangjingyeye <1025993141@qq.com> Date: Mon, 5 Sep 2022 07:03:16 +0000 Subject: [PATCH] update pgnet --- configs/e2e/e2e_r50_vd_pg.yml | 6 +++--- ppocr/data/imaug/pg_process.py | 8 ++++---- ppocr/modeling/heads/e2e_pg_head.py | 13 +++++++++++-- ppocr/postprocess/pg_postprocess.py | 6 +++--- ppocr/utils/e2e_utils/extract_textpoint_fast.py | 12 ++++++------ ppocr/utils/e2e_utils/pgnet_pp_utils.py | 6 +++--- 6 files changed, 30 insertions(+), 21 deletions(-) diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 5f1fde6b..4adbd2d4 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -33,7 +33,7 @@ Architecture: name: PGFPN Head: name: PGHead - tcc_channels: 37 # the length of character dict + character_dict_path: ppocr/utils/ic15_dict.txt # the same as Global:character_dict_path Loss: name: PGLoss @@ -58,7 +58,7 @@ PostProcess: name: PGPostProcess score_thresh: 0.5 mode: fast # fast or slow two ways - tcc_type: v3 # same as PGProcessTrain: tcc_type + point_gather_mode: v3 # same as PGProcessTrain: point_gather_mode Metric: name: E2EMetric @@ -85,7 +85,7 @@ Train: min_crop_size: 24 min_text_size: 4 max_text_size: 512 - tcc_type: v3 # two ways, v2 is original code, v3 is updated code + point_gather_mode: v3 # two ways, v2 is original code, v3 is updated code - KeepKeys: keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order loader: diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py index 2c8f8821..622a5a68 100644 --- a/ppocr/data/imaug/pg_process.py +++ b/ppocr/data/imaug/pg_process.py @@ -33,7 +33,7 @@ class PGProcessTrain(object): min_crop_size=24, min_text_size=4, max_text_size=512, - tcc_type='v3', + point_gather_mode='v3', **kwargs): self.tcl_len = tcl_len self.max_text_length = max_text_length @@ -45,7 +45,7 @@ class PGProcessTrain(object): self.min_text_size = min_text_size self.max_text_size = max_text_size self.use_resize = use_resize - self.tcc_type = tcc_type + self.point_gather_mode = point_gather_mode self.Lexicon_Table = self.get_dict(character_dict_path) self.pad_num = len(self.Lexicon_Table) self.img_id = 0 @@ -531,7 +531,7 @@ class PGProcessTrain(object): average_shrink_height = self.calculate_average_height( stcl_quads) - if self.tcc_type == 'v3': + if self.point_gather_mode == 'v3': self.f_direction = direction_map[:, :, :-1].copy() pos_res = self.fit_and_gather_tcl_points_v3( min_area_quad, @@ -545,7 +545,7 @@ class PGProcessTrain(object): continue pos_l, pos_m = pos_res[0], pos_res[1] - elif self.tcc_type == 'v2': + elif self.point_gather_mode == 'v2': pos_l, pos_m = self.fit_and_gather_tcl_points_v2( min_area_quad, poly, diff --git a/ppocr/modeling/heads/e2e_pg_head.py b/ppocr/modeling/heads/e2e_pg_head.py index 4bdabeb4..514962ef 100644 --- a/ppocr/modeling/heads/e2e_pg_head.py +++ b/ppocr/modeling/heads/e2e_pg_head.py @@ -66,8 +66,17 @@ class PGHead(nn.Layer): """ """ - def __init__(self, in_channels, tcc_channels=37, **kwargs): + def __init__(self, + in_channels, + character_dict_path='ppocr/utils/ic15_dict.txt', + **kwargs): super(PGHead, self).__init__() + + # get character_length + with open(character_dict_path, "rb") as fin: + lines = fin.readlines() + character_length = len(lines) + 1 + self.conv_f_score1 = ConvBNLayer( in_channels=in_channels, out_channels=64, @@ -178,7 +187,7 @@ class PGHead(nn.Layer): name="conv_f_char{}".format(5)) self.conv3 = nn.Conv2D( in_channels=256, - out_channels=tcc_channels, + out_channels=character_length, kernel_size=3, stride=1, padding=1, diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 7f17579b..1a52979c 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -31,12 +31,12 @@ class PGPostProcess(object): """ def __init__(self, character_dict_path, valid_set, score_thresh, mode, - tcc_type, **kwargs): + point_gather_mode, **kwargs): self.character_dict_path = character_dict_path self.valid_set = valid_set self.score_thresh = score_thresh self.mode = mode - self.tcc_type = tcc_type + self.point_gather_mode = point_gather_mode # c++ la-nms is faster, but only support python 3.5 self.is_python35 = False @@ -50,7 +50,7 @@ class PGPostProcess(object): self.score_thresh, outs_dict, shape_list, - tcc_type=self.tcc_type) + point_gather_mode=self.point_gather_mode) if self.mode == 'fast': data = post.pg_postprocess_fast() else: diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py index fee4145f..6cf3eb84 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint_fast.py +++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -91,9 +91,9 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4, - tcc_type='v3'): + point_gather_mode='v3'): _, _, C = logits_map.shape - if tcc_type == 'v3': + if point_gather_mode == 'v3': insert_num = 0 gather_info = np.array(gather_info) length = len(gather_info) - 1 @@ -130,7 +130,7 @@ def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, pts_num=6, - tcc_type='v3'): + point_gather_mode='v3'): """ CTC decoder using multiple processes. """ @@ -140,7 +140,7 @@ def ctc_decoder_for_image(gather_info_list, if len(gather_info) < pts_num: continue dst_str, xys_list = instance_ctc_greedy_decoder( - gather_info, logits_map, pts_num=pts_num, tcc_type='v3') + gather_info, logits_map, pts_num=pts_num, point_gather_mode='v3') dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) if len(dst_str_readable) < 2: continue @@ -383,7 +383,7 @@ def generate_pivot_list_fast(p_score, f_direction, Lexicon_Table, score_thresh=0.5, - tcc_type='v3'): + point_gather_mode='v3'): """ return center point and end point of TCL instance; filter with the char maps; """ @@ -414,7 +414,7 @@ def generate_pivot_list_fast(p_score, all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table, - tcc_type='v3') + point_gather_mode='v3') return keep_yxs_list, decoded_str diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py index 605ab0e1..12f9dac5 100644 --- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -34,13 +34,13 @@ class PGNet_PostProcess(object): score_thresh, outs_dict, shape_list, - tcc_type='v3'): + point_gather_mode='v3'): self.Lexicon_Table = get_dict(character_dict_path) self.valid_set = valid_set self.score_thresh = score_thresh self.outs_dict = outs_dict self.shape_list = shape_list - self.tcc_type = tcc_type + self.point_gather_mode = point_gather_mode def pg_postprocess_fast(self): p_score = self.outs_dict['f_score'] @@ -65,7 +65,7 @@ class PGNet_PostProcess(object): p_direction, self.Lexicon_Table, score_thresh=self.score_thresh, - tcc_type=self.tcc_type) + point_gather_mode=self.point_gather_mode) poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w, src_h, self.valid_set) -- GitLab