test_imperative_resnet.py 15.0 KB
Newer Older
M
minqiyang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
16

M
minqiyang 已提交
17
import numpy as np
18
from test_imperative_base import new_program_scope
19
from utils import DyGraphProgramDescTracerTestHelper
M
minqiyang 已提交
20 21 22

import paddle
import paddle.fluid as fluid
23
from paddle.fluid import core
L
lujun 已提交
24
from paddle.fluid.dygraph.base import to_variable
25
from paddle.fluid.layer_helper import LayerHelper
26
from paddle.nn import BatchNorm
M
minqiyang 已提交
27

28
# NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1

# Mini-batch size shared by the dygraph run, the static-graph run and the
# learning-strategy description below.
batch_size = 8

# Training configuration handed to ``optimizer_setting`` and read by the test
# for its batch size. ``steps`` inside ``learning_strategy`` is not read
# anywhere in this file.
train_parameters = dict(
    input_size=[3, 224, 224],
    input_mean=[0.485, 0.456, 0.406],
    input_std=[0.229, 0.224, 0.225],
    learning_strategy=dict(
        name="piecewise_decay",
        batch_size=batch_size,
        epochs=[30, 60, 90],
        steps=[0.1, 0.01, 0.001, 0.0001],
    ),
    batch_size=batch_size,
    lr=0.1,
    total_images=1281164,
)


47
def optimizer_setting(params, parameter_list=None):
    """Build the optimizer shared by the dygraph and static-graph runs.

    Args:
        params: training configuration dict. ``params["learning_strategy"]``
            must be a dict whose ``"name"`` is ``"piecewise_decay"``.
        parameter_list: parameters to optimize; forwarded to the optimizer
            only in dygraph (non-static) mode and ignored otherwise.

    Returns:
        A ``fluid.optimizer.SGD`` with a constant learning rate of 0.01.

    Raises:
        KeyError: if ``params`` lacks ``"learning_strategy"`` or the strategy
            lacks ``"name"``.
        ValueError: if the learning strategy is not ``"piecewise_decay"``.
    """
    ls = params["learning_strategy"]
    if ls["name"] != "piecewise_decay":
        # Previously execution fell through to ``return optimizer`` and died
        # with UnboundLocalError; report the unsupported strategy instead.
        raise ValueError(
            "unsupported learning strategy: {}".format(ls["name"])
        )

    # TODO(minqiyang): Add learning rate scheduler support to dygraph mode.
    # The schedule this would replace the constant rate with is piecewise
    # decay: with step = int(total_images / ls["batch_size"] + 1), the
    # boundaries are [step * e for e in ls["epochs"]] and the values are
    # params["lr"] * 0.1**i, driving a Momentum optimizer (momentum=0.9,
    # regularization=L2Decay(1e-4)).
    if fluid._non_static_mode():
        return fluid.optimizer.SGD(
            learning_rate=0.01, parameter_list=parameter_list
        )
    return fluid.optimizer.SGD(learning_rate=0.01)


78
class ConvBNLayer(fluid.Layer):
    """A bias-free ``Conv2D`` followed by ``BatchNorm``.

    The convolution uses a padding of ``(filter_size - 1) // 2``. The optional
    activation ``act`` is applied by the ``BatchNorm`` layer. ``use_cudnn`` is
    accepted for call-site compatibility but is not forwarded anywhere.
    """

    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        stride=1,
        groups=1,
        act=None,
        use_cudnn=False,
    ):
        super().__init__()

        half_kernel = (filter_size - 1) // 2
        self._conv = paddle.nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=half_kernel,
            groups=groups,
            bias_attr=False,
        )

        self._batch_norm = BatchNorm(num_filters, act=act)

    def forward(self, inputs):
        # conv -> batch norm (+ act, if one was given)
        return self._batch_norm(self._conv(inputs))


110
class BottleneckBlock(fluid.Layer):
    """ResNet bottleneck unit: 1x1 -> 3x3 (strided) -> 1x1 (x4 channels).

    The main-branch output is added to a residual branch and passed through
    ReLU. If ``shortcut`` is true, the residual is the block input itself.
    Otherwise it goes through a strided 1x1 ``ConvBNLayer`` that produces
    ``num_filters * 4`` channels with the same stride as the main branch.
    """

    def __init__(
        self, num_channels, num_filters, stride, shortcut=True, use_cudnn=False
    ):
        super().__init__()

        expanded_filters = num_filters * 4

        # NOTE: sublayers are created in a fixed order (conv0, conv1, conv2,
        # short); keep it so parameter creation order stays unchanged.
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act='relu',
            use_cudnn=use_cudnn,
        )
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu',
            use_cudnn=use_cudnn,
        )
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=expanded_filters,
            filter_size=1,
            act=None,
            use_cudnn=use_cudnn,
        )

        if not shortcut:
            # Projection branch for the residual path.
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=expanded_filters,
                filter_size=1,
                stride=stride,
                use_cudnn=use_cudnn,
            )

        self.shortcut = shortcut

    def forward(self, inputs):
        main_branch = self.conv2(self.conv1(self.conv0(inputs)))
        residual = inputs if self.shortcut else self.short(inputs)

        merged = paddle.add(x=residual, y=main_branch)

        relu_helper = LayerHelper(self.full_name(), act='relu')
        return relu_helper.append_activation(merged)
