diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index a1cf5fad138f068c9eac5fe8d681c9f08b192270..562866cf60ce3ad297aaaeb1db707bdbe482e97a 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -76,7 +76,7 @@ list(REMOVE_ITEM TEST_OPS test_image_classification_resnet)
 list(REMOVE_ITEM TEST_OPS test_bilinear_interp_op)
 list(REMOVE_ITEM TEST_OPS test_nearest_interp_op)
 list(REMOVE_ITEM TEST_OPS test_imperative_resnet)
-list(REMOVE_ITEM TEST_OPS test_imperative_optimizer)
+list(REMOVE_ITEM TEST_OPS test_imperative_mnist)
 list(REMOVE_ITEM TEST_OPS test_ir_memory_optimize_transformer)
 foreach(TEST_OP ${TEST_OPS})
     py_test_modules(${TEST_OP} MODULES ${TEST_OP})
@@ -87,7 +87,7 @@ py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op SERIAL)
 py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op SERIAL)
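+# FLAGS_cudnn_deterministic=1 makes cuDNN pick deterministic algorithms,
+# keeping these GPU tests reproducible across runs.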
 py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS
   FLAGS_cudnn_deterministic=1)
-py_test_modules(test_imperative_optimizer MODULES test_imperative_optimizer ENVS
+py_test_modules(test_imperative_mnist MODULES test_imperative_mnist ENVS
   FLAGS_cudnn_deterministic=1)
 if(WITH_DISTRIBUTE)
     py_test_modules(test_dist_train MODULES test_dist_train SERIAL)
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
index d0a5a883174cb33a035b344f9489b2ba02ba99f1..d821324364c1feb85f138a23aff552d9b8c9f297 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_mnist.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import print_function
+
 import contextlib
 import unittest
 import numpy as np
@@ -21,112 +23,56 @@ import paddle
 import paddle.fluid as fluid
 from paddle.fluid import core
 from paddle.fluid.optimizer import SGDOptimizer
-from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
+from paddle.fluid.imperative.nn import FC
 from paddle.fluid.imperative.base import to_variable
 from test_imperative_base import new_program_scope
 
 
-class SimpleImgConvPool(fluid.imperative.Layer):
-    def __init__(self,
-                 num_channels,
-                 num_filters,
-                 filter_size,
-                 pool_size,
-                 pool_stride,
-                 pool_padding=0,
-                 pool_type='max',
-                 global_pooling=False,
-                 conv_stride=1,
-                 conv_padding=0,
-                 conv_dilation=1,
-                 conv_groups=1,
-                 act=None,
-                 use_cudnn=False,
-                 param_attr=None,
-                 bias_attr=None):
-        super(SimpleImgConvPool, self).__init__()
-
-        self._conv2d = Conv2D(
-            num_channels=num_channels,
-            num_filters=num_filters,
-            filter_size=filter_size,
-            stride=conv_stride,
-            padding=conv_padding,
-            dilation=conv_dilation,
-            groups=conv_groups,
-            param_attr=None,
-            bias_attr=None,
-            use_cudnn=use_cudnn)
-
-        self._pool2d = Pool2D(
-            pool_size=pool_size,
-            pool_type=pool_type,
-            pool_stride=pool_stride,
-            pool_padding=pool_padding,
-            global_pooling=global_pooling,
-            use_cudnn=use_cudnn)
-
-    def forward(self, inputs):
-        x = self._conv2d(inputs)
-        x = self._pool2d(x)
-        return x
-
-
-class MNIST(fluid.imperative.Layer):
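+# A deliberately small two-layer MLP: the optimizer test only needs a
+# model that trains in a couple of mini-batches, not a full conv net.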
+class MLP(fluid.imperative.Layer):
     def __init__(self, param_attr=None, bias_attr=None):
-        super(MNIST, self).__init__()
-
-        self._simple_img_conv_pool_1 = SimpleImgConvPool(
-            1, 20, 5, 2, 2, act="relu")
+        super(MLP, self).__init__()
+        self._fc1 = FC(10)
+        self._fc2 = FC(10)
 
-        self._simple_img_conv_pool_2 = SimpleImgConvPool(
-            20, 50, 5, 2, 2, act="relu")
+    def forward(self, inputs):
+        y = self._fc1(inputs)
+        y = self._fc2(y)
+        return y
 
-        pool_2_shape = 50 * 8 * 8
-        SIZE = 10
-        scale = (2.0 / (pool_2_shape**2 * SIZE))**0.5
-        self._fc = FC(10,
-                      param_attr=fluid.param_attr.ParamAttr(
-                          initializer=fluid.initializer.NormalInitializer(
-                              loc=0.0, scale=scale)))
 
-    def forward(self, inputs):
-        x = self._simple_img_conv_pool_1(inputs)
-        x = self._simple_img_conv_pool_2(x)
-        x = self._fc(x)
-        return x
+class TestImperativeOptimizerBase(unittest.TestCase):
+    def setUp(self):
+        self.batch_num = 2
 
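+    # Subclasses can override this to exercise a different optimizer
+    # against the same dynamic-vs-static comparison below.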
+    def get_optimizer(self):
+        self.optimizer = SGDOptimizer(learning_rate=1e-3)
 
-class TestImperativeMnist(unittest.TestCase):
-    def test_mnist_cpu_float32(self):
+    def test_optimizer_float32(self):
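+        # Train the MLP for a few batches in imperative mode, then repeat
+        # the same training with the static graph executor and compare.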
         seed = 90
-
         with fluid.imperative.guard():
             fluid.default_startup_program().random_seed = seed
             fluid.default_main_program().random_seed = seed
 
-            mnist = MNIST()
-            sgd = SGDOptimizer(learning_rate=1e-3)
+            mlp = MLP()
+            self.get_optimizer()
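+            # drop_last=True keeps every batch at exactly 128 samples,
+            # matching the hard-coded reshape of y_data below.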
             train_reader = paddle.batch(
-                paddle.dataset.mnist.train(), batch_size=128)
+                paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
 
             dy_param_init_value = {}
             for batch_id, data in enumerate(train_reader()):
-                if batch_id >= 2:
+                if batch_id >= self.batch_num:
                     break
 
-                x_data = np.array(
+                dy_x_data = np.array(
                     [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
                 y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                     128, 1)
 
-                img = to_variable(x_data)
+                img = to_variable(dy_x_data)
                 label = to_variable(y_data)
                 label._stop_gradient = True
 
-                cost = mnist(img)
-                loss = fluid.layers.cross_entropy(cost, label)
-                avg_loss = fluid.layers.mean(loss)
+                cost = mlp(img)
+                avg_loss = fluid.layers.reduce_mean(cost)
                 dy_out = avg_loss._numpy()
 
                 if batch_id == 0:
@@ -135,7 +81,8 @@ class TestImperativeMnist(unittest.TestCase):
                         dy_param_init_value[param.name] = param._numpy()
 
                 avg_loss._backward()
-                sgd.minimize(avg_loss)
+                self.optimizer.minimize(avg_loss)
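+                # Gradients accumulate across batches in imperative mode,
+                # so clear them after each optimizer step.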
+                mlp.clear_gradients()
                 dy_param_value = {}
                 for param in fluid.default_main_program().global_block(
                 ).all_parameters():
@@ -149,23 +96,21 @@ class TestImperativeMnist(unittest.TestCase):
             ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
 
-            mnist = MNIST()
+            mlp = MLP()
-            sgd = SGDOptimizer(learning_rate=1e-3)
+            self.get_optimizer()
             train_reader = paddle.batch(
-                paddle.dataset.mnist.train(), batch_size=128)
+                paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
 
             img = fluid.layers.data(
                 name='pixel', shape=[1, 28, 28], dtype='float32')
             label = fluid.layers.data(name='label', shape=[1], dtype='int64')
-            cost = mnist(img)
+            cost = mlp(img)
-            loss = fluid.layers.cross_entropy(cost, label)
-            avg_loss = fluid.layers.mean(loss)
-            sgd.minimize(avg_loss)
+            avg_loss = fluid.layers.reduce_mean(cost)
+            self.optimizer.minimize(avg_loss)
 
             # initialize params and fetch them
             static_param_init_value = {}
             static_param_name_list = []
-            for param in fluid.default_startup_program().global_block(
-            ).all_parameters():
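+            # Ask the layer for its own parameters instead of walking
+            # the startup program's global block.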
+            for param in mlp.parameters():
                 static_param_name_list.append(param.name)
 
             out = exe.run(fluid.default_startup_program(),
@@ -175,10 +120,10 @@ class TestImperativeMnist(unittest.TestCase):
                 static_param_init_value[static_param_name_list[i]] = out[i]
 
             for batch_id, data in enumerate(train_reader()):
-                if batch_id >= 2:
+                if batch_id >= self.batch_num:
                     break
 
-                x_data = np.array(
+                static_x_data = np.array(
                     [x[0].reshape(1, 28, 28) for x in data]).astype('float32')
                 y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                     [128, 1])
@@ -186,7 +131,7 @@ class TestImperativeMnist(unittest.TestCase):
                 fetch_list = [avg_loss.name]
                 fetch_list.extend(static_param_name_list)
                 out = exe.run(fluid.default_main_program(),
-                              feed={"pixel": x_data,
+                              feed={"pixel": static_x_data,
                                     "label": y_data},
                               fetch_list=fetch_list)
 
@@ -196,11 +141,12 @@ class TestImperativeMnist(unittest.TestCase):
                     static_param_value[static_param_name_list[i - 1]] = out[i]
 
         for key, value in six.iteritems(static_param_init_value):
-            self.assertTrue(
-                np.allclose(value.all(), dy_param_init_value[key].all()))
-        self.assertTrue(np.allclose(static_out.all(), dy_out.all()))
+            self.assertTrue(np.allclose(value, dy_param_init_value[key]))
+
+        self.assertTrue(np.allclose(static_out, dy_out))
+
         for key, value in six.iteritems(static_param_value):
-            self.assertTrue(np.allclose(value.all(), dy_param_value[key].all()))
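+            # atol=1e-5: trained parameters may differ by float32 rounding
+            # between the imperative and static runs.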
+            self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))
 
 
 if __name__ == '__main__':