# test_imperative_resnet.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import unittest

import numpy as np

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph import TracedLayer
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph

from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program

# NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1

batch_size = 8
# Hyper-parameters shared by the dygraph and the static-graph runs of the test.
train_parameters = {
    "input_size": [3, 224, 224],
    "input_mean": [0.485, 0.456, 0.406],
    "input_std": [0.229, 0.224, 0.225],
    "learning_strategy": {
        "name": "piecewise_decay",
        "batch_size": batch_size,
        "epochs": [30, 60, 90],
        "steps": [0.1, 0.01, 0.001, 0.0001],
    },
    "batch_size": batch_size,
    "lr": 0.1,
    # NOTE(review): optimizer_setting falls back to 1281167 when this key is
    # missing; the two numbers differ -- confirm which one is intended.
    "total_images": 1281164,
}


def optimizer_setting(params, parameter_list=None):
    """Build the SGD optimizer used by both the dygraph and the static run.

    Args:
        params: hyper-parameter dict shaped like ``train_parameters``. Reads
            ``learning_strategy`` (``name``, ``batch_size``, ``epochs``),
            ``lr`` and, optionally, ``total_images``.
        parameter_list: parameters to optimize. Only passed to the optimizer
            when ``fluid._non_static_mode()`` is true (dygraph mode).

    Returns:
        A ``fluid.optimizer.SGD`` with a fixed learning rate of 0.01.

    Raises:
        ValueError: if ``learning_strategy["name"]`` is not
            ``"piecewise_decay"``. Previously this case fell through to
            ``return optimizer`` with ``optimizer`` unbound.
    """
    ls = params["learning_strategy"]
    if ls["name"] != "piecewise_decay":
        raise ValueError(
            "unsupported learning strategy: {}".format(ls["name"])
        )

    total_images = params.get("total_images", 1281167)
    step = int(total_images / ls["batch_size"] + 1)

    # Boundaries/values of the piecewise-decay schedule. They are only used by
    # the disabled Momentum/piecewise_decay setup in the TODO below.
    bd = [step * e for e in ls["epochs"]]
    base_lr = params["lr"]
    lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]

    if fluid._non_static_mode():
        optimizer = fluid.optimizer.SGD(
            learning_rate=0.01, parameter_list=parameter_list
        )
    else:
        optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    # TODO(minqiyang): Add learning rate scheduler support to dygraph mode
    #  optimizer = fluid.optimizer.Momentum(
    #  learning_rate=params["lr"],
    #  learning_rate=fluid.layers.piecewise_decay(
    #  boundaries=bd, values=lr),
    #  momentum=0.9,
    #  regularization=fluid.regularizer.L2Decay(1e-4))

    return optimizer


class ConvBNLayer(fluid.Layer):
    """A bias-free ``paddle.nn.Conv2D`` followed by a ``BatchNorm`` layer.

    Args:
        num_channels: input channels of the convolution.
        num_filters: output channels of the convolution; also the BatchNorm
            channel count.
        filter_size: convolution kernel size. Each side is padded by
            ``(filter_size - 1) // 2``.
        stride: convolution stride.
        groups: convolution groups.
        act: activation passed to the BatchNorm layer (``None`` for none).
        use_cudnn: accepted for interface compatibility.
            NOTE(review): it is not forwarded to either sub-layer.
    """

    def __init__(
        self,
        num_channels,
        num_filters,
        filter_size,
        stride=1,
        groups=1,
        act=None,
        use_cudnn=False,
    ):
        super().__init__()

        self._conv = paddle.nn.Conv2D(
            in_channels=num_channels,
            out_channels=num_filters,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            bias_attr=False,
        )

        self._batch_norm = BatchNorm(num_filters, act=act)

    def forward(self, inputs):
        """Return ``batch_norm(conv(inputs))``."""
        y = self._conv(inputs)
        y = self._batch_norm(y)

        return y


class BottleneckBlock(fluid.Layer):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 ConvBNLayers plus a residual.

    The last 1x1 convolution outputs ``num_filters * 4`` channels. When
    ``shortcut`` is False, the residual path goes through an extra 1x1
    ConvBNLayer with ``num_filters * 4`` outputs and the same ``stride`` as
    the 3x3 convolution. Otherwise the input is added unchanged. A ReLU is
    applied to the sum.

    Args:
        num_channels: input channels of the block.
        num_filters: channels of the inner 1x1/3x3 convolutions.
        stride: stride of the 3x3 convolution and of the projection.
        shortcut: if True, add the input directly instead of projecting it.
        use_cudnn: passed through to every ConvBNLayer.
    """

    def __init__(
        self, num_channels, num_filters, stride, shortcut=True, use_cudnn=False
    ):
        super().__init__()

        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act='relu',
            use_cudnn=use_cudnn,
        )
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu',
            use_cudnn=use_cudnn,
        )
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None,
            use_cudnn=use_cudnn,
        )

        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                stride=stride,
                use_cudnn=use_cudnn,
            )

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return ``relu(residual(inputs) + conv2(conv1(conv0(inputs))))``."""
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        # Identity residual, or the projection built when shortcut=False.
        short = inputs if self.shortcut else self.short(inputs)

        y = fluid.layers.elementwise_add(x=short, y=conv2)

        # The block's final ReLU is appended through LayerHelper.
        layer_helper = LayerHelper(self.full_name(), act='relu')
        return layer_helper.append_activation(y)


