提交 8e05d54c 编写于 作者: 文幕地方's avatar 文幕地方

fix win train bug

上级 dc51469b
......@@ -61,7 +61,7 @@ def eval(args):
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=0,
num_workers=8,
use_shared_memory=True,
collate_fn=None, )
......
......@@ -94,14 +94,14 @@ def train(args):
train_dataloader = paddle.io.DataLoader(
train_dataset,
batch_sampler=train_sampler,
num_workers=0,
num_workers=8,
use_shared_memory=True,
collate_fn=None, )
eval_dataloader = paddle.io.DataLoader(
eval_dataset,
batch_size=args.per_gpu_eval_batch_size,
num_workers=0,
num_workers=8,
use_shared_memory=True,
collate_fn=None, )
......
......@@ -79,14 +79,36 @@ class XFUNDataset(Dataset):
self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
self.return_keys = {
'bbox': 'np',
'input_ids': 'np',
'labels': 'np',
'attention_mask': 'np',
'image': 'np',
'token_type_ids': 'np',
'entities': 'dict',
'relations': 'dict',
'bbox': {
'type': 'np',
'dtype': 'int64'
},
'input_ids': {
'type': 'np',
'dtype': 'int64'
},
'labels': {
'type': 'np',
'dtype': 'int64'
},
'attention_mask': {
'type': 'np',
'dtype': 'int64'
},
'image': {
'type': 'np',
'dtype': 'float32'
},
'token_type_ids': {
'type': 'np',
'dtype': 'int64'
},
'entities': {
'type': 'dict'
},
'relations': {
'type': 'dict'
}
}
if load_mode == "all":
......@@ -103,7 +125,7 @@ class XFUNDataset(Dataset):
return_special_tokens_mask=False):
# Padding
needs_to_be_padded = pad_to_max_seq_len and \
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
max_seq_len and len(encoded_inputs["input_ids"]) < max_seq_len
if needs_to_be_padded:
difference = max_seq_len - len(encoded_inputs["input_ids"])
......@@ -412,8 +434,8 @@ class XFUNDataset(Dataset):
return_data = {}
for k, v in data.items():
if k in self.return_keys:
if self.return_keys[k] == 'np':
v = np.array(v)
if self.return_keys[k]['type'] == 'np':
v = np.array(v, dtype=self.return_keys[k]['dtype'])
return_data[k] = v
return return_data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册