From 82f6ef8aa3bb027384bf3d4e425151b258848d21 Mon Sep 17 00:00:00 2001 From: yukavio <67678385+yukavio@users.noreply.github.com> Date: Fri, 8 Jan 2021 17:05:33 +0800 Subject: [PATCH] fix prune demo batchsize (#591) * fix prune doc * fix prune demo batchsize Co-authored-by: ceci3 --- demo/prune/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/demo/prune/train.py b/demo/prune/train.py index edb30a39..44082c95 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -146,12 +146,13 @@ def compress(args): paddle.static.load(paddle.static.default_main_program(), args.pretrained_model, exe) + batch_size_per_card = int(args.batch_size / len(places)) train_loader = paddle.io.DataLoader( train_dataset, places=places, feed_list=[image, label], drop_last=True, - batch_size=args.batch_size, + batch_size=batch_size_per_card, shuffle=True, return_list=False, use_shared_memory=True, @@ -163,7 +164,7 @@ def compress(args): drop_last=False, return_list=False, use_shared_memory=True, - batch_size=args.batch_size, + batch_size=batch_size_per_card, shuffle=False) def test(epoch, program): -- GitLab