Commit 98411d04 authored by typhoonzero

refine dist train

Parent 58e9bc20
@@ -218,17 +218,58 @@ def dist_transpile(trainer_id, args, train_prog, startup_prog):
            'PADDLE_TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
        )
def test_parallel(exe, test_args, args, test_prog):
def append_bn_repeat_init_op(main_prog, startup_prog, num_repeats):
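    # Collect the Mean and Variance variables of every batch_norm op in the
    # main program, then create a "<name>.repeat.<i>" copy of each for every
    # repeat, initialized in the startup program by cloning the original
    # fill_constant op, so each repeated sub-batch keeps its own batch_norm
    # statistics.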
    repeat_vars = set()
    for op in main_prog.global_block().ops:
        if op.type == "batch_norm":
            repeat_vars.add(op.input("Mean")[0])
            repeat_vars.add(op.input("Variance")[0])

    for i in range(num_repeats):
        for op in startup_prog.global_block().ops:
            if op.type == "fill_constant":
                for oname in op.output_arg_names:
                    if oname in repeat_vars:
                        var = startup_prog.global_block().var(oname)
                        repeat_var_name = "%s.repeat.%d" % (oname, i)
                        repeat_var = startup_prog.global_block().create_var(
                            name=repeat_var_name,
                            type=var.type,
                            dtype=var.dtype,
                            shape=var.shape,
                            persistable=var.persistable
                        )
                        main_prog.global_block()._clone_variable(repeat_var)
                        startup_prog.global_block().append_op(
                            type="fill_constant",
                            inputs={},
                            outputs={"Out": repeat_var},
                            attrs=op.all_attrs()
                        )
def copyback_repeat_bn_params(main_prog):
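    # Copy the batch_norm statistics of the first repeat ("<name>.repeat.0")
    # back into the original Mean/Variance tensors in the global scope, so the
    # unmodified test program reads up-to-date values; the copy is placed on
    # CUDAPlace(0), i.e. testing runs on GPU 0.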
    repeat_vars = set()
    for op in main_prog.global_block().ops:
        if op.type == "batch_norm":
            repeat_vars.add(op.input("Mean")[0])
            repeat_vars.add(op.input("Variance")[0])
    for vname in repeat_vars:
        real_var = fluid.global_scope().find_var("%s.repeat.0" % vname).get_tensor()
        orig_var = fluid.global_scope().find_var(vname).get_tensor()
        orig_var.set(np.array(real_var), fluid.CUDAPlace(0))  # test on GPU0
def test_single(exe, test_args, args, test_prog):
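    # Evaluate with a plain single-program executor: start the test reader
    # (test_args[4]), fetch the accuracy variables (test_args[2]) batch by
    # batch, and aggregate them with fluid.metrics.Accuracy.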
    acc_evaluators = []
    for i in six.moves.xrange(len(test_args[2])):
    for i in xrange(len(test_args[2])):
        acc_evaluators.append(fluid.metrics.Accuracy())

    to_fetch = [v.name for v in test_args[2]]
    test_args[4].start()
    while True:
        try:
            acc_rets = exe.run(fetch_list=to_fetch)
            acc_rets = exe.run(program=test_prog, fetch_list=to_fetch)
            for i, e in enumerate(acc_evaluators):
                e.update(
                    value=np.array(acc_rets[i]), weight=args.batch_size)
@@ -238,6 +279,7 @@ def test_parallel(exe, test_args, args, test_prog):
    return [e.eval() for e in acc_evaluators]


def train_parallel(train_args, test_args, args, train_prog, test_prog,
                   startup_prog, nccl_id_var, num_trainers, trainer_id):
    over_all_start = time.time()
@@ -248,11 +290,18 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
        time.sleep(30)

    startup_exe = fluid.Executor(place)
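    # When several sub-batches are merged per update (multi_batch_repeat > 1),
    # each batch_norm variable needs its per-repeat copies created before the
    # startup program is run.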
    if args.multi_batch_repeat > 1:
        append_bn_repeat_init_op(train_prog, startup_prog, args.multi_batch_repeat)
    startup_exe.run(startup_prog)

    strategy = fluid.ExecutionStrategy()
    strategy.num_threads = args.cpus
    strategy.allow_op_delay = False
    build_strategy = fluid.BuildStrategy()
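    # Rebuild the pass list from the strategy, insert the multi_batch_merge_pass
    # near its end, and tell it how many sub-batches to merge per update.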
    if args.multi_batch_repeat > 1:
        pass_builder = build_strategy._create_passes_from_strategy()
        mypass = pass_builder.insert_pass(
            len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
        mypass.set_int("num_repeats", args.multi_batch_repeat)
    if args.reduce_strategy == "reduce":
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.Reduce
@@ -278,15 +327,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
        num_trainers=num_trainers,
        trainer_id=trainer_id)

    if not args.no_test:
        if args.update_method == "pserver":
            test_scope = None
        else:
            test_scope = fluid.Scope()
        test_exe = fluid.ParallelExecutor(
            True, main_program=test_prog, share_vars_from=exe,
            scope=test_scope)

    pyreader = train_args[4]
    for pass_id in range(args.pass_num):
        num_samples = 0
@@ -297,7 +337,6 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
            fetch_list = [avg_loss.name]
            acc_name_list = [v.name for v in train_args[2]]
            fetch_list.extend(acc_name_list)
            try:
                if batch_id % 30 == 0:
                    fetch_ret = exe.run(fetch_list)
@@ -320,7 +359,9 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
        pyreader.reset()
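        # End-of-pass evaluation: if sub-batches were merged, copy the repeated
        # batch_norm statistics back first, then run the plain test program on
        # the startup executor.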
        if not args.no_test and test_args[2]:
            test_ret = test_parallel(test_exe, test_args, args, test_prog)
            if args.multi_batch_repeat > 1:
                copyback_repeat_bn_params(train_prog)
            test_ret = test_single(startup_exe, test_args, args, test_prog)
            print("Pass: %d, Test Accuracy: %s\n" %
                  (pass_id, [np.mean(np.array(v)) for v in test_ret]))
@@ -376,7 +417,7 @@ def main():
            raise Exception(
                "Must configure correct environments to run dist train.")
        all_args.extend([train_prog, test_prog, startup_prog])
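        # Only the TRAINER role runs train_parallel, with the extra
        # nccl_id_var, num_trainers and trainer_id arguments; the PSERVER role
        # is handled in the elif branch below.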
        if args.gpus > 1 and os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
        if os.getenv("PADDLE_TRAINING_ROLE") == "TRAINER":
            all_args.extend([nccl_id_var, num_trainers, trainer_id])
            train_parallel(*all_args)
        elif os.getenv("PADDLE_TRAINING_ROLE") == "PSERVER":