# test_imperative_resnet.py (14.8 KB)
# Extracted from a GitLab blame view; lines 1-22 of the original file were
# last committed by minqiyang.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import unittest
import numpy as np
import six

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
M
minqiyang 已提交
23
from paddle.fluid.layer_helper import LayerHelper
24
from paddle.fluid import Conv2D, Pool2D, BatchNorm, Linear
L
lujun 已提交
25
from paddle.fluid.dygraph.base import to_variable
M
minqiyang 已提交
26
from test_imperative_base import new_program_scope
27
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
28
from paddle.fluid.dygraph import TracedLayer
M
minqiyang 已提交
29

# NOTE(zhiqiu): run with FLAGS_cudnn_deterministic=1

batch_size = 8

# Training configuration shared by the dygraph and static-graph runs.
# optimizer_setting reads "learning_strategy" (name, batch_size, epochs),
# "lr" and "total_images"; the test itself reads "batch_size".
train_parameters = {
    "input_size": [3, 224, 224],
    "input_mean": [0.485, 0.456, 0.406],
    "input_std": [0.229, 0.224, 0.225],
    "learning_strategy": {
        "name": "piecewise_decay",
        "batch_size": batch_size,
        "epochs": [30, 60, 90],
        "steps": [0.1, 0.01, 0.001, 0.0001]
    },
    "batch_size": batch_size,
    "lr": 0.1,
    "total_images": 1281164,
}


def optimizer_setting(params, parameter_list=None):
    """Build the optimizer shared by the dygraph and static-graph runs.

    Args:
        params: training configuration dict. Reads
            ``params["learning_strategy"]`` (``name``, ``batch_size``,
            ``epochs``), ``params["lr"]`` and, optionally,
            ``params["total_images"]`` (defaults to 1281167).
        parameter_list: parameters to optimize; only forwarded to the
            optimizer when running in dygraph mode.

    Returns:
        A ``fluid.optimizer.SGD`` with a fixed learning rate of 0.01.

    Raises:
        ValueError: if ``learning_strategy["name"]`` is not
            ``"piecewise_decay"``, the only strategy handled here. Previously
            this case fell through and failed with an UnboundLocalError.
    """
    ls = params["learning_strategy"]
    if ls["name"] != "piecewise_decay":
        raise ValueError(
            "unsupported learning strategy: {}".format(ls["name"]))

    total_images = params.get("total_images", 1281167)
    batch_size = ls["batch_size"]
    step = int(total_images / batch_size + 1)

    # Piecewise-decay boundaries (in steps) and per-segment learning rates.
    # They are not consumed yet; see the TODO below.
    bd = [step * e for e in ls["epochs"]]
    base_lr = params["lr"]
    lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]

    if fluid.in_dygraph_mode():
        optimizer = fluid.optimizer.SGD(learning_rate=0.01,
                                        parameter_list=parameter_list)
    else:
        optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    # TODO(minqiyang): Add learning rate scheduler support to dygraph mode.
    # The intended optimizer is Momentum(momentum=0.9) with
    # learning_rate=fluid.layers.piecewise_decay(boundaries=bd, values=lr)
    # and regularization=fluid.regularizer.L2Decay(1e-4).

    return optimizer


class ConvBNLayer(fluid.Layer):
    """A 2-D convolution followed by batch normalization.

    The convolution carries no activation and is built with
    ``use_cudnn=False``. The optional activation ``act`` is applied by the
    BatchNorm layer.

    Args:
        num_channels: input channel count of the convolution.
        num_filters: output channel count (also the BatchNorm width).
        filter_size: square kernel size.
        stride: convolution stride.
        groups: convolution groups.
        act: activation name passed to BatchNorm, or None.
    """

    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None):
        super(ConvBNLayer, self).__init__()

        # The conv is created before the batch norm, in the original order.
        self._conv = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            # Keeps the spatial size unchanged for odd kernels at stride 1.
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            bias_attr=None,
            use_cudnn=False)

        self._batch_norm = BatchNorm(num_filters, act=act)

    def forward(self, inputs):
        """Apply the convolution, then batch normalization (+ ``act``)."""
        y = self._conv(inputs)
        y = self._batch_norm(y)

        return y


