# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import argparse
import ast
import os
import paddle
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable

from paddle.distributed import fleet
from paddle.distributed.fleet.base import role_maker
import math
import time

IMAGENET1000 = 1281167
base_lr = 0.1
momentum_rate = 0.9
l2_decay = 1e-4


def parse_args():
    parser = argparse.ArgumentParser("Training for ResNet.")
    parser.add_argument(
        "--use_data_parallel",
        type=ast.literal_eval,
        default=False,
        help="The flag indicating whether to use data parallel mode to train the model."
    )
    parser.add_argument(
        "-e", "--epoch", default=120, type=int, help="number of training epochs")
    parser.add_argument(
        "-b", "--batch_size", default=32, type=int, help="batch size")
    parser.add_argument("--ce", action="store_true", help="run ce")

    # NOTE: used in benchmark
    parser.add_argument(
        "--max_iter", default=0, type=int,
        help="maximum number of iterations to train; used in benchmark")
    args = parser.parse_args()
    return args


args = parse_args()
batch_size = args.batch_size


def optimizer_setting(parameter_list=None):
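    """Build a Momentum optimizer with piecewise learning-rate decay.

    The learning rate starts at base_lr and is divided by 10 at epoch
    boundaries 30, 60 and 90 (converted to step counts), with L2 weight
    decay of l2_decay.
    """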

    total_images = IMAGENET1000

    step = int(math.ceil(float(total_images) / batch_size))

    epochs = [30, 60, 90]
    bd = [step * e for e in epochs]

    lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
    if fluid.in_dygraph_mode():
        optimizer = fluid.optimizer.Momentum(
            learning_rate=fluid.layers.piecewise_decay(
                boundaries=bd, values=lr),
            momentum=momentum_rate,
            regularization=fluid.regularizer.L2Decay(l2_decay),
            parameter_list=parameter_list)
    else:
        optimizer = fluid.optimizer.Momentum(
            learning_rate=fluid.layers.piecewise_decay(
                boundaries=bd, values=lr),
            momentum=momentum_rate,
            regularization=fluid.regularizer.L2Decay(l2_decay))

    return optimizer


class ConvBNLayer(fluid.dygraph.Layer):
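    """Conv2D followed by BatchNorm; the activation (if any) is applied after BN."""
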
    def __init__(self,
                 num_channels,
                 num_filters,
                 filter_size,
                 stride=1,
                 groups=1,
                 act=None):
        super(ConvBNLayer, self).__init__()

        self._conv = Conv2D(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=groups,
            act=None,
            bias_attr=False)

        self._batch_norm = BatchNorm(num_filters, act=act)

    def forward(self, inputs):
        y = self._conv(inputs)
        y = self._batch_norm(y)

        return y


class BottleneckBlock(fluid.dygraph.Layer):
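    """Bottleneck residual block: 1x1, 3x3, 1x1 convolutions plus a
    shortcut branch, followed by ReLU."""
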
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True):
        super(BottleneckBlock, self).__init__()

        self.conv0 = ConvBNLayer(
            num_channels=num_channels,
            num_filters=num_filters,
            filter_size=1,
            act='relu')
        self.conv1 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters,
            filter_size=3,
            stride=stride,
            act='relu')
        self.conv2 = ConvBNLayer(
            num_channels=num_filters,
            num_filters=num_filters * 4,
            filter_size=1,
            act=None)

        if not shortcut:
            self.short = ConvBNLayer(
                num_channels=num_channels,
                num_filters=num_filters * 4,
                filter_size=1,
                stride=stride)

        self.shortcut = shortcut

        self._num_channels_out = num_filters * 4

    def forward(self, inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = fluid.layers.elementwise_add(x=short, y=conv2)

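        # append a ReLU op to the residual sum via LayerHelper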
        layer_helper = LayerHelper(self.full_name(), act='relu')
        return layer_helper.append_activation(y)


class ResNet(fluid.dygraph.Layer):
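    """ResNet-{50,101,152} backbone with a softmax classification head."""
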
    def __init__(self, layers=50, class_dim=102):
        super(ResNet, self).__init__()

        self.layers = layers
        supported_layers = [50, 101, 152]
        assert layers in supported_layers, \
            "supported layers are {}, but got {}".format(supported_layers, layers)

        if layers == 50:
            depth = [3, 4, 6, 3]
        elif layers == 101:
            depth = [3, 4, 23, 3]
        elif layers == 152:
            depth = [3, 8, 36, 3]
        num_channels = [64, 256, 512, 1024]
        num_filters = [64, 128, 256, 512]

        self.conv = ConvBNLayer(
            num_channels=3,
            num_filters=64,
            filter_size=7,
            stride=2,
            act='relu')
        self.pool2d_max = Pool2D(
            pool_size=3,
            pool_stride=2,
            pool_padding=1,
            pool_type='max')

        self.bottleneck_block_list = []
        for block in range(len(depth)):
            shortcut = False
            for i in range(depth[block]):
                bottleneck_block = self.add_sublayer(
                    'bb_%d_%d' % (block, i),
                    BottleneckBlock(
                        num_channels=num_channels[block]
                        if i == 0 else num_filters[block] * 4,
                        num_filters=num_filters[block],
                        stride=2 if i == 0 and block != 0 else 1,
                        shortcut=shortcut))
                self.bottleneck_block_list.append(bottleneck_block)
                shortcut = True

        self.pool2d_avg = Pool2D(
            pool_size=7, pool_type='avg', global_pooling=True)

        self.pool2d_avg_output = num_filters[-1] * 4  # global pooling leaves 1x1 spatial dims

        stdv = 1.0 / math.sqrt(2048.0)

        self.out = Linear(
            self.pool2d_avg_output,
            class_dim,
            act='softmax',
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.Uniform(-stdv, stdv)))

    def forward(self, inputs):
        y = self.conv(inputs)
        y = self.pool2d_max(y)
        for bottleneck_block in self.bottleneck_block_list:
            y = bottleneck_block(y)
        y = self.pool2d_avg(y)
        y = fluid.layers.reshape(y, shape=[-1, self.pool2d_avg_output])
        y = self.out(y)
        return y


def eval(model, reader):
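    """Evaluate `model` over `reader`, printing running top-1/top-5 accuracy."""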

    model.eval()
    total_acc1 = 0.0
    total_acc5 = 0.0
    total_sample = 0
    for batch_id, data in enumerate(reader()):
        dy_x_data = np.array(
            [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
        if len(data) != batch_size:
            continue
        y_data = np.array([x[1] for x in data]).astype('int64').reshape(
            batch_size, 1)

        img = to_variable(dy_x_data)
        label = to_variable(y_data)
        label.stop_gradient = True

        out = model(img)

        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

        total_acc1 += acc_top1.numpy()
        total_acc5 += acc_top5.numpy()
        total_sample += 1

        if batch_id % 10 == 0:
            print("test | batch step %d, acc1 %0.3f acc5 %0.3f" % \
                  ( batch_id, total_acc1 / total_sample, total_acc5 / total_sample))
    if args.ce:
        print("kpis\ttest_acc1\t%0.3f" % (total_acc1 / total_sample))
        print("kpis\ttest_acc5\t%0.3f" % (total_acc5 / total_sample))
    print("final eval acc1 %0.3f acc5 %0.3f" % \
          (total_acc1 / total_sample, total_acc5 / total_sample))


def train_resnet():
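    """Train ResNet on the flowers dataset, optionally in data-parallel mode."""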
    epoch = args.epoch
    if args.use_data_parallel:
        place_idx = int(os.environ['FLAGS_selected_gpus'])
        place = fluid.CUDAPlace(place_idx)
    else:
        place = fluid.CUDAPlace(0)
    with fluid.dygraph.guard(place):
        if args.ce:
            print("ce mode")
            seed = 33
            np.random.seed(seed)
            fluid.default_startup_program().random_seed = seed
            fluid.default_main_program().random_seed = seed

        resnet = ResNet()
        optimizer = optimizer_setting(parameter_list=resnet.parameters())

        if args.use_data_parallel:
            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
            fleet.init(role)
            dist_strategy = fleet.DistributedStrategy()
            optimizer = fleet.distributed_optimizer(optimizer, dist_strategy)
            # must be called after distributed_optimizer so that dist_strategy is applied
            resnet = fleet.build_distributed_model(resnet)

        train_reader = paddle.batch(
            paddle.dataset.flowers.train(use_xmap=False), batch_size=batch_size)
        if args.use_data_parallel:
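            # shard batches across trainers so each consumes a distinct slice of the data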
            train_reader = fluid.contrib.reader.distributed_batch_reader(
                train_reader)

        test_reader = paddle.batch(
            paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)


        # NOTE: used in benchmark
        total_batch_num = 0

        for eop in range(epoch):

            resnet.train()
            total_loss = 0.0
            total_acc1 = 0.0
            total_acc5 = 0.0
            total_sample = 0

            for batch_id, data in enumerate(train_reader()):

                # NOTE: used in benchmark
                if args.max_iter and total_batch_num == args.max_iter:
                    return
                batch_start = time.time()

                dy_x_data = np.array(
                    [x[0].reshape(3, 224, 224) for x in data]).astype('float32')
                if len(data) != batch_size:
                    continue
                y_data = np.array(
                    [x[1] for x in data]).astype('int64').reshape(-1, 1)

                img = to_variable(dy_x_data)
                label = to_variable(y_data)
                label.stop_gradient = True

                out = resnet(img)
                loss = fluid.layers.cross_entropy(input=out, label=label)
                avg_loss = fluid.layers.mean(x=loss)

                acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
                acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)

                dy_out = avg_loss.numpy()

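                # in data-parallel mode, scale the loss by the trainer count
                # and all-reduce gradients across trainers after backward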
                if args.use_data_parallel:
                    avg_loss = resnet.scale_loss(avg_loss)
                    avg_loss.backward()
                    resnet.apply_collective_grads()
                else:
                    avg_loss.backward()

                optimizer.minimize(avg_loss)
                resnet.clear_gradients()

                batch_end = time.time()
                train_batch_cost = batch_end - batch_start
                total_loss += dy_out
                total_acc1 += acc_top1.numpy()
                total_acc5 += acc_top5.numpy()
                total_sample += 1
                total_batch_num += 1  # used in benchmark
                if batch_id % 10 == 0:
                    print( "epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f, batch cost: %.5f" % \
H
Hongyu Liu 已提交
390
                           ( eop, batch_id, total_loss / total_sample, \
H
hysunflower 已提交
391
                             total_acc1 / total_sample, total_acc5 / total_sample, train_batch_cost))

            if args.ce:
                print("kpis\ttrain_acc1\t%0.3f" % (total_acc1 / total_sample))
                print("kpis\ttrain_acc5\t%0.3f" % (total_acc5 / total_sample))
                print("kpis\ttrain_loss\t%0.3f" % (total_loss / total_sample))
            print("epoch %d | batch step %d, loss %0.3f acc1 %0.3f acc5 %0.3f" % \
                  (eop, batch_id, total_loss / total_sample, \
                   total_acc1 / total_sample, total_acc5 / total_sample))
            resnet.eval()
            eval(resnet, test_reader)

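            # in data-parallel mode, only the rank-0 trainer saves parameters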
            save_parameters = (not args.use_data_parallel) or (
                args.use_data_parallel and
                fluid.dygraph.parallel.Env().local_rank == 0)
            if save_parameters:
                fluid.save_dygraph(resnet.state_dict(), 'resnet_params')


if __name__ == '__main__':
    train_resnet()