diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py index eefd5709a3bd7a34c3aaec370690158372562328..e878bfa4e9b2be7ff356e1e3861da2afc56063d7 100644 --- a/python/paddle/v2/framework/tests/mnist.py +++ b/python/paddle/v2/framework/tests/mnist.py @@ -205,7 +205,8 @@ train_reader = paddle.batch( def test(cost_name): - test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=BATCH_SIZE) cost = [] error = [] for data in test_reader():