From c640267b00a16f74e3a355abdc7408e4c6135e1b Mon Sep 17 00:00:00 2001
From: Bai Yifan <me@ethanbai.com>
Date: Fri, 25 Dec 2020 14:48:15 +0800
Subject: [PATCH] fix quant/distill demo dataloader (#565) (#568)

---
 demo/distillation/distill.py         | 6 +++---
 demo/quant/pact_quant_aware/train.py | 6 +++---
 demo/quant/quant_aware/train.py      | 6 +++---
 3 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/demo/distillation/distill.py b/demo/distillation/distill.py
index a180c3c9..d7417470 100644
--- a/demo/distillation/distill.py
+++ b/demo/distillation/distill.py
@@ -123,15 +123,15 @@ def compress(args):
                 batch_size=args.batch_size,
                 return_list=False,
                 shuffle=True,
-                use_shared_memory=False,
-                num_workers=1)
+                use_shared_memory=True,
+                num_workers=4)
             valid_loader = paddle.io.DataLoader(
                 val_dataset,
                 places=place,
                 feed_list=[image, label],
                 drop_last=False,
                 return_list=False,
-                use_shared_memory=False,
+                use_shared_memory=True,
                 batch_size=args.batch_size,
                 shuffle=False)
             # model definition
diff --git a/demo/quant/pact_quant_aware/train.py b/demo/quant/pact_quant_aware/train.py
index 9be4d6f8..d6210a2e 100644
--- a/demo/quant/pact_quant_aware/train.py
+++ b/demo/quant/pact_quant_aware/train.py
@@ -159,9 +159,9 @@ def compress(args):
         drop_last=True,
         return_list=False,
         batch_size=args.batch_size,
-        use_shared_memory=False,
+        use_shared_memory=True,
         shuffle=True,
-        num_workers=1)
+        num_workers=4)
 
     valid_loader = paddle.io.DataLoader(
         val_dataset,
@@ -170,7 +170,7 @@ def compress(args):
         drop_last=False,
         return_list=False,
         batch_size=args.batch_size,
-        use_shared_memory=False,
+        use_shared_memory=True,
         shuffle=False)
 
     if args.analysis:
diff --git a/demo/quant/quant_aware/train.py b/demo/quant/quant_aware/train.py
index a1c0ea20..af43aeff 100644
--- a/demo/quant/quant_aware/train.py
+++ b/demo/quant/quant_aware/train.py
@@ -169,9 +169,9 @@ def compress(args):
         drop_last=True,
         batch_size=args.batch_size,
         return_list=False,
-        use_shared_memory=False,
+        use_shared_memory=True,
         shuffle=True,
-        num_workers=1)
+        num_workers=4)
     valid_loader = paddle.io.DataLoader(
         val_dataset,
         places=place,
@@ -179,7 +179,7 @@ def compress(args):
         drop_last=False,
         return_list=False,
         batch_size=args.batch_size,
-        use_shared_memory=False,
+        use_shared_memory=True,
         shuffle=False)
 
     def test(epoch, program):
-- 
GitLab