test_imperative_resnet.py 14.4 KB
Newer Older
M
minqiyang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import unittest
import numpy as np
import six

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
M
minqiyang 已提交
23
from paddle.fluid.layer_helper import LayerHelper
24
from paddle.fluid import Conv2D, Pool2D, BatchNorm, FC
L
lujun 已提交
25
from paddle.fluid.dygraph.base import to_variable
M
minqiyang 已提交
26
from test_imperative_base import new_program_scope
27 28
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.dygraph.jit import TracedLayer
M
minqiyang 已提交
29

30
# Mini-batch size shared by the dygraph run and the static-graph run.
batch_size = 8

# Training hyper-parameters shared by both runs.  Not every entry is read in
# this file; the input_* entries in particular are not consumed here.
train_parameters = dict(
    input_size=[3, 224, 224],
    input_mean=[0.485, 0.456, 0.406],
    input_std=[0.229, 0.224, 0.225],
    learning_strategy=dict(
        name="piecewise_decay",
        batch_size=batch_size,
        epochs=[30, 60, 90],
        steps=[0.1, 0.01, 0.001, 0.0001]),
    batch_size=batch_size,
    lr=0.1,
    total_images=1281164)


def optimizer_setting(params):
    """Build the optimizer used by both the dygraph and static-graph runs.

    Args:
        params: hyper-parameter dict; ``params["learning_strategy"]["name"]``
            selects the strategy.

    Returns:
        A ``fluid.optimizer.SGD`` with a fixed learning rate of 0.01.  The
        piecewise learning-rate schedule described by ``params`` is not
        applied yet (see the TODO below).

    Raises:
        KeyError: if ``params`` has no ``"learning_strategy"`` entry.
        ValueError: if the learning strategy is not ``"piecewise_decay"``
            (previously this path crashed with UnboundLocalError).
    """
    ls = params["learning_strategy"]
    if ls["name"] != "piecewise_decay":
        raise ValueError("unsupported learning strategy: {}".format(ls[
            "name"]))

    # TODO(minqiyang): Add learning rate scheduler support to dygraph mode.
    # Once it exists, replace the fixed-rate SGD below with roughly:
    #   step = total_images // ls["batch_size"] + 1
    #   bd = [step * e for e in ls["epochs"]]
    #   lr = [params["lr"] * (0.1**i) for i in range(len(bd) + 1)]
    #   fluid.optimizer.Momentum(
    #       learning_rate=fluid.layers.piecewise_decay(
    #           boundaries=bd, values=lr),
    #       momentum=0.9,
    #       regularization=fluid.regularizer.L2Decay(1e-4))
    return fluid.optimizer.SGD(learning_rate=0.01)


73
class ConvBNLayer(fluid.Layer):
    """A Conv2D without activation followed by a BatchNorm layer.

    The activation ``act`` is handed to the BatchNorm layer, not to the
    convolution.  Padding is ``(filter_size - 1) // 2``, so for odd filter
    sizes with stride 1 the spatial size is kept.

    Args:
        name_scope: name scope passed to ``fluid.Layer``.
        num_filters: number of output channels.
        filter_size: square convolution kernel size.
        stride: convolution stride.
        groups: number of convolution groups.
        act: activation passed to BatchNorm (``None`` for none).
    """

    def __init__(self,
                 name_scope,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None):
        super(ConvBNLayer, self).__init__(name_scope)

        scope = self.full_name()
        same_padding = (filter_size - 1) // 2

        self._conv = Conv2D(
            scope,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=same_padding,
            groups=groups,
            act=None,
            bias_attr=None,
            use_cudnn=False)

        self._batch_norm = BatchNorm(scope, num_filters, act=act)

    def forward(self, inputs):
        """Run the convolution, then batch normalization, on ``inputs``."""
        return self._batch_norm(self._conv(inputs))


