From 4aea89faa2875015be421eee0247d9516677972b Mon Sep 17 00:00:00 2001 From: lujun Date: Mon, 15 Apr 2019 12:56:38 +0800 Subject: [PATCH] fix vgg-test. test=develop --- .../high-level-api/cifar10_small_test_set.py | 16 ++++++++++++++++ .../test_image_classification_vgg_new_api.py | 6 ++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py b/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py index 48c0f3d36..6f24ec45a 100644 --- a/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py +++ b/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py @@ -88,3 +88,19 @@ def train10(batch_size=None): paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch', batch_size=batch_size) + + +def test10(batch_size=None): + """ + CIFAR-10 test set creator. + + It returns a reader creator, each sample in the reader is image pixels in + [0, 1] and label in [0, 9]. + + :return: Test reader creator. + :rtype: callable + """ + return reader_creator( + paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), + 'test_batch', + batch_size=batch_size) diff --git a/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py b/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py index 82294d4b2..0a27aa0fc 100644 --- a/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py +++ b/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py @@ -89,9 +89,11 @@ def train(use_cuda, train_program, parallel, params_dirname): cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10), batch_size=BATCH_SIZE, drop_last=False) - + # Use only part of the test set data validation program test_reader = paddle.batch( - paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False) + cifar10_small_test_set.test10(BATCH_SIZE), + batch_size=BATCH_SIZE, + drop_last=False) def event_handler(event): if isinstance(event, EndStepEvent): -- GitLab