From 197abda0c6ab55b847bad521b99b5e59086ffe48 Mon Sep 17 00:00:00 2001 From: yukavio <67678385+yukavio@users.noreply.github.com> Date: Wed, 6 Jan 2021 20:49:59 +0800 Subject: [PATCH] fix prune demo batchsize (#590) --- demo/prune/train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/demo/prune/train.py b/demo/prune/train.py index edb30a39..44082c95 100644 --- a/demo/prune/train.py +++ b/demo/prune/train.py @@ -146,12 +146,13 @@ def compress(args): paddle.static.load(paddle.static.default_main_program(), args.pretrained_model, exe) + batch_size_per_card = int(args.batch_size / len(places)) train_loader = paddle.io.DataLoader( train_dataset, places=places, feed_list=[image, label], drop_last=True, - batch_size=args.batch_size, + batch_size=batch_size_per_card, shuffle=True, return_list=False, use_shared_memory=True, @@ -163,7 +164,7 @@ def compress(args): drop_last=False, return_list=False, use_shared_memory=True, - batch_size=args.batch_size, + batch_size=batch_size_per_card, shuffle=False) def test(epoch, program): -- GitLab