103
class BottleneckBlock(fluid.Layer):
    """ResNet bottleneck block: three ConvBNLayers plus a residual add and ReLU.

    Main branch: 1x1 conv (``num_filters``, relu) -> 3x3 conv
    (``num_filters``, ``stride``, relu) -> 1x1 conv (``4 * num_filters``, no
    activation).  When ``shortcut`` is True the block input is added as-is;
    otherwise it first goes through a strided 1x1 ConvBNLayer projection
    with ``4 * num_filters`` channels.

    Args:
        name_scope: name scope passed to ``fluid.Layer``.
        num_filters: channel count of the first two convolutions.
        stride: stride of the 3x3 convolution and of the projection.
        shortcut: True for an identity residual, False for a projection.
    """

    def __init__(self, name_scope, num_filters, stride, shortcut=True):
        super(BottleneckBlock, self).__init__(name_scope)

        scope = self.full_name()
        expanded_filters = num_filters * 4

        self.conv0 = ConvBNLayer(
            scope, num_filters=num_filters, filter_size=1, act='relu')
        self.conv1 = ConvBNLayer(
            scope,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu')
        self.conv2 = ConvBNLayer(
            scope, num_filters=expanded_filters, filter_size=1, act=None)

        if not shortcut:
            # Projection used instead of the identity residual.
            self.short = ConvBNLayer(
                scope,
                num_filters=expanded_filters,
                filter_size=1,
                stride=stride)

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return relu(residual + main_branch(inputs))."""
        branch = self.conv2(self.conv1(self.conv0(inputs)))
        residual = inputs if self.shortcut else self.short(inputs)

        summed = fluid.layers.elementwise_add(x=residual, y=branch)
        relu_helper = LayerHelper(self.full_name(), act='relu')
        return relu_helper.append_activation(summed)
M
minqiyang 已提交
147 148


149
class ResNet(fluid.Layer):
    """Bottleneck ResNet (50, 101 or 152 layers) with a softmax FC head.

    Layout: 7x7/2 ConvBNLayer -> 3x3/2 max pool -> four stages of
    BottleneckBlocks -> global average pool -> FC(``class_dim``, softmax).

    Args:
        name_scope: name scope passed to ``fluid.Layer``.
        layers: network depth; must be one of 50, 101 or 152.
        class_dim: number of output classes.
    """

    def __init__(self, name_scope, layers=50, class_dim=102):
        super(ResNet, self).__init__(name_scope)

        self.layers = layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, layers)

        # Number of bottleneck blocks in each of the four stages.
        if layers == 50:
            blocks_per_stage = [3, 4, 6, 3]
        elif layers == 101:
            blocks_per_stage = [3, 4, 23, 3]
        elif layers == 152:
            blocks_per_stage = [3, 8, 36, 3]
        stage_filters = [64, 128, 256, 512]

        self.conv = ConvBNLayer(
            self.full_name(),
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu')
        self.pool2d_max = Pool2D(
            self.full_name(),
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max')

        self.bottleneck_block_list = []
        for stage, (count, filters) in enumerate(
                zip(blocks_per_stage, stage_filters)):
            for idx in range(count):
                # The first block of each stage uses a projection shortcut
                # (shortcut=False); it also downsamples, except in stage 0.
                block = self.add_sublayer(
                    'bb_%d_%d' % (stage, idx),
                    BottleneckBlock(
                        self.full_name(),
                        num_filters=filters,
                        stride=2 if idx == 0 and stage != 0 else 1,
                        shortcut=idx > 0))
                self.bottleneck_block_list.append(block)

        self.pool2d_avg = Pool2D(
            self.full_name(), pool_size=7, pool_type='avg', global_pooling=True)

        import math
        # 2048 = 512 * 4: channel count leaving the last bottleneck stage.
        bound = 1.0 / math.sqrt(2048.0)

        self.out = FC(self.full_name(),
                      size=class_dim,
                      act='softmax',
                      param_attr=fluid.param_attr.ParamAttr(
                          initializer=fluid.initializer.Uniform(-bound, bound)))

    def forward(self, inputs):
        """Return the softmax class probabilities for ``inputs``."""
        y = self.pool2d_max(self.conv(inputs))
        for block in self.bottleneck_block_list:
            y = block(y)
        return self.out(self.pool2d_avg(y))


L
lujun 已提交
215
class TestDygraphResnet(unittest.TestCase):
    """Compare ResNet training in dygraph mode against static-graph mode.

    Both modes build ``ResNet("resnet")`` with the same seeds and train for
    ``batch_num`` batches of the flowers dataset.  The test then checks that
    the last loss, the initial parameters, the last gradients and the last
    parameters agree.  During the dygraph run the model is also traced with
    ``TracedLayer`` every 5 batches, and the traced program's output is
    compared with eager output.
    """

    def reader_decorator(self, reader):
        """Wrap ``reader`` so each sample becomes (image reshaped to
        (3, 224, 224), int64 label reshaped to (1,))."""

        def _reader_imple():
            for item in reader():
                doc = np.array(item[0]).reshape(3, 224, 224)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield doc, label

        return _reader_imple

    def _assert_all_close_and_finite(self, static_values, dygraph_values):
        """Assert both dicts have equal size, and that every static array
        is close to its dygraph counterpart and contains no NaN/inf."""
        self.assertEqual(len(dygraph_values), len(static_values))
        for key, value in six.iteritems(static_values):
            self.assertTrue(np.allclose(value, dygraph_values[key]))
            # Element-wise checks; calling .all()/.any() first would reduce
            # the array to one bool and make these checks vacuous.
            self.assertTrue(np.isfinite(value).all())
            self.assertFalse(np.isnan(value).any())

    def test_resnet_float32(self):
        import random

        seed = 90

        batch_size = train_parameters["batch_size"]
        batch_num = 10

        traced_layer = None

        with fluid.dygraph.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            resnet = ResNet("resnet")
            optimizer = optimizer_setting(train_parameters)
            np.random.seed(seed)
            # Must be a call: assigning to random.seed would replace the
            # function with an int and leave the generator unseeded.
            random.seed(seed)

            batch_py_reader = fluid.io.PyReader(capacity=1)
            batch_py_reader.decorate_sample_list_generator(
                paddle.batch(
                    self.reader_decorator(
                        paddle.dataset.flowers.train(use_xmap=False)),
                    batch_size=batch_size,
                    drop_last=True),
                places=fluid.CPUPlace())

            dy_param_init_value = {
                param.name: param.numpy()
                for param in resnet.parameters()
            }

            helper = DyGraphProgramDescTracerTestHelper(self)
            program = None

            for batch_id, data in enumerate(batch_py_reader()):
                if batch_id >= batch_num:
                    break

                img = data[0]
                label = data[1]
                label.stop_gradient = True

                out = None
                if batch_id % 5 == 0:
                    out, traced_layer = TracedLayer.trace(resnet, img)
                    # Every trace of the same model must give the same program.
                    if program is not None:
                        self.assertTrue(
                            is_equal_program(program, traced_layer.program))

                    # NOTE: writes the model into the current working directory.
                    traced_layer.save_inference_model(
                        './infer_imperative_resnet')

                    program = traced_layer.program
                else:
                    out = resnet(img)

                if traced_layer is not None:
                    # In eval mode, eager and traced outputs must match.
                    resnet.eval()
                    traced_layer._switch(is_test=True)
                    out_dygraph = resnet([img])
                    out_static = traced_layer([img])
                    traced_layer._switch(is_test=False)
                    helper.assertEachVar(out_dygraph, out_static)
                    resnet.train()

                loss = fluid.layers.cross_entropy(input=out, label=label)
                avg_loss = fluid.layers.mean(x=loss)

                dy_out = avg_loss.numpy()

                # Some parameters only exist after the first forward pass;
                # record their initial values too.
                if batch_id == 0:
                    for param in resnet.parameters():
                        if param.name not in dy_param_init_value:
                            dy_param_init_value[param.name] = param.numpy()

                avg_loss.backward()

                dy_grad_value = {
                    param.name + core.grad_var_suffix():
                    np.array(param._ivar._grad_ivar().value().get_tensor())
                    for param in resnet.parameters() if param.trainable
                }

                optimizer.minimize(avg_loss)
                resnet.clear_gradients()

                dy_param_value = {
                    param.name: param.numpy()
                    for param in resnet.parameters()
                }

        with new_program_scope():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))

            resnet = ResNet("resnet")
            optimizer = optimizer_setting(train_parameters)

            np.random.seed(seed)
            random.seed(seed)
            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size)

            img = fluid.layers.data(
                name='pixel', shape=[3, 224, 224], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            out = resnet(img)
            loss = fluid.layers.cross_entropy(input=out, label=label)
            avg_loss = fluid.layers.mean(x=loss)
            optimizer.minimize(avg_loss)

            # initialize params and fetch them
            static_param_name_list = [
                param.name for param in resnet.parameters()
            ]
            static_grad_name_list = [
                param.name + core.grad_var_suffix()
                for param in resnet.parameters() if param.trainable
            ]

            out = exe.run(fluid.default_startup_program(),
                          fetch_list=static_param_name_list)
            static_param_init_value = dict(zip(static_param_name_list, out))

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                static_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
                y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                    [batch_size, 1])

                if traced_layer is not None:
                    traced_layer([static_x_data])

                # Fetch layout: [loss, *params, *grads].
                fetch_list = [avg_loss.name]
                fetch_list.extend(static_param_name_list)
                fetch_list.extend(static_grad_name_list)
                out = exe.run(fluid.default_main_program(),
                              feed={"pixel": static_x_data,
                                    "label": y_data},
                              fetch_list=fetch_list)

                static_out = out[0]
                param_start_pos = 1
                grad_start_pos = len(static_param_name_list) + param_start_pos
                static_param_value = dict(
                    zip(static_param_name_list,
                        out[param_start_pos:grad_start_pos]))
                static_grad_value = dict(
                    zip(static_grad_name_list, out[grad_start_pos:]))

        print("static", static_out)
        print("dygraph", dy_out)
        self.assertTrue(np.allclose(static_out, dy_out))

        self._assert_all_close_and_finite(static_param_init_value,
                                          dy_param_init_value)
        self._assert_all_close_and_finite(static_grad_value, dy_grad_value)
        self._assert_all_close_and_finite(static_param_value, dy_param_value)
M
minqiyang 已提交
415 416 417 418


if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()