diff --git a/fluid/PaddleCV/image_classification/dist_train/README.md b/fluid/PaddleCV/image_classification/dist_train/README.md
index a595a540adfa770253909e432e99a27228d5f062..0b2729cce4fa2e0780b8db5f87da49a8e221c665 100644
--- a/fluid/PaddleCV/image_classification/dist_train/README.md
+++ b/fluid/PaddleCV/image_classification/dist_train/README.md
@@ -7,13 +7,15 @@ large-scaled distributed training with two distributed mode: parameter server mo
 
 Before getting started, please make sure you have go throught the imagenet [Data Preparation](../README.md#data-preparation).
 
-1. The entrypoint file is `dist_train.py`, some important flags are as follows:
+1. The entrypoint file is `dist_train.py`. Its command-line arguments are almost the same as those of the original `train.py`, with the following arguments specific to distributed training.
 
-   - `model`, the model to run with, default is the fine tune model `DistResnet`.
-   - `batch_size`, the batch_size per device.
    - `update_method`, specify the update method, can choose from local, pserver or nccl2.
-   - `device`, use CPU or GPU device.
-   - `gpus`, the GPU device count that the process used.
+   - `multi_batch_repeat`, set this greater than 1 to merge batches before pushing gradients to pservers.
+   - `start_test_pass`, when to start running tests.
+   - `num_threads`, how many threads will be used for ParallelExecutor.
+   - `split_var`, in pserver mode, whether to split one parameter onto several pservers, default True.
+   - `async_mode`, do async training, default False.
+   - `reduce_strategy`, choose from "reduce" or "allreduce".
 
 you can check out more details of the flags by `python dist_train.py --help`.
 
@@ -21,66 +23,27 @@ Before getting started, please make sure you have go throught the imagenet [Data Preparation](../README.md#data-preparation).
 
 We use the environment variable to distinguish the different training role of a distributed training job.
 
-   - `PADDLE_TRAINING_ROLE`, the current training role, should be in [PSERVER, TRAINER].
-   - `PADDLE_TRAINERS`, the trainer count of a job.
-   - `PADDLE_CURRENT_IP`, the current instance IP.
-   - `PADDLE_PSERVER_IPS`, the parameter server IP list, separated by "," only be used with update_method is pserver.
-   - `PADDLE_TRAINER_ID`, the unique trainer ID of a job, the ranging is [0, PADDLE_TRAINERS).
-   - `PADDLE_PSERVER_PORT`, the port of the parameter pserver listened on.
-   - `PADDLE_TRAINER_IPS`, the trainer IP list, separated by ",", only be used with upadte_method is nccl2.
-
-### Parameter Server Mode
-
-In this example, we launched 4 parameter server instances and 4 trainer instances in the cluster:
-
-1. launch parameter server process
-
-   ``` bash
-   PADDLE_TRAINING_ROLE=PSERVER \
-   PADDLE_TRAINERS=4 \
-   PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
-   PADDLE_CURRENT_IP=192.168.0.100 \
-   PADDLE_PSERVER_PORT=7164 \
-   python dist_train.py \
-       --model=DistResnet \
-       --batch_size=32 \
-       --update_method=pserver \
-       --device=CPU \
-       --data_dir=../data/ILSVRC2012
-   ```
-
-1. launch trainer process
-
-   ``` bash
-   PADDLE_TRAINING_ROLE=TRAINER \
-   PADDLE_TRAINERS=4 \
-   PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
-   PADDLE_TRAINER_ID=0 \
-   PADDLE_PSERVER_PORT=7164 \
-   python dist_train.py \
-       --model=DistResnet \
-       --batch_size=32 \
-       --update_method=pserver \
-       --device=GPU \
-       --data_dir=../data/ILSVRC2012
-   ```
-
-### NCCL2 Collective Mode
-
-1. launch trainer process
-
-   ``` bash
-   PADDLE_TRAINING_ROLE=TRAINER \
-   PADDLE_TRAINERS=4 \
-   PADDLE_TRAINER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
-   PADDLE_TRAINER_ID=0 \
-   python dist_train.py \
-       --model=DistResnet \
-       --batch_size=32 \
-       --update_method=nccl2 \
-       --device=GPU \
-       --data_dir=../data/ILSVRC2012
-   ```
+   - General envs:
+     - `PADDLE_TRAINER_ID`, the unique trainer ID of a job, in the range [0, PADDLE_TRAINERS_NUM).
+     - `PADDLE_TRAINERS_NUM`, the trainer count of a distributed job.
+     - `PADDLE_CURRENT_ENDPOINT`, current process endpoint.
+   - Pserver mode:
+     - `PADDLE_TRAINING_ROLE`, the current training role, should be in [PSERVER, TRAINER].
+     - `PADDLE_PSERVER_ENDPOINTS`, the parameter server endpoint list, separated by ",".
+   - NCCL2 mode:
+     - `PADDLE_TRAINER_ENDPOINTS`, endpoint list for each worker, separated by ",".
+
+### Try Out Different Distributed Training Modes
+
+You can test if distributed training works on a single node before deploying to the "real" cluster.
+
+***NOTE: for best performance, we recommend using multi-process mode (see No.3 below), together with fp16.***
+
+1. simply run `python dist_train.py` to start local training with default configurations.
+2. for pserver mode, run `bash run_ps_mode.sh` to start 2 pservers and 2 trainers; these 2 trainers
+   will use GPU 0 and 1 to simulate 2 workers.
+3. for nccl2 mode, run `bash run_nccl2_mode.sh` to start 2 workers.
+4. for local/distributed multi-process mode, run `run_mp_mode.sh` (this test uses 4 GPUs).
 
 ### Visualize the Training Process
 
@@ -88,16 +51,10 @@ It's easy to draw the learning curve accroding to the training logs, for example
 the logs of ResNet50 is as follows:
 
 ``` text
-Pass 0, batch 0, loss 7.0336914, accucacys: [0.0, 0.00390625]
-Pass 0, batch 1, loss 7.094781, accucacys: [0.0, 0.0]
-Pass 0, batch 2, loss 7.007068, accucacys: [0.0, 0.0078125]
-Pass 0, batch 3, loss 7.1056547, accucacys: [0.00390625, 0.00390625]
-Pass 0, batch 4, loss 7.133543, accucacys: [0.0, 0.0078125]
-Pass 0, batch 5, loss 7.3055463, accucacys: [0.0078125, 0.01171875]
-Pass 0, batch 6, loss 7.341838, accucacys: [0.0078125, 0.01171875]
-Pass 0, batch 7, loss 7.290557, accucacys: [0.0, 0.0]
-Pass 0, batch 8, loss 7.264951, accucacys: [0.0, 0.00390625]
-Pass 0, batch 9, loss 7.43522, accucacys: [0.00390625, 0.00390625]
+Pass 0, batch 30, loss 7.569439, acc1: 0.0125, acc5: 0.0125, avg batch time 0.1720
+Pass 0, batch 60, loss 7.027379, acc1: 0.0, acc5: 0.0, avg batch time 0.1551
+Pass 0, batch 90, loss 6.819984, acc1: 0.0, acc5: 0.0125, avg batch time 0.1492
+Pass 0, batch 120, loss 6.9076853, acc1: 0.0, acc5: 0.0125, avg batch time 0.1464
 ```
 
 The below figure shows top 1 train accuracy for local training with 8 GPUs and distributed training
diff --git a/fluid/PaddleCV/image_classification/dist_train/batch_merge.py b/fluid/PaddleCV/image_classification/dist_train/batch_merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..7215cd586cb8ecf95a11b19e43106ad4aaea8029
--- /dev/null
+++ b/fluid/PaddleCV/image_classification/dist_train/batch_merge.py
@@ -0,0 +1,45 @@
+import numpy as np
+import paddle.fluid as fluid
+
+def copyback_repeat_bn_params(main_prog):
+    # Copy the BN mean/variance accumulated in the "<var>.repeat.0" copies back into the
+    # original variables, so tests run with the merged-batch statistics.
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+    for vname in repeat_vars:
+        real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
+        orig_var = fluid.global_scope().find_var(vname).get_tensor()
+        orig_var.set(np.array(real_var), fluid.CUDAPlace(0)) # test on GPU0
+
+def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
+    # Create "<var>.repeat.<i>" copies of the BN mean/variance variables and append
+    # fill_constant ops to initialize them in the startup program.
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+
+    for i in range(num_repeats):
+        for op in startup_prog.global_block().ops:
+            if op.type == "fill_constant":
+                for oname in op.output_arg_names:
+                    if oname in repeat_vars:
+                        var = startup_prog.global_block().var(oname)
+                        repeat_var_name = "%s.repeat.%d" % (oname, i)
+                        repeat_var = startup_prog.global_block().create_var(
+                            name=repeat_var_name,
+                            type=var.type,
+                            dtype=var.dtype,
+                            shape=var.shape,
+                            persistable=var.persistable
+                        )
+                        main_prog.global_block()._clone_variable(repeat_var)
+                        startup_prog.global_block().append_op(
+                            type="fill_constant",
+                            inputs={},
+                            outputs={"Out": repeat_var},
+                            attrs=op.all_attrs()
+                        )
+
diff --git a/fluid/PaddleCV/image_classification/dist_train/dist_train.py b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
index 11e08aa89ccee3960f9fdf4751f89b4fdb7a2e7b..eb314085ea890ca8c6650a2c71d3d08f195a4def 100644
--- a/fluid/PaddleCV/image_classification/dist_train/dist_train.py
+++ b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
@@ -16,6 +16,8 @@ import argparse
 import time
 import os
 import traceback
+import functools
+import subprocess
 
 import numpy as np
 
@@ -28,127 +30,115 @@ sys.path.append("..")
 import models
 import utils
 from reader import train, val
+from utility import add_arguments, print_arguments
+from batch_merge import copyback_repeat_bn_params, append_bn_repeat_init_op
+from dist_utils import pserver_prepare, nccl2_prepare
+from env import dist_env
 
 def parse_args():
-    parser = argparse.ArgumentParser('Distributed Image Classification Training.')
-    parser.add_argument(
-        '--model',
-        type=str,
-        default='DistResNet',
-        help='The model to run.')
-    parser.add_argument(
-        '--batch_size', type=int, default=32, help='The minibatch size per device.')
-    parser.add_argument(
-        '--multi_batch_repeat', type=int, default=1, help='Batch merge repeats.')
-    parser.add_argument(
-        '--learning_rate', type=float, default=0.1, help='The learning rate.')
-    parser.add_argument(
-        '--pass_num', type=int, default=90, help='The number of passes.')
-    parser.add_argument(
-        '--data_format',
-        type=str,
-        default='NCHW',
-        choices=['NCHW', 'NHWC'],
-        help='The data data_format, now only support NCHW.')
-    parser.add_argument(
-        '--device',
-        type=str,
-        default='GPU',
-        choices=['CPU', 'GPU'],
-        help='The device type.')
-    parser.add_argument(
-        '--gpus',
-        type=int,
-        default=1,
-        help='If gpus > 1, will use ParallelExecutor to run, else use Executor.')
-    parser.add_argument(
-        '--cpus',
-        type=int,
-        default=1,
-        help='If cpus > 1, will set ParallelExecutor to use multiple threads.')
-    parser.add_argument(
-        '--no_test',
-        action='store_true',
-        help='If set, do not test the testset during training.')
-    parser.add_argument(
-        '--memory_optimize',
-        action='store_true',
-        help='If set, optimize runtime memory before start.')
-    parser.add_argument(
-        '--update_method',
-        type=str,
-        default='local',
-        choices=['local', 'pserver', 'nccl2'],
-        help='Choose parameter update method, can be local, pserver, nccl2.')
-    parser.add_argument(
-        '--no_split_var',
-        action='store_true',
-        default=False,
-        help='Whether split variables into blocks when update_method is pserver')
-    parser.add_argument(
-        '--async_mode',
-        action='store_true',
-        default=False,
-        help='Whether start pserver in async mode to support ASGD')
-    parser.add_argument(
-        '--reduce_strategy',
-        type=str,
-        choices=['reduce', 'all_reduce'],
-        default='all_reduce',
-        help='Specify the reduce strategy, can be reduce, all_reduce')
-    parser.add_argument(
-        '--data_dir',
-        type=str,
-        default="../data/ILSVRC2012",
-        help="The ImageNet dataset root dir."
-    )
+    parser = argparse.ArgumentParser(description=__doc__)
+    add_arg = functools.partial(add_arguments, argparser=parser)
+    # yapf: disable
+    add_arg('batch_size',       int,   256,                  "Minibatch size.")
+    add_arg('use_gpu',          bool,  True,                 "Whether to use GPU or not.")
+    add_arg('total_images',     int,   1281167,              "Training image number.")
+    add_arg('num_epochs',       int,   120,                  "number of epochs.")
+    add_arg('class_dim',        int,   1000,                 "Class number.")
+    add_arg('image_shape',      str,   "3,224,224",          "input image size")
+    add_arg('model_save_dir',   str,   "output",             "model save directory")
+    add_arg('with_mem_opt',     bool,  False,                "Whether to use memory optimization or not.")
+    add_arg('pretrained_model', str,   None,                 "Whether to use pretrained model.")
+    add_arg('checkpoint',       str,   None,                 "Whether to resume checkpoint.")
+    add_arg('lr',               float, 0.1,                  "set learning rate.")
+    add_arg('lr_strategy',      str,   "piecewise_decay",    "Set the learning rate decay strategy.")
+    add_arg('model',            str,   "DistResNet",         "Set the network to use.")
+    add_arg('enable_ce',        bool,  False,                "If set True, enable continuous evaluation job.")
+    add_arg('data_dir',         str,   "./data/ILSVRC2012",  "The ImageNet dataset root dir.")
+    add_arg('model_category',   str,   "models",             "Whether to use models_name or not, valid value:'models','models_name'")
+    add_arg('fp16',             bool,  False,                "Enable half precision training with fp16.")
+    add_arg('scale_loss',       float, 1.0,                  "Scale loss for fp16.")
+    # for distributed
+    add_arg('update_method',      str,  "local",      "Can be local, pserver, nccl2.")
+    add_arg('multi_batch_repeat', int,  1,            "Batch merge repeats.")
+    add_arg('start_test_pass',    int,  0,            "Start test after x passes.")
+    add_arg('num_threads',        int,  8,            "Use num_threads to run the fluid program.")
+    add_arg('split_var',          bool, True,         "Split params on pserver.")
+    add_arg('async_mode',         bool, False,        "Async distributed training, only for pserver mode.")
+    add_arg('reduce_strategy',    str,  "allreduce",  "Choose from reduce or allreduce.")
+    # yapf: enable
     args = parser.parse_args()
     return args
 
-def get_model(args, is_train, main_prog, startup_prog):
-    pyreader = None
-    class_dim = 1000
-    if args.data_format == 'NCHW':
-        dshape = [3, 224, 224]
+def get_device_num():
+    if os.getenv("CPU_NUM"):
+        return int(os.getenv("CPU_NUM"))
+    visible_device = os.getenv('CUDA_VISIBLE_DEVICES')
+    if visible_device:
+        device_num = len(visible_device.split(','))
     else:
-        dshape = [224, 224, 3]
+        device_num = subprocess.check_output(['nvidia-smi', '-L']).decode().count('\n')
+    return device_num
+
+def prepare_reader(is_train, pyreader, args):
     if is_train:
         reader = train(data_dir=args.data_dir)
     else:
         reader = val(data_dir=args.data_dir)
+    if is_train:
+        bs = args.batch_size / get_device_num()
+    else:
+        bs = 16
+    pyreader.decorate_paddle_reader(
+        paddle.batch(
+            reader,
+            batch_size=bs))
+
+def build_program(is_train, main_prog, startup_prog, args):
+    pyreader = None
+    class_dim = args.class_dim
+    image_shape = [int(m) for m in args.image_shape.split(",")]
 
-    trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
+    trainer_count = args.dist_env["num_trainers"]
     with fluid.program_guard(main_prog, startup_prog):
+        pyreader = fluid.layers.py_reader(
+            capacity=16,
+            shapes=([-1] + image_shape, (-1, 1)),
+            dtypes=('float32', 'int64'),
+            name="train_reader" if is_train else "test_reader",
+            use_double_buffer=True)
         with fluid.unique_name.guard():
-            pyreader = fluid.layers.py_reader(
-                capacity=args.batch_size * args.gpus,
-                shapes=([-1] + dshape, (-1, 1)),
-                dtypes=('float32', 'int64'),
-                name="train_reader" if is_train else "test_reader",
-                use_double_buffer=True)
-            input, label = fluid.layers.read_file(pyreader)
+            image, label = fluid.layers.read_file(pyreader)
+            if args.fp16:
+                image = fluid.layers.cast(image, "float16")
             model_def = models.__dict__[args.model](layers=50, is_train=is_train)
-            predict = model_def.net(input, class_dim=class_dim)
+            predict = model_def.net(image, class_dim=class_dim)
+            cost, pred = fluid.layers.softmax_with_cross_entropy(predict, label, return_softmax=True)
+            if args.scale_loss > 1:
+                avg_cost = fluid.layers.mean(x=cost) * float(args.scale_loss)
+            else:
+                avg_cost = fluid.layers.mean(x=cost)
 
-            cost = fluid.layers.cross_entropy(input=predict, label=label)
-            avg_cost = fluid.layers.mean(x=cost)
-
-            batch_acc1 = fluid.layers.accuracy(input=predict, label=label, k=1)
-            batch_acc5 = fluid.layers.accuracy(input=predict, label=label, k=5)
+            batch_acc1 = fluid.layers.accuracy(input=pred, label=label, k=1)
+            batch_acc5 = fluid.layers.accuracy(input=pred, label=label, k=5)
 
             optimizer = None
             if is_train:
-                start_lr = args.learning_rate
+                start_lr = args.lr
                 # n * worker * repeat
-                end_lr = args.learning_rate * trainer_count * args.multi_batch_repeat
-                total_images = 1281167 / trainer_count
-                step = int(total_images / (args.batch_size * args.gpus * args.multi_batch_repeat) + 1)
+                end_lr = args.lr * trainer_count * args.multi_batch_repeat
+                total_images = args.total_images / trainer_count
+                step = int(total_images / (args.batch_size * args.multi_batch_repeat) + 1)
                 warmup_steps = step * 5  # warmup 5 passes
                 epochs = [30, 60, 80]
                 bd = [step * e for e in epochs]
                 base_lr = end_lr
                 lr = []
                 lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
+                print("start lr: %s, end lr: %s, decay boundaries: %s" % (
+                    start_lr,
+                    end_lr,
+                    bd
+                ))
 
                 # NOTE: we put weight decay in layers config, and remove
                 # weight decay on bn layers, so don't add weight decay in
@@ -159,151 +149,77 @@ def get_model(args, is_train, main_prog, startup_prog):
                             boundaries=bd, values=lr),
                         warmup_steps, start_lr, end_lr),
                     momentum=0.9)
