# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
import os
import traceback

import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import six
import sys
sys.path.append("..")
import models
from reader import train, val

def parse_args(argv=None):
    """Parse the command-line options for distributed image-classification training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    parser = argparse.ArgumentParser('Distributed Image Classification Training.')
    parser.add_argument(
        '--model',
        type=str,
        default='DistResNet',
        help='The model to run.')
    parser.add_argument(
        '--batch_size', type=int, default=32, help='The minibatch size per device.')
    parser.add_argument(
        '--multi_batch_repeat', type=int, default=1, help='Batch merge repeats.')
    parser.add_argument(
        '--learning_rate', type=float, default=0.1, help='The learning rate.')
    parser.add_argument(
        '--pass_num', type=int, default=90, help='The number of passes.')
    parser.add_argument(
        '--data_format',
        type=str,
        default='NCHW',
        choices=['NCHW', 'NHWC'],
        help='The data data_format, now only support NCHW.')
    parser.add_argument(
        '--device',
        type=str,
        default='GPU',
        choices=['CPU', 'GPU'],
        help='The device type.')
    parser.add_argument(
        '--gpus',
        type=int,
        default=1,
        help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
    parser.add_argument(
        '--cpus',
        type=int,
        default=1,
        help='If cpus > 1, will set ParallelExecutor to use multiple threads.')
    parser.add_argument(
        '--no_test',
        action='store_true',
        help='If set, do not test the testset during training.')
    parser.add_argument(
        '--memory_optimize',
        action='store_true',
        help='If set, optimize runtime memory before start.')
    parser.add_argument(
        '--update_method',
        type=str,
        default='local',
        choices=['local', 'pserver', 'nccl2'],
        help='Choose parameter update method, can be local, pserver, nccl2.')
    parser.add_argument(
        '--no_split_var',
        action='store_true',
        default=False,
        help='Whether split variables into blocks when update_method is pserver')
    parser.add_argument(
        '--async_mode',
        action='store_true',
        default=False,
        help='Whether start pserver in async mode to support ASGD')
    parser.add_argument(
        '--reduce_strategy',
        type=str,
        choices=['reduce', 'all_reduce'],
        default='all_reduce',
        help='Specify the reduce strategy, can be reduce, all_reduce')
    parser.add_argument(
        '--data_dir',
        type=str,
        default="../data/ILSVRC2012",
        help="The ImageNet dataset root dir."
    )
    # argv=None -> argparse falls back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args

def get_model(args, is_train, main_prog, startup_prog):
    """Build the training or evaluation network and attach a py_reader to it.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).
        is_train: ``True`` builds the training graph (including the Momentum
            optimizer and learning-rate schedule) fed by ``train``; ``False``
            builds the evaluation graph fed by ``val``.
        main_prog: ``fluid.Program`` that receives the network ops.
        startup_prog: ``fluid.Program`` that receives initialisation ops.

    Returns:
        5-tuple ``(avg_cost, optimizer, [batch_acc1, batch_acc5], None,
        pyreader)``. ``optimizer`` is ``None`` when ``is_train`` is false.
        The fourth slot is always ``None``; it is kept so callers that index
        the result positionally (``[2]``, ``[4]``) keep working.
    """
    class_dim = 1000
    if args.data_format == 'NCHW':
        dshape = [3, 224, 224]
    else:
        dshape = [224, 224, 3]
    if is_train:
        reader = train(data_dir=args.data_dir)
    else:
        reader = val(data_dir=args.data_dir)

    trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
    with fluid.program_guard(main_prog, startup_prog):
        with fluid.unique_name.guard():
            pyreader = fluid.layers.py_reader(
                capacity=args.batch_size * args.gpus,
                shapes=([-1] + dshape, (-1, 1)),
                dtypes=('float32', 'int64'),
                name="train_reader" if is_train else "test_reader",
                use_double_buffer=True)
            image, label = fluid.layers.read_file(pyreader)
            # NOTE: depth is fixed at 50 layers regardless of --model.
            model_def = models.__dict__[args.model](layers=50, is_train=is_train)
            predict = model_def.net(image, class_dim=class_dim)

            cost = fluid.layers.cross_entropy(input=predict, label=label)
            avg_cost = fluid.layers.mean(x=cost)

            batch_acc1 = fluid.layers.accuracy(input=predict, label=label, k=1)
            batch_acc5 = fluid.layers.accuracy(input=predict, label=label, k=5)

            optimizer = None
            if is_train:
                start_lr = args.learning_rate
                # n * worker * repeat: scale the target LR with the number of
                # trainers and merged batches.
                end_lr = args.learning_rate * trainer_count * args.multi_batch_repeat
                # 1281167 = number of ILSVRC2012 training images, split
                # across trainers.
                total_images = 1281167 / trainer_count
                step = int(total_images / (args.batch_size * args.gpus * args.multi_batch_repeat) + 1)
                warmup_steps = step * 5  # warmup 5 passes
                epochs = [30, 60, 80]
                bd = [step * e for e in epochs]
                # After warmup, start from end_lr and multiply by 0.1 at each
                # boundary in `bd` (passes 30, 60, 80).
                base_lr = end_lr
                lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]

                optimizer = fluid.optimizer.Momentum(
                    # lr_warmup presumably ramps start_lr -> end_lr over
                    # warmup_steps before handing off to piecewise_decay.
                    learning_rate=models.learning_rate.lr_warmup(
                        fluid.layers.piecewise_decay(
                            boundaries=bd, values=lr),
                        warmup_steps, start_lr, end_lr),
                    momentum=0.9,
                    regularization=fluid.regularizer.L2Decay(1e-4))
                optimizer.minimize(avg_cost)

    pyreader.decorate_paddle_reader(
        paddle.batch(
            reader,
            batch_size=args.batch_size))

    return avg_cost, optimizer, [batch_acc1, batch_acc5], None, pyreader

def append_nccl2_prepare(trainer_id, startup_prog):
    """Transpile ``startup_prog`` for multi-node NCCL2 all-reduce training.

    Reads PADDLE_PSERVER_PORT, PADDLE_TRAINER_IPS (comma separated) and
    PADDLE_CURRENT_IP from the environment to build the trainer endpoint list.

    Args:
        trainer_id: Index of this trainer among all trainer nodes.
        startup_prog: ``fluid.Program`` that is transpiled in place.

    Returns:
        ``(nccl_id_var, num_trainers, trainer_id)``, the triple ``main``
        unpacks. ``nccl_id_var`` is ``None``: in transpiler "nccl2" mode the
        NCCL-id setup is presumably inserted into ``startup_prog`` itself, so
        no separate variable is handed to ``train_parallel``.

    Raises:
        ValueError: if any of the required environment variables is unset.
    """
    port = os.getenv("PADDLE_PSERVER_PORT")
    worker_ips = os.getenv("PADDLE_TRAINER_IPS")
    current_ip = os.getenv("PADDLE_CURRENT_IP")
    if not (port and worker_ips and current_ip):
        raise ValueError(
            "nccl2 mode requires PADDLE_PSERVER_PORT, PADDLE_TRAINER_IPS and "
            "PADDLE_CURRENT_IP environment variables to be set")

    worker_endpoints = [':'.join([ip, port]) for ip in worker_ips.split(",")]
    current_endpoint = current_ip + ":" + port
    num_trainers = len(worker_endpoints)

    config = fluid.DistributeTranspilerConfig()
    config.mode = "nccl2"
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(trainer_id, trainers=','.join(worker_endpoints),
                current_endpoint=current_endpoint,
                startup_program=startup_prog)
    return None, num_trainers, trainer_id


def dist_transpile(trainer_id, args, train_prog, startup_prog):
    """Transpile the programs for parameter-server (pserver) training.

    Environment variables read: PADDLE_TRAINING_ROLE, PADDLE_PSERVER_IPS
    (comma separated), PADDLE_PSERVER_PORT (default "6174"), PADDLE_TRAINERS
    (default "1", matching ``get_model``) and PADDLE_CURRENT_IP.

    Args:
        trainer_id: Index of this trainer.
        args: Parsed command-line namespace; ``no_split_var`` and
            ``async_mode`` are used.
        train_prog: Training ``fluid.Program`` to transpile.
        startup_prog: Startup ``fluid.Program`` to transpile.

    Returns:
        ``(main_program, startup_program)`` for this process's role. For
        PSERVER these are the pserver program and its startup program. For
        TRAINER they are the trainer program and the original
        ``startup_prog``.

    Raises:
        ValueError: if PADDLE_TRAINING_ROLE is neither TRAINER nor PSERVER,
            or PADDLE_PSERVER_IPS is empty.
    """
    # Validate the role before doing any transpile work.
    training_role = os.getenv("PADDLE_TRAINING_ROLE")
    if training_role not in ("PSERVER", "TRAINER"):
        raise ValueError(
            'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
        )
    pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
    if not pserver_ips:
        # An empty list would otherwise yield the bogus endpoint ":<port>".
        raise ValueError(
            'PADDLE_PSERVER_IPS environment variable must list the pserver IPs'
        )

    port = os.getenv("PADDLE_PSERVER_PORT", "6174")
    pserver_endpoints = ",".join(
        ':'.join([ip, port]) for ip in pserver_ips.split(","))
    trainers = int(os.getenv("PADDLE_TRAINERS", "1"))
    current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port

    config = fluid.DistributeTranspilerConfig()
    config.slice_var_up = not args.no_split_var
    t = fluid.DistributeTranspiler(config=config)
    t.transpile(
        trainer_id,
        program=train_prog,
        pservers=pserver_endpoints,
        trainers=trainers,
        sync_mode=not args.async_mode,
        startup_program=startup_prog)
    if training_role == "PSERVER":
        pserver_program = t.get_pserver_program(current_endpoint)
        pserver_startup_program = t.get_startup_program(
            current_endpoint, pserver_program, startup_program=startup_prog)
        return pserver_program, pserver_startup_program
    # training_role == "TRAINER"
    train_program = t.get_trainer_program()
    return train_program, startup_prog


def test_parallel(exe, test_args, args, test_prog):
    """Run one full pass over the test reader and return the accuracies.

    Args:
        exe: ParallelExecutor built over the test program.
        test_args: Sequence returned by ``get_model(..., is_train=False)``;
            index 2 holds the accuracy variables and index 4 the py_reader.
        args: Parsed command-line namespace; ``args.batch_size`` is the
            weight given to every fetched accuracy value.
        test_prog: Not used here; kept for call-site compatibility.

    Returns:
        One ``fluid.metrics.Accuracy.eval()`` result per accuracy variable,
        in the order of ``test_args[2]``.
    """
    acc_vars = test_args[2]
    test_reader = test_args[4]
    acc_evaluators = [fluid.metrics.Accuracy() for _ in acc_vars]
    to_fetch = [v.name for v in acc_vars]

    test_reader.start()
    while True:
        try:
            acc_rets = exe.run(fetch_list=to_fetch)
        except fluid.core.EOFException:
            # Reader exhausted: reset it so the next pass can start() again.
            test_reader.reset()
            break
        for evaluator, ret in zip(acc_evaluators, acc_rets):
            evaluator.update(value=np.array(ret), weight=args.batch_size)

    return [e.eval() for e in acc_evaluators]

def train_parallel(train_args, test_args, args, train_prog, test_prog,
                   startup_prog, nccl_id_var, num_trainers, trainer_id):
    """Train with ParallelExecutor for ``args.pass_num`` passes.

    Unless ``--no_test`` is set, the test set is evaluated after every pass.

    Args:
        train_args: Sequence from ``get_model(..., is_train=True)``: index 0
            is the loss variable, 2 the accuracy variables, 4 the py_reader.
        test_args: Same layout from ``get_model(..., is_train=False)``.
        args: Parsed command-line namespace.
        train_prog: Program run by the training executor.
        test_prog: Program run by the test executor.
        startup_prog: Program run once with a plain Executor before training.
        nccl_id_var: When truthy, trainer 0 sleeps 30 s before starting.
        num_trainers: Trainer count passed to ParallelExecutor (forced to 1
            in pserver mode).
        trainer_id: This trainer's index (forced to 0 in pserver mode).
    """
    over_all_start = time.time()
    # Honour --device for the ParallelExecutors too. They were previously
    # hard-coded to use_cuda=True, so --device CPU only affected startup.
    use_cuda = args.device == 'GPU'
    place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()

    if nccl_id_var and trainer_id == 0:
        #FIXME(wuyi): wait other trainer to start listening
        time.sleep(30)

    startup_exe = fluid.Executor(place)
    startup_exe.run(startup_prog)
    strategy = fluid.ExecutionStrategy()
    strategy.num_threads = args.cpus
    strategy.allow_op_delay = False
    build_strategy = fluid.BuildStrategy()
    if args.reduce_strategy == "reduce":
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.Reduce
    else:
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.AllReduce

    avg_loss = train_args[0]

    if args.update_method == "pserver":
        # parameter server mode distributed training, merge
        # gradients on local server, do not initialize
        # ParallelExecutor with multi server all-reduce mode.
        num_trainers = 1
        trainer_id = 0

    exe = fluid.ParallelExecutor(
        use_cuda,
        avg_loss.name,
        main_program=train_prog,
        exec_strategy=strategy,
        build_strategy=build_strategy,
        num_trainers=num_trainers,
        trainer_id=trainer_id)

    test_exe = None
    if not args.no_test:
        # pserver mode passes scope=None; otherwise the test executor gets
        # its own Scope while sharing variables with `exe`.
        test_scope = None if args.update_method == "pserver" else fluid.Scope()
        test_exe = fluid.ParallelExecutor(
            use_cuda, main_program=test_prog, share_vars_from=exe,
            scope=test_scope)

    pyreader = train_args[4]
    # Loop-invariant: loss first, then every accuracy variable.
    fetch_list = [avg_loss.name] + [v.name for v in train_args[2]]
    for pass_id in range(args.pass_num):
        num_samples = 0
        start_time = time.time()
        batch_id = 0
        pyreader.start()
        while True:
            # Only fetch (and log) every 30th batch; the other steps run
            # with an empty fetch list.
            should_log = batch_id % 30 == 0
            try:
                fetch_ret = exe.run(fetch_list if should_log else [])
            except fluid.core.EOFException:
                break
            except fluid.core.EnforceNotMet:
                traceback.print_exc()
                break
            num_samples += args.batch_size * args.gpus

            if should_log:
                fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
                print("Pass %d, batch %d, loss %s, accucacys: %s" %
                      (pass_id, batch_id, fetched_data[0], fetched_data[1:]))
            batch_id += 1

        print_train_time(start_time, time.time(), num_samples)
        pyreader.reset()

        if not args.no_test and test_args[2]:
            test_ret = test_parallel(test_exe, test_args, args, test_prog)
            print("Pass: %d, Test Accuracy: %s\n" %
                  (pass_id, [np.mean(np.array(v)) for v in test_ret]))

    startup_exe.close()
    print("total train time: ", time.time() - over_all_start)


def print_arguments(args):
    """Print every parsed option as ``name: value``, sorted by option name.

    Args:
        args: Namespace-like object whose attributes are the options
            (``vars(args)`` must work).
    """
    print('----------- Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')


def print_train_time(start_time, end_time, num_samples):
    """Report one pass's sample count, wall time and throughput.

    Args:
        start_time: Timestamp in seconds when the pass started.
        end_time: Timestamp in seconds when the pass finished.
        num_samples: Number of examples processed between the two.
    """
    elapsed = end_time - start_time
    throughput = num_samples / elapsed
    report = '\nTotal examples: %d, total time: %.5f, %.5f examples/sed\n' % (
        num_samples, elapsed, throughput)
    print(report)


def print_paddle_envs():
    """Print every environment variable whose name contains ``PADDLE_``.

    Variables are listed in sorted name order so logs from different runs
    and nodes are directly comparable.
    """
    print('----------- Configuration envs -----------')
    for k in sorted(os.environ):
        if "PADDLE_" in k:
            print("ENV %s:%s" % (k, os.environ[k]))
    print('------------------------------------------------')


def main():
    """Entry point: build the programs, select the update method, and train.

    ``--update_method`` selects the code path:
      * ``pserver``: transpile for parameter-server training. A PSERVER
        process runs the server program with a CPU Executor; a TRAINER
        process trains with ``train_parallel``. The process exits afterwards.
      * ``nccl2``: transpile ``startup_prog`` for multi-node all-reduce,
        then train.
      * ``local``: train on this node only.
    """
    args = parse_args()
    print_arguments(args)
    print_paddle_envs()

    # the unique trainer id, starting from 0, needed by trainer
    # only
    nccl_id_var, num_trainers, trainer_id = (
        None, 1, int(os.getenv("PADDLE_TRAINER_ID", "0")))

    train_prog = fluid.Program()
    test_prog = fluid.Program()
    startup_prog = fluid.Program()

    train_args = list(get_model(args, True, train_prog, startup_prog))
    test_args = list(get_model(args, False, test_prog, startup_prog))

    all_args = [train_args, test_args, args]

    if args.update_method == "pserver":
        training_role = os.getenv("PADDLE_TRAINING_ROLE")
        train_prog, startup_prog = dist_transpile(trainer_id, args, train_prog,
                                                  startup_prog)
        if not train_prog:
            raise Exception(
                "Must configure correct environments to run dist train.")
        all_args.extend([train_prog, test_prog, startup_prog])
        if training_role == "TRAINER":
            # train_parallel handles a single device as well, so trainers
            # with --gpus 1 must train too. They previously fell through
            # and exited without training.
            all_args.extend([nccl_id_var, num_trainers, trainer_id])
            train_parallel(*all_args)
        elif training_role == "PSERVER":
            # start pserver with Executor
            server_exe = fluid.Executor(fluid.CPUPlace())
            server_exe.run(startup_prog)
            server_exe.run(train_prog)
        sys.exit(0)

    # for other update methods, use default programs
    all_args.extend([train_prog, test_prog, startup_prog])

    if args.update_method == "nccl2":
        nccl_id_var, num_trainers, trainer_id = append_nccl2_prepare(
            trainer_id, startup_prog)

    all_args.extend([nccl_id_var, num_trainers, trainer_id])
    train_parallel(*all_args)


if __name__ == "__main__":
    main()