未验证 提交 197abda0 编写于 作者: Y yukavio 提交者: GitHub

fix prune demo batchsize (#590)

...@@ -146,12 +146,13 @@ def compress(args): ...@@ -146,12 +146,13 @@ def compress(args):
paddle.static.load(paddle.static.default_main_program(), paddle.static.load(paddle.static.default_main_program(),
args.pretrained_model, exe) args.pretrained_model, exe)
batch_size_per_card = int(args.batch_size / len(places))
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
train_dataset, train_dataset,
places=places, places=places,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
batch_size=args.batch_size, batch_size=batch_size_per_card,
shuffle=True, shuffle=True,
return_list=False, return_list=False,
use_shared_memory=True, use_shared_memory=True,
...@@ -163,7 +164,7 @@ def compress(args): ...@@ -163,7 +164,7 @@ def compress(args):
drop_last=False, drop_last=False,
return_list=False, return_list=False,
use_shared_memory=True, use_shared_memory=True,
batch_size=args.batch_size, batch_size=batch_size_per_card,
shuffle=False) shuffle=False)
def test(epoch, program): def test(epoch, program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部