# test_imperative_resnet.py
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import math
import unittest

import numpy as np
import six

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph import TracedLayer
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.framework import _test_eager_guard, _in_legacy_dygraph
from paddle.fluid.layer_helper import LayerHelper

from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program

#NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1

# Mini-batch size used by both the dygraph and the static-graph runs.
batch_size = 8

# Hyper-parameters read by optimizer_setting() and by the test case below.
# Key order matches the original literal.
train_parameters = dict(
    input_size=[3, 224, 224],
    input_mean=[0.485, 0.456, 0.406],
    input_std=[0.229, 0.224, 0.225],
    learning_strategy=dict(
        name="piecewise_decay",
        batch_size=batch_size,
        epochs=[30, 60, 90],
        steps=[0.1, 0.01, 0.001, 0.0001],
    ),
    batch_size=batch_size,
    lr=0.1,
    total_images=1281164,
)


def optimizer_setting(params, parameter_list=None):
    """Build the optimizer shared by the dygraph and static-graph runs.

    Args:
        params: hyper-parameter dict laid out like ``train_parameters``.
            Reads ``params["learning_strategy"]`` (``name``, ``batch_size``,
            ``epochs``), ``params["lr"]`` and, optionally,
            ``params["total_images"]`` (defaults to 1281167).
        parameter_list: parameters to optimize. Only forwarded to the
            optimizer in non-static (dygraph) mode.

    Returns:
        A ``fluid.optimizer.SGD`` with a fixed learning rate of 0.01.

    Raises:
        ValueError: if the learning strategy name is not
            ``"piecewise_decay"`` (previously this fell through to an
            ``UnboundLocalError`` on ``optimizer``).
    """
    ls = params["learning_strategy"]
    if ls["name"] != "piecewise_decay":
        raise ValueError("unsupported learning strategy: {}".format(
            ls["name"]))

    total_images = params.get("total_images", 1281167)
    batch_size = ls["batch_size"]
    step = int(total_images / batch_size + 1)

    # Piecewise-decay boundaries (in steps) and per-interval learning rates.
    # They are not consumed yet -- see the TODO below.
    bd = [step * e for e in ls["epochs"]]
    base_lr = params["lr"]
    lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]

    if fluid._non_static_mode():
        optimizer = fluid.optimizer.SGD(learning_rate=0.01,
                                        parameter_list=parameter_list)
    else:
        optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    # TODO(minqiyang): Add learning rate scheduler support to dygraph mode,
    # i.e. switch to fluid.optimizer.Momentum with
    # learning_rate=fluid.layers.piecewise_decay(boundaries=bd, values=lr),
    # momentum=0.9 and regularization=fluid.regularizer.L2Decay(1e-4).

    return optimizer


class ConvBNLayer(fluid.Layer):
    """A bias-free Conv2D followed by BatchNorm.

    The convolution itself has no activation and is padded by
    ``(filter_size - 1) // 2`` on each side; the optional activation ``act``
    is applied by the BatchNorm layer.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None,
                 use_cudnn=False):
        super(ConvBNLayer, self).__init__()

        half_window = (filter_size - 1) // 2
        self._conv = Conv2D(num_channels=num_channels,
                            num_filters=num_filters,
                            filter_size=filter_size,
                            stride=stride,
                            padding=half_window,
                            groups=groups,
                            act=None,
                            bias_attr=False,
                            use_cudnn=use_cudnn)
        self._batch_norm = BatchNorm(num_filters, act=act)

    def forward(self, inputs):
        """Apply the convolution, then batch normalization (+ activation)."""
        return self._batch_norm(self._conv(inputs))


class BottleneckBlock(fluid.Layer):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 ConvBNLayers plus a residual.

    Args:
        num_channels: channels of the block input.
        num_filters: width of the inner convolutions; the block outputs
            ``num_filters * 4`` channels.
        stride: stride of the 3x3 convolution (and of the projection branch).
        shortcut: if True the input is added back unchanged; if False it is
            first projected by a strided 1x1 ConvBNLayer.
        use_cudnn: forwarded to every ConvBNLayer.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True,
                 use_cudnn=False):
        super(BottleneckBlock, self).__init__()

        out_channels = num_filters * 4

        # NOTE: sub-layers are created in the original order (conv0, conv1,
        # conv2, short); auto-generated parameter names presumably depend on it.
        self.conv0 = ConvBNLayer(num_channels=num_channels,
                                 num_filters=num_filters,
                                 filter_size=1,
                                 act='relu',
                                 use_cudnn=use_cudnn)
        self.conv1 = ConvBNLayer(num_channels=num_filters,
                                 num_filters=num_filters,
                                 filter_size=3,
                                 stride=stride,
                                 act='relu',
                                 use_cudnn=use_cudnn)
        self.conv2 = ConvBNLayer(num_channels=num_filters,
                                 num_filters=out_channels,
                                 filter_size=1,
                                 act=None,
                                 use_cudnn=use_cudnn)
        if not shortcut:
            self.short = ConvBNLayer(num_channels=num_channels,
                                     num_filters=out_channels,
                                     filter_size=1,
                                     stride=stride,
                                     use_cudnn=use_cudnn)

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return relu(residual + conv2(conv1(conv0(inputs))))."""
        main_branch = self.conv2(self.conv1(self.conv0(inputs)))
        residual = inputs if self.shortcut else self.short(inputs)
        summed = fluid.layers.elementwise_add(x=residual, y=main_branch)
        relu_helper = LayerHelper(self.full_name(), act='relu')
        return relu_helper.append_activation(summed)


