From 310d399b8335aec29f967d2b2c7609506b8c0780 Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Mon, 8 Mar 2021 15:44:07 +0800 Subject: [PATCH] ADD PGnet_v3 --- ppocr/data/pgnet_dataset.py | 6 ++++-- ppocr/postprocess/pg_postprocess.py | 19 ------------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index 82c580ce..ed970d7e 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -19,10 +19,11 @@ import random class PGDateSet(Dataset): - def __init__(self, config, mode, logger): + def __init__(self, config, mode, logger, seed=None): super(PGDateSet, self).__init__() self.logger = logger + self.seed = seed global_config = config['Global'] dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] @@ -36,7 +37,6 @@ class PGDateSet(Dataset): assert len( ratio_list ) == data_source_num, "The length of ratio_list should be the same as the file_list." - # self.data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] logger.info("Initialize indexs of datasets:%s" % label_file_list) @@ -50,6 +50,7 @@ class PGDateSet(Dataset): def shuffle_data_random(self): if self.do_shuffle: + random.seed(self.seed) random.shuffle(self.data_lines) return @@ -122,6 +123,7 @@ class PGDateSet(Dataset): else: print("Unrecognized data format...") exit(-1) + random.seed(self.seed) image_files = random.sample( image_files, round(len(image_files) * ratio_list[idx])) data_lines.extend(image_files) diff --git a/ppocr/postprocess/pg_postprocess.py b/ppocr/postprocess/pg_postprocess.py index 90031a83..1b340b42 100644 --- a/ppocr/postprocess/pg_postprocess.py +++ b/ppocr/postprocess/pg_postprocess.py @@ -113,7 +113,6 @@ class PGPostProcess(object): all_point_pair_list = [] for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs): if len(yx_center_line) == 1: - print('the length of tcl point is less than 2, repeat') yx_center_line.append(yx_center_line[-1]) # expand corresponding offset for total-text. @@ -148,7 +147,6 @@ class PGPostProcess(object): # ndarry: (x, 2) detected_poly, pair_length_info = point_pair2poly(point_pair_list) - print('expand along width. {}'.format(detected_poly.shape)) detected_poly = expand_poly_along_width( detected_poly, shrink_ratio_of_width=0.2) detected_poly[:, 0] = np.clip( @@ -157,7 +155,6 @@ class PGPostProcess(object): detected_poly[:, 1], a_min=0, a_max=src_h) if len(keep_str) < 2: - print('--> too short, {}'.format(keep_str)) continue keep_str_list.append(keep_str) @@ -175,20 +172,4 @@ class PGPostProcess(object): 'points': poly_list, 'strs': keep_str_list, } - # visualization - # if self.save_visualization: - # visualize_e2e_result(im_fn, poly_list, keep_str_list, src_im) - # visualize_point_result(im_fn, all_point_list, all_point_pair_list, src_im) - - # save detected boxes - # txt_dir = (result_path[:-1] if result_path.endswith('/') else result_path) + '_txt_anno' - # if not os.path.exists(txt_dir): - # os.makedirs(txt_dir) - # res_file = os.path.join(txt_dir, '{}.txt'.format(im_prefix)) - # with open(res_file, 'w') as f: - # for i_box, box in enumerate(poly_list): - # seq_str = keep_str_list[i_box] - # box = np.round(box).astype('int32') - # box_str = ','.join(str(s) for s in (box.flatten().tolist())) - # f.write('{}\t{}\r\n'.format(box_str, seq_str)) return data -- GitLab