M
minqiyang 已提交
164 165


166
class ResNet(fluid.Layer):
    """Bottleneck ResNet (50, 101 or 152 layers) with a softmax classifier head.

    Args:
        layers: network depth; must be one of 50, 101 or 152.
        class_dim: number of outputs of the final ``Linear`` layer.
        use_cudnn: forwarded to the stem ``ConvBNLayer`` and to every
            ``BottleneckBlock``.

    Raises:
        ValueError: if ``layers`` is not a supported depth.
    """

    def __init__(self, layers=50, class_dim=102, use_cudnn=True):
        super().__init__()

        self.layers = layers
        # Number of bottleneck blocks in each of the four stages, per depth.
        depth_by_layers = {
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }
        supported_layers = list(depth_by_layers)
        # Explicit check rather than ``assert``: an assert is stripped under
        # ``python -O``, after which an unsupported depth would surface as an
        # UnboundLocalError on ``depth``.
        if layers not in depth_by_layers:
            raise ValueError(
                "supported layers are {} but input layer is {}".format(
                    supported_layers, layers
                )
            )
        depth = depth_by_layers[layers]

        num_channels = [64, 256, 512, 1024]
        num_filters = [64, 128, 256, 512]

        self.conv = ConvBNLayer(
            num_channels=3,
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu',
            use_cudnn=use_cudnn,
        )
        self.pool2d_max = paddle.nn.MaxPool2D(
            kernel_size=3, stride=2, padding=1
        )

        self.bottleneck_block_list = []
        for block, block_count in enumerate(depth):
            # The first block of every stage projects its residual branch.
            shortcut = False
            for i in range(block_count):
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(
                        num_channels=num_channels[block]
                        if i == 0
                        else num_filters[block] * 4,
                        num_filters=num_filters[block],
                        # Stages after the first downsample in their first block.
                        stride=2 if i == 0 and block != 0 else 1,
                        shortcut=shortcut,
                        use_cudnn=use_cudnn,
                    ),
                )
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = paddle.nn.AdaptiveAvgPool2D(1)

        # Pooling to 1x1 leaves num_filters[-1] * 4 features per sample.
        self.pool2d_avg_output = num_filters[-1] * 4 * 1 * 1

        import math

        # Uniform init bound 1/sqrt(fan_in) for the classifier. This was
        # previously hard-coded as sqrt(2048), which equals pool2d_avg_output.
        stdv = 1.0 / math.sqrt(self.pool2d_avg_output * 1.0)

        self.out = paddle.nn.Linear(
            self.pool2d_avg_output,
            class_dim,
            weight_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)
            ),
        )

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        y = paddle.reshape(y, shape=[-1, self.pool2d_avg_output])
        y = self.out(y)
        # Returns class probabilities; callers use cross_entropy with
        # use_softmax=False accordingly.
        y = paddle.nn.functional.softmax(y)
        return y


