test_imperative_mnist.py 8.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

17
import numpy as np
18
from test_imperative_base import new_program_scope
19
from utils import DyGraphProgramDescTracerTestHelper
20 21 22 23 24

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.optimizer import SGDOptimizer
25
from paddle.nn import Linear
26 27


M
minqiyang 已提交
28
class SimpleImgConvPool(fluid.dygraph.Layer):
    """Conv2D, optional activation, then 2-D pooling.

    Args:
        num_channels: Input channel count of the convolution.
        num_filters: Output channel count of the convolution.
        filter_size: Convolution kernel size.
        pool_size: Pooling kernel size (unused when ``global_pooling``).
        pool_stride: Pooling stride (unused when ``global_pooling``).
        pool_padding: Pooling padding (unused when ``global_pooling``).
        pool_type: ``'max'`` or ``'avg'``.
        global_pooling: If True, pool each feature map down to 1x1 instead
            of using ``pool_size`` / ``pool_stride`` / ``pool_padding``.
        conv_stride: Forwarded to ``paddle.nn.Conv2D`` as ``stride``.
        conv_padding: Forwarded to ``paddle.nn.Conv2D`` as ``padding``.
        conv_dilation: Forwarded to ``paddle.nn.Conv2D`` as ``dilation``.
        conv_groups: Forwarded to ``paddle.nn.Conv2D`` as ``groups``.
        act: Name of a function in ``paddle.nn.functional`` (e.g.
            ``'relu'``) applied to the convolution output before pooling,
            or None for no activation.
        use_cudnn: Accepted for signature compatibility only; it is not
            forwarded (``paddle.nn.Conv2D`` takes no such argument).
        param_attr: Forwarded to ``paddle.nn.Conv2D`` as ``weight_attr``.
        bias_attr: Forwarded to ``paddle.nn.Conv2D`` as ``bias_attr``.

    Raises:
        ValueError: If ``pool_type`` is not ``'max'``/``'avg'`` or ``act``
            does not name a callable in ``paddle.nn.functional``.
    """

    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        pool_size,
        pool_stride,
        pool_padding=0,
        pool_type='max',
        global_pooling=False,
        conv_stride=1,
        conv_padding=0,
        conv_dilation=1,
        conv_groups=1,
        act=None,
        use_cudnn=False,
        param_attr=None,
        bias_attr=None,
    ):
        super().__init__()

        # Validate up front so a typo fails at construction time, not
        # silently (previously every one of these options was ignored).
        if pool_type not in ('max', 'avg'):
            raise ValueError(
                f"pool_type must be 'max' or 'avg', got {pool_type!r}"
            )
        act_fn = None
        if act is not None:
            act_fn = getattr(paddle.nn.functional, act, None)
            if not callable(act_fn):
                raise ValueError(
                    "act must name a function in paddle.nn.functional, "
                    f"got {act!r}"
                )
        self._act_fn = act_fn

        self._conv2d = paddle.nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=conv_stride,
            padding=conv_padding,
            dilation=conv_dilation,
            groups=conv_groups,
            weight_attr=param_attr,
            bias_attr=bias_attr,
        )

        if global_pooling:
            # Reduce each feature map to a single value (1x1 output).
            if pool_type == 'max':
                self._pool2d = paddle.nn.AdaptiveMaxPool2D(output_size=1)
            else:
                self._pool2d = paddle.nn.AdaptiveAvgPool2D(output_size=1)
        elif pool_type == 'max':
            self._pool2d = paddle.nn.MaxPool2D(
                kernel_size=pool_size,
                stride=pool_stride,
                padding=pool_padding,
            )
        else:
            self._pool2d = paddle.nn.AvgPool2D(
                kernel_size=pool_size,
                stride=pool_stride,
                padding=pool_padding,
            )

    def forward(self, inputs):
        """Apply convolution, the configured activation, then pooling."""
        x = self._conv2d(inputs)
        if self._act_fn is not None:
            x = self._act_fn(x)
        return self._pool2d(x)
71 72


