diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md
index 744cdc74a24c66873720e011ab3ea3af5b8f6abc..2216950e58654a6143fbc9955568ee66d21512fc 100644
--- a/ppstructure/vqa/README.md
+++ b/ppstructure/vqa/README.md
@@ -165,6 +165,7 @@ python3.7 train_ser.py \
     --learning_rate 5e-5 \
     --warmup_steps 50 \
     --evaluate_during_training \
+    --num_workers 8 \
     --seed 2048 \
     --resume
 ```
@@ -177,6 +178,7 @@ python3 eval_ser.py \
     --eval_data_dir "XFUND/zh_val/image" \
     --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \
     --per_gpu_eval_batch_size 8 \
+    --num_workers 8 \
     --output_dir "output/ser/" \
     --seed 2048
 ```
@@ -234,6 +236,7 @@ python3 train_re.py \
     --warmup_steps 50 \
     --per_gpu_train_batch_size 8 \
     --per_gpu_eval_batch_size 8 \
+    --num_workers 8 \
     --evaluate_during_training \
     --seed 2048
 
@@ -257,6 +260,7 @@ python3 train_re.py \
     --warmup_steps 50 \
     --per_gpu_train_batch_size 8 \
     --per_gpu_eval_batch_size 8 \
+    --num_workers 8 \
     --evaluate_during_training \
     --seed 2048 \
     --resume
@@ -276,6 +280,7 @@ python3 eval_re.py \
     --label_map_path 'labels/labels_ser.txt' \
     --output_dir "output/re_test/" \
     --per_gpu_eval_batch_size 8 \
+    --num_workers 8 \
     --seed 2048
 ```
 Finally, the `precision`, `recall` and `f1` metrics are printed.
diff --git a/ppstructure/vqa/eval_re.py b/ppstructure/vqa/eval_re.py
index 45c23660474f7574e01a3551fd7cb1ee7929aec1..12bb9cabdb8b4d6482a121ca3b73089b3d0244ff 100644
--- a/ppstructure/vqa/eval_re.py
+++ b/ppstructure/vqa/eval_re.py
@@ -112,7 +112,7 @@ def eval(args):
     eval_dataloader = paddle.io.DataLoader(
         eval_dataset,
         batch_size=args.per_gpu_eval_batch_size,
-        num_workers=8,
+        num_workers=args.num_workers,
         shuffle=False,
         collate_fn=DataCollator())
 
diff --git a/ppstructure/vqa/eval_ser.py b/ppstructure/vqa/eval_ser.py
index c9de25fb565cb2efe8b1ba9da28ff96db25b95f8..acf37452a44fe00a4421044f7ed80de69dfacbb8 100644
--- a/ppstructure/vqa/eval_ser.py
+++ b/ppstructure/vqa/eval_ser.py
@@ -61,7 +61,7 @@ def eval(args):
     eval_dataloader = paddle.io.DataLoader(
         eval_dataset,
         batch_size=args.per_gpu_eval_batch_size,
-        num_workers=8,
+        num_workers=args.num_workers,
         use_shared_memory=True,
         collate_fn=None, )
 
diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py
index c7e701c8d2e19599b10357b4e8b4b2f10e454deb..47d694678013295a1c664a7bdb6a7fe13a0b36a5 100644
--- a/ppstructure/vqa/train_re.py
+++ b/ppstructure/vqa/train_re.py
@@ -97,14 +97,14 @@ def train(args):
     train_dataloader = paddle.io.DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
-        num_workers=8,
+        num_workers=args.num_workers,
         use_shared_memory=True,
         collate_fn=DataCollator())
 
     eval_dataloader = paddle.io.DataLoader(
         eval_dataset,
         batch_size=args.per_gpu_eval_batch_size,
-        num_workers=8,
+        num_workers=args.num_workers,
         shuffle=False,
         collate_fn=DataCollator())
 
diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py
index 58eb09918c33ce53e4248d64d486ef9d29c0d504..6791cea89adfbcccbea8861c2ea2d3c5e94fb713 100644
--- a/ppstructure/vqa/train_ser.py
+++ b/ppstructure/vqa/train_ser.py
@@ -94,14 +94,14 @@ def train(args):
     train_dataloader = paddle.io.DataLoader(
         train_dataset,
         batch_sampler=train_sampler,
-        num_workers=8,
+        num_workers=args.num_workers,
         use_shared_memory=True,
         collate_fn=None, )
 
     eval_dataloader = paddle.io.DataLoader(
         eval_dataset,
         batch_size=args.per_gpu_eval_batch_size,
-        num_workers=8,
+        num_workers=args.num_workers,
         use_shared_memory=True,
         collate_fn=None, )
 
diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py
index 44a6298080a456b514a719edf82ff1f1b60fd077..9f6eebb662c8d41542f54795896f4a4f8b3f1b8e 100644
--- a/ppstructure/vqa/utils.py
+++ b/ppstructure/vqa/utils.py
@@ -363,6 +363,7 @@ def parse_args():
     parser.add_argument("--output_dir", default=None, type=str, required=True,)
     parser.add_argument("--max_seq_length", default=512, type=int,)
     parser.add_argument("--evaluate_during_training", action="store_true",)
+    parser.add_argument("--num_workers", default=8, type=int,)
     parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                         help="Batch size per GPU/CPU for training.",)
     parser.add_argument("--per_gpu_eval_batch_size", default=8,
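
The change is mechanical: `utils.py` gains a `--num_workers` flag (default 8, matching the previously hard-coded value), and every `paddle.io.DataLoader` call in the train/eval scripts reads `args.num_workers` instead of the literal `8`. Below is a minimal, self-contained sketch of that pattern, not part of the patch; `ToyDataset` and the reduced argument list are illustrative stand-ins for the real SER/RE datasets and `parse_args()`.

```python
import argparse

import numpy as np
from paddle.io import DataLoader, Dataset


class ToyDataset(Dataset):
    """Illustrative stand-in for the SER/RE datasets used by train_*.py / eval_*.py."""

    def __getitem__(self, idx):
        feature = np.full([16], idx, dtype="float32")
        label = np.array([idx % 2], dtype="int64")
        return feature, label

    def __len__(self):
        return 32


def parse_args():
    parser = argparse.ArgumentParser()
    # Mirrors the flag the patch adds in utils.py: configurable, default 8.
    parser.add_argument("--num_workers", default=8, type=int)
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int)
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    eval_dataloader = DataLoader(
        ToyDataset(),
        batch_size=args.per_gpu_eval_batch_size,
        num_workers=args.num_workers,  # previously hard-coded to 8
        shuffle=False)
    for feature, label in eval_dataloader:
        print(feature.shape, label.shape)
```

With the flag exposed, `--num_workers 0` falls back to loading data in the main process, which is convenient for debugging, while larger values spawn that many worker processes for prefetching.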