train.py 15.6 KB
Newer Older
Y
Yancey1989 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import cProfile
import time
import os
import traceback

import numpy as np
Y
update  
Yancey1989 已提交
22
import torchvision_reader
Y
Yancey1989 已提交
23
import torch
Y
Yancey1989 已提交
24 25 26 27 28 29 30 31 32 33
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.profiler as profiler
import paddle.fluid.transpiler.distribute_transpiler as distribute_transpiler

import sys
sys.path.append("..")
from utility import add_arguments, print_arguments
import functools
Y
Yancey1989 已提交
34
from models.fast_resnet import FastResNet, lr_decay
Y
Yancey1989 已提交
35 36 37 38 39 40 41 42 43 44 45 46
import utils

def parse_args():
    """Build the command-line parser for this script and parse the arguments.

    Every option is registered through ``add_arguments`` (imported from
    ``utility``), pre-bound to the parser with ``functools.partial``.

    Returns:
        The namespace produced by ``argparse.ArgumentParser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    add_arg = functools.partial(add_arguments, argparser=parser)
    # NOTE(review): several options below (e.g. image_shape, data_dir,
    # reduce_strategy, lr_strategy) are not read by the code visible in this
    # file -- confirm whether they are consumed elsewhere.
    # yapf: disable
    add_arg('total_images',     int,   1281167,              "Training image number.")
    add_arg('num_epochs',       int,   120,                  "number of epochs.")
    add_arg('image_shape',      str,   "3,224,224",          "input image size")
    add_arg('model_save_dir',   str,   "output",             "model save directory")
    add_arg('pretrained_model', str,   None,                 "Whether to use pretrained model.")
    add_arg('checkpoint',       str,   None,                 "Whether to resume checkpoint.")
    add_arg('lr',               float, 1.0,                  "set learning rate.")
    add_arg('lr_strategy',      str,   "piecewise_decay",    "Set the learning rate decay strategy.")
    add_arg('model',            str,   "FastResNet",         "Set the network to use.")
    add_arg('data_dir',         str,   "./data/ILSVRC2012",  "The ImageNet dataset root dir.")
    add_arg('model_category',   str,   "models",             "Whether to use models_name or not, valid value:'models','models_name'" )
    add_arg('fp16',             bool,  False,                "Enable half precision training with fp16." )
    add_arg('scale_loss',       float, 1.0,                  "Scale loss for fp16." )
    # for distributed
    add_arg('start_test_pass',  int,    0,                  "Start test after x passes.")
    add_arg('num_threads',      int,    8,                  "Use num_threads to run the fluid program.")
    add_arg('reduce_strategy',  str,    "allreduce",        "Choose from reduce or allreduce.")
    # Help text corrected: the default registered here is 30, not 5.
    add_arg('log_period',       int,    30,                 "Print period, default is 30.")
    add_arg('memory_optimize',  bool,   True,               "Whether to enable memory optimize.")
    add_arg('best_acc5',        float,  0.93,               "The best acc5, default is 93%.")
    # yapf: enable
    args = parser.parse_args()
    return args

def get_device_num():
    """Return the number of GPUs this process should use.

    If ``CUDA_VISIBLE_DEVICES`` is set and names at least one device, the
    count of its non-empty comma-separated entries is returned. Otherwise the
    count is the number of newline characters printed by ``nvidia-smi -L``.

    Returns:
        int: the device count.

    Raises:
        OSError / subprocess.CalledProcessError: propagated from
            ``subprocess.check_output`` when the ``nvidia-smi`` fallback is
            needed but cannot be run successfully.
    """
    import subprocess
    visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
    if visible_device:
        # Skip blank entries so a stray or trailing comma ("0,1,") is not
        # counted as an extra device.
        device_ids = [d for d in visible_device.split(',') if d.strip()]
        if device_ids:
            return len(device_ids)
    listing = subprocess.check_output(['nvidia-smi', '-L']).decode()
    return listing.count('\n')

Y
Yancey1989 已提交
75
# Device count resolved once, at import time. Importing this module therefore
# reads CUDA_VISIBLE_DEVICES and may run `nvidia-smi -L` (see get_device_num).
DEVICE_NUM = get_device_num()
Y
Yancey1989 已提交
76

Y
Yancey1989 已提交
77
def test_parallel(exe, test_args, args, test_reader, feeder, bs):
    """Run one evaluation pass and return the accumulated accuracies.

    Args:
        exe: executor; called as ``exe.run(fetch_list=..., feed=...)``.
        test_args: tuple whose index 2 holds the accuracy variables to fetch
            (the layout returned by ``build_program``).
        args: parsed command-line arguments (not used in this function).
        test_reader: callable returning an iterable of batches.
        feeder: object whose ``feed(batch)`` builds the feed for ``exe.run``.
        bs: batch size used as the weight of every metric update; the sample
            count for the throughput report is ``batches * bs * DEVICE_NUM``.

    Returns:
        list: one ``eval()`` result per accuracy variable, in fetch order.
    """
    acc_vars = test_args[2]
    acc_evaluators = [fluid.metrics.Accuracy() for _ in acc_vars]
    to_fetch = [v.name for v in acc_vars]

    num_batches = 0
    start_ts = time.time()
    for batch_id, data in enumerate(test_reader()):
        acc_rets = exe.run(fetch_list=to_fetch, feed=feeder.feed(data))
        ret_result = [np.mean(np.array(ret)) for ret in acc_rets]
        print("Test batch: [%d], acc_rets: [%s]" % (batch_id, ret_result))
        for evaluator, acc in zip(acc_evaluators, acc_rets):
            evaluator.update(value=np.array(acc), weight=bs)
        num_batches = batch_id + 1

    # Count batches, not the last batch index: the index is one short of the
    # number of batches actually processed.
    num_samples = num_batches * bs * DEVICE_NUM
    print_train_time(start_ts, time.time(), num_samples)

    return [e.eval() for e in acc_evaluators]

Y
Yancey1989 已提交
97 98 99 100

def build_program(args, is_train, main_prog, startup_prog, py_reader_startup_prog, sz, trn_dir, bs, min_scale, rect_val=False):
    """Populate ``main_prog`` with the FastResNet graph for training or testing.

    Args:
        args: parsed command-line arguments; reads ``fp16``, ``scale_loss``,
            ``total_images``, ``lr`` and ``memory_optimize``.
        is_train: when True the input comes from a ``py_reader`` and the
            optimizer is appended; when False plain ``image``/``label`` data
            layers are declared and no optimizer is built.
        main_prog: program that receives the network.
        startup_prog: program that receives the initialisation ops.
        py_reader_startup_prog: separate startup program that receives only
            the py_reader's initialisation.
        sz: image side length; used for the py_reader shape and passed to the
            model as ``img_size``.
        trn_dir: not used in this function.
        bs: batch size; the py_reader capacity is ``bs * DEVICE_NUM``.
        min_scale: not used in this function.
        rect_val: not used in this function.

    Returns:
        tuple: ``(avg_cost, optimizer, [batch_acc1, batch_acc5], pyreader)``;
        ``optimizer`` and ``pyreader`` are ``None`` when ``is_train`` is False.
    """
    dshape = [3, sz, sz]
    class_dim = 1000
    pyreader = None
    with fluid.program_guard(main_prog, startup_prog):
        with fluid.unique_name.guard():
            if is_train:
                # The reader is created under its own startup program so it
                # can be initialised independently of the parameters.
                with fluid.program_guard(main_prog, py_reader_startup_prog):
                    with fluid.unique_name.guard():
                        pyreader = fluid.layers.py_reader(
                            capacity=bs * DEVICE_NUM,
                            shapes=([-1] + dshape, (-1, 1)),
                            dtypes=('uint8', 'int64'),
                            # This branch only runs for training, so the
                            # reader name is always the train variant.
                            name="train_reader_" + str(sz),
                            use_double_buffer=True)
                image, label = fluid.layers.read_file(pyreader)
            else:
                # NOTE(review): 244 looks like a typo for 224 and ignores
                # `sz`; left unchanged here -- confirm before altering.
                image = fluid.layers.data(name="image", shape=[3, 244, 244], dtype="uint8")
                label = fluid.layers.data(name="label", shape=[1], dtype="int64")
            cast_img_type = "float16" if args.fp16 else "float32"
            cast = fluid.layers.cast(image, cast_img_type)
            # Created with value 0.0 here; the real per-channel statistics
            # must be written into these persistable variables after startup.
            img_mean = fluid.layers.create_global_var([3, 1, 1], 0.0, cast_img_type, name="img_mean", persistable=True)
            img_std = fluid.layers.create_global_var([3, 1, 1], 0.0, cast_img_type, name="img_std", persistable=True)
            # image = (image - (mean * 255.0)) / (std * 255.0)
            t1 = fluid.layers.elementwise_sub(cast, img_mean, axis=1)
            t2 = fluid.layers.elementwise_div(t1, img_std, axis=1)

            model = FastResNet(is_train=is_train)
            predict = model.net(t2, class_dim=class_dim, img_size=sz)
            cost, pred = fluid.layers.softmax_with_cross_entropy(predict, label, return_softmax=True)
            if args.scale_loss > 1:
                avg_cost = fluid.layers.mean(x=cost) * float(args.scale_loss)
            else:
                avg_cost = fluid.layers.mean(x=cost)

            batch_acc1 = fluid.layers.accuracy(input=pred, label=label, k=1)
            batch_acc5 = fluid.layers.accuracy(input=pred, label=label, k=5)

            # configure optimize
            optimizer = None
            if is_train:
                total_images = args.total_images
                lr = args.lr

                # Piecewise schedule: (start, end) epoch ranges, the total
                # batch size in each range, and the (from, to) learning rates.
                epochs = [(0,7), (7,13), (13, 22), (22, 25), (25, 28)]
                # Loop variables are deliberately not named `bs`, so the `bs`
                # parameter is never shadowed (comprehension variables leak
                # into the enclosing scope under Python 2).
                bs_epoch = [phase_bs * DEVICE_NUM for phase_bs in [224, 224, 96, 96, 50]]
                bs_scale = [total_bs * 1.0 / bs_epoch[0] for total_bs in bs_epoch]
                lrs = [(lr, lr*2), (lr*2, lr/4), (lr*bs_scale[2], lr/10*bs_scale[2]), (lr/10*bs_scale[2], lr/100*bs_scale[2]), (lr/100*bs_scale[4], lr/1000*bs_scale[4]), lr/1000*bs_scale[4]]

                boundaries, values = lr_decay(lrs, epochs, bs_epoch, total_images)

                optimizer = fluid.optimizer.Momentum(
                    learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
                    momentum=0.9)
                if args.fp16:
                    params_grads = optimizer.backward(avg_cost)
                    master_params_grads = utils.create_master_params_grads(
                        params_grads, main_prog, startup_prog, args.scale_loss)
                    optimizer.apply_gradients(master_params_grads)
                    utils.master_param_to_train_param(master_params_grads, params_grads, main_prog)
                else:
                    optimizer.minimize(avg_cost)

                if args.memory_optimize:
                    fluid.memory_optimize(main_prog, skip_grads=True)

    return avg_cost, optimizer, [batch_acc1, batch_acc5], pyreader
Y
Yancey1989 已提交
166 167


Y
Yancey1989 已提交
168
def refresh_program(args, epoch, sz, trn_dir, bs, val_bs, need_update_start_prog=False, min_scale=0.08, rect_val=False):
    """Rebuild the train/test programs and executors for a new image size.

    Args:
        args: parsed command-line arguments; reads ``fp16`` and
            ``num_threads`` here and is forwarded to ``build_program``.
        epoch: current epoch index (only used in the log line).
        sz: image side length forwarded to ``build_program``.
        trn_dir: training sub-directory (logged and forwarded).
        bs: training batch size.
        val_bs: validation batch size.
        need_update_start_prog: when True, also run the parameter startup
            program, then overwrite every startup variable whose name starts
            with ``conv2d_`` with Kaiming-normal values and fill the
            ``img_mean`` / ``img_std`` variables.
        min_scale: forwarded to ``build_program``.
        rect_val: forwarded to the test ``build_program`` call only.

    Returns:
        tuple: ``(train_args, test_args, test_prog, train_exe, test_exe)``.
    """
    print('refresh program: epoch: [%d], image size: [%d], trn_dir: [%s], batch_size:[%d]' % (epoch, sz, trn_dir, bs))
    train_prog = fluid.Program()
    test_prog = fluid.Program()
    startup_prog = fluid.Program()
    py_reader_startup_prog = fluid.Program()

    train_args = build_program(args, True, train_prog, startup_prog, py_reader_startup_prog, sz, trn_dir, bs, min_scale)
    test_args = build_program(args, False, test_prog, startup_prog, py_reader_startup_prog, sz, trn_dir, val_bs, min_scale, rect_val=rect_val)

    place = core.CUDAPlace(0)
    startup_exe = fluid.Executor(place)
    # The py_reader is initialised on every refresh; the parameters only when
    # the caller asks for it.
    startup_exe.run(py_reader_startup_prog)

    if need_update_start_prog:
        startup_exe.run(startup_prog)
        conv2d_w_vars = [var for var in startup_prog.global_block().vars.values() if var.name.startswith('conv2d_')]
        for var in conv2d_w_vars:
            torch_w = torch.empty(var.shape)
            kaiming_np = torch.nn.init.kaiming_normal_(torch_w, mode='fan_out', nonlinearity='relu').numpy()
            _set_scope_tensor(var.name, kaiming_np, place, args.fp16)

        # Per-channel statistics scaled by 255.0, shaped (3, 1, 1) to match
        # the variables they are written into.
        np_tensors = {}
        np_tensors["img_mean"] = np.array([0.485 * 255.0, 0.456 * 255.0, 0.406 * 255.0]).reshape((3, 1, 1))
        np_tensors["img_std"] = np.array([0.229 * 255.0, 0.224 * 255.0, 0.225 * 255.0]).reshape((3, 1, 1))
        for vname, np_tensor in np_tensors.items():
            _set_scope_tensor(vname, np_tensor, place, args.fp16)

    strategy = fluid.ExecutionStrategy()
    strategy.num_threads = args.num_threads
    strategy.allow_op_delay = False
    strategy.num_iteration_per_drop_scope = 1
    build_strategy = fluid.BuildStrategy()
    # NOTE(review): AllReduce is hard-coded; args.reduce_strategy is not
    # consulted here -- confirm whether that is intentional.
    build_strategy.reduce_strategy = fluid.BuildStrategy().ReduceStrategy.AllReduce

    avg_loss = train_args[0]
    train_exe = fluid.ParallelExecutor(
        True,
        avg_loss.name,
        main_program=train_prog,
        exec_strategy=strategy,
        build_strategy=build_strategy)
    test_exe = fluid.ParallelExecutor(
        True, main_program=test_prog, share_vars_from=train_exe)

    return train_args, test_args, test_prog, train_exe, test_exe


def _set_scope_tensor(var_name, values, place, fp16):
    """Write ``values`` into the global-scope tensor named ``var_name``.

    With ``fp16`` the data is converted to float16 and passed as its uint16
    view; otherwise it is converted to float32.
    """
    tensor = fluid.global_scope().find_var(var_name).get_tensor()
    if fp16:
        tensor.set(np.array(values, dtype="float16").view(np.uint16), place)
    else:
        tensor.set(np.array(values, dtype="float32"), place)

Y
Yancey1989 已提交
224 225 226 227 228 229 230 231 232 233 234 235
def prepare_reader(epoch_id, train_py_reader, train_bs, val_bs, trn_dir, img_dim, min_scale, rect_val, data_root="/data/imagenet"):
    """Attach a fresh training reader to the py_reader and build the test reader.

    Args:
        epoch_id: current epoch index; ``epoch_id + 1`` is the shuffle seed.
        train_py_reader: py_reader that receives the batched training reader.
        train_bs: batch size for the training reader.
        val_bs: validation batch size; the test reader is batched with
            ``val_bs * DEVICE_NUM``.
        trn_dir: sub-directory prefix placed between the dataset root and
            ``train`` / ``validation`` (e.g. ``"sz/160/"`` or ``""``).
        img_dim: image size passed to both readers as ``sz``.
        min_scale: passed to the training reader.
        rect_val: passed to the test reader.
        data_root: dataset root directory. The default reproduces the path
            that was previously hard-coded.

    Returns:
        The batched test reader.
    """
    train_dir = "%s/%strain" % (data_root, trn_dir)
    val_dir = "%s/%svalidation" % (data_root, trn_dir)

    train_reader = torchvision_reader.train(
        traindir=train_dir, sz=img_dim, min_scale=min_scale, shuffle_seed=epoch_id+1)
    train_py_reader.decorate_paddle_reader(paddle.batch(train_reader, batch_size=train_bs))

    test_reader = torchvision_reader.test(
        valdir=val_dir, bs=val_bs*DEVICE_NUM, sz=img_dim, rect_val=rect_val)
    test_batched_reader = paddle.batch(test_reader, batch_size=val_bs * DEVICE_NUM)

    return test_batched_reader


Y
Yancey1989 已提交
236 237 238 239 240 241 242 243 244
# NOTE: only need to benchmark using parallelexe
def train_parallel(args):
    """Train with progressive resizing and evaluate after every epoch.

    The programs and executors are rebuilt at epochs 0, 13 and 25, each time
    with a different image size, batch size and training directory. Each
    epoch runs until the py_reader raises ``EOFException``, then a full test
    pass is executed. Training stops early once the test top-5 accuracy
    exceeds ``args.best_acc5``.

    Args:
        args: parsed command-line arguments; reads ``num_epochs``,
            ``log_period`` and ``best_acc5`` here and is forwarded to
            ``refresh_program`` and ``test_parallel``.
    """
    over_all_start = time.time()
    test_prog = fluid.Program()

    exe = None
    test_exe = None
    train_args = None
    test_args = None
    ## dynamic batch size, image size...
    bs = 224
    val_bs = 64
    trn_dir = "sz/160/"
    img_dim = 128
    min_scale = 0.08
    rect_val = False
    for epoch_id in range(args.num_epochs):
        # refresh program
        if epoch_id == 0:
            train_args, test_args, test_prog, exe, test_exe = refresh_program(args, epoch_id, sz=img_dim, trn_dir=trn_dir, bs=bs, val_bs=val_bs, need_update_start_prog=True)
        elif epoch_id == 13:
            bs = 96
            trn_dir = "sz/352/"
            img_dim = 224
            min_scale = 0.087
            train_args, test_args, test_prog, exe, test_exe = refresh_program(args, epoch_id, sz=img_dim, trn_dir=trn_dir, bs=bs, val_bs=val_bs, min_scale=min_scale)
        elif epoch_id == 25:
            bs = 50
            val_bs = 8
            trn_dir = ""
            img_dim = 288
            min_scale = 0.5
            rect_val = True
            train_args, test_args, test_prog, exe, test_exe = refresh_program(args, epoch_id, sz=img_dim, trn_dir=trn_dir, bs=bs, val_bs=val_bs, min_scale=min_scale, rect_val=rect_val)

        avg_loss = train_args[0]
        num_samples = 0
        iters = 0
        start_time = time.time()
        train_py_reader = train_args[3]
        test_reader = prepare_reader(epoch_id, train_py_reader, bs, val_bs, trn_dir, img_dim=img_dim, min_scale=min_scale, rect_val=rect_val)
        train_py_reader.start() # start pyreader

        # Loss, the accuracy variables, then the learning rate; the order is
        # relied on when the fetched values are printed below. It does not
        # change within an epoch, so it is built once here.
        fetch_list = [avg_loss.name]
        fetch_list.extend(v.name for v in train_args[2])
        fetch_list.append("learning_rate")

        batch_start_time = time.time()
        while True:
            # Only fetch (and print) every `log_period` iterations.
            should_print = iters % args.log_period == 0

            fetch_ret = []
            try:
                if should_print:
                    fetch_ret = exe.run(fetch_list)
                else:
                    exe.run([])
            except fluid.core.EOFException:
                print("Finish current epoch, will reset pyreader...")
                train_py_reader.reset()
                break
            except fluid.core.EnforceNotMet:
                traceback.print_exc()
                # sys.exit instead of the `exit` helper, which is only
                # installed by the `site` module.
                sys.exit(1)

            num_samples += bs * DEVICE_NUM

            if should_print:
                fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
                print("Epoch %d, batch %d, loss %s, accucacys: %s, learning_rate %s, py_reader queue_size: %d, avg batch time: %0.4f secs" %
                      (epoch_id, iters, fetched_data[0], fetched_data[1:-1], fetched_data[-1], train_py_reader.queue.size(), (time.time() - batch_start_time)*1.0/args.log_period))
                batch_start_time = time.time()
            iters += 1

        print_train_time(start_time, time.time(), num_samples)
        feed_list = [test_prog.global_block().var(varname) for varname in ("image", "label")]
        test_feeder = fluid.DataFeeder(feed_list=feed_list, place=fluid.CUDAPlace(0))
        test_ret = test_parallel(test_exe, test_args, args, test_reader, test_feeder, val_bs)
        test_acc1, test_acc5 = [np.mean(np.array(v)) for v in test_ret]
        print("Epoch: %d, Test Accuracy: %s, Spend %.2f hours\n" %
            (epoch_id, [test_acc1, test_acc5], (time.time() - over_all_start) / 3600))
        if test_acc5 > args.best_acc5:
            print("Achieve the best top-1 acc %f, top-5 acc: %f" % (test_acc1, test_acc5))
            break

    # A single formatted string prints the same under Python 2 and 3; the
    # two-argument form printed a tuple under Python 2.
    print("total train time: %s" % (time.time() - over_all_start))

Y
Yancey1989 已提交
327 328

def print_train_time(start_time, end_time, num_samples):
    """Print the sample count, elapsed time and throughput of a run.

    Args:
        start_time: start timestamp in seconds.
        end_time: end timestamp in seconds.
        num_samples: number of examples processed between the two timestamps.
    """
    train_elapsed = end_time - start_time
    # Guard against a zero (or negative) interval, e.g. an empty reader on a
    # coarse clock: report 0 throughput instead of raising ZeroDivisionError.
    if train_elapsed > 0:
        examples_per_sec = num_samples / float(train_elapsed)
    else:
        examples_per_sec = 0.0
    print('\nTotal examples: %d, total time: %.5f, %.5f examples/sec\n' %
          (num_samples, train_elapsed, examples_per_sec))
Y
Yancey1989 已提交
333 334 335 336


def print_paddle_envs():
    """Print DEVICE_NUM and every environment variable whose name contains "PADDLE_"."""
    print('----------- Configuration envs -----------')
    print("DEVICE_NUM: %d" % DEVICE_NUM)
    for k in os.environ:
        if "PADDLE_" in k:
            # Function-call form with one pre-formatted string: valid and
            # identical in output under both Python 2 and Python 3.
            print("ENV %s:%s" % (k, os.environ[k]))
    print('------------------------------------------------')


def main():
    """Script entry point: parse options, echo the configuration, then train."""
    cli_args = parse_args()
    print_arguments(cli_args)
    print_paddle_envs()
    train_parallel(cli_args)


# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()