M
minqiyang 已提交
73
class MNIST(fluid.dygraph.Layer):
    """Two conv-pool stages followed by a linear head and softmax.

    The flatten size ``pool_2_shape`` (50 x 4 x 4) assumes a 1-channel
    28x28 input. Each stage is a 5x5 unpadded conv followed by a 2x2 /
    stride-2 pool, so the spatial size goes 28 -> 24 -> 12 -> 8 -> 4.
    ``forward`` returns softmax probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()

        # 1x28x28 -> 20x12x12
        self._simple_img_conv_pool_1 = SimpleImgConvPool(
            1, 20, 5, 2, 2, act="relu"
        )
        # 20x12x12 -> 50x4x4
        self._simple_img_conv_pool_2 = SimpleImgConvPool(
            20, 50, 5, 2, 2, act="relu"
        )

        # Flattened size of the second stage's output.
        self.pool_2_shape = 50 * 4 * 4
        SIZE = 10  # number of output classes
        scale = (2.0 / (self.pool_2_shape**2 * SIZE)) ** 0.5
        self._fc = Linear(
            self.pool_2_shape,
            SIZE,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Normal(mean=0.0, std=scale)
            ),
        )

    def forward(self, inputs):
        """Return class probabilities for a batch of images."""
        x = self._simple_img_conv_pool_1(inputs)
        x = self._simple_img_conv_pool_2(x)
        # Flatten to (batch, pool_2_shape) for the fully-connected head.
        x = paddle.reshape(x, shape=[-1, self.pool_2_shape])
        x = self._fc(x)
        return paddle.nn.functional.softmax(x)


class TestImperativeMnist(unittest.TestCase):
    """Check that dygraph and static-graph MNIST training agree.

    Both modes build ``MNIST`` with the same random seed and train on the
    same un-shuffled MNIST batches for ``batch_num`` steps. The test then
    compares four things: the last input batch, the initial parameters,
    the last loss, and the final parameters.
    """

    def reader_decorator(self, reader):
        """Wrap a sample reader to yield ``(1x28x28 image, int64 label)``.

        Args:
            reader: Callable returning an iterable of samples whose first
                element is a 784-value image and second is the label.

        Returns:
            A generator function yielding ``(image, label)``, where image
            has shape ``(1, 28, 28)`` and label is int64 with shape ``(1,)``.
        """

        def _reader_imple():
            for item in reader():
                image = np.array(item[0]).reshape(1, 28, 28)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield image, label

        return _reader_imple

    def test_mnist_float32(self):
        seed = 90
        epoch_num = 1
        batch_size = 128
        batch_num = 50

        # Hook for comparing against a traced static layer. It is never
        # assigned in this test, so the traced comparisons below are skipped.
        traced_layer = None

        # ---- dygraph training ----
        with fluid.dygraph.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            mnist = MNIST()
            sgd = SGDOptimizer(
                learning_rate=1e-3, parameter_list=mnist.parameters()
            )

            batch_py_reader = fluid.io.PyReader(capacity=1)
            batch_py_reader.decorate_sample_list_generator(
                paddle.batch(
                    self.reader_decorator(paddle.dataset.mnist.train()),
                    batch_size=batch_size,
                    drop_last=True,
                ),
                places=fluid.CPUPlace(),
            )

            mnist.train()
            dy_param_init_value = {}

            helper = DyGraphProgramDescTracerTestHelper(self)
            for epoch in range(epoch_num):
                for batch_id, data in enumerate(batch_py_reader()):
                    if batch_id >= batch_num:
                        break
                    img, label = data[0], data[1]
                    dy_x_data = img.numpy()
                    label.stop_gradient = True

                    cost = mnist(img)

                    if traced_layer is not None:
                        cost_static = traced_layer([img])
                        helper.assertEachVar(cost, cost_static)

                    loss = paddle.nn.functional.cross_entropy(
                        cost, label, reduction='none', use_softmax=False
                    )
                    avg_loss = paddle.mean(loss)
                    dy_out = avg_loss.numpy()

                    # Snapshot the initial weights before the first update.
                    if epoch == 0 and batch_id == 0:
                        dy_param_init_value = {
                            param.name: param.numpy()
                            for param in mnist.parameters()
                        }

                    avg_loss.backward()
                    sgd.minimize(avg_loss)
                    mnist.clear_gradients()

                    dy_param_value = {
                        param.name: param.numpy()
                        for param in mnist.parameters()
                    }

        # ---- static-graph training ----
        with new_program_scope():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            exe = fluid.Executor(
                fluid.CPUPlace()
                if not core.is_compiled_with_cuda()
                else fluid.CUDAPlace(0)
            )

            mnist = MNIST()
            sgd = SGDOptimizer(learning_rate=1e-3)
            train_reader = paddle.batch(
                paddle.dataset.mnist.train(),
                batch_size=batch_size,
                drop_last=True,
            )

            img = paddle.static.data(
                name='pixel', shape=[-1, 1, 28, 28], dtype='float32'
            )
            label = paddle.static.data(
                name='label', shape=[-1, 1], dtype='int64'
            )
            cost = mnist(img)
            loss = paddle.nn.functional.cross_entropy(
                cost, label, reduction='none', use_softmax=False
            )
            avg_loss = paddle.mean(loss)
            sgd.minimize(avg_loss)

            # Run the startup program and fetch the initialized parameters.
            static_param_name_list = [
                param.name for param in mnist.parameters()
            ]
            out = exe.run(
                fluid.default_startup_program(),
                fetch_list=static_param_name_list,
            )
            static_param_init_value = dict(zip(static_param_name_list, out))

            # Fetch order: loss first, then every parameter.
            fetch_list = [avg_loss.name] + static_param_name_list
            for _ in range(epoch_num):
                for batch_id, data in enumerate(train_reader()):
                    if batch_id >= batch_num:
                        break
                    static_x_data = np.array(
                        [x[0].reshape(1, 28, 28) for x in data]
                    ).astype('float32')
                    y_data = (
                        np.array([x[1] for x in data])
                        .astype('int64')
                        .reshape([batch_size, 1])
                    )

                    if traced_layer is not None:
                        traced_layer([static_x_data])

                    out = exe.run(
                        fluid.default_main_program(),
                        feed={"pixel": static_x_data, "label": y_data},
                        fetch_list=fetch_list,
                    )

                    static_out = out[0]
                    static_param_value = dict(
                        zip(static_param_name_list, out[1:])
                    )

        # Both modes must have seen the same last batch. Compare the arrays
        # themselves: ``.all()`` collapses each to a single bool ("every
        # pixel nonzero"), which does not check equality at all.
        np.testing.assert_allclose(dy_x_data, static_x_data, rtol=1e-05)

        for key, value in static_param_init_value.items():
            np.testing.assert_allclose(
                value, dy_param_init_value[key], rtol=1e-05
            )

        np.testing.assert_allclose(static_out, dy_out, rtol=1e-05)

        for key, value in static_param_value.items():
            np.testing.assert_allclose(
                value, dy_param_value[key], rtol=1e-05, atol=1e-05
            )
272 273 274


if __name__ == '__main__':
    # Run with the static-graph mode enabled by default; the dygraph part of
    # the test opts in explicitly via ``fluid.dygraph.guard()``.
    paddle.enable_static()
    unittest.main()