Unverified commit b9c32c77, authored by Wu Yi, committed by GitHub

Merge pull request #1893 from typhoonzero/fix_batch_merge_for_develop

fix batch merge for develop
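In short, this commit moves the insertion of "multi_batch_merge_pass" so that it happens after build_strategy.num_trainers and build_strategy.trainer_id are set (which also lets build_strategy add "allreduce_deps_pass" automatically), changes the insert position from len(all_passes) - 2 to len(all_passes) - 4, switches mypass.set_int to mypass.set, and disables inplace and memory optimization. A minimal sketch of the resulting setup, assuming num_trainers, trainer_id, and multi_batch_repeat are supplied by the caller (in the diff below they come from the script's args and dist env):

import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
# Batch merging is not used together with inplace/memory optimization here.
build_strategy.enable_inplace = False
build_strategy.memory_optimize = False

# Set trainer info first so build_strategy adds "allreduce_deps_pass" automatically.
build_strategy.num_trainers = num_trainers
build_strategy.trainer_id = trainer_id

if multi_batch_repeat > 1:
    # Insert the batch-merge pass four passes from the end of the pass list.
    pass_builder = build_strategy._finalize_strategy_and_create_passes()
    mypass = pass_builder.insert_pass(
        len(pass_builder.all_passes()) - 4, "multi_batch_merge_pass")
    mypass.set("num_repeats", multi_batch_repeat)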
import paddle.fluid as fluid
import numpy as np

def copyback_repeat_bn_params(main_prog):
    repeat_vars = set()
    ...

@@ -187,6 +187,24 @@ def test_single(exe, test_prog, args, pyreader, fetch_list):
    test_avg_loss = np.mean(np.array(test_losses))
    return test_avg_loss, np.mean(acc1.eval()), np.mean(acc5.eval())

def test_parallel(exe, test_prog, args, pyreader, fetch_list):
    acc1 = fluid.metrics.Accuracy()
    acc5 = fluid.metrics.Accuracy()
    test_losses = []
    pyreader.start()
    while True:
        try:
            acc_rets = exe.run(fetch_list=fetch_list)
            test_losses.append(acc_rets[0])
            acc1.update(value=np.array(acc_rets[1]), weight=args.batch_size)
            acc5.update(value=np.array(acc_rets[2]), weight=args.batch_size)
        except fluid.core.EOFException:
            pyreader.reset()
            break
    test_avg_loss = np.mean(np.array(test_losses))
    return test_avg_loss, np.mean(acc1.eval()), np.mean(acc5.eval())

def run_pserver(train_prog, startup_prog):
    server_exe = fluid.Executor(fluid.CPUPlace())
    server_exe.run(startup_prog)

@@ -224,11 +242,10 @@ def train_parallel(args):
    strategy = fluid.ExecutionStrategy()
    strategy.num_threads = args.num_threads
    build_strategy = fluid.BuildStrategy()
    if args.multi_batch_repeat > 1:
        pass_builder = build_strategy._finalize_strategy_and_create_passes()
        mypass = pass_builder.insert_pass(
            len(pass_builder.all_passes()) - 2, "multi_batch_merge_pass")
        mypass.set_int("num_repeats", args.multi_batch_repeat)
    build_strategy.enable_inplace = False
    build_strategy.memory_optimize = False
    if args.reduce_strategy == "reduce":
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.Reduce
@@ -245,6 +262,15 @@ def train_parallel(args):
    else:
        num_trainers = args.dist_env["num_trainers"]
        trainer_id = args.dist_env["trainer_id"]

    # Set this to let build_strategy add "allreduce_deps_pass" automatically.
    build_strategy.num_trainers = num_trainers
    build_strategy.trainer_id = trainer_id

    if args.multi_batch_repeat > 1:
        pass_builder = build_strategy._finalize_strategy_and_create_passes()
        mypass = pass_builder.insert_pass(
            len(pass_builder.all_passes()) - 4, "multi_batch_merge_pass")
        mypass.set("num_repeats", args.multi_batch_repeat)

    exe = fluid.ParallelExecutor(
        True,
@@ -255,6 +281,14 @@ def train_parallel(args):
        num_trainers=num_trainers,
        trainer_id=trainer_id)

    # Uncomment the lines below to use ParallelExecutor to run the test.
    # test_exe = fluid.ParallelExecutor(
    #     True,
    #     main_program=test_prog,
    #     share_vars_from=exe,
    #     scope=fluid.global_scope().new_scope()
    # )

    over_all_start = time.time()
    fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
    steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"]
@@ -293,6 +327,8 @@ def train_parallel(args):
        copyback_repeat_bn_params(train_prog)
        test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name]
        test_ret = test_single(startup_exe, test_prog, args, test_pyreader, test_fetch_list)
        # NOTE: switch to the line below if you use ParallelExecutor to run the test.
        # test_ret = test_parallel(test_exe, test_prog, args, test_pyreader, test_fetch_list)
        print("Pass: %d, Test Loss %s, test acc1: %s, test acc5: %s\n" %
              (pass_id, test_ret[0], test_ret[1], test_ret[2]))
@@ -324,3 +360,4 @@ def main():

if __name__ == "__main__":
    main()
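As the comments in the diff point out, the test pass can also be run through a ParallelExecutor that shares variables with the training executor. A minimal sketch of that path, following the commented-out lines above (exe, test_prog, args, test_pyreader, and test_fetch_list are the objects already defined in this script):

test_exe = fluid.ParallelExecutor(
    True,  # use_cuda, mirroring the training executor
    main_program=test_prog,
    share_vars_from=exe,
    scope=fluid.global_scope().new_scope())
test_ret = test_parallel(test_exe, test_prog, args, test_pyreader, test_fetch_list)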