# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import contextlib
import unittest
import numpy as np
import six

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph


class SimpleImgConvPool(fluid.dygraph.Layer):
    """A ``Conv2D`` layer followed by a ``Pool2D`` layer.

    Args:
        num_channels: number of input channels of the convolution.
        num_filters: number of output channels of the convolution.
        filter_size: convolution kernel size.
        pool_size: pooling window size.
        pool_stride: pooling stride.
        pool_padding: pooling padding.
        pool_type: pooling type passed to ``Pool2D`` (e.g. ``'max'``).
        global_pooling: passed to ``Pool2D``.
        conv_stride: convolution stride (``Conv2D`` ``stride``).
        conv_padding: convolution padding (``Conv2D`` ``padding``).
        conv_dilation: convolution dilation (``Conv2D`` ``dilation``).
        conv_groups: convolution groups (``Conv2D`` ``groups``).
        act: activation passed to ``Conv2D``.
        use_cudnn: passed to both ``Conv2D`` and ``Pool2D``.
        param_attr: weight attribute passed to ``Conv2D``.
        bias_attr: bias attribute passed to ``Conv2D``.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 pool_size,
                 pool_stride,
                 pool_padding=0,
                 pool_type='max',
                 global_pooling=False,
                 conv_stride=1,
                 conv_padding=0,
                 conv_dilation=1,
                 conv_groups=1,
                 act=None,
                 use_cudnn=False,
                 param_attr=None,
                 bias_attr=None):
        super(SimpleImgConvPool, self).__init__()

        # Forward act/param_attr/bias_attr instead of silently dropping them;
        # previously callers asking for e.g. act="relu" got no activation.
        self._conv2d = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            use_cudnn=use_cudnn,
            act=act)

        self._pool2d = Pool2D(
            pool_size=pool_size,
            pool_type=pool_type,
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            global_pooling=global_pooling,
            use_cudnn=use_cudnn)

    def forward(self, inputs):
        """Apply the convolution, then the pooling, to ``inputs``."""
        x = self._conv2d(inputs)
        x = self._pool2d(x)
        return x


class MNIST(fluid.dygraph.Layer):
    """Small MNIST classifier: two conv+pool stages and a softmax FC layer."""

    def __init__(self):
        super(MNIST, self).__init__()

        # 1 -> 20 channels, 5x5 kernel, 2x2 pooling with stride 2.
        self._simple_img_conv_pool_1 = SimpleImgConvPool(
            1, 20, 5, 2, 2, act="relu")

        # 20 -> 50 channels, 5x5 kernel, 2x2 pooling with stride 2.
        self._simple_img_conv_pool_2 = SimpleImgConvPool(
            20, 50, 5, 2, 2, act="relu")

        # Flattened feature size after stage 2, assuming 1x28x28 inputs:
        # 28 -conv5-> 24 -pool2-> 12 -conv5-> 8 -pool2-> 4, with 50 channels.
        self.pool_2_shape = 50 * 4 * 4
        # Number of output classes; used both for the FC width and for the
        # weight-initializer scale so the two cannot drift apart.
        SIZE = 10
        scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
        self._fc = Linear(
            self.pool_2_shape,
            SIZE,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.NormalInitializer(
                    loc=0.0, scale=scale)),
            act="softmax")

    def forward(self, inputs):
        """Return softmax class probabilities of shape (batch, SIZE).

        ``inputs`` is presumably an image batch shaped (batch, 1, 28, 28),
        matching the data layers used by the test below — confirm for other
        callers.
        """
        x = self._simple_img_conv_pool_1(inputs)
        x = self._simple_img_conv_pool_2(x)
        x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
        x = self._fc(x)
        return x


class TestImperativeMnist(unittest.TestCase):
    """Compare MNIST training in dygraph mode against static-graph mode.

    Both runs use the same seed, data order and optimizer settings, so the
    last input batch, the initial parameters, the final loss and the final
    parameters are expected to match.
    """

    def reader_decorator(self, reader):
        """Wrap an MNIST sample reader to yield (image, label) numpy pairs.

        The image is reshaped to (1, 28, 28) and the label becomes an int64
        array of shape (1,), the per-sample layout fed to the dygraph
        PyReader.
        """

        def _reader_imple():
            for item in reader():
                image = np.array(item[0]).reshape(1, 28, 28)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield image, label

        return _reader_imple

    def func_test_mnist_float32(self):
        """Train ``batch_num`` batches in dygraph and static mode and compare.

        In legacy dygraph mode the model is also traced with ``TracedLayer``
        every 10 batches. The traced program's output is checked against the
        dygraph output, and consecutive traced programs must be equal.
        """
        seed = 90
        epoch_num = 1
        batch_size = 128
        batch_num = 50

        traced_layer = None

        # ---- dygraph run ----
        with fluid.dygraph.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            mnist = MNIST()
            sgd = SGDOptimizer(
                learning_rate=1e-3, parameter_list=mnist.parameters())

            batch_py_reader = fluid.io.PyReader(capacity=1)
            batch_py_reader.decorate_sample_list_generator(
                paddle.batch(
                    self.reader_decorator(paddle.dataset.mnist.train()),
                    batch_size=batch_size,
                    drop_last=True),
                places=fluid.CPUPlace())

            mnist.train()
            dy_param_init_value = {}

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None
            for epoch in range(epoch_num):
                for batch_id, data in enumerate(batch_py_reader()):
                    if batch_id >= batch_num:
                        break
                    img = data[0]
                    dy_x_data = img.numpy()
                    label = data[1]
                    label.stop_gradient = True

                    # TracedLayer is only exercised in legacy dygraph mode.
                    if batch_id % 10 == 0 and _in_legacy_dygraph():
                        cost, traced_layer = paddle.jit.TracedLayer.trace(
                            mnist, inputs=img)
                        if program is not None:
                            # Re-tracing the same model must produce the same
                            # program. (assertTrue's second argument is only a
                            # message, so the old two-argument call never
                            # compared anything.)
                            self.assertTrue(
                                is_equal_program(program,
                                                 traced_layer.program))
                        program = traced_layer.program
                        traced_layer.save_inference_model(
                            './infer_imperative_mnist')
                    else:
                        cost = mnist(img)

                    # Once a traced layer exists, its output for this batch
                    # must match the dygraph output.
                    if traced_layer is not None:
                        cost_static = traced_layer([img])
                        helper.assertEachVar(cost, cost_static)

                    loss = fluid.layers.cross_entropy(cost, label)
                    avg_loss = fluid.layers.mean(loss)

                    dy_out = avg_loss.numpy()

                    # Record initial parameters before the first update.
                    if epoch == 0 and batch_id == 0:
                        for param in mnist.parameters():
                            dy_param_init_value[param.name] = param.numpy()

                    avg_loss.backward()
                    sgd.minimize(avg_loss)
                    mnist.clear_gradients()

                    dy_param_value = {}
                    for param in mnist.parameters():
                        dy_param_value[param.name] = param.numpy()

        # ---- static-graph run ----
        with new_program_scope():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))

            mnist = MNIST()
            sgd = SGDOptimizer(learning_rate=1e-3)
            train_reader = paddle.batch(
                paddle.dataset.mnist.train(),
                batch_size=batch_size,
                drop_last=True)

            img = fluid.layers.data(
                name='pixel', shape=[1, 28, 28], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            cost = mnist(img)
            loss = fluid.layers.cross_entropy(cost, label)
            avg_loss = fluid.layers.mean(loss)
            sgd.minimize(avg_loss)

            # Initialize the parameters and fetch their initial values.
            static_param_name_list = [
                param.name for param in mnist.parameters()
            ]
            out = exe.run(fluid.default_startup_program(),
                          fetch_list=static_param_name_list)
            static_param_init_value = dict(zip(static_param_name_list, out))

            # Loss first, then every parameter in name-list order.
            fetch_list = [avg_loss.name]
            fetch_list.extend(static_param_name_list)

            for epoch in range(epoch_num):
                for batch_id, data in enumerate(train_reader()):
                    if batch_id >= batch_num:
                        break
                    static_x_data = np.array(
                        [x[0].reshape(1, 28, 28)
                         for x in data]).astype('float32')
                    y_data = np.array(
                        [x[1] for x in data]).astype('int64').reshape(
                            [batch_size, 1])

                    # Also run the traced layer on raw numpy input, when one
                    # was produced by the dygraph run above.
                    if traced_layer is not None:
                        traced_layer([static_x_data])

                    out = exe.run(
                        fluid.default_main_program(),
                        feed={"pixel": static_x_data,
                              "label": y_data},
                        fetch_list=fetch_list)

                    static_out = out[0]
                    static_param_value = dict(
                        zip(static_param_name_list, out[1:]))

        # Both runs must have consumed the same final batch. (The previous
        # `.all()` calls collapsed each array to one bool, so this check
        # compared True with True and could never fail.)
        self.assertTrue(np.allclose(dy_x_data, static_x_data))

        for key, value in six.iteritems(static_param_init_value):
            self.assertTrue(np.allclose(value, dy_param_init_value[key]))

        self.assertTrue(np.allclose(static_out, dy_out))

        for key, value in six.iteritems(static_param_value):
            self.assertTrue(np.allclose(value, dy_param_value[key], atol=1e-5))

    def test_mnist_float32(self):
        """Run the comparison inside the eager guard, then outside it."""
        with _test_eager_guard():
            self.func_test_mnist_float32()
        self.func_test_mnist_float32()


if __name__ == '__main__':
    # Make static-graph mode the global default; the dygraph sections of the
    # test opt in explicitly via fluid.dygraph.guard().
    paddle.enable_static()
    unittest.main()