-                optimizer.minimize(avg_cost)
+                if args.fp16:
+                    params_grads = optimizer.backward(avg_cost)
+                    master_params_grads = utils.create_master_params_grads(
+                        params_grads, main_prog, startup_prog, args.scale_loss)
+                    optimizer.apply_gradients(master_params_grads)
+                    utils.master_param_to_train_param(master_params_grads, params_grads, main_prog)
+                else:
+                    optimizer.minimize(avg_cost)
 
-    batched_reader = None
-    pyreader.decorate_paddle_reader(
-        paddle.batch(
-            reader,
-            batch_size=args.batch_size))
-
-    return avg_cost, optimizer, [batch_acc1,
-                                 batch_acc5], batched_reader, pyreader
-
-def append_nccl2_prepare(trainer_id, startup_prog):
-    trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
-    port = os.getenv("PADDLE_PSERVER_PORT")
-    worker_ips = os.getenv("PADDLE_TRAINER_IPS")
-    worker_endpoints = []
-    for ip in worker_ips.split(","):
-        worker_endpoints.append(':'.join([ip, port]))
-    current_endpoint = os.getenv("PADDLE_CURRENT_IP") + ":" + port
-    num_trainers = len(worker_endpoints)
-
-    config = fluid.DistributeTranspilerConfig()
-    config.mode = "nccl2"
-    t = fluid.DistributeTranspiler(config=config)
-    t.transpile(trainer_id, trainers=','.join(worker_endpoints),
-        current_endpoint=current_endpoint,
-        startup_program=startup_prog)
-    return num_trainers, trainer_id
-
-
-def dist_transpile(trainer_id, args, train_prog, startup_prog):
-    port = os.getenv("PADDLE_PSERVER_PORT", "6174")
-    pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
-    eplist = []
-    for ip in pserver_ips.split(","):
-        eplist.append(':'.join([ip, port]))
-    pserver_endpoints = ",".join(eplist)
-    trainers = int(os.getenv("PADDLE_TRAINERS"))
-    current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
-    training_role = os.getenv("PADDLE_TRAINING_ROLE")
-
-    config = fluid.DistributeTranspilerConfig()
-    config.slice_var_up = not args.no_split_var
-    t = fluid.DistributeTranspiler(config=config)
-    t.transpile(
-        trainer_id,
-        program=train_prog,
-        pservers=pserver_endpoints,
-        trainers=trainers,
-        sync_mode=not args.async_mode,
-        startup_program=startup_prog)
-    if training_role == "PSERVER":
-        pserver_program = t.get_pserver_program(current_endpoint)
-        pserver_startup_program = t.get_startup_program(
-            current_endpoint, pserver_program, startup_program=startup_prog)
-        return pserver_program, pserver_startup_program
-    elif training_role == "TRAINER":
-        train_program = t.get_trainer_program()
-        return train_program, startup_prog
-    else:
-        raise ValueError(
-            'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
-        )
-
-def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
-    repeat_vars = set()
-    for op in main_prog.global_block().ops:
-        if op.type == "batch_norm":
-            repeat_vars.add(op.input("Mean")[0])
-            repeat_vars.add(op.input("Variance")[0])
-
-    for i in range(num_repeats):
-        for op in startup_prog.global_block().ops:
-            if op.type == "fill_constant":
-                for oname in op.output_arg_names:
-                    if oname in repeat_vars:
-                        var = startup_prog.global_block().var(oname)
-                        repeat_var_name = "%s.repeat.%d" % (oname, i)
-                        repeat_var = startup_prog.global_block().create_var(
-                            name=repeat_var_name,
-                            type=var.type,
-                            dtype=var.dtype,
-                            shape=var.shape,
-                            persistable=var.persistable
-                        )
-                        main_prog.global_block()._clone_variable(repeat_var)
-                        startup_prog.global_block().append_op(
-                            type="fill_constant",
-                            inputs={},
-                            outputs={"Out": repeat_var},
-                            attrs=op.all_attrs()
-                        )
-
-
-def copyback_repeat_bn_params(main_prog):
-    repeat_vars = set()
-    for op in main_prog.global_block().ops:
-        if op.type == "batch_norm":
-            repeat_vars.add(op.input("Mean")[0])
-            repeat_vars.add(op.input("Variance")[0])
-    for vname in repeat_vars:
-        real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
-        orig_var = fluid.global_scope().find_var(vname).get_tensor()
-        orig_var.set(np.array(real_var), fluid.CUDAPlace(0)) # test on GPU0
-
-
-def test_single(exe, test_args, args, test_prog):
-    acc_evaluators = []
-    for i in xrange(len(test_args[2])):
-        acc_evaluators.append(fluid.metrics.Accuracy())
-
-    to_fetch = [v.name for v in test_args[2]]
-    test_args[4].start()
+    # prepare reader for current program
+    prepare_reader(is_train, pyreader, args)
+
+    return pyreader, avg_cost, batch_acc1, batch_acc5
+
+
+def test_single(exe, test_prog, args, pyreader, fetch_list):
+    acc1 = fluid.metrics.Accuracy()
+    acc5 = fluid.metrics.Accuracy()
+    test_losses = []
+    pyreader.start()
     while True:
         try:
