diff --git a/demo/prune/train.py b/demo/prune/train.py
index edb30a3937cbbd18c2f237c9c01ce86979823006..44082c9563f146a523a08184a51e8008cb08d8d8 100644
--- a/demo/prune/train.py
+++ b/demo/prune/train.py
@@ -146,12 +146,13 @@ def compress(args):
         paddle.static.load(paddle.static.default_main_program(),
                            args.pretrained_model, exe)
 
+    batch_size_per_card = int(args.batch_size / len(places))
     train_loader = paddle.io.DataLoader(
         train_dataset,
         places=places,
         feed_list=[image, label],
         drop_last=True,
-        batch_size=args.batch_size,
+        batch_size=batch_size_per_card,
         shuffle=True,
         return_list=False,
         use_shared_memory=True,
@@ -163,7 +164,7 @@ def compress(args):
         drop_last=False,
         return_list=False,
         use_shared_memory=True,
-        batch_size=args.batch_size,
+        batch_size=batch_size_per_card,
         shuffle=False)
 
     def test(epoch, program):