未验证 提交 8c80a251 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #4977 from WenmuZhou/fix_vqa

fix win train bug
...@@ -165,6 +165,7 @@ python3.7 train_ser.py \ ...@@ -165,6 +165,7 @@ python3.7 train_ser.py \
--learning_rate 5e-5 \ --learning_rate 5e-5 \
--warmup_steps 50 \ --warmup_steps 50 \
--evaluate_during_training \ --evaluate_during_training \
--num_workers 8 \
--seed 2048 \ --seed 2048 \
--resume --resume
``` ```
...@@ -177,6 +178,7 @@ python3 eval_ser.py \ ...@@ -177,6 +178,7 @@ python3 eval_ser.py \
--eval_data_dir "XFUND/zh_val/image" \ --eval_data_dir "XFUND/zh_val/image" \
--eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
--per_gpu_eval_batch_size 8 \ --per_gpu_eval_batch_size 8 \
--num_workers 8 \
--output_dir "output/ser/" \ --output_dir "output/ser/" \
--seed 2048 --seed 2048
``` ```
...@@ -234,6 +236,7 @@ python3 train_re.py \ ...@@ -234,6 +236,7 @@ python3 train_re.py \
--warmup_steps 50 \ --warmup_steps 50 \
--per_gpu_train_batch_size 8 \ --per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \ --per_gpu_eval_batch_size 8 \
--num_workers 8 \
--evaluate_during_training \ --evaluate_during_training \
--seed 2048 --seed 2048
...@@ -257,6 +260,7 @@ python3 train_re.py \ ...@@ -257,6 +260,7 @@ python3 train_re.py \
--warmup_steps 50 \ --warmup_steps 50 \
--per_gpu_train_batch_size 8 \ --per_gpu_train_batch_size 8 \
--per_gpu_eval_batch_size 8 \ --per_gpu_eval_batch_size 8 \
--num_workers 8 \
--evaluate_during_training \ --evaluate_during_training \
--seed 2048 \ --seed 2048 \
--resume --resume
...@@ -276,6 +280,7 @@ python3 eval_re.py \ ...@@ -276,6 +280,7 @@ python3 eval_re.py \
--label_map_path 'labels/labels_ser.txt' \ --label_map_path 'labels/labels_ser.txt' \
--output_dir "output/re_test/" \ --output_dir "output/re_test/" \
--per_gpu_eval_batch_size 8 \ --per_gpu_eval_batch_size 8 \
--num_workers 8 \
--seed 2048 --seed 2048
``` ```
最终会打印出`precision`, `recall`, `f1`等指标 最终会打印出`precision`, `recall`, `f1`等指标
......
...@@ -112,7 +112,7 @@ def eval(args): ...@@ -112,7 +112,7 @@ def eval(args):
eval_dataloader = paddle.io.DataLoader( eval_dataloader = paddle.io.DataLoader(
eval_dataset, eval_dataset,
batch_size=args.per_gpu_eval_batch_size, batch_size=args.per_gpu_eval_batch_size,
num_workers=8, num_workers=args.num_workers,
shuffle=False, shuffle=False,
collate_fn=DataCollator()) collate_fn=DataCollator())
......
...@@ -61,7 +61,7 @@ def eval(args): ...@@ -61,7 +61,7 @@ def eval(args):
eval_dataloader = paddle.io.DataLoader( eval_dataloader = paddle.io.DataLoader(
eval_dataset, eval_dataset,
batch_size=args.per_gpu_eval_batch_size, batch_size=args.per_gpu_eval_batch_size,
num_workers=0, num_workers=args.num_workers,
use_shared_memory=True, use_shared_memory=True,
collate_fn=None, ) collate_fn=None, )
......
...@@ -97,14 +97,14 @@ def train(args): ...@@ -97,14 +97,14 @@ def train(args):
train_dataloader = paddle.io.DataLoader( train_dataloader = paddle.io.DataLoader(
train_dataset, train_dataset,
batch_sampler=train_sampler, batch_sampler=train_sampler,
num_workers=8, num_workers=args.num_workers,
use_shared_memory=True, use_shared_memory=True,
collate_fn=DataCollator()) collate_fn=DataCollator())
eval_dataloader = paddle.io.DataLoader( eval_dataloader = paddle.io.DataLoader(
eval_dataset, eval_dataset,
batch_size=args.per_gpu_eval_batch_size, batch_size=args.per_gpu_eval_batch_size,
num_workers=8, num_workers=args.num_workers,
shuffle=False, shuffle=False,
collate_fn=DataCollator()) collate_fn=DataCollator())
......
...@@ -94,14 +94,14 @@ def train(args): ...@@ -94,14 +94,14 @@ def train(args):
train_dataloader = paddle.io.DataLoader( train_dataloader = paddle.io.DataLoader(
train_dataset, train_dataset,
batch_sampler=train_sampler, batch_sampler=train_sampler,
num_workers=0, num_workers=args.num_workers,
use_shared_memory=True, use_shared_memory=True,
collate_fn=None, ) collate_fn=None, )
eval_dataloader = paddle.io.DataLoader( eval_dataloader = paddle.io.DataLoader(
eval_dataset, eval_dataset,
batch_size=args.per_gpu_eval_batch_size, batch_size=args.per_gpu_eval_batch_size,
num_workers=0, num_workers=args.num_workers,
use_shared_memory=True, use_shared_memory=True,
collate_fn=None, ) collate_fn=None, )
......
...@@ -363,6 +363,7 @@ def parse_args(): ...@@ -363,6 +363,7 @@ def parse_args():
parser.add_argument("--output_dir", default=None, type=str, required=True,) parser.add_argument("--output_dir", default=None, type=str, required=True,)
parser.add_argument("--max_seq_length", default=512, type=int,) parser.add_argument("--max_seq_length", default=512, type=int,)
parser.add_argument("--evaluate_during_training", action="store_true",) parser.add_argument("--evaluate_during_training", action="store_true",)
parser.add_argument("--num_workers", default=8, type=int,)
parser.add_argument("--per_gpu_train_batch_size", default=8, parser.add_argument("--per_gpu_train_batch_size", default=8,
type=int, help="Batch size per GPU/CPU for training.",) type=int, help="Batch size per GPU/CPU for training.",)
parser.add_argument("--per_gpu_eval_batch_size", default=8, parser.add_argument("--per_gpu_eval_batch_size", default=8,
......
...@@ -79,14 +79,36 @@ class XFUNDataset(Dataset): ...@@ -79,14 +79,36 @@ class XFUNDataset(Dataset):
self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2} self.entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
self.return_keys = { self.return_keys = {
'bbox': 'np', 'bbox': {
'input_ids': 'np', 'type': 'np',
'labels': 'np', 'dtype': 'int64'
'attention_mask': 'np', },
'image': 'np', 'input_ids': {
'token_type_ids': 'np', 'type': 'np',
'entities': 'dict', 'dtype': 'int64'
'relations': 'dict', },
'labels': {
'type': 'np',
'dtype': 'int64'
},
'attention_mask': {
'type': 'np',
'dtype': 'int64'
},
'image': {
'type': 'np',
'dtype': 'float32'
},
'token_type_ids': {
'type': 'np',
'dtype': 'int64'
},
'entities': {
'type': 'dict'
},
'relations': {
'type': 'dict'
}
} }
if load_mode == "all": if load_mode == "all":
...@@ -412,8 +434,8 @@ class XFUNDataset(Dataset): ...@@ -412,8 +434,8 @@ class XFUNDataset(Dataset):
return_data = {} return_data = {}
for k, v in data.items(): for k, v in data.items():
if k in self.return_keys: if k in self.return_keys:
if self.return_keys[k] == 'np': if self.return_keys[k]['type'] == 'np':
v = np.array(v) v = np.array(v, dtype=self.return_keys[k]['dtype'])
return_data[k] = v return_data[k] = v
return return_data return return_data
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册