diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e2d6dc9327bf3725d2fb6c32d18c0b71bd6ac408..e1b49809d199096ad06b90c4562aa5dbfa634db1 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -38,6 +38,9 @@ class LMDBDataSet(Dataset): np.random.shuffle(self.data_idx_order_list) self.ops = create_operators(dataset_config['transforms'], global_config) + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + def load_hierarchical_lmdb_dataset(self, data_dir): lmdb_sets = {} dataset_idx = 0 diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index 5adcd02c4a24074c0252a8590fd89f015a6ff152..6f80179c4eb971ace360edb5368f6a2acd5a6322 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -49,6 +49,8 @@ class PGDataSet(Dataset): self.ops = create_operators(dataset_config['transforms'], global_config) + self.need_reset = True in [x < 1 for x in ratio_list] + def shuffle_data_random(self): if self.do_shuffle: random.seed(self.seed) diff --git a/ppocr/data/pubtab_dataset.py b/ppocr/data/pubtab_dataset.py index 78b76c5afb8c96bc96730c7b8ad76b4bafa31c67..671cda76fb4c36f3ac6bcc7da5a7fc4de241c0e2 100644 --- a/ppocr/data/pubtab_dataset.py +++ b/ppocr/data/pubtab_dataset.py @@ -53,6 +53,9 @@ class PubTabDataSet(Dataset): self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) + ratio_list = dataset_config.get("ratio_list", [1.0]) + self.need_reset = True in [x < 1 for x in ratio_list] + def shuffle_data_random(self): if self.do_shuffle: random.seed(self.seed) @@ -70,7 +73,7 @@ class PubTabDataSet(Dataset): prob = self.img_select_prob[file_name] if prob < random.uniform(0, 1): select_flag = False - + if self.table_select_type: structure = info['html']['structure']['tokens'].copy() structure_str = ''.join(structure) @@ -79,13 +82,17 @@ class PubTabDataSet(Dataset): table_type = "complex" if table_type == "complex": if self.table_select_prob < random.uniform(0, 1): - select_flag = False - + select_flag = False + if select_flag: cells = info['html']['cells'].copy() structure = info['html']['structure'].copy() img_path = os.path.join(self.data_dir, file_name) - data = {'img_path': img_path, 'cells': cells, 'structure':structure} + data = { + 'img_path': img_path, + 'cells': cells, + 'structure': structure + } if not os.path.exists(img_path): raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: