From 8e05d54c7f351fd638e3cfbb7319471410e4c53f Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Mon, 20 Dec 2021 09:02:40 +0000 Subject: [PATCH] fix win train bug --- ppstructure/vqa/eval_ser.py | 2 +- ppstructure/vqa/train_ser.py | 4 ++-- ppstructure/vqa/xfun.py | 44 +++++++++++++++++++++++++++--------- 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/ppstructure/vqa/eval_ser.py b/ppstructure/vqa/eval_ser.py index e0612219..c9de25fb 100644 --- a/ppstructure/vqa/eval_ser.py +++ b/ppstructure/vqa/eval_ser.py @@ -61,7 +61,7 @@ def eval(args): eval_dataloader = paddle.io.DataLoader( eval_dataset, batch_size=args.per_gpu_eval_batch_size, - num_workers=0, + num_workers=8, use_shared_memory=True, collate_fn=None, ) diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py index d6c297c4..58eb0991 100644 --- a/ppstructure/vqa/train_ser.py +++ b/ppstructure/vqa/train_ser.py @@ -94,14 +94,14 @@ def train(args): train_dataloader = paddle.io.DataLoader( train_dataset, batch_sampler=train_sampler, - num_workers=0, + num_workers=8, use_shared_memory=True, collate_fn=None, ) eval_dataloader = paddle.io.DataLoader( eval_dataset, batch_size=args.per_gpu_eval_batch_size, - num_workers=0, + num_workers=8, use_shared_memory=True, collate_fn=None, ) diff --git a/ppstructure/vqa/xfun.py b/ppstructure/vqa/xfun.py index eb9750dd..f5dbe507 100644 --- a/ppstructure/vqa/xfun.py +++ b/ppstructure/vqa/xfun.py @@ -79,14 +79,36 @@ class XFUNDataset(Dataset): self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} self.return_keys = { - 'bbox': 'np', - 'input_ids': 'np', - 'labels': 'np', - 'attention_mask': 'np', - 'image': 'np', - 'token_type_ids': 'np', - 'entities': 'dict', - 'relations': 'dict', + 'bbox': { + 'type': 'np', + 'dtype': 'int64' + }, + 'input_ids': { + 'type': 'np', + 'dtype': 'int64' + }, + 'labels': { + 'type': 'np', + 'dtype': 'int64' + }, + 'attention_mask': { + 'type': 'np', + 'dtype': 'int64' + }, + 'image': { + 'type': 'np', + 'dtype': 'float32' + }, + 'token_type_ids': { + 'type': 'np', + 'dtype': 'int64' + }, + 'entities': { + 'type': 'dict' + }, + 'relations': { + 'type': 'dict' + } } if load_mode == "all": @@ -103,7 +125,7 @@ class XFUNDataset(Dataset): return_special_tokens_mask=False): # Padding needs_to_be_padded = pad_to_max_seq_len and \ - max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len + max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len if needs_to_be_padded: difference = max_seq_len - len(encoded_inputs["input_ids"]) @@ -412,8 +434,8 @@ class XFUNDataset(Dataset): return_data = {} for k, v in data.items(): if k in self.return_keys: - if self.return_keys[k] == 'np': - v = np.array(v) + if self.return_keys[k]['type'] == 'np': + v = np.array(v, dtype=self.return_keys[k]['dtype']) return_data[k] = v return return_data -- GitLab