diff --git a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py index 5ab01839fbc20bbd3c242878c4ea23a00f7b0dca..6aae2c0507dca998cfba07465b96a7feca24fb59 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py @@ -117,6 +117,7 @@ class TestImperativeMnist(unittest.TestCase): train_reader = paddle.batch( paddle.dataset.mnist.train(), batch_size=128, drop_last=True) + mnist.train() dy_param_init_value = {} for epoch in range(epoch_num): for batch_id, data in enumerate(train_reader()):