class ResNet(fluid.Layer):
    """ResNet-50/101/152 image classifier built from BottleneckBlock stages.

    Args:
        layers: network depth; must be 50, 101 or 152.
        class_dim: number of output classes of the final softmax Linear layer.
        use_cudnn: forwarded to every ConvBNLayer.

    Raises:
        ValueError: if ``layers`` is not a supported depth. (This used to be
            an ``assert`` and was therefore skipped under ``python -O``.)
    """

    def __init__(self, layers=50, class_dim=102, use_cudnn=True):
        super(ResNet, self).__init__()

        self.layers = layers
        # Number of bottleneck blocks in each of the four stages.
        depth_by_layers = {
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }
        supported_layers = sorted(depth_by_layers)
        if layers not in depth_by_layers:
            raise ValueError(
                "supported layers are {} but input layer is {}".format(
                    supported_layers, layers))
        depth = depth_by_layers[layers]

        num_channels = [64, 256, 512, 1024]
        num_filters = [64, 128, 256, 512]

        self.conv = ConvBNLayer(num_channels=3,
                                num_filters=64,
                                filter_size=7,
                                stride=2,
                                act='relu',
                                use_cudnn=use_cudnn)
        self.pool2d_max = Pool2D(pool_size=3,
                                 pool_stride=2,
                                 pool_padding=1,
                                 pool_type='max')

        self.bottleneck_block_list = []
        for block, block_count in enumerate(depth):
            # The first block of each stage projects its input
            # (shortcut=False); the remaining blocks use the identity.
            shortcut = False
            for i in range(block_count):
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(num_channels=num_channels[block]
                                    if i == 0 else num_filters[block] * 4,
                                    num_filters=num_filters[block],
                                    stride=2 if i == 0 and block != 0 else 1,
                                    shortcut=shortcut,
                                    use_cudnn=use_cudnn))
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = Pool2D(pool_size=7,
                                 pool_type='avg',
                                 global_pooling=True)

        self.pool2d_avg_output = num_filters[-1] * 4 * 1 * 1

        # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) init for the classifier.
        # fan_in is pool2d_avg_output (2048), the previously hard-coded value.
        stdv = 1.0 / math.sqrt(self.pool2d_avg_output * 1.0)

        self.out = Linear(
            self.pool2d_avg_output,
            class_dim,
            act='softmax',
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)))

    def forward(self, inputs):
        """Return softmax class probabilities for a batch of images.

        ``inputs`` is assumed to be NCHW with 3 channels (the test feeds
        [N, 3, 224, 224]) -- TODO confirm for other callers.
        """
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
        y = self.out(y)
        return y


