# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import unittest
import numpy as np
import six

import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import Conv2D, Pool2D, BatchNorm, FC
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper

batch_size = 8

# Hyper-parameters shared by the dygraph and static-graph runs in this test.
# NOTE(review): "total_images" (1281164) differs from the 1281167 fallback in
# optimizer_setting; it only feeds the piecewise-decay boundaries, which the
# current SGD optimizer does not use -- confirm which value is intended.
train_parameters = {
    "input_size": [3, 224, 224],
    "input_mean": [0.485, 0.456, 0.406],
    "input_std": [0.229, 0.224, 0.225],
    "learning_strategy": {
        "name": "piecewise_decay",
        "batch_size": batch_size,
        "epochs": [30, 60, 90],
        "steps": [0.1, 0.01, 0.001, 0.0001]
    },
    "batch_size": batch_size,
    "lr": 0.1,
    "total_images": 1281164,
}


def optimizer_setting(params):
    """Build the optimizer shared by the dygraph and static-graph runs.

    Args:
        params: hyper-parameter dict shaped like ``train_parameters``. Reads
            ``params["learning_strategy"]`` (``name``, ``batch_size``,
            ``epochs``), ``params["lr"]`` and, optionally,
            ``params["total_images"]`` (defaults to 1281167).

    Returns:
        A ``fluid.optimizer.SGD`` with a fixed learning rate of 0.01.

    Raises:
        ValueError: if the learning strategy name is not
            ``"piecewise_decay"`` (this previously surfaced as an
            ``UnboundLocalError`` on the return statement).
    """
    ls = params["learning_strategy"]
    if ls["name"] != "piecewise_decay":
        raise ValueError("unsupported learning strategy: {}".format(ls["name"]))

    total_images = params.get("total_images", 1281167)
    batch_size = ls["batch_size"]
    step = int(total_images / batch_size + 1)

    # Piecewise-decay boundaries (in steps) and per-interval learning rates.
    # They are only consumed by the Momentum optimizer in the TODO below.
    bd = [step * e for e in ls["epochs"]]
    base_lr = params["lr"]
    lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]

    # TODO(minqiyang): Add learning rate scheduler support to dygraph mode
    # optimizer = fluid.optimizer.Momentum(
    #     learning_rate=fluid.layers.piecewise_decay(
    #         boundaries=bd, values=lr),
    #     momentum=0.9,
    #     regularization=fluid.regularizer.L2Decay(1e-4))
    optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    return optimizer


class ConvBNLayer(fluid.Layer):
    """Convolution followed by batch normalization.

    The convolution has no activation and pads by ``(filter_size - 1) // 2``
    (which keeps the spatial size for odd filter sizes at stride 1). The
    optional activation ``act`` is applied by the BatchNorm layer instead.
    """

    def __init__(self,
                 name_scope,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None):
        super(ConvBNLayer, self).__init__(name_scope)

        self._conv = Conv2D(
            self.full_name(),
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            bias_attr=None)

        self._batch_norm = BatchNorm(self.full_name(), num_filters, act=act)

    def forward(self, inputs):
        """Apply the convolution, then batch norm (with ``act``)."""
        y = self._conv(inputs)
        y = self._batch_norm(y)

        return y