-            acc_rets = exe.run(program=test_prog, fetch_list=to_fetch)
-            for i, e in enumerate(acc_evaluators):
-                e.update(
-                    value=np.array(acc_rets[i]), weight=args.batch_size)
-        except fluid.core.EOFException as eof:
-            test_args[4].reset()
+            acc_rets = exe.run(program=test_prog, fetch_list=fetch_list)
+            test_losses.append(acc_rets[0])
+            acc1.update(value=np.array(acc_rets[1]), weight=args.batch_size)
+            acc5.update(value=np.array(acc_rets[2]), weight=args.batch_size)
+        except fluid.core.EOFException:
+            pyreader.reset()
             break
+    test_avg_loss = np.mean(np.array(test_losses))
+    return test_avg_loss, np.mean(acc1.eval()), np.mean(acc5.eval())
 
-    return [e.eval() for e in acc_evaluators]
+def run_pserver(train_prog, startup_prog):
+    server_exe = fluid.Executor(fluid.CPUPlace())
+    server_exe.run(startup_prog)
+    server_exe.run(train_prog)
 
+def train_parallel(args):
+    train_prog = fluid.Program()
+    test_prog = fluid.Program()
+    startup_prog = fluid.Program()
 
-def train_parallel(train_args, test_args, args, train_prog, test_prog,
-                   startup_prog, num_trainers, trainer_id):
-    over_all_start = time.time()
-    place = core.CPUPlace() if args.device == 'CPU' else core.CUDAPlace(0)
+    train_pyreader, train_cost, train_acc1, train_acc5 = build_program(True, train_prog, startup_prog, args)
+    test_pyreader, test_cost, test_acc1, test_acc5 = build_program(False, test_prog, startup_prog, args)
+
+    if args.update_method == "pserver":
+        train_prog, startup_prog = pserver_prepare(args, train_prog, startup_prog)
+    elif args.update_method == "nccl2":
+        nccl2_prepare(args, startup_prog)
 