class TestDygraphResnet(unittest.TestCase):
    """Checks that ResNet trains identically in dygraph and static-graph mode."""

    def reader_decorator(self, reader):
        """Return a reader yielding ``(image, label)`` pairs.

        Each image is reshaped to ``(3, 224, 224)`` and each label becomes an
        int64 array of shape ``(1,)``.
        """

        def _reader_imple():
            for item in reader():
                doc = np.array(item[0]).reshape(3, 224, 224)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield doc, label

        return _reader_imple

    def _assert_values_match(self, static_values, dygraph_values):
        """Assert two name -> ndarray dicts agree.

        Both dicts must have the same number of entries, and every array in
        ``static_values`` must be close to the same-named array in
        ``dygraph_values`` and contain only finite numbers.
        """
        self.assertEqual(len(dygraph_values), len(static_values))
        for key, value in static_values.items():
            self.assertTrue(np.allclose(value, dygraph_values[key]))
            # Check element-wise. The previous ``np.isfinite(value.all())``
            # reduced the array to one bool first, which is always finite.
            self.assertTrue(np.isfinite(value).all())
            self.assertFalse(np.isnan(value).any())

    def func_test_resnet_float32(self):
        """Train ResNet in dygraph and static mode from the same seed.

        The loss, initial parameters, gradients and final parameters must
        match between the two runs. In legacy dygraph mode the model is also
        traced with ``TracedLayer`` every 5 batches. Each new trace must equal
        the previous one and is saved as an inference model. Once a trace
        exists, the eval-mode dygraph output must match it, and the trace is
        also run on the static-mode inputs.
        """
        seed = 90

        batch_size = train_parameters["batch_size"]
        batch_num = 10

        traced_layer = None

        with fluid.dygraph.guard():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            resnet = ResNet()
            optimizer = optimizer_setting(train_parameters,
                                          parameter_list=resnet.parameters())
            np.random.seed(seed)

            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size)

            dy_param_init_value = {
                param.name: param.numpy()
                for param in resnet.parameters()
            }

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                dy_x_data = np.array([x[0].reshape(3, 224, 224)
                                      for x in data]).astype('float32')
                y_data = np.array([x[1] for x in data
                                   ]).astype('int64').reshape(batch_size, 1)

                img = to_variable(dy_x_data)
                label = to_variable(y_data)
                label.stop_gradient = True

                # Re-trace every 5 batches (legacy dygraph only) and check
                # the traced program is unchanged between traces.
                if batch_id % 5 == 0 and _in_legacy_dygraph():
                    out, traced_layer = TracedLayer.trace(resnet, img)
                    if program is not None:
                        self.assertTrue(
                            is_equal_program(program, traced_layer.program))

                    traced_layer.save_inference_model(
                        './infer_imperative_resnet')

                    program = traced_layer.program
                else:
                    out = resnet(img)

                if traced_layer is not None:
                    # Compare eval-mode dygraph output with the traced static
                    # program, then switch both back to training mode.
                    resnet.eval()
                    traced_layer._switch(is_test=True)
                    out_dygraph = resnet(img)
                    out_static = traced_layer([img])
                    traced_layer._switch(is_test=False)
                    helper.assertEachVar(out_dygraph, out_static)
                    resnet.train()

                loss = fluid.layers.cross_entropy(input=out, label=label)
                avg_loss = paddle.mean(x=loss)

                dy_out = avg_loss.numpy()

                if batch_id == 0:
                    # Also record parameters that did not exist before the
                    # first forward pass (presumably created lazily).
                    for param in resnet.parameters():
                        if param.name not in dy_param_init_value:
                            dy_param_init_value[param.name] = param.numpy()

                avg_loss.backward()

                dy_grad_value = {}
                for param in resnet.parameters():
                    if param.trainable:
                        np_array = np.array(
                            param._grad_ivar().value().get_tensor())
                        dy_grad_value[param.name +
                                      core.grad_var_suffix()] = np_array

                optimizer.minimize(avg_loss)
                resnet.clear_gradients()

                dy_param_value = {
                    param.name: param.numpy()
                    for param in resnet.parameters()
                }

        with new_program_scope():
            paddle.seed(seed)
            paddle.framework.random._manual_program_seed(seed)

            place = (fluid.CUDAPlace(0)
                     if core.is_compiled_with_cuda() else fluid.CPUPlace())
            exe = fluid.Executor(place)

            resnet = ResNet()
            optimizer = optimizer_setting(train_parameters)

            np.random.seed(seed)
            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size)

            img = fluid.layers.data(name='pixel',
                                    shape=[3, 224, 224],
                                    dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            out = resnet(img)
            loss = fluid.layers.cross_entropy(input=out, label=label)
            avg_loss = paddle.mean(x=loss)
            optimizer.minimize(avg_loss)

            # initialize params and fetch them
            static_param_name_list = [
                param.name for param in resnet.parameters()
            ]
            static_grad_name_list = [
                param.name + core.grad_var_suffix()
                for param in resnet.parameters() if param.trainable
            ]

            out = exe.run(fluid.default_startup_program(),
                          fetch_list=static_param_name_list)
            static_param_init_value = dict(zip(static_param_name_list, out))

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                static_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
                y_data = np.array([x[1] for x in data
                                   ]).astype('int64').reshape([batch_size, 1])

                if traced_layer is not None:
                    traced_layer([static_x_data])

                fetch_list = [avg_loss.name]
                fetch_list.extend(static_param_name_list)
                fetch_list.extend(static_grad_name_list)
                out = exe.run(fluid.default_main_program(),
                              feed={
                                  "pixel": static_x_data,
                                  "label": y_data
                              },
                              fetch_list=fetch_list)

                # ``out`` follows fetch_list order: [loss, *params, *grads].
                static_out = out[0]
                param_start_pos = 1
                grad_start_pos = len(static_param_name_list) + param_start_pos
                grad_end_pos = grad_start_pos + len(static_grad_name_list)
                static_param_value = dict(
                    zip(static_param_name_list,
                        out[param_start_pos:grad_start_pos]))
                static_grad_value = dict(
                    zip(static_grad_name_list,
                        out[grad_start_pos:grad_end_pos]))

        print("static", static_out)
        print("dygraph", dy_out)
        self.assertTrue(np.allclose(static_out, dy_out))

        self._assert_values_match(static_param_init_value, dy_param_init_value)
        self._assert_values_match(static_grad_value, dy_grad_value)
        self._assert_values_match(static_param_value, dy_param_value)

    def test_resnet_float32(self):
        """Run the comparison under the eager guard, then in legacy dygraph."""
        with _test_eager_guard():
            self.func_test_resnet_float32()
        self.func_test_resnet_float32()


if __name__ == '__main__':
    # Enable static-graph mode globally before running the tests; the
    # dygraph half of the test switches modes itself via fluid.dygraph.guard().
    paddle.enable_static()
    unittest.main()