diff --git a/demo/prune/train.py b/demo/prune/train.py index edb30a3937cbbd18c2f237c9c01ce86979823006..44082c9563f146a523a08184a51e8008cb08d8d8 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -146,12 +146,13 @@ def compress(args): paddle.static.load(paddle.static.default_main_program(), args.pretrained_model, exe) + batch_size_per_card = int(args.batch_size / len(places)) train_loader = paddle.io.DataLoader( train_dataset, places=places, feed_list=[image, label], drop_last=True, - batch_size=args.batch_size, + batch_size=batch_size_per_card, shuffle=True, return_list=False, use_shared_memory=True, @@ -163,7 +164,7 @@ def compress(args): drop_last=False, return_list=False, use_shared_memory=True, - batch_size=args.batch_size, + batch_size=batch_size_per_card, shuffle=False) def test(epoch, program):