-    if args.update_method == "nccl2" and trainer_id == 0:
-        #FIXME(typhoonzero): wait other trainer to start listening
-        time.sleep(30)
+    if args.dist_env["training_role"] == "PSERVER":
+        run_pserver(train_prog, startup_prog)
+        exit(0)
+
+    if args.use_gpu:
+        # NOTE: for multi process mode: one process per GPU device.
+        gpu_id = 0
+        if os.getenv("FLAGS_selected_gpus"):
+            gpu_id = int(os.getenv("FLAGS_selected_gpus"))
+    place = core.CUDAPlace(gpu_id) if args.use_gpu else core.CPUPlace()
 
     startup_exe = fluid.Executor(place)
     if args.multi_batch_repeat > 1:
         append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
     startup_exe.run(startup_prog)
+
     strategy = fluid.ExecutionStrategy()
-    strategy.num_threads = args.cpus
-    strategy.allow_op_delay = False
+    strategy.num_threads = args.num_threads
     build_strategy = fluid.BuildStrategy()
     if args.multi_batch_repeat > 1:
-        pass_builder = build_strategy._create_passes_from_strategy()
+        pass_builder = build_strategy._finalize_strategy_and_create_passes()
         mypass = pass_builder.insert_pass(
             len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
         mypass.set_int("num_repeats", args.multi_batch_repeat)
@@ -314,73 +230,65 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         build_strategy.reduce_strategy = fluid.BuildStrategy(
         ).ReduceStrategy.AllReduce
 
-    avg_loss = train_args[0]
-
-    if args.update_method == "pserver":
+    if args.update_method == "pserver" or args.update_method == "local":
         # parameter server mode distributed training, merge
         # gradients on local server, do not initialize
         # ParallelExecutor with multi server all-reduce mode.
         num_trainers = 1
         trainer_id = 0
+    else:
+        num_trainers = args.dist_env["num_trainers"]
+        trainer_id = args.dist_env["trainer_id"]
 
     exe = fluid.ParallelExecutor(
         True,
-        avg_loss.name,
+        train_cost.name,
         main_program=train_prog,
         exec_strategy=strategy,
         build_strategy=build_strategy,
         num_trainers=num_trainers,
         trainer_id=trainer_id)
 
-    pyreader = train_args[4]
-    for pass_id in range(args.pass_num):
+    over_all_start = time.time()
+    fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
+    for pass_id in range(args.num_epochs):
         num_samples = 0
         start_time = time.time()
-        batch_id = 0
-        pyreader.start()
+        batch_id = 1
+        train_pyreader.start()
         while True:
-            fetch_list = [avg_loss.name]
-            acc_name_list = [v.name for v in train_args[2]]
-            fetch_list.extend(acc_name_list)
             try:
                 if batch_id % 30 == 0:
                     fetch_ret = exe.run(fetch_list)
+                    fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
+                    print("Pass %d, batch %d, loss %s, acc1: %s, acc5: %s, avg batch time %.4f" %
+                          (pass_id, batch_id, fetched_data[0], fetched_data[1],
+                           fetched_data[2], (time.time()-start_time) / batch_id))
                 else:
                     fetch_ret = exe.run([])
-            except fluid.core.EOFException as eof:
+            except fluid.core.EOFException:
                 break
-            except fluid.core.EnforceNotMet as ex:
+            except fluid.core.EnforceNotMet:
                 traceback.print_exc()
                 break
-            num_samples += args.batch_size * args.gpus
-
-            if batch_id % 30 == 0:
-                fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
-                print("Pass %d, batch %d, loss %s, accucacys: %s" %
-                      (pass_id, batch_id, fetched_data[0], fetched_data[1:]))
+            num_samples += args.batch_size
             batch_id += 1
 
         print_train_time(start_time, time.time(), num_samples)
-        pyreader.reset()
+        train_pyreader.reset()
 
-        if not args.no_test and test_args[2]:
+        if pass_id > args.start_test_pass:
             if args.multi_batch_repeat > 1:
                 copyback_repeat_bn_params(train_prog)
-            test_ret = test_single(startup_exe, test_args, args, test_prog)
-            print("Pass: %d, Test Accuracy: %s\n" %
-                  (pass_id, [np.mean(np.array(v)) for v in test_ret]))
+            test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name]
+            test_ret = test_single(startup_exe, test_prog, args, test_pyreader, test_fetch_list)
+            print("Pass: %d, Test Loss %s, test acc1: %s, test acc5: %s\n" %
+                  (pass_id, test_ret[0], test_ret[1], test_ret[2]))
 
     startup_exe.close()
     print("total train time: ", time.time() - over_all_start)
 
 