L
lujun 已提交
245
class TestDygraphResnet(unittest.TestCase):
    """Train ResNet for a few batches in dygraph and in static-graph mode, then
    check that the losses, initial parameters, gradients and updated
    parameters of the two runs agree."""

    def reader_decorator(self, reader):
        """Wrap ``reader`` so that each sample yields a ``(3, 224, 224)`` array
        and an int64 label of shape ``(1,)``.

        Not used by ``test_resnet_float32``.
        """

        def _reader_imple():
            for item in reader():
                doc = np.array(item[0]).reshape(3, 224, 224)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield doc, label

        return _reader_imple

    def _assert_matches_and_finite(self, static_values, dy_values):
        """Assert both dicts have the same size and that each static value is
        close to its dygraph counterpart and contains no inf/NaN."""
        self.assertEqual(len(dy_values), len(static_values))
        for key, value in static_values.items():
            np.testing.assert_allclose(value, dy_values[key], rtol=1e-05)
            # Reduce AFTER the element-wise test. The previous form,
            # ``np.isfinite(value.all())``, checked a single bool and could
            # never fail.
            self.assertTrue(np.isfinite(value).all())
            self.assertFalse(np.isnan(value).any())

    def test_resnet_float32(self):
        seed = 90

        batch_size = train_parameters["batch_size"]
        batch_num = 10

        # Never assigned in this test, so the traced-layer comparisons below
        # are currently skipped.
        traced_layer = None

        # ---- dygraph run ----
        with fluid.dygraph.guard():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            resnet = ResNet()
            optimizer = optimizer_setting(
                train_parameters, parameter_list=resnet.parameters()
            )
            np.random.seed(seed)

            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size,
            )

            dy_param_init_value = {
                param.name: param.numpy() for param in resnet.parameters()
            }

            helper = DyGraphProgramDescTracerTestHelper(self)

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                dy_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]
                ).astype('float32')
                y_data = (
                    np.array([x[1] for x in data])
                    .astype('int64')
                    .reshape(batch_size, 1)
                )

                img = to_variable(dy_x_data)
                label = to_variable(y_data)
                label.stop_gradient = True

                out = resnet(img)

                if traced_layer is not None:
                    resnet.eval()
                    traced_layer._switch(is_test=True)
                    out_dygraph = resnet(img)
                    out_static = traced_layer([img])
                    traced_layer._switch(is_test=False)
                    helper.assertEachVar(out_dygraph, out_static)
                    resnet.train()

                loss = paddle.nn.functional.cross_entropy(
                    input=out, label=label, reduction='none', use_softmax=False
                )
                avg_loss = paddle.mean(x=loss)

                dy_out = avg_loss.numpy()

                if batch_id == 0:
                    # Pick up any parameters created lazily by the first
                    # forward pass.
                    for param in resnet.parameters():
                        if param.name not in dy_param_init_value:
                            dy_param_init_value[param.name] = param.numpy()

                avg_loss.backward()

                dy_grad_value = {}
                for param in resnet.parameters():
                    if param.trainable:
                        dy_grad_value[
                            param.name + core.grad_var_suffix()
                        ] = np.array(param._grad_ivar().value().get_tensor())

                optimizer.minimize(avg_loss)
                resnet.clear_gradients()

                dy_param_value = {
                    param.name: param.numpy() for param in resnet.parameters()
                }

        # ---- static-graph run ----
        with new_program_scope():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            exe = fluid.Executor(
                fluid.CPUPlace()
                if not core.is_compiled_with_cuda()
                else fluid.CUDAPlace(0)
            )

            resnet = ResNet()
            optimizer = optimizer_setting(train_parameters)

            np.random.seed(seed)
            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size,
            )

            img = fluid.layers.data(
                name='pixel', shape=[3, 224, 224], dtype='float32'
            )
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            out = resnet(img)
            loss = paddle.nn.functional.cross_entropy(
                input=out, label=label, reduction='none', use_softmax=False
            )
            avg_loss = paddle.mean(x=loss)
            optimizer.minimize(avg_loss)

            # initialize params and fetch them
            static_param_name_list = [
                param.name for param in resnet.parameters()
            ]
            static_grad_name_list = [
                param.name + core.grad_var_suffix()
                for param in resnet.parameters()
                if param.trainable
            ]

            out = exe.run(
                fluid.default_startup_program(),
                fetch_list=static_param_name_list,
            )
            static_param_init_value = dict(zip(static_param_name_list, out))

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                static_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]
                ).astype('float32')
                y_data = (
                    np.array([x[1] for x in data])
                    .astype('int64')
                    .reshape([batch_size, 1])
                )

                if traced_layer is not None:
                    traced_layer([static_x_data])

                fetch_list = (
                    [avg_loss.name]
                    + static_param_name_list
                    + static_grad_name_list
                )
                out = exe.run(
                    fluid.default_main_program(),
                    feed={"pixel": static_x_data, "label": y_data},
                    fetch_list=fetch_list,
                )

                # Fetched layout: [loss, *params, *grads].
                static_out = out[0]
                param_start_pos = 1
                grad_start_pos = len(static_param_name_list) + param_start_pos
                static_param_value = dict(
                    zip(
                        static_param_name_list,
                        out[param_start_pos:grad_start_pos],
                    )
                )
                static_grad_value = dict(
                    zip(static_grad_name_list, out[grad_start_pos:])
                )

        print("static", static_out)
        print("dygraph", dy_out)
        np.testing.assert_allclose(static_out, dy_out, rtol=1e-05)

        self._assert_matches_and_finite(
            static_param_init_value, dy_param_init_value
        )
        self._assert_matches_and_finite(static_grad_value, dy_grad_value)
        self._assert_matches_and_finite(static_param_value, dy_param_value)
M
minqiyang 已提交
463 464 465


if __name__ == '__main__':
    # Run in static-graph mode by default; the dygraph half of the test
    # switches modes itself through ``fluid.dygraph.guard()``.
    paddle.enable_static()
    unittest.main()