From 2059a5a281c743425951a8c24435896fde01fe13 Mon Sep 17 00:00:00 2001 From: WenmuZhou <572459439@qq.com> Date: Mon, 20 Dec 2021 09:14:27 +0000 Subject: [PATCH] add num_workers to args --- ppstructure/vqa/README.md | 5 +++++ ppstructure/vqa/eval_re.py | 2 +- ppstructure/vqa/eval_ser.py | 2 +- ppstructure/vqa/train_re.py | 4 ++-- ppstructure/vqa/train_ser.py | 4 ++-- ppstructure/vqa/utils.py | 1 + 6 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ppstructure/vqa/README.md b/ppstructure/vqa/README.md index 744cdc74..2216950e 100644 --- a/ppstructure/vqa/README.md +++ b/ppstructure/vqa/README.md @@ -165,6 +165,7 @@ python3.7 train_ser.py \ --learning_rate 5e-5 \ --warmup_steps 50 \ --evaluate_during_training \ + --num_workers 8 \ --seed 2048 \ --resume ``` @@ -177,6 +178,7 @@ python3 eval_ser.py \ --eval_data_dir "XFUND/zh_val/image" \ --eval_label_path "XFUND/zh_val/xfun_normalize_val.json" \ --per_gpu_eval_batch_size 8 \ + --num_workers 8 \ --output_dir "output/ser/" \ --seed 2048 ``` @@ -234,6 +236,7 @@ python3 train_re.py \ --warmup_steps 50 \ --per_gpu_train_batch_size 8 \ --per_gpu_eval_batch_size 8 \ + --num_workers 8 \ --evaluate_during_training \ --seed 2048 @@ -257,6 +260,7 @@ python3 train_re.py \ --warmup_steps 50 \ --per_gpu_train_batch_size 8 \ --per_gpu_eval_batch_size 8 \ + --num_workers 8 \ --evaluate_during_training \ --seed 2048 \ --resume @@ -276,6 +280,7 @@ python3 eval_re.py \ --label_map_path 'labels/labels_ser.txt' \ --output_dir "output/re_test/" \ --per_gpu_eval_batch_size 8 \ + --num_workers 8 \ --seed 2048 ``` 最终会打印出`precision`, `recall`, `f1`等指标 diff --git a/ppstructure/vqa/eval_re.py b/ppstructure/vqa/eval_re.py index 45c23660..12bb9cab 100644 --- a/ppstructure/vqa/eval_re.py +++ b/ppstructure/vqa/eval_re.py @@ -112,7 +112,7 @@ def eval(args): eval_dataloader = paddle.io.DataLoader( eval_dataset, batch_size=args.per_gpu_eval_batch_size, - num_workers=8, + num_workers=args.num_workers, shuffle=False, collate_fn=DataCollator()) diff --git a/ppstructure/vqa/eval_ser.py b/ppstructure/vqa/eval_ser.py index c9de25fb..acf37452 100644 --- a/ppstructure/vqa/eval_ser.py +++ b/ppstructure/vqa/eval_ser.py @@ -61,7 +61,7 @@ def eval(args): eval_dataloader = paddle.io.DataLoader( eval_dataset, batch_size=args.per_gpu_eval_batch_size, - num_workers=8, + num_workers=args.num_workers, use_shared_memory=True, collate_fn=None, ) diff --git a/ppstructure/vqa/train_re.py b/ppstructure/vqa/train_re.py index c7e701c8..47d69467 100644 --- a/ppstructure/vqa/train_re.py +++ b/ppstructure/vqa/train_re.py @@ -97,14 +97,14 @@ def train(args): train_dataloader = paddle.io.DataLoader( train_dataset, batch_sampler=train_sampler, - num_workers=8, + num_workers=args.num_workers, use_shared_memory=True, collate_fn=DataCollator()) eval_dataloader = paddle.io.DataLoader( eval_dataset, batch_size=args.per_gpu_eval_batch_size, - num_workers=8, + num_workers=args.num_workers, shuffle=False, collate_fn=DataCollator()) diff --git a/ppstructure/vqa/train_ser.py b/ppstructure/vqa/train_ser.py index 58eb0991..6791cea8 100644 --- a/ppstructure/vqa/train_ser.py +++ b/ppstructure/vqa/train_ser.py @@ -94,14 +94,14 @@ def train(args): train_dataloader = paddle.io.DataLoader( train_dataset, batch_sampler=train_sampler, - num_workers=8, + num_workers=args.num_workers, use_shared_memory=True, collate_fn=None, ) eval_dataloader = paddle.io.DataLoader( eval_dataset, batch_size=args.per_gpu_eval_batch_size, - num_workers=8, + num_workers=args.num_workers, use_shared_memory=True, collate_fn=None, ) diff --git a/ppstructure/vqa/utils.py b/ppstructure/vqa/utils.py index 44a62980..9f6eebb6 100644 --- a/ppstructure/vqa/utils.py +++ b/ppstructure/vqa/utils.py @@ -363,6 +363,7 @@ def parse_args(): parser.add_argument("--output_dir", default=None, type=str, required=True,) parser.add_argument("--max_seq_length", default=512, type=int,) parser.add_argument("--evaluate_during_training", action="store_true",) + parser.add_argument("--num_workers", efault=8, type=int,) parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",) parser.add_argument("--per_gpu_eval_batch_size", default=8, -- GitLab