-def print_arguments(args):
-    print('----------- Configuration Arguments -----------')
-    for arg, value in sorted(six.iteritems(vars(args))):
-        print('%s: %s' % (arg, value))
-    print('------------------------------------------------')
-
-
 def print_train_time(start_time, end_time, num_samples):
     train_elapsed = end_time - start_time
     examples_per_sec = num_samples / train_elapsed
@@ -400,47 +308,8 @@ def main():
     args = parse_args()
     print_arguments(args)
     print_paddle_envs()
-
-    # the unique trainer id, starting from 0, needed by trainer
-    # only
-    num_trainers, trainer_id = (
-        1, int(os.getenv("PADDLE_TRAINER_ID", "0")))
-
-    train_prog = fluid.Program()
-    test_prog = fluid.Program()
-    startup_prog = fluid.Program()
-
-    train_args = list(get_model(args, True, train_prog, startup_prog))
-    test_args = list(get_model(args, False, test_prog, startup_prog))
-
-    all_args = [train_args, test_args, args]
-
-    if args.update_method == "pserver":
-        train_prog, startup_prog = dist_transpile(trainer_id, args, train_prog,
-                                                  startup_prog)
-        if not train_prog:
-            raise Exception(
-                "Must configure correct environments to run dist train.")
-        all_args.extend([train_prog, test_prog, startup_prog])
-        if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
-            all_args.extend([num_trainers, trainer_id])
-            train_parallel(*all_args)
-        elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
-            # start pserver with Executor
-            server_exe = fluid.Executor(fluid.CPUPlace())
-            server_exe.run(startup_prog)
-            server_exe.run(train_prog)
-        exit(0)
-
-    # for other update methods, use default programs
-    all_args.extend([train_prog, test_prog, startup_prog])
-
-    if args.update_method == "nccl2":
-        num_trainers, trainer_id = append_nccl2_prepare(
-            trainer_id, startup_prog)
-
-    all_args.extend([num_trainers, trainer_id])
-    train_parallel(*all_args)
+    args.dist_env = dist_env()
+    train_parallel(args)
 
 if __name__ == "__main__":
     main()
