# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph


class SimpleImgConvPool(fluid.dygraph.Layer):
    """A ``Conv2D`` layer followed by a ``Pool2D`` layer.

    Args:
        num_channels: number of input channels of the convolution.
        num_filters: number of convolution filters (output channels).
        filter_size: convolution kernel size.
        pool_size: pooling window size.
        pool_stride: pooling stride.
        pool_padding: pooling padding (default 0).
        pool_type: pooling type passed to ``Pool2D`` (default ``'max'``).
        global_pooling: passed to ``Pool2D`` as ``global_pooling``.
        conv_stride: convolution stride (default 1).
        conv_padding: convolution padding (default 0).
        conv_dilation: convolution dilation (default 1).
        conv_groups: convolution groups (default 1).
        act: activation applied by ``Conv2D`` to its output (default None).
        use_cudnn: passed to both ``Conv2D`` and ``Pool2D``.
        param_attr: weight attribute for ``Conv2D`` (default None).
        bias_attr: bias attribute for ``Conv2D`` (default None).
    """

    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        pool_size,
        pool_stride,
        pool_padding=0,
        pool_type='max',
        global_pooling=False,
        conv_stride=1,
        conv_padding=0,
        conv_dilation=1,
        conv_groups=1,
        act=None,
        use_cudnn=False,
        param_attr=None,
        bias_attr=None,
    ):
        super().__init__()

        # Forward the caller's activation and parameter attributes to the
        # convolution instead of silently discarding them.
        self._conv2d = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
            use_cudnn=use_cudnn,
            act=act,
        )

        self._pool2d = Pool2D(
            pool_size=pool_size,
            pool_type=pool_type,
            pool_stride=pool_stride,
            pool_padding=pool_padding,
            global_pooling=global_pooling,
            use_cudnn=use_cudnn,
        )

    def forward(self, inputs):
        """Apply the convolution, then the pooling, and return the result."""
        x = self._conv2d(inputs)
        x = self._pool2d(x)
        return x


class MNIST(fluid.dygraph.Layer):
    """Small MNIST classifier: two conv/pool stages and a softmax FC layer.

    With ``1x28x28`` inputs, each stage (5x5 convolution without padding,
    then 2x2 pooling with stride 2) shrinks the feature maps 28 -> 12 and
    then 12 -> 4, so the second stage emits ``50 * 4 * 4`` values per sample.
    """

    def __init__(self):
        super().__init__()

        self._simple_img_conv_pool_1 = SimpleImgConvPool(
            1, 20, 5, 2, 2, act="relu"
        )

        self._simple_img_conv_pool_2 = SimpleImgConvPool(
            20, 50, 5, 2, 2, act="relu"
        )

        # Flattened size of the second stage's output (50 channels x 4 x 4).
        self.pool_2_shape = 50 * 4 * 4
        SIZE = 10  # output width of the final softmax layer (digit classes)
        # Standard deviation of the normal initializer for the FC weights.
        scale = (2.0 / (self.pool_2_shape**2 * SIZE)) ** 0.5
        self._fc = Linear(
            self.pool_2_shape,
            SIZE,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.NormalInitializer(
                    loc=0.0, scale=scale
                )
            ),
            act="softmax",
        )

    def forward(self, inputs):
        """Return the softmax class probabilities for a batch of images."""
        x = self._simple_img_conv_pool_1(inputs)
        x = self._simple_img_conv_pool_2(x)
        # Flatten each sample's feature maps for the fully connected layer.
        x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
        x = self._fc(x)
        return x


