From 23eb335dedeb597a56b757d6d5de1b5a0eb5dc87 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Wed, 12 Jan 2022 09:54:07 +0000 Subject: [PATCH] add need_reset to dataset --- ppocr/data/lmdb_dataset.py | 3 +++ ppocr/data/pgnet_dataset.py | 2 ++ ppocr/data/pubtab_dataset.py | 15 +++++++++++---- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e2d6dc93..e1b49809 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -38,6 +38,9 @@ class LMDBDataSet(Dataset): np.random.shuffle(self.data_idx_order_list) self.ops = create_operators(dataset_config['transforms'], global_config) + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + def load_hierarchical_lmdb_dataset(self, data_dir): lmdb_sets = {} dataset_idx = 0 diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index 5adcd02c..6f80179c 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -49,6 +49,8 @@ class PGDataSet(Dataset): self.ops = create_operators(dataset_config['transforms'], global_config) + self.need_reset = True in [x < 1 for x in ratio_list] + def shuffle_data_random(self): if self.do_shuffle: random.seed(self.seed) diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py index 78b76c5a..671cda76 100644 --- a/ppocr/data/pubtab_dataset.py +++ b/ppocr/data/pubtab_dataset.py @@ -53,6 +53,9 @@ class PubTabDataSet(Dataset): self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + def shuffle_data_random(self): if self.do_shuffle: random.seed(self.seed) @@ -70,7 +73,7 @@ class PubTabDataSet(Dataset): prob = self.img_select_prob[file_name] if prob < random.uniform(0, 1): select_flag = False - + if self.table_select_type: structure = info['html']['structure']['tokens'].copy() structure_str = ''.join(structure) @@ -79,13 +82,17 @@ class PubTabDataSet(Dataset): table_type = "complex" if table_type == "complex": if self.table_select_prob < random.uniform(0, 1): - select_flag = False - + select_flag = False + if select_flag: cells = info['html']['cells'].copy() structure = info['html']['structure'].copy() img_path = os.path.join(self.data_dir, file_name) - data = {'img_path': img_path, 'cells': cells, 'structure':structure} + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure + } if not os.path.exists(img_path): raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: -- GitLab