diff --git a/fluid/PaddleCV/image_classification/dist_train/README.md b/fluid/PaddleCV/image_classification/dist_train/README.md
index 282a026acf1ee6d5b1c17aa05a2a8f734047c006..a595a540adfa770253909e432e99a27228d5f062 100644
--- a/fluid/PaddleCV/image_classification/dist_train/README.md
+++ b/fluid/PaddleCV/image_classification/dist_train/README.md
@@ -9,7 +9,7 @@ Before getting started, please make sure you have go throught the imagenet [Data
1. The entrypoint file is `dist_train.py`, some important flags are as follows:
- - `model`, the model to run with, such as `ResNet50`, `ResNet101` and etc..
+ - `model`, the model to run with, default is the fine tune model `DistResnet`.
- `batch_size`, the batch_size per device.
- `update_method`, specify the update method, can choose from local, pserver or nccl2.
- `device`, use CPU or GPU device.
@@ -35,14 +35,14 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
1. launch parameter server process
- ``` python
+ ``` bash
PADDLE_TRAINING_ROLE=PSERVER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_CURRENT_IP=192.168.0.100 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
- --model=ResNet50 \
+ --model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--device=CPU \
@@ -51,34 +51,33 @@ In this example, we launched 4 parameter server instances and 4 trainer instance
1. launch trainer process
- ``` python
+ ``` bash
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
- --model=ResNet50 \
+ --model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--device=GPU \
--data_dir=../data/ILSVRC2012
-
```
### NCCL2 Collective Mode
1. launch trainer process
- ``` python
+ ``` bash
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_TRAINER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
python dist_train.py \
- --model=ResNet50 \
+ --model=DistResnet \
--batch_size=32 \
- --update_method=pserver \
+ --update_method=nccl2 \
--device=GPU \
--data_dir=../data/ILSVRC2012
```
@@ -101,13 +100,37 @@ Pass 0, batch 8, loss 7.264951, accucacys: [0.0, 0.00390625]
Pass 0, batch 9, loss 7.43522, accucacys: [0.00390625, 0.00390625]
```
-The training accucacys top1 of local training, distributed training with NCCL2 and parameter server architecture on the ResNet50 model are shown in the below figure:
+The below figure shows top 1 train accuracy for local training with 8 GPUs and distributed training
+with 32 GPUs, and also distributed training with batch merge feature turned on. Note that the
+red curve is trained with origin model configuration, which does not have the warmup and some detailed
+modifications.
+
+For distributed training with 32GPUs using `--model DistResnet` we can achieve test accuracy 75.5% after
+90 passes of training (the test accuracy is not shown in below figure). We can also achieve this result
+using "batch merge" feature by setting `--multi_batch_repeat 4` and with higher throughput.
-Training acc1 curves
+Training top-1 accuracy curves
+### Finetuning for Distributed Training
+
+The default resnet50 distributed training config is based on this paper: https://arxiv.org/pdf/1706.02677.pdf
+
+- use `--model DistResnet`
+- we use 32 P40 GPUs with 4 Nodes, each has 8 GPUs
+- we set `batch_size=32` for each GPU, in `batch_merge=on` case, we repeat 4 times before communicating with pserver.
+- learning rate starts from 0.1 and warm up to 0.4 in 5 passes(because we already have gradient merging,
+ so we only need to linear scale up to trainer count) using 4 nodes.
+- using batch_merge (`--multi_batch_repeat 4`) can make better use of GPU computing power and increase the
+ total training throughput. Because in the fine-tune configuration, we have to use `batch_size=32` per GPU,
+ and recent GPU is so fast that the communication between nodes may delay the total speed. In batch_merge mode
+ we run several batches forward and backward computation, then merge the gradients and send to pserver for
+ optimization, we use different batch norm mean and variance variable in each repeat so that adding repeats
+ behaves the same as adding more GPUs.
+
+
### Performance
TBD
diff --git a/fluid/PaddleCV/image_classification/dist_train/dist_train.py b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
index 160bfb95ac4cdb38083891b9e5f3e76d5371fc06..05c0c23212cfe49f6ef7332143f833d7d7fa7486 100644
--- a/fluid/PaddleCV/image_classification/dist_train/dist_train.py
+++ b/fluid/PaddleCV/image_classification/dist_train/dist_train.py
@@ -218,17 +218,58 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
-
-def test_parallel(exe, test_args, args, test_prog):
+def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
+ repeat_vars = set()
+ for op in main_prog.global_block().ops:
+ if op.type == "batch_norm":
+ repeat_vars.add(op.input("Mean")[0])
+ repeat_vars.add(op.input("Variance")[0])
+
+ for i in range(num_repeats):
+ for op in startup_prog.global_block().ops:
+ if op.type == "fill_constant":
+ for oname in op.output_arg_names:
+ if oname in repeat_vars:
+ var = startup_prog.global_block().var(oname)
+ repeat_var_name = "%s.repeat.%d" % (oname, i)
+ repeat_var = startup_prog.global_block().create_var(
+ name=repeat_var_name,
+ type=var.type,
+ dtype=var.dtype,
+ shape=var.shape,
+ persistable=var.persistable
+ )
+ main_prog.global_block()._clone_variable(repeat_var)
+ startup_prog.global_block().append_op(
+ type="fill_constant",
+ inputs={},
+ outputs={"Out": repeat_var},
+ attrs=op.all_attrs()
+ )
+
+
+def copyback_repeat_bn_params(main_prog):
+ repeat_vars = set()
+ for op in main_prog.global_block().ops:
+ if op.type == "batch_norm":
+ repeat_vars.add(op.input("Mean")[0])
+ repeat_vars.add(op.input("Variance")[0])
+ for vname in repeat_vars:
+ real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
+ orig_var = fluid.global_scope().find_var(vname).get_tensor()
+ orig_var.set(np.array(real_var), fluid.CUDAPlace(0)) # test on GPU0
+
+
+def test_single(exe, test_args, args, test_prog):
acc_evaluators = []
- for i in six.moves.xrange(len(test_args[2])):
+ for i in xrange(len(test_args[2])):
acc_evaluators.append(fluid.metrics.Accuracy())
to_fetch = [v.name for v in test_args[2]]
test_args[4].start()
while True:
try:
- acc_rets = exe.run(fetch_list=to_fetch)
+ acc_rets = exe.run(program=test_prog, fetch_list=to_fetch)
for i, e in enumerate(acc_evaluators):
e.update(
value=np.array(acc_rets[i]), weight=args.batch_size)
@@ -238,6 +279,7 @@ def test_parallel(exe, test_args, args, test_prog):
return [e.eval() for e in acc_evaluators]
+
def train_parallel(train_args, test_args, args, train_prog, test_prog,
startup_prog, nccl_id_var, num_trainers, trainer_id):
over_all_start = time.time()
@@ -248,11 +290,18 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
time.sleep(30)
startup_exe = fluid.Executor(place)
+ if args.multi_batch_repeat > 1:
+ append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
startup_exe.run(startup_prog)
strategy = fluid.ExecutionStrategy()
strategy.num_threads = args.cpus
strategy.allow_op_delay = False
build_strategy = fluid.BuildStrategy()
+ if args.multi_batch_repeat > 1:
+ pass_builder = build_strategy._create_passes_from_strategy()
+ mypass = pass_builder.insert_pass(
+ len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
+ mypass.set_int("num_repeats", args.multi_batch_repeat)
if args.reduce_strategy == "reduce":
build_strategy.reduce_strategy = fluid.BuildStrategy(
).ReduceStrategy.Reduce
@@ -278,15 +327,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
num_trainers=num_trainers,
trainer_id=trainer_id)
- if not args.no_test:
- if args.update_method == "pserver":
- test_scope = None
- else:
- test_scope = fluid.Scope()
- test_exe = fluid.ParallelExecutor(
- True, main_program=test_prog, share_vars_from=exe,
- scope=test_scope)
-
pyreader = train_args[4]
for pass_id in range(args.pass_num):
num_samples = 0
@@ -297,7 +337,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
fetch_list = [avg_loss.name]
acc_name_list = [v.name for v in train_args[2]]
fetch_list.extend(acc_name_list)
-
try:
if batch_id % 30 == 0:
fetch_ret = exe.run(fetch_list)
@@ -320,7 +359,9 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
pyreader.reset()
if not args.no_test and test_args[2]:
- test_ret = test_parallel(test_exe, test_args, args, test_prog)
+ if args.multi_batch_repeat > 1:
+ copyback_repeat_bn_params(train_prog)
+ test_ret = test_single(startup_exe, test_args, args, test_prog)
print("Pass: %d, Test Accuracy: %s\n" %
(pass_id, [np.mean(np.array(v)) for v in test_ret]))
@@ -376,7 +417,7 @@ def main():
raise Exception(
"Must configure correct environments to run dist train.")
all_args.extend([train_prog, test_prog, startup_prog])
- if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
+ if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
all_args.extend([nccl_id_var, num_trainers, trainer_id])
train_parallel(*all_args)
elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
diff --git a/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png b/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png
index 6d4c478743d0e5af0a9d727c76b433849c6a81dc..a4dda7babd8054b813596e138945b05e7df330c0 100644
Binary files a/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png and b/fluid/PaddleCV/image_classification/images/resnet50_32gpus-acc1.png differ