class BottleneckBlock(fluid.Layer):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 ConvBN plus a residual add.

    The last 1x1 ConvBNLayer expands the channel count to
    ``num_filters * 4``. If ``shortcut`` is True, the block input is added
    unchanged. Otherwise it is first projected by a 1x1 ConvBNLayer with the
    same ``stride`` and ``num_filters * 4`` outputs. A ReLU is applied to the
    sum.

    Args:
        num_channels: input channel count.
        num_filters: bottleneck width (output is ``num_filters * 4``).
        stride: stride of the 3x3 conv and of the projection shortcut.
        shortcut: if True, use the identity shortcut; otherwise project.
    """

    def __init__(self, num_channels, num_filters, stride, shortcut=True):
        super(BottleneckBlock, self).__init__()

        # Sublayers are created in the original order (conv0, conv1, conv2,
        # short), presumably because parameter naming/creation order must
        # match between the dygraph and static runs.
        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act='relu')
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu')
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None)

        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                stride=stride)

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return ``relu(shortcut(inputs) + conv2(conv1(conv0(inputs))))``."""
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = fluid.layers.elementwise_add(x=short, y=conv2)

        # The ReLU is appended through a LayerHelper scoped to this layer.
        layer_helper = LayerHelper(self.full_name(), act='relu')
        return layer_helper.append_activation(y)


class ResNet(fluid.Layer):
    """Bottleneck ResNet (50/101/152 layers) with a softmax classifier.

    Args:
        layers: network depth; must be one of 50, 101 or 152.
        class_dim: number of outputs of the final softmax Linear layer.
    """

    def __init__(self, layers=50, class_dim=102):
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, layers)

        # Number of bottleneck blocks in each of the four stages.
        depth = {
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[layers]
        num_channels = [64, 256, 512, 1024]
        num_filters = [64, 128, 256, 512]

        # Sublayers are created in the original order (stem conv, max pool,
        # bottleneck blocks, avg pool, classifier).
        self.conv = ConvBNLayer(
            num_channels=3, num_filters=64, filter_size=7, stride=2, act='relu')
        self.pool2d_max = Pool2D(
            pool_size=3, pool_stride=2, pool_padding=1, pool_type='max')

        self.bottleneck_block_list = []
        for block in range(len(depth)):
            # Only the first block of each stage uses a projection shortcut.
            shortcut = False
            for i in range(depth[block]):
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(
                        num_channels=num_channels[block]
                        if i == 0 else num_filters[block] * 4,
                        num_filters=num_filters[block],
                        # The first block of every stage except the first
                        # downsamples with stride 2.
                        stride=2 if i == 0 and block != 0 else 1,
                        shortcut=shortcut))
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = Pool2D(
            pool_size=7, pool_type='avg', global_pooling=True)

        # Flattened feature width after global average pooling.
        self.pool2d_avg_output = num_filters[-1] * 4 * 1 * 1

        import math
        # Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) classifier init. fan_in is
        # derived from the pooled width instead of the hard-coded 2048, which
        # only matched it because num_filters[-1] == 512.
        stdv = 1.0 / math.sqrt(self.pool2d_avg_output * 1.0)

        self.out = Linear(
            self.pool2d_avg_output,
            class_dim,
            act='softmax',
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)))

    def forward(self, inputs):
        """Run the stem, all bottleneck blocks, pooling and the classifier."""
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
        y = self.out(y)
        return y


L
lujun 已提交
219
class TestDygraphResnet(unittest.TestCase):
220 221 222 223 224 225 226 227 228
    def reader_decorator(self, reader):
        def _reader_imple():
            for item in reader():
                doc = np.array(item[0]).reshape(3, 224, 224)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield doc, label

        return _reader_imple