class ResNet(fluid.Layer):
    """ResNet classifier with 50, 101 or 152 layers built from BottleneckBlocks.

    Args:
        layers: network depth; must be one of 50, 101, 152.
        class_dim: output size of the final softmax ``Linear`` layer.
        use_cudnn: passed through to the stem ConvBNLayer and every
            BottleneckBlock.
    """

    def __init__(self, layers=50, class_dim=102, use_cudnn=True):
        super().__init__()

        self.layers = layers
        # Number of bottleneck blocks in each of the four stages, per depth.
        depth_by_layers = {
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }
        supported_layers = list(depth_by_layers)
        assert (
            layers in supported_layers
        ), "supported layers are {} but input layer is {}".format(
            supported_layers, layers
        )

        depth = depth_by_layers[layers]
        num_channels = [64, 256, 512, 1024]
        num_filters = [64, 128, 256, 512]

        self.conv = ConvBNLayer(
            num_channels=3,
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu',
            use_cudnn=use_cudnn,
        )
        self.pool2d_max = Pool2D(
            pool_size=3, pool_stride=2, pool_padding=1, pool_type='max'
        )

        self.bottleneck_block_list = []
        for block in range(len(depth)):
            # The first block of every stage projects its residual path;
            # the remaining blocks use the identity shortcut.
            shortcut = False
            for i in range(depth[block]):
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(
                        num_channels=num_channels[block]
                        if i == 0
                        else num_filters[block] * 4,
                        num_filters=num_filters[block],
                        # Downsample at the start of every stage but the first.
                        stride=2 if i == 0 and block != 0 else 1,
                        shortcut=shortcut,
                        use_cudnn=use_cudnn,
                    ),
                )
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = Pool2D(
            pool_size=7, pool_type='avg', global_pooling=True
        )

        # Feature size after global average pooling of the last stage.
        self.pool2d_avg_output = num_filters[-1] * 4 * 1 * 1

        # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) init for the classifier.
        # Derived from the actual fan-in rather than a hard-coded 2048.
        stdv = 1.0 / math.sqrt(self.pool2d_avg_output * 1.0)

        self.out = Linear(
            self.pool2d_avg_output,
            class_dim,
            act='softmax',
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)
            ),
        )

    def forward(self, inputs):
        """Run stem, bottleneck stages, global pooling and the classifier."""
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        # Flatten pooled features to [-1, pool2d_avg_output] for the Linear.
        y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
        y = self.out(y)
        return y


