diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index 4adbd2d4302d3098c99bad6e4286f8fcf9664513..4642f544868f720d413f7f5242740705bc9fd0a5 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -58,7 +58,7 @@ PostProcess: name: PGPostProcess score_thresh: 0.5 mode: fast # fast or slow two ways - point_gather_mode: v3 # same as PGProcessTrain: point_gather_mode + point_gather_mode: align # same as PGProcessTrain: point_gather_mode Metric: name: E2EMetric @@ -85,7 +85,7 @@ Train: min_crop_size: 24 min_text_size: 4 max_text_size: 512 - point_gather_mode: v3 # two ways, v2 is original code, v3 is updated code + point_gather_mode: align # two mode: align and none, align mode is better than none mode - KeepKeys: keep_keys: [ 'images', 'tcl_maps', 'tcl_label_maps', 'border_maps','direction_maps', 'training_masks', 'label_list', 'pos_list', 'pos_mask' ] # dataloader will return list in this order loader: diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py index 622a5a68f176040063d2b619ec678893e3c3ad1d..f1e5f912b7a55dc3b9e883a9f4f8c5de482dcd5a 100644 --- a/ppocr/data/imaug/pg_process.py +++ b/ppocr/data/imaug/pg_process.py @@ -33,7 +33,7 @@ class PGProcessTrain(object): min_crop_size=24, min_text_size=4, max_text_size=512, - point_gather_mode='v3', + point_gather_mode=None, **kwargs): self.tcl_len = tcl_len self.max_text_length = max_text_length @@ -531,7 +531,7 @@ class PGProcessTrain(object): average_shrink_height = self.calculate_average_height( stcl_quads) - if self.point_gather_mode == 'v3': + if self.point_gather_mode == 'align': self.f_direction = direction_map[:, :, :-1].copy() pos_res = self.fit_and_gather_tcl_points_v3( min_area_quad, @@ -545,7 +545,7 @@ class PGProcessTrain(object): continue pos_l, pos_m = pos_res[0], pos_res[1] - elif self.point_gather_mode == 'v2': + else: pos_l, pos_m = self.fit_and_gather_tcl_points_v2( min_area_quad, poly, diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 1a52979c1487b1d439e07a4bb972fd9e1f39901f..058cf8b907de296094d3ed2fc7e6981939ced328 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -30,8 +30,13 @@ class PGPostProcess(object): The post process for PGNet. """ - def __init__(self, character_dict_path, valid_set, score_thresh, mode, - point_gather_mode, **kwargs): + def __init__(self, + character_dict_path, + valid_set, + score_thresh, + mode, + point_gather_mode=None, + **kwargs): self.character_dict_path = character_dict_path self.valid_set = valid_set self.score_thresh = score_thresh diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py index 6cf3eb84539e9d00ae942e25dc43ff88d7ae323a..a85b8e78ead00e64630b57400b9e5141eb0181a8 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint_fast.py +++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -91,9 +91,9 @@ def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True): def instance_ctc_greedy_decoder(gather_info, logits_map, pts_num=4, - point_gather_mode='v3'): + point_gather_mode=None): _, _, C = logits_map.shape - if point_gather_mode == 'v3': + if point_gather_mode == 'align': insert_num = 0 gather_info = np.array(gather_info) length = len(gather_info) - 1 @@ -115,6 +115,8 @@ def instance_ctc_greedy_decoder(gather_info, gather_info, insert_index, insert_value, axis=0) insert_num += insert_num_temp gather_info = gather_info.tolist() + else: + pass ys, xs = zip(*gather_info) logits_seq = logits_map[list(ys), list(xs)] probs_seq = logits_seq @@ -130,7 +132,7 @@ def ctc_decoder_for_image(gather_info_list, logits_map, Lexicon_Table, pts_num=6, - point_gather_mode='v3'): + point_gather_mode=None): """ CTC decoder using multiple processes. """ @@ -140,7 +142,10 @@ def ctc_decoder_for_image(gather_info_list, if len(gather_info) < pts_num: continue dst_str, xys_list = instance_ctc_greedy_decoder( - gather_info, logits_map, pts_num=pts_num, point_gather_mode='v3') + gather_info, + logits_map, + pts_num=pts_num, + point_gather_mode=point_gather_mode) dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str]) if len(dst_str_readable) < 2: continue @@ -383,7 +388,7 @@ def generate_pivot_list_fast(p_score, f_direction, Lexicon_Table, score_thresh=0.5, - point_gather_mode='v3'): + point_gather_mode=None): """ return center point and end point of TCL instance; filter with the char maps; """ @@ -414,7 +419,7 @@ def generate_pivot_list_fast(p_score, all_pos_yxs, logits_map=p_char_maps, Lexicon_Table=Lexicon_Table, - point_gather_mode='v3') + point_gather_mode=point_gather_mode) return keep_yxs_list, decoded_str diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py index 12f9dac5f336c4f3f09258842500f2331ddbd8b3..06a766b0e714e2792c0b0d3069963de998eb9eb7 100644 --- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -34,7 +34,7 @@ class PGNet_PostProcess(object): score_thresh, outs_dict, shape_list, - point_gather_mode='v3'): + point_gather_mode=None): self.Lexicon_Table = get_dict(character_dict_path) self.valid_set = valid_set self.score_thresh = score_thresh