test_imperative_mnist.py 10.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import unittest
import numpy as np
import six

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
24
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
M
minqiyang 已提交
25
from paddle.fluid.dygraph.base import to_variable
26
from test_imperative_base import new_program_scope
27
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
J
Jiabin Yang 已提交
28
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
29 30


M
minqiyang 已提交
31
class SimpleImgConvPool(fluid.dygraph.Layer):
    """A ``Conv2D`` layer followed by a ``Pool2D`` layer.

    Args:
        num_channels: forwarded to ``Conv2D`` as ``num_channels``.
        num_filters: forwarded to ``Conv2D`` as ``num_filters``.
        filter_size: forwarded to ``Conv2D`` as ``filter_size``.
        pool_size: forwarded to ``Pool2D`` as ``pool_size``.
        pool_stride: forwarded to ``Pool2D`` as ``pool_stride``.
        pool_padding: forwarded to ``Pool2D`` as ``pool_padding``.
        pool_type: forwarded to ``Pool2D`` as ``pool_type``.
        global_pooling: forwarded to ``Pool2D`` as ``global_pooling``.
        conv_stride: forwarded to ``Conv2D`` as ``stride``.
        conv_padding: forwarded to ``Conv2D`` as ``padding``.
        conv_dilation: forwarded to ``Conv2D`` as ``dilation``.
        conv_groups: forwarded to ``Conv2D`` as ``groups``.
        act: accepted for signature compatibility; not forwarded (see the
            note in ``__init__``).
        use_cudnn: forwarded to both ``Conv2D`` and ``Pool2D``.
        param_attr: forwarded to ``Conv2D`` as ``param_attr``.
        bias_attr: forwarded to ``Conv2D`` as ``bias_attr``.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 pool_size,
                 pool_stride,
                 pool_padding=0,
                 pool_type='max',
                 global_pooling=False,
                 conv_stride=1,
                 conv_padding=0,
                 conv_dilation=1,
                 conv_groups=1,
                 act=None,
                 use_cudnn=False,
                 param_attr=None,
                 bias_attr=None):
        super(SimpleImgConvPool, self).__init__()

        # NOTE(review): `act` is intentionally NOT passed to Conv2D. Callers
        # in this file pass act="relu"; forwarding it would change the
        # numerics of the model under test. Confirm whether the activation
        # was meant to be applied before wiring it through.
        self._conv2d = Conv2D(num_channels=num_channels,
                              num_filters=num_filters,
                              filter_size=filter_size,
                              stride=conv_stride,
                              padding=conv_padding,
                              dilation=conv_dilation,
                              groups=conv_groups,
                              # Honour the caller-supplied attributes instead
                              # of discarding them; both default to None.
                              param_attr=param_attr,
                              bias_attr=bias_attr,
                              use_cudnn=use_cudnn)

        self._pool2d = Pool2D(pool_size=pool_size,
                              pool_type=pool_type,
                              pool_stride=pool_stride,
                              pool_padding=pool_padding,
                              global_pooling=global_pooling,
                              use_cudnn=use_cudnn)

    def forward(self, inputs):
        """Apply the convolution and then the pooling to ``inputs``."""
        x = self._conv2d(inputs)
        x = self._pool2d(x)
        return x
74 75


M
minqiyang 已提交
76
class MNIST(fluid.dygraph.Layer):
    """Two ``SimpleImgConvPool`` stages followed by a softmax ``Linear`` layer."""

    def __init__(self):
        super(MNIST, self).__init__()

        # Sub-layers are created in the order conv-pool 1, conv-pool 2, fc.
        self._simple_img_conv_pool_1 = SimpleImgConvPool(num_channels=1,
                                                         num_filters=20,
                                                         filter_size=5,
                                                         pool_size=2,
                                                         pool_stride=2,
                                                         act="relu")
        self._simple_img_conv_pool_2 = SimpleImgConvPool(num_channels=20,
                                                         num_filters=50,
                                                         filter_size=5,
                                                         pool_size=2,
                                                         pool_stride=2,
                                                         act="relu")

        # Width the second stage's output is reshaped to before the fc layer.
        self.pool_2_shape = 50 * 4 * 4

        num_classes = 10
        # Standard deviation of the normal initializer for the fc weights.
        init_std = (2.0 / (self.pool_2_shape**2 * num_classes))**0.5
        fc_weight_attr = fluid.param_attr.ParamAttr(
            initializer=fluid.initializer.NormalInitializer(loc=0.0,
                                                            scale=init_std))
        self._fc = Linear(self.pool_2_shape,
                          num_classes,
                          param_attr=fc_weight_attr,
                          act="softmax")

    def forward(self, inputs):
        """Run both conv-pool stages, reshape to 2-D, and apply the fc layer."""
        features = self._simple_img_conv_pool_2(
            self._simple_img_conv_pool_1(inputs))
        flattened = fluid.layers.reshape(features,
                                         shape=[-1, self.pool_2_shape])
        return self._fc(flattened)


class TestImperativeMnist(unittest.TestCase):
    """Trains ``MNIST`` in dygraph and in static graph and compares the results."""

    def reader_decorator(self, reader):
        """Wrap ``reader`` so every sample is yielded as ``(image, label)``.

        The image (``item[0]``) is reshaped to ``(1, 28, 28)`` and the label
        (``item[1]``) is cast to ``int64`` and reshaped to ``(1,)``.

        Args:
            reader: a callable returning an iterable of samples.

        Returns:
            A generator function producing the converted samples.
        """

        def _reader_imple():
            for item in reader():
                image = np.array(item[0]).reshape(1, 28, 28)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield image, label

        return _reader_imple

    def func_test_mnist_float32(self):
        """Run ``batch_num`` SGD steps in both modes and compare them.

        Compared values: the last input batch, the initial parameters, the
        last average loss, and the parameters after the last step.
        """
        seed = 90
        epoch_num = 1
        batch_size = 128
        batch_num = 50

        traced_layer = None

        # ---- dygraph run ----
        with fluid.dygraph.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            mnist = MNIST()
            sgd = SGDOptimizer(learning_rate=1e-3,
                               parameter_list=mnist.parameters())

            batch_py_reader = fluid.io.PyReader(capacity=1)
            batch_py_reader.decorate_sample_list_generator(
                paddle.batch(self.reader_decorator(
                    paddle.dataset.mnist.train()),
                             batch_size=batch_size,
                             drop_last=True),
                places=fluid.CPUPlace())

            mnist.train()
            dy_param_init_value = {}

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None
            for epoch in range(epoch_num):
                for batch_id, data in enumerate(batch_py_reader()):
                    if batch_id >= batch_num:
                        break
                    img = data[0]
                    dy_x_data = img.numpy()
                    label = data[1]
                    label.stop_gradient = True

                    if batch_id % 10 == 0 and _in_legacy_dygraph():
                        cost, traced_layer = paddle.jit.TracedLayer.trace(
                            mnist, inputs=img)
                        if program is not None:
                            # assertTrue's second positional argument is the
                            # failure message, so the programs must be
                            # compared explicitly.
                            self.assertTrue(
                                is_equal_program(program,
                                                 traced_layer.program))
                        program = traced_layer.program
                        traced_layer.save_inference_model(
                            './infer_imperative_mnist')
                    else:
                        cost = mnist(img)

                    if traced_layer is not None:
                        cost_static = traced_layer([img])
                        helper.assertEachVar(cost, cost_static)

                    loss = fluid.layers.cross_entropy(cost, label)
                    avg_loss = paddle.mean(loss)

                    dy_out = avg_loss.numpy()

                    # Snapshot the parameters before the first update.
                    if epoch == 0 and batch_id == 0:
                        for param in mnist.parameters():
                            dy_param_init_value[param.name] = param.numpy()

                    avg_loss.backward()
                    sgd.minimize(avg_loss)
                    mnist.clear_gradients()

                    dy_param_value = {
                        param.name: param.numpy()
                        for param in mnist.parameters()
                    }

        # ---- static-graph run ----
        with new_program_scope():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))

            mnist = MNIST()
            sgd = SGDOptimizer(learning_rate=1e-3)
            train_reader = paddle.batch(paddle.dataset.mnist.train(),
                                        batch_size=batch_size,
                                        drop_last=True)

            img = fluid.layers.data(name='pixel',
                                    shape=[1, 28, 28],
                                    dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            cost = mnist(img)
            loss = fluid.layers.cross_entropy(cost, label)
            avg_loss = paddle.mean(loss)
            sgd.minimize(avg_loss)

            # initialize params and fetch them
            static_param_name_list = [
                param.name for param in mnist.parameters()
            ]

            out = exe.run(fluid.default_startup_program(),
                          fetch_list=static_param_name_list)

            static_param_init_value = dict(zip(static_param_name_list, out))

            for epoch in range(epoch_num):
                for batch_id, data in enumerate(train_reader()):
                    if batch_id >= batch_num:
                        break
                    static_x_data = np.array([
                        x[0].reshape(1, 28, 28) for x in data
                    ]).astype('float32')
                    y_data = np.array([x[1]
                                       for x in data]).astype('int64').reshape(
                                           [batch_size, 1])

                    fetch_list = [avg_loss.name]
                    fetch_list.extend(static_param_name_list)

                    # Also call the traced layer on the numpy batch; its
                    # result is not checked here.
                    if traced_layer is not None:
                        traced_layer([static_x_data])

                    out = exe.run(fluid.default_main_program(),
                                  feed={
                                      "pixel": static_x_data,
                                      "label": y_data
                                  },
                                  fetch_list=fetch_list)

                    # out[0] is the loss; out[1:] follow the fetch order of
                    # static_param_name_list.
                    static_out = out[0]
                    static_param_value = dict(
                        zip(static_param_name_list, out[1:]))

        # Compare the arrays themselves: reducing each side with .all()
        # first would only compare two booleans.
        np.testing.assert_allclose(dy_x_data, static_x_data, rtol=1e-05)

        for key, value in six.iteritems(static_param_init_value):
            np.testing.assert_allclose(value,
                                       dy_param_init_value[key],
                                       rtol=1e-05)

        np.testing.assert_allclose(static_out, dy_out, rtol=1e-05)

        for key, value in six.iteritems(static_param_value):
            np.testing.assert_allclose(value,
                                       dy_param_value[key],
                                       rtol=1e-05,
                                       atol=1e-05)

    def test_mnist_float32(self):
        """Run the comparison inside ``_test_eager_guard`` and again outside it."""
        with _test_eager_guard():
            self.func_test_mnist_float32()
        self.func_test_mnist_float32()

281 282

if __name__ == '__main__':
    # Call paddle.enable_static() first, then hand control to unittest.
    paddle.enable_static()
    unittest.main()