未验证 提交 3718a146 编写于 作者: B Bai Yifan 提交者: GitHub

fix quant_post ut (#528)

上级 6fa05ff3
......@@ -55,13 +55,15 @@ class TestQuantAwareCase1(StaticCase):
mode='train', backend='cv2', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
feed_list=[image, label],
drop_last=True,
return_list=False,
batch_size=64)
batch_size=64,
return_list=False)
valid_loader = paddle.io.DataLoader(
test_dataset,
places=place,
......@@ -69,6 +71,14 @@ class TestQuantAwareCase1(StaticCase):
batch_size=64,
return_list=False)
def sample_generator_creator():
def __reader__():
for data in test_dataset:
image, label = data
yield image, label
return __reader__
def train(program):
iter = 0
for data in train_loader():
......@@ -115,7 +125,7 @@ class TestQuantAwareCase1(StaticCase):
exe,
'./test_quant_post',
'./test_quant_post_inference',
sample_generator=paddle.dataset.mnist.test(),
sample_generator=sample_generator_creator(),
model_filename='model',
params_filename='params',
batch_nums=10)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册