From 16fec90f670669612d8b493af0a26622b3c687bd Mon Sep 17 00:00:00 2001 From: Bai Yifan Date: Tue, 22 Dec 2020 18:50:57 +0800 Subject: [PATCH] fix quant/distill demo dataloader (#565) --- demo/distillation/distill.py | 6 +++--- demo/quant/pact_quant_aware/train.py | 6 +++--- demo/quant/quant_aware/train.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/demo/distillation/distill.py b/demo/distillation/distill.py index a180c3c9..d7417470 100644 --- a/demo/distillation/distill.py +++ b/demo/distillation/distill.py @@ -123,15 +123,15 @@ def compress(args): batch_size=args.batch_size, return_list=False, shuffle=True, - use_shared_memory=False, - num_workers=1) + use_shared_memory=True, + num_workers=4) valid_loader = paddle.io.DataLoader( val_dataset, places=place, feed_list=[image, label], drop_last=False, return_list=False, - use_shared_memory=False, + use_shared_memory=True, batch_size=args.batch_size, shuffle=False) # model definition diff --git a/demo/quant/pact_quant_aware/train.py b/demo/quant/pact_quant_aware/train.py index 9be4d6f8..d6210a2e 100644 --- a/demo/quant/pact_quant_aware/train.py +++ b/demo/quant/pact_quant_aware/train.py @@ -159,9 +159,9 @@ def compress(args): drop_last=True, return_list=False, batch_size=args.batch_size, - use_shared_memory=False, + use_shared_memory=True, shuffle=True, - num_workers=1) + num_workers=4) valid_loader = paddle.io.DataLoader( val_dataset, @@ -170,7 +170,7 @@ def compress(args): drop_last=False, return_list=False, batch_size=args.batch_size, - use_shared_memory=False, + use_shared_memory=True, shuffle=False) if args.analysis: diff --git a/demo/quant/quant_aware/train.py b/demo/quant/quant_aware/train.py index a1c0ea20..af43aeff 100644 --- a/demo/quant/quant_aware/train.py +++ b/demo/quant/quant_aware/train.py @@ -169,9 +169,9 @@ def compress(args): drop_last=True, batch_size=args.batch_size, return_list=False, - use_shared_memory=False, + use_shared_memory=True, shuffle=True, - num_workers=1) + num_workers=4) valid_loader = paddle.io.DataLoader( val_dataset, places=place, @@ -179,7 +179,7 @@ def compress(args): drop_last=False, return_list=False, batch_size=args.batch_size, - use_shared_memory=False, + use_shared_memory=True, shuffle=False) def test(epoch, program): -- GitLab