未验证 提交 3718a146 编写于 作者: B Bai Yifan 提交者: GitHub

fix quant_post ut (#528)

上级 6fa05ff3
...@@ -55,13 +55,15 @@ class TestQuantAwareCase1(StaticCase): ...@@ -55,13 +55,15 @@ class TestQuantAwareCase1(StaticCase):
mode='train', backend='cv2', transform=transform) mode='train', backend='cv2', transform=transform)
test_dataset = paddle.vision.datasets.MNIST( test_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform) mode='test', backend='cv2', transform=transform)
train_loader = paddle.io.DataLoader( train_loader = paddle.io.DataLoader(
train_dataset, train_dataset,
places=place, places=place,
feed_list=[image, label], feed_list=[image, label],
drop_last=True, drop_last=True,
return_list=False, batch_size=64,
batch_size=64) return_list=False)
valid_loader = paddle.io.DataLoader( valid_loader = paddle.io.DataLoader(
test_dataset, test_dataset,
places=place, places=place,
...@@ -69,6 +71,14 @@ class TestQuantAwareCase1(StaticCase): ...@@ -69,6 +71,14 @@ class TestQuantAwareCase1(StaticCase):
batch_size=64, batch_size=64,
return_list=False) return_list=False)
def sample_generator_creator():
def __reader__():
for data in test_dataset:
image, label = data
yield image, label
return __reader__
def train(program): def train(program):
iter = 0 iter = 0
for data in train_loader(): for data in train_loader():
...@@ -115,7 +125,7 @@ class TestQuantAwareCase1(StaticCase): ...@@ -115,7 +125,7 @@ class TestQuantAwareCase1(StaticCase):
exe, exe,
'./test_quant_post', './test_quant_post',
'./test_quant_post_inference', './test_quant_post_inference',
sample_generator=paddle.dataset.mnist.test(), sample_generator=sample_generator_creator(),
model_filename='model', model_filename='model',
params_filename='params', params_filename='params',
batch_nums=10) batch_nums=10)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册