diff --git a/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py b/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py index 48c0f3d3611547308b5d4460748d3aab765f5805..6f24ec45aa6f27814e489b8dce49fe69f62d4f10 100644 --- a/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py +++ b/python/paddle/fluid/tests/book/high-level-api/cifar10_small_test_set.py @@ -88,3 +88,19 @@ def train10(batch_size=None): paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), 'data_batch', batch_size=batch_size) + + +def test10(batch_size=None): + """ + CIFAR-10 test set creator. + + It returns a reader creator, each sample in the reader is image pixels in + [0, 1] and label in [0, 9]. + + :return: Test reader creator. + :rtype: callable + """ + return reader_creator( + paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), + 'test_batch', + batch_size=batch_size) diff --git a/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py b/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py index 82294d4b26fe64e6cddc81f9ba3480caf5b51620..0a27aa0fcfece36f1a8ae5ad0477d75a15fd88da 100644 --- a/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py +++ b/python/paddle/fluid/tests/book/high-level-api/test_image_classification_vgg_new_api.py @@ -89,9 +89,11 @@ def train(use_cuda, train_program, parallel, params_dirname): cifar10_small_test_set.train10(batch_size=10), buf_size=128 * 10), batch_size=BATCH_SIZE, drop_last=False) - + # Use only part of the test set data validation program test_reader = paddle.batch( - paddle.dataset.cifar.test10(), batch_size=BATCH_SIZE, drop_last=False) + cifar10_small_test_set.test10(BATCH_SIZE), + batch_size=BATCH_SIZE, + drop_last=False) def event_handler(event): if isinstance(event, EndStepEvent):