Unverified · Commit 21ce9630 authored by guru4elephant, committed by GitHub

Merge pull request #1403 from typhoonzero/refine_dist_resnet50

refine dist train
@@ -9,7 +9,7 @@ Before getting started, please make sure you have gone through the imagenet [Data
1. The entrypoint file is `dist_train.py`; some important flags are as follows:
-    - `model`, the model to run with, such as `ResNet50`, `ResNet101`, etc.
+    - `model`, the model to run with; the default is the fine-tune model `DistResnet`.
    - `batch_size`, the batch size per device.
    - `update_method`, specify the update method; choose from `local`, `pserver`, or `nccl2`.
    - `device`, use CPU or GPU device.
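
For reference, a minimal single-node run with these flags might look like the following sketch (it assumes local mode needs none of the `PADDLE_*` environment variables, and that the dataset lives at `../data/ILSVRC2012` as in the examples below):

``` bash
python dist_train.py \
--model=DistResnet \
--batch_size=32 \
--update_method=local \
--device=GPU \
--data_dir=../data/ILSVRC2012
```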
@@ -35,14 +35,14 @@ In this example, we launched 4 parameter server instances and 4 trainer instances
1. launch parameter server process
-``` python
+``` bash
PADDLE_TRAINING_ROLE=PSERVER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_CURRENT_IP=192.168.0.100 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
---model=ResNet50 \
+--model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--device=CPU \
@@ -51,34 +51,33 @@ In this example, we launched 4 parameter server instances and 4 trainer instances
1. launch trainer process
-``` python
+``` bash
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
---model=ResNet50 \
+--model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--device=GPU \
--data_dir=../data/ILSVRC2012
```
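
The commands above launch the processes for the first node (192.168.0.100). Each remaining node presumably runs the same two commands with only its own values substituted; a sketch:

``` bash
# On the second node (192.168.0.101), keep everything else identical:
PADDLE_CURRENT_IP=192.168.0.101   # for the pserver process
PADDLE_TRAINER_ID=1               # for the trainer process
```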
### NCCL2 Collective Mode
1. launch trainer process
-``` python
+``` bash
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_TRAINER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
python dist_train.py \
---model=ResNet50 \
+--model=DistResnet \
--batch_size=32 \
---update_method=pserver \
+--update_method=nccl2 \
--device=GPU \
--data_dir=../data/ILSVRC2012
```
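
Note that NCCL2 mode launches no parameter servers: `PADDLE_TRAINER_IPS` lists all trainer nodes, and each node presumably runs the identical command with only `PADDLE_TRAINER_ID` changed; a sketch:

``` bash
# On the second trainer node (192.168.0.101):
PADDLE_TRAINER_ID=1
```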
@@ -101,13 +100,37 @@ Pass 0, batch 8, loss 7.264951, accucacys: [0.0, 0.00390625]
Pass 0, batch 9, loss 7.43522, accucacys: [0.00390625, 0.00390625]
```
-The training accucacys top1 of local training, distributed training with NCCL2 and parameter server architecture on the ResNet50 model are shown in the below figure:
+The figure below shows top-1 training accuracy for local training with 8 GPUs, distributed training
+with 32 GPUs, and distributed training with the batch-merge feature turned on. Note that the
+red curve was trained with the original model configuration, which does not include the warmup and
+some of the detailed modifications.
For distributed training with 32 GPUs using `--model DistResnet`, we can achieve a test accuracy of 75.5% after
90 passes of training (test accuracy is not shown in the figure below). We can also achieve the same result
with higher throughput using the "batch merge" feature, by setting `--multi_batch_repeat 4`.
<p align="center">
<img src="../images/resnet50_32gpus-acc1.png" height=300 width=528 > <br/>
-Training acc1 curves
+Training top-1 accuracy curves
</p>
### Finetuning for Distributed Training
The default ResNet50 distributed training config is based on this paper: https://arxiv.org/pdf/1706.02677.pdf

- use `--model DistResnet`
- we use 32 P40 GPUs, on 4 nodes with 8 GPUs each
- we set `batch_size=32` for each GPU; in the `batch_merge=on` case, we repeat 4 batches before communicating with the pserver.
- the learning rate starts at 0.1 and warms up to 0.4 over the first 5 passes (because the gradients are already merged,
  we only need to scale the learning rate linearly with the trainer count) using 4 nodes.
- using batch merge (`--multi_batch_repeat 4`) makes better use of GPU computing power and increases the
  total training throughput. In the fine-tune configuration we have to use `batch_size=32` per GPU,
  and recent GPUs are so fast that communication between nodes can bottleneck the overall speed. In batch-merge mode
  we run forward and backward computation for several batches, then merge the gradients and send them to the pserver for
  optimization; we use separate batch-norm mean and variance variables in each repeat, so adding repeats
  behaves the same as adding more GPUs (see the example command below).
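
Putting these flags together, the trainer launch command for the batch-merge fine-tune run might look like this sketch (the pserver-mode trainer example from above, plus `--multi_batch_repeat`; pserver processes are launched as before):

``` bash
PADDLE_TRAINING_ROLE=TRAINER \
PADDLE_TRAINERS=4 \
PADDLE_PSERVER_IPS=192.168.0.100,192.168.0.101,192.168.0.102,192.168.0.103 \
PADDLE_TRAINER_ID=0 \
PADDLE_PSERVER_PORT=7164 \
python dist_train.py \
--model=DistResnet \
--batch_size=32 \
--update_method=pserver \
--device=GPU \
--multi_batch_repeat 4 \
--data_dir=../data/ILSVRC2012
```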
### Performance
TBD
@@ -218,17 +218,58 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
        'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
    )


-def test_parallel(exe, test_args, args, test_prog):
+def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
+    # Collect every batch_norm Mean/Variance variable so each batch-merge
+    # repeat can get its own copy of the BN statistics.
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+
+    # For each repeat, clone the fill_constant initializers of those
+    # variables under the name "<var>.repeat.<i>".
+    for i in range(num_repeats):
+        for op in startup_prog.global_block().ops:
+            if op.type == "fill_constant":
+                for oname in op.output_arg_names:
+                    if oname in repeat_vars:
+                        var = startup_prog.global_block().var(oname)
+                        repeat_var_name = "%s.repeat.%d" % (oname, i)
+                        repeat_var = startup_prog.global_block().create_var(
+                            name=repeat_var_name,
+                            type=var.type,
+                            dtype=var.dtype,
+                            shape=var.shape,
+                            persistable=var.persistable
+                        )
+                        main_prog.global_block()._clone_variable(repeat_var)
+                        startup_prog.global_block().append_op(
+                            type="fill_constant",
+                            inputs={},
+                            outputs={"Out": repeat_var},
+                            attrs=op.all_attrs()
+                        )

+def copyback_repeat_bn_params(main_prog):
+    # After batch-merge training, copy the BN statistics of the first
+    # repeat back into the original variables so the (unmerged) test
+    # program evaluates with up-to-date mean/variance.
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+    for vname in repeat_vars:
+        real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
+        orig_var = fluid.global_scope().find_var(vname).get_tensor()
+        orig_var.set(np.array(real_var), fluid.CUDAPlace(0))  # test on GPU0

+def test_single(exe, test_args, args, test_prog):
    acc_evaluators = []
-    for i in six.moves.xrange(len(test_args[2])):
+    for i in xrange(len(test_args[2])):
        acc_evaluators.append(fluid.metrics.Accuracy())
    to_fetch = [v.name for v in test_args[2]]
    test_args[4].start()
    while True:
        try:
-            acc_rets = exe.run(fetch_list=to_fetch)
+            acc_rets = exe.run(program=test_prog, fetch_list=to_fetch)
            for i, e in enumerate(acc_evaluators):
                e.update(
                    value=np.array(acc_rets[i]), weight=args.batch_size)
@@ -238,6 +279,7 @@ def test_parallel(exe, test_args, args, test_prog):
    return [e.eval() for e in acc_evaluators]


def train_parallel(train_args, test_args, args, train_prog, test_prog,
                   startup_prog, nccl_id_var, num_trainers, trainer_id):
    over_all_start = time.time()
@@ -248,11 +290,18 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
        time.sleep(30)

    startup_exe = fluid.Executor(place)
+    if args.multi_batch_repeat > 1:
+        # Clone BN mean/variance initializers once per repeat before
+        # running the startup program.
+        append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
    startup_exe.run(startup_prog)
    strategy = fluid.ExecutionStrategy()
    strategy.num_threads = args.cpus
    strategy.allow_op_delay = False
    build_strategy = fluid.BuildStrategy()
+    if args.multi_batch_repeat > 1:
+        # Insert the multi_batch_merge_pass near the end of the build
+        # passes so each device repeats forward/backward before merging.
+        pass_builder = build_strategy._create_passes_from_strategy()
+        mypass = pass_builder.insert_pass(
+            len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
+        mypass.set_int("num_repeats", args.multi_batch_repeat)
    if args.reduce_strategy == "reduce":
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.Reduce
@@ -278,15 +327,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
        num_trainers=num_trainers,
        trainer_id=trainer_id)

-    if not args.no_test:
-        if args.update_method == "pserver":
-            test_scope = None
-        else:
-            test_scope = fluid.Scope()
-        test_exe = fluid.ParallelExecutor(
-            True, main_program=test_prog, share_vars_from=exe,
-            scope=test_scope)

    pyreader = train_args[4]
    for pass_id in range(args.pass_num):
        num_samples = 0
@@ -297,7 +337,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
            fetch_list = [avg_loss.name]
            acc_name_list = [v.name for v in train_args[2]]
            fetch_list.extend(acc_name_list)
            try:
                if batch_id % 30 == 0:
                    fetch_ret = exe.run(fetch_list)
@@ -320,7 +359,9 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
        pyreader.reset()

        if not args.no_test and test_args[2]:
-            test_ret = test_parallel(test_exe, test_args, args, test_prog)
+            if args.multi_batch_repeat > 1:
+                # Copy merged BN statistics back before single-program testing.
+                copyback_repeat_bn_params(train_prog)
+            test_ret = test_single(startup_exe, test_args, args, test_prog)
            print("Pass: %d, Test Accuracy: %s\n" %
                  (pass_id, [np.mean(np.array(v)) for v in test_ret]))
@@ -376,7 +417,7 @@ def main():
        raise Exception(
            "Must configure correct environments to run dist train.")
    all_args.extend([train_prog, test_prog, startup_prog])
-    if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
+    if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
        all_args.extend([nccl_id_var, num_trainers, trainer_id])
        train_parallel(*all_args)
    elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":