提交 23eb335d 编写于 作者: 文幕地方

add need_reset to dataset

上级 f671f133
...@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset): ...@@ -38,6 +38,9 @@ class LMDBDataSet(Dataset):
np.random.shuffle(self.data_idx_order_list) np.random.shuffle(self.data_idx_order_list)
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def load_hierarchical_lmdb_dataset(self, data_dir): def load_hierarchical_lmdb_dataset(self, data_dir):
lmdb_sets = {} lmdb_sets = {}
dataset_idx = 0 dataset_idx = 0
......
...@@ -49,6 +49,8 @@ class PGDataSet(Dataset): ...@@ -49,6 +49,8 @@ class PGDataSet(Dataset):
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
......
...@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset): ...@@ -53,6 +53,9 @@ class PubTabDataSet(Dataset):
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
...@@ -70,7 +73,7 @@ class PubTabDataSet(Dataset): ...@@ -70,7 +73,7 @@ class PubTabDataSet(Dataset):
prob = self.img_select_prob[file_name] prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1): if prob < random.uniform(0, 1):
select_flag = False select_flag = False
if self.table_select_type: if self.table_select_type:
structure = info['html']['structure']['tokens'].copy() structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure) structure_str = ''.join(structure)
...@@ -79,13 +82,17 @@ class PubTabDataSet(Dataset): ...@@ -79,13 +82,17 @@ class PubTabDataSet(Dataset):
table_type = "complex" table_type = "complex"
if table_type == "complex": if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1): if self.table_select_prob < random.uniform(0, 1):
select_flag = False select_flag = False
if select_flag: if select_flag:
cells = info['html']['cells'].copy() cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy() structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'cells': cells, 'structure':structure} data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path): if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path)) raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f: with open(data['img_path'], 'rb') as f:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册