提交 4aea89fa 编写于 作者: L lujun

fix vgg-test. test=develop

上级 919b58d1
......@@ -88,3 +88,19 @@ def train10(batch_size=None):
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch',
batch_size=batch_size)
def test10(batch_size=None):
"""
CIFAR-10 test set creator.
It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9].
:return: Test reader creator.
:rtype: callable
"""
return reader_creator(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch',
batch_size=batch_size)
......@@ -89,9 +89,11 @@ def train(use_cuda, train_program, parallel, params_dirname):
cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10),
batch_size=BATCH_SIZE,
drop_last=False)
# Use only part of the test set data validation program
test_reader = paddle.batch(
paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False)
cifar10_small_test_set.test10(BATCH_SIZE),
batch_size=BATCH_SIZE,
drop_last=False)
def event_handler(event):
if isinstance(event, EndStepEvent):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册