未验证 提交 16fec90f 编写于 作者: B Bai Yifan 提交者: GitHub

fix quant/distill demo dataloader (#565)

上级 51578a2b
...@@ -123,15 +123,15 @@ def compress(args): ...@@ -123,15 +123,15 @@ def compress(args):
batch_size=args.batch_size, batch_size=args.batch_size,
return_list=False, return_list=False,
shuffle=True, shuffle=True,
use_shared_memory=False, use_shared_memory=True,
num_workers=1) num_workers=4)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
val_dataset, val_dataset,
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=False, drop_last=False,
return_list=False, return_list=False,
use_shared_memory=False, use_shared_memory=True,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=False) shuffle=False)
# model definition # model definition
......
...@@ -159,9 +159,9 @@ def compress(args): ...@@ -159,9 +159,9 @@ def compress(args):
drop_last=True, drop_last=True,
return_list=False, return_list=False,
batch_size=args.batch_size, batch_size=args.batch_size,
use_shared_memory=False, use_shared_memory=True,
shuffle=True, shuffle=True,
num_workers=1) num_workers=4)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
val_dataset, val_dataset,
...@@ -170,7 +170,7 @@ def compress(args): ...@@ -170,7 +170,7 @@ def compress(args):
drop_last=False, drop_last=False,
return_list=False, return_list=False,
batch_size=args.batch_size, batch_size=args.batch_size,
use_shared_memory=False, use_shared_memory=True,
shuffle=False) shuffle=False)
if args.analysis: if args.analysis:
......
...@@ -169,9 +169,9 @@ def compress(args): ...@@ -169,9 +169,9 @@ def compress(args):
drop_last=True, drop_last=True,
batch_size=args.batch_size, batch_size=args.batch_size,
return_list=False, return_list=False,
use_shared_memory=False, use_shared_memory=True,
shuffle=True, shuffle=True,
num_workers=1) num_workers=4)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
val_dataset, val_dataset,
places=place, places=place,
...@@ -179,7 +179,7 @@ def compress(args): ...@@ -179,7 +179,7 @@ def compress(args):
drop_last=False, drop_last=False,
return_list=False, return_list=False,
batch_size=args.batch_size, batch_size=args.batch_size,
use_shared_memory=False, use_shared_memory=True,
shuffle=False) shuffle=False)
def test(epoch, program): def test(epoch, program):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册