diff --git a/fluid/PaddleCV/image_classification/dist_train/dist_utils.py b/fluid/PaddleCV/image_classification/dist_train/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..51007273f717fe815d684aaae7c02b3d7245c4e7
--- /dev/null
+++ b/fluid/PaddleCV/image_classification/dist_train/dist_utils.py
@@ -0,0 +1,43 @@
+import os
+import paddle.fluid as fluid
+
+
+def nccl2_prepare(args, startup_prog):
+    config = fluid.DistributeTranspilerConfig()
+    config.mode = "nccl2"
+    t = fluid.DistributeTranspiler(config=config)
+
+    envs = args.dist_env
+
+    t.transpile(envs["trainer_id"],
+        trainers=','.join(envs["trainer_endpoints"]),
+        current_endpoint=envs["current_endpoint"],
+        startup_program=startup_prog)
+
+
+def pserver_prepare(args, train_prog, startup_prog):
+    config = fluid.DistributeTranspilerConfig()
+    config.slice_var_up = args.split_var
+    t = fluid.DistributeTranspiler(config=config)
+    envs = args.dist_env
+    training_role = envs["training_role"]
+
+    t.transpile(
+        envs["trainer_id"],
+        program=train_prog,
+        pservers=envs["pserver_endpoints"],
+        trainers=envs["num_trainers"],
+        sync_mode=not args.async_mode,
+        startup_program=startup_prog)
+    if training_role == "PSERVER":
+        pserver_program = t.get_pserver_program(envs["current_endpoint"])
+        pserver_startup_program = t.get_startup_program(
+            envs["current_endpoint"], pserver_program, startup_program=startup_prog)
+        return pserver_program, pserver_startup_program
+    elif training_role == "TRAINER":
+        train_program = t.get_trainer_program()
+        return train_program, startup_prog
+    else:
+        raise ValueError(
+            'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
+        )
diff --git a/fluid/PaddleCV/image_classification/dist_train/env.py b/fluid/PaddleCV/image_classification/dist_train/env.py
new file mode 100644
index 0000000000000000000000000000000000000000..f85297e4d3e24322176ad25ee34366f446e18896
--- /dev/null
+++ b/fluid/PaddleCV/image_classification/dist_train/env.py
@@ -0,0 +1,33 @@
+import os
+
+
+def dist_env():
+    """
+    Return a dict of all the environment variables that distributed training may use.
+    NOTE: you may rewrite this function to suit your cluster environments.
+    """
+    trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+    num_trainers = 1
+    training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER")
+    assert(training_role == "PSERVER" or training_role == "TRAINER")
+
+    # - PADDLE_TRAINER_ENDPOINTS means nccl2 mode.
+    # - PADDLE_PSERVER_ENDPOINTS means pserver mode.
+    # - PADDLE_CURRENT_ENDPOINT means current process endpoint.
+    trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
+    pserver_endpoints = os.getenv("PADDLE_PSERVER_ENDPOINTS")
+    current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")
+    if trainer_endpoints:
+        trainer_endpoints = trainer_endpoints.split(",")
+        num_trainers = len(trainer_endpoints)
+    elif pserver_endpoints:
+        num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM"))
+
+    return {
+        "trainer_id": trainer_id,
+        "num_trainers": num_trainers,
+        "current_endpoint": current_endpoint,
+        "training_role": training_role,
+        "pserver_endpoints": pserver_endpoints,
+        "trainer_endpoints": trainer_endpoints
+    }
diff --git a/fluid/PaddleCV/image_classification/dist_train/run_mp_mode.sh b/fluid/PaddleCV/image_classification/dist_train/run_mp_mode.sh
new file mode 100755
index 0000000000000000000000000000000000000000..bf04e078284f02be0774209a599b839d0bbf20f5
--- /dev/null
+++ b/fluid/PaddleCV/image_classification/dist_train/run_mp_mode.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+# Test using 4 GPUs
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export MODEL="DistResNet"
+export PADDLE_TRAINER_ENDPOINTS="127.0.0.1:7160,127.0.0.1:7161,127.0.0.1:7162,127.0.0.1:7163"
+# PADDLE_TRAINERS_NUM is used only by the data reader in nccl2 mode
+export PADDLE_TRAINERS_NUM="4"
+
+mkdir -p logs
+
+for i in {0..3}
+do
+PADDLE_TRAINING_ROLE="TRAINER" \
+PADDLE_CURRENT_ENDPOINT="127.0.0.1:716${i}" \
+PADDLE_TRAINER_ID="${i}" \
+FLAGS_selected_gpus="${i}" \
+python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 --fp16 1 --scale_loss 8 &> logs/tr$i.log &
+done
diff --git a/fluid/PaddleCV/image_classification/dist_train/run_nccl2_mode.sh b/fluid/PaddleCV/image_classification/dist_train/run_nccl2_mode.sh
new file mode 100755
index 0000000000000000000000000000000000000000..120a96647e093de6af362bd51d8e6942249db56f
--- /dev/null
+++ b/fluid/PaddleCV/image_classification/dist_train/run_nccl2_mode.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+export MODEL="DistResNet"
+export PADDLE_TRAINER_ENDPOINTS="127.0.0.1:7160,127.0.0.1:7161"
+# PADDLE_TRAINERS_NUM is used only by the data reader in nccl2 mode
+export PADDLE_TRAINERS_NUM="2"
+
+mkdir -p logs
+
+# NOTE: set NCCL_P2P_DISABLE so that nccl2 distributed training can run on a single node.
+
+PADDLE_TRAINING_ROLE="TRAINER" \
+PADDLE_CURRENT_ENDPOINT="127.0.0.1:7160" \
+PADDLE_TRAINER_ID="0" \
+CUDA_VISIBLE_DEVICES="0" \
+NCCL_P2P_DISABLE="1" \
+python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr0.log &
+
+PADDLE_TRAINING_ROLE="TRAINER" \
+PADDLE_CURRENT_ENDPOINT="127.0.0.1:7161" \
+PADDLE_TRAINER_ID="1" \
+CUDA_VISIBLE_DEVICES="1" \
+NCCL_P2P_DISABLE="1" \
+python dist_train.py --model $MODEL --update_method nccl2 --batch_size 32 &> logs/tr1.log &
diff --git a/fluid/PaddleCV/image_classification/dist_train/run_ps_mode.sh b/fluid/PaddleCV/image_classification/dist_train/run_ps_mode.sh
new file mode 100755
index 0000000000000000000000000000000000000000..99926afbb04e0bc2795a4fd7fd8b4ff58aefec31
--- /dev/null
+++ b/fluid/PaddleCV/image_classification/dist_train/run_ps_mode.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+export MODEL="DistResNet"
+export PADDLE_PSERVER_ENDPOINTS="127.0.0.1:7160,127.0.0.1:7161"
+export PADDLE_TRAINERS_NUM="2"
+
+mkdir -p logs
+
+PADDLE_TRAINING_ROLE="PSERVER" \
+PADDLE_CURRENT_ENDPOINT="127.0.0.1:7160" \
+python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/ps0.log &
+
+PADDLE_TRAINING_ROLE="PSERVER" \
+PADDLE_CURRENT_ENDPOINT="127.0.0.1:7161" \
+python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/ps1.log &
+
+PADDLE_TRAINING_ROLE="TRAINER" \
+PADDLE_CURRENT_ENDPOINT="127.0.0.1:7160" \
+PADDLE_TRAINER_ID="0" \
+CUDA_VISIBLE_DEVICES="0" \
+python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/tr0.log &
+
+PADDLE_TRAINING_ROLE="TRAINER" \
+PADDLE_CURRENT_ENDPOINT="127.0.0.1:7161" \
+PADDLE_TRAINER_ID="1" \
+CUDA_VISIBLE_DEVICES="1" \
+python dist_train.py --model $MODEL --update_method pserver --batch_size 32 &> logs/tr1.log &
diff --git a/fluid/PaddleCV/image_classification/models/resnet_dist.py b/fluid/PaddleCV/image_classification/models/resnet_dist.py
index 9aed8a9841d4ae9d47cbbe15df51ab8652d5e3fc..3420d790c25534b4a73ea660b2d880ff899ee62f 100644
--- a/fluid/PaddleCV/image_classification/models/resnet_dist.py
+++ b/fluid/PaddleCV/image_classification/models/resnet_dist.py
@@ -14,8 +14,9 @@ train_parameters = {
     "learning_strategy": {
         "name": "piecewise_decay",
         "batch_size": 256,
-        "epochs": [30, 60, 90],
-        "steps": [0.1, 0.01, 0.001, 0.0001]
+        "epochs": [30, 60, 80],
+        "steps": [0.1, 0.01, 0.001, 0.0001],
+        "warmup_passes": 5
     }
 }
 
@@ -118,3 +119,4 @@ class DistResNet():
         short = self.shortcut(input, num_filters * 4, stride)
 
         return fluid.layers.elementwise_add(x=short, y=conv2, act='relu')
+
diff --git a/fluid/PaddleCV/image_classification/reader.py b/fluid/PaddleCV/image_classification/reader.py
index 316b956a0788e593f63e4cf7592c16eec1b1aba8..3d52acc3813d309153a75a2188b5587ecbe13e97 100644
--- a/fluid/PaddleCV/image_classification/reader.py
+++ b/fluid/PaddleCV/image_classification/reader.py
@@ -139,7 +139,7 @@ def _reader_creator(file_list,
             if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'):
                 # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits
                 trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
-                trainer_count = int(os.getenv("PADDLE_TRAINERS", "1"))
+                trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
                 per_node_lines = len(full_lines) // trainer_count
                 lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1) *
                                    per_node_lines]