Commit 98411d04 authored by typhoonzero

refine dist train

Parent 58e9bc20
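In brief: this change adds multi-batch merging to the benchmark's distributed trainer. When --multi_batch_repeat > 1, a multi_batch_merge_pass is inserted into the build passes, per-repeat copies of every batch_norm Mean/Variance variable are created and initialized in the startup program (append_bn_repeat_init_op), and the repeat-0 statistics are copied back before evaluation (copyback_repeat_bn_params). Evaluation moves from the ParallelExecutor-based test_parallel to a plain-Executor test_single, and the trainer arguments (nccl_id_var, num_trainers, trainer_id) are now passed whenever PADDLE_TRAINING_ROLE is TRAINER, not only when args.gpus > 1.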
@@ -218,17 +218,58 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
         'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
     )
 
 
-def test_parallel(exe, test_args, args, test_prog):
+def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+
+    for i in range(num_repeats):
+        for op in startup_prog.global_block().ops:
+            if op.type == "fill_constant":
+                for oname in op.output_arg_names:
+                    if oname in repeat_vars:
+                        var = startup_prog.global_block().var(oname)
+                        repeat_var_name = "%s.repeat.%d" % (oname, i)
+                        repeat_var = startup_prog.global_block().create_var(
+                            name=repeat_var_name,
+                            type=var.type,
+                            dtype=var.dtype,
+                            shape=var.shape,
+                            persistable=var.persistable
+                        )
+                        main_prog.global_block()._clone_variable(repeat_var)
+                        startup_prog.global_block().append_op(
+                            type="fill_constant",
+                            inputs={},
+                            outputs={"Out": repeat_var},
+                            attrs=op.all_attrs()
+                        )
+
+
+def copyback_repeat_bn_params(main_prog):
+    repeat_vars = set()
+    for op in main_prog.global_block().ops:
+        if op.type == "batch_norm":
+            repeat_vars.add(op.input("Mean")[0])
+            repeat_vars.add(op.input("Variance")[0])
+    for vname in repeat_vars:
+        real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
+        orig_var = fluid.global_scope().find_var(vname).get_tensor()
+        orig_var.set(np.array(real_var), fluid.CUDAPlace(0))  # test on GPU0
+
+
+def test_single(exe, test_args, args, test_prog):
     acc_evaluators = []
-    for i in six.moves.xrange(len(test_args[2])):
+    for i in xrange(len(test_args[2])):
        acc_evaluators.append(fluid.metrics.Accuracy())
     to_fetch = [v.name for v in test_args[2]]
     test_args[4].start()
     while True:
         try:
-            acc_rets = exe.run(fetch_list=to_fetch)
+            acc_rets = exe.run(program=test_prog, fetch_list=to_fetch)
             for i, e in enumerate(acc_evaluators):
                 e.update(
                     value=np.array(acc_rets[i]), weight=args.batch_size)
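For orientation, a hypothetical usage sketch of the two helpers added above (an editorial illustration, not part of the diff; the names follow the diff). append_bn_repeat_init_op clones each batch_norm Mean/Variance initializer into "<name>.repeat.<i>" variables, and copyback_repeat_bn_params restores the repeat-0 copy before evaluation:

    # Before running the startup program: create <stat>.repeat.0 .. <stat>.repeat.N-1
    # next to each batch_norm statistic, with the same fill_constant initializer.
    append_bn_repeat_init_op(train_prog, startup_prog, num_repeats=args.multi_batch_repeat)
    startup_exe.run(startup_prog)
    # ... training iterations ...
    # Before evaluation: copy the repeat-0 statistics back into Mean/Variance.
    copyback_repeat_bn_params(train_prog)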
@@ -238,6 +279,7 @@ def test_parallel(exe, test_args, args, test_prog):
     return [e.eval() for e in acc_evaluators]
 
 
 def train_parallel(train_args, test_args, args, train_prog, test_prog,
                    startup_prog, nccl_id_var, num_trainers, trainer_id):
     over_all_start = time.time()
@@ -248,11 +290,18 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         time.sleep(30)
 
     startup_exe = fluid.Executor(place)
+    if args.multi_batch_repeat > 1:
+        append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
     startup_exe.run(startup_prog)
     strategy = fluid.ExecutionStrategy()
     strategy.num_threads = args.cpus
     strategy.allow_op_delay = False
     build_strategy = fluid.BuildStrategy()
+    if args.multi_batch_repeat > 1:
+        pass_builder = build_strategy._create_passes_from_strategy()
+        mypass = pass_builder.insert_pass(
+            len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
+        mypass.set_int("num_repeats", args.multi_batch_repeat)
     if args.reduce_strategy == "reduce":
         build_strategy.reduce_strategy = fluid.BuildStrategy(
         ).ReduceStrategy.Reduce
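The pass inserted above is configured only by name and num_repeats; judging from the repeat helpers and copyback logic, its effect is to process several micro-batches per optimization step so the update (and batch_norm statistics) behave as if a larger batch were used. As a conceptual illustration of that idea only (a NumPy sketch under that assumption, not Paddle's pass), merging num_repeats micro-batches into one SGD update looks like:

    import numpy as np

    def sgd_with_merged_batches(w, micro_batches, lr=0.01, num_repeats=2):
        # Take `num_repeats` micro-batches at a time: compute each gradient
        # separately, then apply a single update with their average.
        for i in range(0, len(micro_batches), num_repeats):
            grads = []
            for x, y in micro_batches[i:i + num_repeats]:
                err = x.dot(w) - y                         # linear-model residual
                grads.append(2.0 * x.T.dot(err) / len(y))  # grad of mean squared error
            w = w - lr * np.mean(grads, axis=0)            # one merged SGD step
        return w

    rng = np.random.RandomState(0)
    data = [(rng.randn(8, 3), rng.randn(8)) for _ in range(4)]
    w = sgd_with_merged_batches(np.zeros(3), data)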
@@ -278,15 +327,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         num_trainers=num_trainers,
         trainer_id=trainer_id)
 
-    if not args.no_test:
-        if args.update_method == "pserver":
-            test_scope = None
-        else:
-            test_scope = fluid.Scope()
-        test_exe = fluid.ParallelExecutor(
-            True, main_program=test_prog, share_vars_from=exe,
-            scope=test_scope)
-
     pyreader = train_args[4]
     for pass_id in range(args.pass_num):
         num_samples = 0
@@ -297,7 +337,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         fetch_list = [avg_loss.name]
         acc_name_list = [v.name for v in train_args[2]]
         fetch_list.extend(acc_name_list)
-
         try:
             if batch_id % 30 == 0:
                 fetch_ret = exe.run(fetch_list)
@@ -320,7 +359,9 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         pyreader.reset()
 
         if not args.no_test and test_args[2]:
-            test_ret = test_parallel(test_exe, test_args, args, test_prog)
+            if args.multi_batch_repeat > 1:
+                copyback_repeat_bn_params(train_prog)
+            test_ret = test_single(startup_exe, test_args, args, test_prog)
             print("Pass: %d, Test Accuracy: %s\n" %
                   (pass_id, [np.mean(np.array(v)) for v in test_ret]))
@@ -376,7 +417,7 @@ def main():
         raise Exception(
             "Must configure correct environments to run dist train.")
     all_args.extend([train_prog, test_prog, startup_prog])
-    if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
+    if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
         all_args.extend([nccl_id_var, num_trainers, trainer_id])
         train_parallel(*all_args)
     elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":
...