class TestImperativeMnist(unittest.TestCase):
    """Check that MNIST trained in dygraph mode matches static-graph training.

    Both runs use the same seed, the same data order and the same SGD
    settings. The test compares the last input batch, the initial
    parameters, the final loss and the final parameters of the two runs.
    """

    def reader_decorator(self, reader):
        """Wrap a sample reader so it yields ``(image, label)`` pairs.

        ``image`` is reshaped to ``(1, 28, 28)``; ``label`` becomes an int64
        array of shape ``(1,)``.
        """

        def _reader_imple():
            for item in reader():
                image = np.array(item[0]).reshape(1, 28, 28)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield image, label

        return _reader_imple

    def func_test_mnist_float32(self):
        """Train ``batch_num`` batches in dygraph and static mode and compare."""
        seed = 90
        epoch_num = 1
        batch_size = 128
        batch_num = 50

        # Produced during dygraph training (legacy dygraph only) and reused
        # by the static-graph phase below.
        traced_layer = None

        # ---- Dygraph training ----
        with fluid.dygraph.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            mnist = MNIST()
            sgd = SGDOptimizer(
                learning_rate=1e-3, parameter_list=mnist.parameters()
            )

            batch_py_reader = fluid.io.PyReader(capacity=1)
            batch_py_reader.decorate_sample_list_generator(
                paddle.batch(
                    self.reader_decorator(paddle.dataset.mnist.train()),
                    batch_size=batch_size,
                    drop_last=True,
                ),
                places=fluid.CPUPlace(),
            )

            mnist.train()
            dy_param_init_value = {}

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None
            for epoch in range(epoch_num):
                for batch_id, data in enumerate(batch_py_reader()):
                    if batch_id >= batch_num:
                        break
                    img = data[0]
                    dy_x_data = img.numpy()
                    label = data[1]
                    label.stop_gradient = True

                    # Every 10th batch (legacy dygraph only): trace the model
                    # into a static program and save it as an inference model.
                    if batch_id % 10 == 0 and _in_legacy_dygraph():
                        cost, traced_layer = paddle.jit.TracedLayer.trace(
                            mnist, inputs=img
                        )
                        if program is not None:
                            # Only check that tracing produced a program; the
                            # programs of separate traces are not compared.
                            self.assertIsNotNone(traced_layer.program)
                        program = traced_layer.program
                        traced_layer.save_inference_model(
                            './infer_imperative_mnist'
                        )
                    else:
                        cost = mnist(img)

                    # The traced program must reproduce the dygraph output.
                    if traced_layer is not None:
                        cost_static = traced_layer([img])
                        helper.assertEachVar(cost, cost_static)

                    loss = fluid.layers.cross_entropy(cost, label)
                    avg_loss = paddle.mean(loss)

                    dy_out = avg_loss.numpy()

                    # Snapshot the initial parameters before the first update.
                    if epoch == 0 and batch_id == 0:
                        dy_param_init_value = {
                            param.name: param.numpy()
                            for param in mnist.parameters()
                        }

                    avg_loss.backward()
                    sgd.minimize(avg_loss)
                    mnist.clear_gradients()

                    dy_param_value = {
                        param.name: param.numpy()
                        for param in mnist.parameters()
                    }

        # ---- Static-graph training with the same seed and data ----
        with new_program_scope():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            exe = fluid.Executor(
                fluid.CPUPlace()
                if not core.is_compiled_with_cuda()
                else fluid.CUDAPlace(0)
            )

            mnist = MNIST()
            sgd = SGDOptimizer(learning_rate=1e-3)
            train_reader = paddle.batch(
                paddle.dataset.mnist.train(),
                batch_size=batch_size,
                drop_last=True,
            )

            img = fluid.layers.data(
                name='pixel', shape=[1, 28, 28], dtype='float32'
            )
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            cost = mnist(img)
            loss = fluid.layers.cross_entropy(cost, label)
            avg_loss = paddle.mean(loss)
            sgd.minimize(avg_loss)

            # Initialize the parameters and fetch their initial values.
            static_param_name_list = [
                param.name for param in mnist.parameters()
            ]
            out = exe.run(
                fluid.default_startup_program(),
                fetch_list=static_param_name_list,
            )
            static_param_init_value = dict(zip(static_param_name_list, out))

            # Fetch the loss first, then every parameter in name-list order;
            # the list does not change between batches.
            fetch_list = [avg_loss.name] + static_param_name_list

            for _ in range(epoch_num):
                for batch_id, data in enumerate(train_reader()):
                    if batch_id >= batch_num:
                        break
                    static_x_data = np.array(
                        [x[0].reshape(1, 28, 28) for x in data]
                    ).astype('float32')
                    y_data = (
                        np.array([x[1] for x in data])
                        .astype('int64')
                        .reshape([batch_size, 1])
                    )

                    # NOTE(review): the result is discarded; this only runs the
                    # traced layer on static-mode input. Confirm intent.
                    if traced_layer is not None:
                        traced_layer([static_x_data])

                    out = exe.run(
                        fluid.default_main_program(),
                        feed={"pixel": static_x_data, "label": y_data},
                        fetch_list=fetch_list,
                    )

                    static_out = out[0]
                    static_param_value = dict(
                        zip(static_param_name_list, out[1:])
                    )

        # Both phases must have consumed identical data for the last batch.
        np.testing.assert_allclose(dy_x_data, static_x_data, rtol=1e-05)

        for key, value in static_param_init_value.items():
            np.testing.assert_allclose(
                value, dy_param_init_value[key], rtol=1e-05
            )

        np.testing.assert_allclose(static_out, dy_out, rtol=1e-05)

        for key, value in static_param_value.items():
            np.testing.assert_allclose(
                value, dy_param_value[key], rtol=1e-05, atol=1e-05
            )

    def test_mnist_float32(self):
        """Run the comparison under the eager guard and in legacy dygraph."""
        with _test_eager_guard():
            self.func_test_mnist_float32()
        self.func_test_mnist_float32()


if __name__ == '__main__':
    # Static-graph mode is the process default; the tests switch to dygraph
    # explicitly with fluid.dygraph.guard().
    paddle.enable_static()
    unittest.main()