M
minqiyang 已提交
229
    def test_resnet_float32(self):
M
minqiyang 已提交
230 231
        seed = 90

232
        batch_size = train_parameters["batch_size"]
233 234
        batch_num = 10

235 236
        traced_layer = None

L
lujun 已提交
237
        with fluid.dygraph.guard():
238 239 240
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

241 242 243
            resnet = ResNet()
            optimizer = optimizer_setting(
                train_parameters, parameter_list=resnet.parameters())
244 245 246
            np.random.seed(seed)
            import random
            random.seed = seed
247 248 249 250 251 252 253 254 255

            batch_py_reader = fluid.io.PyReader(capacity=1)
            batch_py_reader.decorate_sample_list_generator(
                paddle.batch(
                    self.reader_decorator(
                        paddle.dataset.flowers.train(use_xmap=False)),
                    batch_size=batch_size,
                    drop_last=True),
                places=fluid.CPUPlace())
256 257

            dy_param_init_value = {}
M
minqiyang 已提交
258
            for param in resnet.parameters():
259
                dy_param_init_value[param.name] = param.numpy()
260

261 262
            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None
263

264
            for batch_id, data in enumerate(batch_py_reader()):
M
minqiyang 已提交
265
                if batch_id >= batch_num:
266 267
                    break

268 269
                img = data[0]
                label = data[1]
270
                label.stop_gradient = True
271

272
                out = None
273
                if batch_id % 5 == 0:
274 275 276 277 278 279 280 281 282
                    out, traced_layer = TracedLayer.trace(resnet, img)
                    if program is not None:
                        self.assertTrue(
                            is_equal_program(program, traced_layer.program))

                    traced_layer.save_inference_model(
                        './infer_imperative_resnet')

                    program = traced_layer.program
283 284 285
                else:
                    out = resnet(img)

286 287 288
                if traced_layer is not None:
                    resnet.eval()
                    traced_layer._switch(is_test=True)
289
                    out_dygraph = resnet(img)
290 291 292 293 294
                    out_static = traced_layer([img])
                    traced_layer._switch(is_test=False)
                    helper.assertEachVar(out_dygraph, out_static)
                    resnet.train()

295 296 297
                loss = fluid.layers.cross_entropy(input=out, label=label)
                avg_loss = fluid.layers.mean(x=loss)

298
                dy_out = avg_loss.numpy()
299 300

                if batch_id == 0:
M
minqiyang 已提交
301
                    for param in resnet.parameters():
302
                        if param.name not in dy_param_init_value:
303
                            dy_param_init_value[param.name] = param.numpy()
304

L
lujun 已提交
305
                avg_loss.backward()
306 307

                dy_grad_value = {}
M
minqiyang 已提交
308
                for param in resnet.parameters():
309
                    if param.trainable:
310
                        np_array = np.array(param._grad_ivar().value()
311 312 313 314 315
                                            .get_tensor())
                        dy_grad_value[param.name + core.grad_var_suffix(
                        )] = np_array

                optimizer.minimize(avg_loss)
M
minqiyang 已提交
316
                resnet.clear_gradients()
317 318

                dy_param_value = {}
M
minqiyang 已提交
319
                for param in resnet.parameters():
320
                    dy_param_value[param.name] = param.numpy()
M
minqiyang 已提交
321 322

        with new_program_scope():
M
minqiyang 已提交
323 324 325
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

M
minqiyang 已提交
326 327
            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
328

329
            resnet = ResNet()
330
            optimizer = optimizer_setting(train_parameters)
M
minqiyang 已提交
331 332 333 334

            np.random.seed(seed)
            import random
            random.seed = seed
