diff --git a/fluid/PaddleCV/image_classification/dist_train/README.md b/fluid/PaddleCV/image_classification/dist_train/README.md
index 282a026acf1ee6d5b1c17aa05a2a8f734047c006..a595a540adfa770253909e432e99a27228d5f062 100644
--- a/fluid/PaddleCV/image_classification/dist_train/README.md
+++ b/fluid/PaddleCV/image_classification/dist_train/README.md
@@ -9,7 +9,7 @@ Before getting started, please make sure you have go throught the imagenet [Data
 
 1. The entrypoint file is `dist_train.py`, some important flags are as follows:
 
-    - `model`, the model to run with, such as `ResNet50`, `ResNet101` and etc..
+    - `model`, the model to run with; defaults to the fine-tune model `DistResnet`.
     - `batch_size`, the batch_size per device.
     - `update_method`, specify the update method, can choose from local, pserver or nccl2.
     - `device`, use CPU or GPU device.
@@ -35,14 +35,14 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
 
 1. launch parameter server process
 
-    ``` python
+    ``` bash
     PADDLE_TRAINING_ROLE=PSERVER \
     PADDLE_TRAINERS=4 \
     PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
     PADDLE_CURRENT_IP=192.168.0.100 \
     PADDLE_PSERVER_PORT=7164 \
     python dist_train.py \
-        --model=ResNet50 \
+        --model=DistResnet \
         --batch_size=32 \
         --update_method=pserver \
         --device=CPU \
@@ -51,34 +51,33 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
 
 1. launch trainer process
 
-    ``` python
+    ``` bash
     PADDLE_TRAINING_ROLE=TRAINER \
     PADDLE_TRAINERS=4 \
     PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
     PADDLE_TRAINER_ID=0 \
     PADDLE_PSERVER_PORT=7164 \
     python dist_train.py \
-        --model=ResNet50 \
+        --model=DistResnet \
         --batch_size=32 \
         --update_method=pserver \
         --device=GPU \
         --data_dir=../data/ILSVRC2012
-
     ```
 
 ### NCCL2 Collective Mode
 
 1. launch trainer process
 
-    ``` python
+    ``` bash
     PADDLE_TRAINING_ROLE=TRAINER \
     PADDLE_TRAINERS=4 \
     PADDLE_TRAINER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
     PADDLE_TRAINER_ID=0 \
     python dist_train.py \
-        --model=ResNet50 \
+        --model=DistResnet \
         --batch_size=32 \
-        --update_method=pserver \
+        --update_method=nccl2 \
         --device=GPU \
         --data_dir=../data/ILSVRC2012
     ```
@@ -101,13 +100,37 @@ Pass 0, batch 8, loss 7.264951, accucacys: [0.0, 0.00390625]
 Pass 0, batch 9, loss 7.43522, accucacys: [0.00390625, 0.00390625]
 ```
 
-The training accucacys top1 of local training, distributed training with NCCL2 and parameter server architecture on the ResNet50 model are shown in the below figure:
+The figure below shows the top-1 training accuracy curves for local training with 8 GPUs, distributed
+training with 32 GPUs, and distributed training with the batch merge feature turned on. Note that the
+red curve is trained with the original model configuration, which does not include the warmup and some
+other detailed modifications.
+
+For distributed training with 32 GPUs using `--model DistResnet`, we can achieve a test accuracy of 75.5%
+after 90 passes of training (the test accuracy is not shown in the figure below). We can also achieve this
+result with the batch merge feature enabled (`--multi_batch_repeat 4`), at a higher throughput.


 <img src="../images/resnet50_32gpus-acc1.png" />
-Training acc1 curves
+Training top-1 accuracy curves

+### Finetuning for Distributed Training
+
+The default ResNet50 distributed training configuration is based on this paper: https://arxiv.org/pdf/1706.02677.pdf
+
+- use `--model DistResnet`
+- we use 32 P40 GPUs across 4 nodes, each node with 8 GPUs
+- we set `batch_size=32` for each GPU; in the `batch_merge=on` case, we repeat 4 batches before
+  communicating with the pserver.
+- the learning rate starts from 0.1 and warms up to 0.4 over the first 5 passes (because gradients are
+  already merged, we only need to scale the learning rate linearly with the trainer count, i.e. 4 nodes).
+- using batch merge (`--multi_batch_repeat 4`) makes better use of GPU computing power and increases the
+  total training throughput. In the fine-tune configuration we have to use `batch_size=32` per GPU,
+  and recent GPUs are so fast that communication between nodes can become the bottleneck. In batch merge
+  mode we run forward and backward computation for several batches, then merge the gradients and send
+  them to the pserver for optimization. We use separate batch norm mean and variance variables in each
+  repeat, so that adding repeats behaves the same as adding more GPUs.
+
+
 ### Performance
 
 TBD
 
diff --git a/fluid/PaddleCV/image_classification/dist_train/dist_train.py b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
index 160bfb95ac4cdb38083891b9e5f3e76d5371fc06..05c0c23212cfe49f6ef7332143f833d7d7fa7486 100644
--- a/fluid/PaddleCV/image_classification/dist_train/dist_train.py
+++ b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
@@ -218,17 +218,58 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
             'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
         )
 
-
-def test_parallel(exe, test_args, args, test_prog):
+def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+
+    for i in range(num_repeats):
+        for op in startup_prog.global_block().ops:
+            if op.type == "fill_constant":
+                for oname in op.output_arg_names:
+                    if oname in repeat_vars:
+                        var = startup_prog.global_block().var(oname)
+                        repeat_var_name = "%s.repeat.%d" % (oname, i)
+                        repeat_var = startup_prog.global_block().create_var(
+                            name=repeat_var_name,
+                            type=var.type,
+                            dtype=var.dtype,
+                            shape=var.shape,
+                            persistable=var.persistable
+                        )
+                        main_prog.global_block()._clone_variable(repeat_var)
+                        startup_prog.global_block().append_op(
+                            type="fill_constant",
+                            inputs={},
+                            outputs={"Out": repeat_var},
+                            attrs=op.all_attrs()
+                        )
+
+
+def copyback_repeat_bn_params(main_prog):
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+    for vname in repeat_vars:
+        real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
+        orig_var = fluid.global_scope().find_var(vname).get_tensor()
+        orig_var.set(np.array(real_var), fluid.CUDAPlace(0))  # test on GPU0
+
+
+def test_single(exe, test_args, args, test_prog):
     acc_evaluators = []
-    for i in six.moves.xrange(len(test_args[2])):
+    for i in xrange(len(test_args[2])):
         acc_evaluators.append(fluid.metrics.Accuracy())
     to_fetch = [v.name for v in test_args[2]]
     test_args[4].start()
     while True:
         try:
-            acc_rets = exe.run(fetch_list=to_fetch)
+            acc_rets = exe.run(program=test_prog, fetch_list=to_fetch)
             for i, e in enumerate(acc_evaluators):
                 e.update(
                     value=np.array(acc_rets[i]),
                     weight=args.batch_size)
@@ -238,6 +279,7 @@ def test_parallel(exe, test_args, args, test_prog):
 
     return [e.eval() for e in acc_evaluators]
 
+
 def train_parallel(train_args, test_args, args, train_prog, test_prog,
                    startup_prog, nccl_id_var, num_trainers, trainer_id):
     over_all_start = time.time()
@@ -248,11 +290,18 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         time.sleep(30)
 
     startup_exe = fluid.Executor(place)
+    if args.multi_batch_repeat > 1:
+        append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
     startup_exe.run(startup_prog)
     strategy = fluid.ExecutionStrategy()
     strategy.num_threads = args.cpus
     strategy.allow_op_delay = False
     build_strategy = fluid.BuildStrategy()
+    if args.multi_batch_repeat > 1:
+        pass_builder = build_strategy._create_passes_from_strategy()
+        mypass = pass_builder.insert_pass(
+            len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
+        mypass.set_int("num_repeats", args.multi_batch_repeat)
     if args.reduce_strategy == "reduce":
         build_strategy.reduce_strategy = fluid.BuildStrategy(
         ).ReduceStrategy.Reduce
@@ -278,15 +327,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         num_trainers=num_trainers,
         trainer_id=trainer_id)
 
-    if not args.no_test:
-        if args.update_method == "pserver":
-            test_scope = None
-        else:
-            test_scope = fluid.Scope()
-        test_exe = fluid.ParallelExecutor(
-            True, main_program=test_prog, share_vars_from=exe,
-            scope=test_scope)
-
     pyreader = train_args[4]
     for pass_id in range(args.pass_num):
         num_samples = 0
@@ -297,7 +337,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
             fetch_list = [avg_loss.name]
             acc_name_list = [v.name for v in train_args[2]]
             fetch_list.extend(acc_name_list)
-
             try:
                 if batch_id % 30 == 0:
                     fetch_ret = exe.run(fetch_list)
@@ -320,7 +359,9 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         pyreader.reset()
 
         if not args.no_test and test_args[2]:
-            test_ret = test_parallel(test_exe, test_args, args, test_prog)
+            if args.multi_batch_repeat > 1:
+                copyback_repeat_bn_params(train_prog)
+            test_ret = test_single(startup_exe, test_args, args, test_prog)
             print("Pass: %d, Test Accuracy: %s\n" %
                   (pass_id, [np.mean(np.array(v)) for v in test_ret]))
 
@@ -376,7 +417,7 @@ def main():
             raise Exception(
                 "Must configure correct environments to run dist train.")
         all_args.extend([train_prog, test_prog, startup_prog])
-        if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
+        if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
            all_args.extend([nccl_id_var, num_trainers, trainer_id])
            train_parallel(*all_args)
     elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
diff --git a/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png b/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png
index 6d4c478743d0e5af0a9d727c76b433849c6a81dc..a4dda7babd8054b813596e138945b05e7df330c0 100644
Binary files a/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png and b/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png differ
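
The README hunks above describe the batch merge feature in prose but never show a full launch command with the new flag. As a minimal illustrative sketch (not part of the patch), a trainer in parameter server mode could be started with batch merge enabled by combining the README's trainer example with `--multi_batch_repeat`; the IPs, port, and data path below are the same placeholder values already used in the README:

``` bash
# Illustrative sketch: pserver-mode trainer launch with batch merge turned on.
# Env vars and flags are the ones shown in the README and dist_train.py hunks above.
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
    --model=DistResnet \
    --batch_size=32 \
    --update_method=pserver \
    --multi_batch_repeat=4 \
    --device=GPU \
    --data_dir=../data/ILSVRC2012
```

With `--batch_size=32` and `--multi_batch_repeat=4`, each trainer runs 4 local batches before merging gradients and talking to the pservers, matching the `batch_merge=on` setting described in the README.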
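The finetuning notes state that the learning rate warms up from 0.1 to 0.4 over the first 5 passes, scaling linearly with the 4-node trainer count. A minimal sketch of that schedule, assuming simple per-pass linear interpolation; the helper below is hypothetical and not taken from dist_train.py:

```python
def warmup_lr(pass_id, base_lr=0.1, peak_lr=0.4, warmup_passes=5):
    """Linearly interpolate from base_lr to peak_lr over the first
    warmup_passes, then hold peak_lr (any later decay is omitted here)."""
    if pass_id >= warmup_passes:
        return peak_lr
    return base_lr + (peak_lr - base_lr) * float(pass_id) / warmup_passes

# Example: pass 0 -> 0.1, pass 5 and later -> 0.4 (0.1 scaled by 4 nodes).
print([round(warmup_lr(p), 3) for p in range(7)])
# [0.1, 0.16, 0.22, 0.28, 0.34, 0.4, 0.4]
```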