class BottleneckBlock(fluid.Layer):
    """ResNet bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The block outputs ``num_filters * 4`` channels. When ``shortcut`` is True
    the input is added unchanged; otherwise it is first projected by a 1x1
    ConvBNLayer with ``num_filters * 4`` filters and the block's ``stride``.
    A ReLU is applied after the residual addition.
    """

    def __init__(self, name_scope, num_filters, stride, shortcut=True):
        super(BottleneckBlock, self).__init__(name_scope)

        self.conv0 = ConvBNLayer(
            self.full_name(),
            num_filters=num_filters,
            filter_size=1,
            act='relu')
        # Only the 3x3 convolution carries the block's stride.
        self.conv1 = ConvBNLayer(
            self.full_name(),
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu')
        # Expand to 4x channels; no activation before the residual add.
        self.conv2 = ConvBNLayer(
            self.full_name(),
            num_filters=num_filters * 4,
            filter_size=1,
            act=None)

        if not shortcut:
            # Projection with the same channel count and stride as the
            # residual branch, so the two can be added element-wise.
            self.short = ConvBNLayer(
                self.full_name(),
                num_filters=num_filters * 4,
                filter_size=1,
                stride=stride)

        self.shortcut = shortcut

    def forward(self, inputs):
        """Return ``relu(shortcut(inputs) + conv2(conv1(conv0(inputs))))``."""
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = fluid.layers.elementwise_add(x=short, y=conv2)

        layer_helper = LayerHelper(self.full_name(), act='relu')
        return layer_helper.append_activation(y)


class ResNet(fluid.Layer):
    """ResNet-50/101/152 classifier built from BottleneckBlock stages.

    Args:
        name_scope: name scope passed to ``fluid.Layer``.
        layers: network depth; one of 50, 101 or 152.
        class_dim: number of outputs of the final softmax FC layer.
    """

    def __init__(self, name_scope, layers=50, class_dim=102):
        super(ResNet, self).__init__(name_scope)

        self.layers = layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {} but input layer is {}".format(supported_layers, layers)

        # Number of bottleneck blocks in each of the four stages.
        depth_by_layers = {
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }
        depth = depth_by_layers[layers]
        num_filters = [64, 128, 256, 512]

        # Stem: 7x7 stride-2 conv followed by 3x3 stride-2 max pooling.
        self.conv = ConvBNLayer(
            self.full_name(),
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu')
        self.pool2d_max = Pool2D(
            self.full_name(),
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max')

        self.bottleneck_block_list = []
        for block, block_depth in enumerate(depth):
            # The first block of each stage uses a projection shortcut.
            shortcut = False
            for i in range(block_depth):
                # Sublayer names determine parameter names; keep them stable.
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(
                        self.full_name(),
                        num_filters=num_filters[block],
                        # Downsample at the start of every stage but the first.
                        stride=2 if i == 0 and block != 0 else 1,
                        shortcut=shortcut))
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = Pool2D(
            self.full_name(), pool_size=7, pool_type='avg', global_pooling=True)

        import math
        # After global pooling the FC input has as many features as the last
        # stage outputs (num_filters[-1] * 4 == 2048); initialize its weights
        # uniformly in [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        fc_input_dim = num_filters[-1] * 4
        stdv = 1.0 / math.sqrt(fc_input_dim * 1.0)

        self.out = FC(self.full_name(),
                      size=class_dim,
                      act='softmax',
                      param_attr=fluid.param_attr.ParamAttr(
                          initializer=fluid.initializer.Uniform(-stdv, stdv)))

    def forward(self, inputs):
        """Stem, bottleneck stages, global average pool, then softmax FC."""
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        y = self.out(y)
        return y


class TestDygraphResnet(unittest.TestCase):
    """Check that dygraph and static-graph ResNet training agree.

    Both modes train a ResNet on the flowers dataset for a few batches with
    identical seeds. The initial parameters, the last loss, the last
    gradients and the updated parameters must match.
    """

    def reader_decorator(self, reader):
        """Wrap a sample reader to yield ``(image, label)`` numpy arrays.

        Each image is reshaped to ``(3, 224, 224)``; each label becomes an
        int64 array of shape ``(1,)``.
        """

        def _reader_imple():
            for item in reader():
                doc = np.array(item[0]).reshape(3, 224, 224)
                label = np.array(item[1]).astype('int64').reshape(1)
                yield doc, label

        return _reader_imple

    def test_resnet_float32(self):
        seed = 90
        batch_size = train_parameters["batch_size"]
        batch_num = 10

        dy_out, dy_param_init_value, dy_grad_value, dy_param_value = \
            self._train_dygraph(seed, batch_size, batch_num)
        (static_out, static_param_init_value, static_grad_value,
         static_param_value) = self._train_static(seed, batch_size, batch_num)

        self.assertTrue(np.allclose(static_out, dy_out))
        self._assert_values_match(static_param_init_value, dy_param_init_value)
        self._assert_values_match(static_grad_value, dy_grad_value)
        self._assert_values_match(static_param_value, dy_param_value)

    def _train_dygraph(self, seed, batch_size, batch_num):
        """Train ResNet in dygraph mode for at most ``batch_num`` batches.

        Returns:
            A tuple ``(last loss, initial params, last grads, final params)``;
            the dicts are keyed by parameter / gradient variable name.
        """
        import random

        with fluid.dygraph.guard():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            resnet = ResNet("resnet")
            optimizer = optimizer_setting(train_parameters)
            np.random.seed(seed)
            # Call (not assign) random.seed: assigning would replace the
            # stdlib function with an int for the whole process.
            random.seed(seed)

            batch_py_reader = fluid.io.PyReader(capacity=1)
            batch_py_reader.decorate_sample_list_generator(
                paddle.batch(
                    self.reader_decorator(
                        paddle.dataset.flowers.train(use_xmap=False)),
                    batch_size=batch_size,
                    drop_last=True),
                places=fluid.CPUPlace())

            # Parameters that exist before the first forward pass; any that
            # appear during batch 0 are added further below.
            dy_param_init_value = {}
            for param in resnet.parameters():
                dy_param_init_value[param.name] = param.numpy()

            helper = DyGraphProgramDescTracerTestHelper(resnet, self)

            for batch_id, data in enumerate(batch_py_reader()):
                if batch_id >= batch_num:
                    break

                img = data[0]
                label = data[1]
                label.stop_gradient = True

                # Every 5th batch runs through the tracer test helper, which
                # presumably cross-checks the traced program (see utils).
                if batch_id % 5 == 0:
                    out, out_static = helper.run(img,
                                                 feed_names=['image'],
                                                 fetch_names=['logits'])
                    helper.assertEachVar(out, out_static)
                else:
                    out = resnet(img)

                loss = fluid.layers.cross_entropy(input=out, label=label)
                avg_loss = fluid.layers.mean(x=loss)

                dy_out = avg_loss.numpy()

                if batch_id == 0:
                    for param in resnet.parameters():
                        if param.name not in dy_param_init_value:
                            dy_param_init_value[param.name] = param.numpy()

                avg_loss.backward()

                dy_grad_value = {}
                for param in resnet.parameters():
                    if param.trainable:
                        np_array = np.array(param._ivar._grad_ivar().value()
                                            .get_tensor())
                        dy_grad_value[param.name +
                                      core.grad_var_suffix()] = np_array

                optimizer.minimize(avg_loss)
                resnet.clear_gradients()

                dy_param_value = {}
                for param in resnet.parameters():
                    dy_param_value[param.name] = param.numpy()

        return dy_out, dy_param_init_value, dy_grad_value, dy_param_value

    def _train_static(self, seed, batch_size, batch_num):
        """Train the same ResNet as a static program with the same seeds.

        Returns the same 4-tuple as ``_train_dygraph``; gradients are keyed
        by ``param.name + core.grad_var_suffix()``.
        """
        import random

        with new_program_scope():
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

            exe = fluid.Executor(fluid.CPUPlace(
            ) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))

            resnet = ResNet("resnet")
            optimizer = optimizer_setting(train_parameters)

            np.random.seed(seed)
            random.seed(seed)
            # drop_last matches the dygraph reader and guarantees the
            # reshape of y_data to [batch_size, 1] below never sees a short
            # final batch.
            train_reader = paddle.batch(
                paddle.dataset.flowers.train(use_xmap=False),
                batch_size=batch_size,
                drop_last=True)

            img = fluid.layers.data(
                name='pixel', shape=[3, 224, 224], dtype='float32')
            label = fluid.layers.data(name='label', shape=[1], dtype='int64')
            out = resnet(img)
            loss = fluid.layers.cross_entropy(input=out, label=label)
            avg_loss = fluid.layers.mean(x=loss)
            optimizer.minimize(avg_loss)

            static_param_name_list = [
                param.name for param in resnet.parameters()
            ]
            static_grad_name_list = [
                param.name + core.grad_var_suffix()
                for param in resnet.parameters() if param.trainable
            ]

            # Initialize parameters and fetch their initial values.
            out = exe.run(fluid.default_startup_program(),
                          fetch_list=static_param_name_list)
            static_param_init_value = dict(zip(static_param_name_list, out))

            # Fetch order: [loss, *params, *grads].
            fetch_list = [avg_loss.name]
            fetch_list.extend(static_param_name_list)
            fetch_list.extend(static_grad_name_list)
            num_params = len(static_param_name_list)

            for batch_id, data in enumerate(train_reader()):
                if batch_id >= batch_num:
                    break

                static_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
                y_data = np.array([x[1] for x in data]).astype('int64').reshape(
                    [batch_size, 1])

                out = exe.run(fluid.default_main_program(),
                              feed={"pixel": static_x_data,
                                    "label": y_data},
                              fetch_list=fetch_list)

                static_out = out[0]
                static_param_value = dict(
                    zip(static_param_name_list, out[1:1 + num_params]))
                static_grad_value = dict(
                    zip(static_grad_name_list, out[1 + num_params:]))

        return (static_out, static_param_init_value, static_grad_value,
                static_param_value)

    def _assert_values_match(self, expected, actual):
        """Assert both dicts have equal size and close, finite values.

        ``np.isfinite(...).all()`` is applied to the whole array; it rejects
        NaN as well as +/-inf in any element.
        """
        self.assertEqual(len(actual), len(expected))
        for key, value in six.iteritems(expected):
            self.assertTrue(np.allclose(value, actual[key]))
            self.assertTrue(np.isfinite(value).all())
            self.assertFalse(np.isnan(value).any())


if __name__ == '__main__':
    # Run this module's unittest test cases when the file is executed directly.
    unittest.main()