From 197abda0c6ab55b847bad521b99b5e59086ffe48 Mon Sep 17 00:00:00 2001
From: yukavio <67678385+yukavio@users.noreply.github.com>
Date: Wed, 6 Jan 2021 20:49:59 +0800
Subject: [PATCH] fix prune demo batchsize (#590)

---
 demo/prune/train.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/demo/prune/train.py b/demo/prune/train.py
index edb30a39..44082c95 100644
--- a/demo/prune/train.py
+++ b/demo/prune/train.py
@@ -146,12 +146,13 @@ def compress(args):
         paddle.static.load(paddle.static.default_main_program(),
                            args.pretrained_model, exe)
 
+    batch_size_per_card = int(args.batch_size / len(places))
     train_loader = paddle.io.DataLoader(
         train_dataset,
         places=places,
         feed_list=[image, label],
         drop_last=True,
-        batch_size=args.batch_size,
+        batch_size=batch_size_per_card,
         shuffle=True,
         return_list=False,
         use_shared_memory=True,
@@ -163,7 +164,7 @@ def compress(args):
         drop_last=False,
         return_list=False,
         use_shared_memory=True,
-        batch_size=args.batch_size,
+        batch_size=batch_size_per_card,
         shuffle=False)
 
     def test(epoch, program):
-- 
GitLab