335
            train_reader = paddle.batch(
M
minqiyang 已提交
336 337
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size)
338 339 340 341 342 343 344 345 346 347 348 349

            img = fluid.layers.data(
                name='pixel', shape=[3, 224, 224], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            out = resnet(img)
            loss = fluid.layers.cross_entropy(input=out, label=label)
            avg_loss = fluid.layers.mean(x=loss)
            optimizer.minimize(avg_loss)

            # initialize params and fetch them
            static_param_init_value = {}
            static_param_name_list = []
M
minqiyang 已提交
350
            static_grad_name_list = []
M
minqiyang 已提交
351
            for param in resnet.parameters():
352
                static_param_name_list.append(param.name)
M
minqiyang 已提交
353
            for param in resnet.parameters():
354
                if param.trainable:
M
minqiyang 已提交
355 356
                    static_grad_name_list.append(param.name +
                                                 core.grad_var_suffix())
357 358 359 360 361 362 363 364

            out = exe.run(fluid.default_startup_program(),
                          fetch_list=static_param_name_list)

            for i in range(len(static_param_name_list)):
                static_param_init_value[static_param_name_list[i]] = out[i]

            for batch_id, data in enumerate(train_reader()):
M
minqiyang 已提交
365
                if batch_id >= batch_num:
366 367
                    break

M
minqiyang 已提交
368
                static_x_data = np.array(
369 370 371 372
                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
                y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                    [batch_size, 1])

373 374 375
                if traced_layer is not None:
                    traced_layer([static_x_data])

M
minqiyang 已提交
376
                fetch_list = [avg_loss.name]
377
                fetch_list.extend(static_param_name_list)
M
minqiyang 已提交
378
                fetch_list.extend(static_grad_name_list)
379
                out = exe.run(fluid.default_main_program(),
M
minqiyang 已提交
380
                              feed={"pixel": static_x_data,
381 382 383 384
                                    "label": y_data},
                              fetch_list=fetch_list)

                static_param_value = {}
M
minqiyang 已提交
385
                static_grad_value = {}
386
                static_out = out[0]
M
minqiyang 已提交
387 388 389 390 391 392 393 394 395 396 397
                param_start_pos = 1
                grad_start_pos = len(static_param_name_list) + param_start_pos
                for i in range(param_start_pos,
                               len(static_param_name_list) + param_start_pos):
                    static_param_value[static_param_name_list[
                        i - param_start_pos]] = out[i]
                for i in range(grad_start_pos,
                               len(static_grad_name_list) + grad_start_pos):
                    static_grad_value[static_grad_name_list[
                        i - grad_start_pos]] = out[i]

H
hong 已提交
398 399
        print("static", static_out)
        print("dygraph", dy_out)
M
minqiyang 已提交
400 401 402
        self.assertTrue(np.allclose(static_out, dy_out))

        self.assertEqual(len(dy_param_init_value), len(static_param_init_value))
X
Xin Pan 已提交
403

M
minqiyang 已提交
404 405
        for key, value in six.iteritems(static_param_init_value):
            self.assertTrue(np.allclose(value, dy_param_init_value[key]))
406 407
            self.assertTrue(np.isfinite(value.all()))
            self.assertFalse(np.isnan(value.any()))
408

M
minqiyang 已提交
409
        self.assertEqual(len(dy_grad_value), len(static_grad_value))
M
minqiyang 已提交
410
        for key, value in six.iteritems(static_grad_value):
M
minqiyang 已提交
411
            self.assertTrue(np.allclose(value, dy_grad_value[key]))
412 413
            self.assertTrue(np.isfinite(value.all()))
            self.assertFalse(np.isnan(value.any()))
414

M
minqiyang 已提交
415
        self.assertEqual(len(dy_param_value), len(static_param_value))
M
minqiyang 已提交
416
        for key, value in six.iteritems(static_param_value):
417 418 419
            self.assertTrue(np.allclose(value, dy_param_value[key]))
            self.assertTrue(np.isfinite(value.all()))
            self.assertFalse(np.isnan(value.any()))
M
minqiyang 已提交
420 421 422 423


if __name__ == '__main__':
    # Script entry point: let unittest discover and run this module's tests.
    unittest.main()