class TestDygraphResnet(unittest.TestCase):
    def reader_decorator(self, reader):
        """Wrap ``reader`` to yield (3x224x224 image array, int64 label of shape (1,))."""

        def _reader_imple():
            for item in reader():
                doc = np.array(item[0]).reshape(3, 224, 224)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield doc, label

        return _reader_imple

    def _assert_values_match(self, static_values, dy_values):
        """Assert two name -> ndarray dicts agree and the static values are finite."""
        self.assertEqual(len(dy_values), len(static_values))
        for key, value in static_values.items():
            np.testing.assert_allclose(value, dy_values[key], rtol=1e-05)
            # Check element-wise. The old ``np.isfinite(value.all())`` and
            # ``np.isnan(value.any())`` tested a single reduced boolean and
            # could never fail.
            self.assertTrue(np.isfinite(value).all())
            self.assertFalse(np.isnan(value).any())

    def func_test_resnet_float32(self):
        """Train ResNet-50 for ``batch_num`` batches in dygraph and in static
        mode with the same seeds, and check that losses, initial parameters,
        gradients and updated parameters match."""
        seed = 90

        batch_size = train_parameters["batch_size"]
        batch_num = 10

        traced_layer = None

        with fluid.dygraph.guard():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            resnet = ResNet()
            optimizer = optimizer_setting(
                train_parameters, parameter_list=resnet.parameters()
            )
            np.random.seed(seed)

            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size,
            )

            dy_param_init_value = {}
            for param in resnet.parameters():
                dy_param_init_value[param.name] = param.numpy()

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                dy_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]
                ).astype('float32')
                y_data = (
                    np.array([x[1] for x in data])
                    .astype('int64')
                    .reshape(batch_size, 1)
                )

                img = to_variable(dy_x_data)
                label = to_variable(y_data)
                label.stop_gradient = True

                out = None
                if batch_id % 5 == 0 and _in_legacy_dygraph():
                    # Re-trace periodically; every trace must yield the same
                    # program as the previous one.
                    out, traced_layer = TracedLayer.trace(resnet, img)
                    if program is not None:
                        self.assertTrue(
                            is_equal_program(program, traced_layer.program)
                        )

                    traced_layer.save_inference_model(
                        './infer_imperative_resnet'
                    )

                    program = traced_layer.program
                else:
                    out = resnet(img)

                if traced_layer is not None:
                    # Compare dygraph and traced outputs in eval mode, then
                    # restore training mode.
                    resnet.eval()
                    traced_layer._switch(is_test=True)
                    out_dygraph = resnet(img)
                    out_static = traced_layer([img])
                    traced_layer._switch(is_test=False)
                    helper.assertEachVar(out_dygraph, out_static)
                    resnet.train()

                loss = fluid.layers.cross_entropy(input=out, label=label)
                avg_loss = paddle.mean(x=loss)

                dy_out = avg_loss.numpy()

                if batch_id == 0:
                    for param in resnet.parameters():
                        if param.name not in dy_param_init_value:
                            dy_param_init_value[param.name] = param.numpy()

                avg_loss.backward()

                dy_grad_value = {}
                for param in resnet.parameters():
                    if param.trainable:
                        np_array = np.array(
                            param._grad_ivar().value().get_tensor()
                        )
                        dy_grad_value[
                            param.name + core.grad_var_suffix()
                        ] = np_array

                optimizer.minimize(avg_loss)
                resnet.clear_gradients()

                dy_param_value = {}
                for param in resnet.parameters():
                    dy_param_value[param.name] = param.numpy()

        with new_program_scope():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            exe = fluid.Executor(
                fluid.CPUPlace()
                if not core.is_compiled_with_cuda()
                else fluid.CUDAPlace(0)
            )

            resnet = ResNet()
            optimizer = optimizer_setting(train_parameters)

            np.random.seed(seed)
            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size,
            )

            img = fluid.layers.data(
                name='pixel', shape=[3, 224, 224], dtype='float32'
            )
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            out = resnet(img)
            loss = fluid.layers.cross_entropy(input=out, label=label)
            avg_loss = paddle.mean(x=loss)
            optimizer.minimize(avg_loss)

            # initialize params and fetch them
            static_param_name_list = []
            static_grad_name_list = []
            for param in resnet.parameters():
                static_param_name_list.append(param.name)
            for param in resnet.parameters():
                if param.trainable:
                    static_grad_name_list.append(
                        param.name + core.grad_var_suffix()
                    )

            out = exe.run(
                fluid.default_startup_program(),
                fetch_list=static_param_name_list,
            )
            static_param_init_value = dict(zip(static_param_name_list, out))

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                static_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]
                ).astype('float32')
                y_data = (
                    np.array([x[1] for x in data])
                    .astype('int64')
                    .reshape([batch_size, 1])
                )

                if traced_layer is not None:
                    # NOTE(review): the result is discarded; presumably this
                    # checks the traced program still runs inside a fresh
                    # program scope -- confirm intent.
                    traced_layer([static_x_data])

                fetch_list = [avg_loss.name]
                fetch_list.extend(static_param_name_list)
                fetch_list.extend(static_grad_name_list)
                out = exe.run(
                    fluid.default_main_program(),
                    feed={"pixel": static_x_data, "label": y_data},
                    fetch_list=fetch_list,
                )

                # ``out`` follows fetch_list order: [loss, *params, *grads].
                static_out = out[0]
                param_start_pos = 1
                grad_start_pos = len(static_param_name_list) + param_start_pos
                static_param_value = dict(
                    zip(
                        static_param_name_list,
                        out[param_start_pos:grad_start_pos],
                    )
                )
                static_grad_value = dict(
                    zip(static_grad_name_list, out[grad_start_pos:])
                )

        print("static", static_out)
        print("dygraph", dy_out)
        np.testing.assert_allclose(static_out, dy_out, rtol=1e-05)

        self._assert_values_match(static_param_init_value, dy_param_init_value)
        self._assert_values_match(static_grad_value, dy_grad_value)
        self._assert_values_match(static_param_value, dy_param_value)

    def test_resnet_float32(self):
        """Run the comparison under the eager guard and then in the default mode."""
        with _test_eager_guard():
            self.func_test_resnet_float32()
        self.func_test_resnet_float32()

if __name__ == '__main__':
    # Enable Paddle static-graph mode before running the test suite.
    paddle.enable_static()
    unittest.main()