Commit df12c591 authored by typhoonzero

refine

Parent 807faec9
@@ -296,11 +296,6 @@ def train_parallel(args):
     #     scope=fluid.global_scope().new_scope()
     # )
-    print("Run test before head")
-    test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name]
-    test_ret = test_single(startup_exe, test_prog, args, test_pyreader, test_fetch_list)
-    print("End run test before head")
     over_all_start = time.time()
     fetch_list = [train_cost.name, train_acc1.name, train_acc5.name]
     steps_per_pass = args.total_images / args.batch_size / args.dist_env["num_trainers"]
@@ -316,8 +311,8 @@ def train_parallel(args):
             if batch_id % 30 == 0:
                 fetch_ret = exe.run(fetch_list)
                 fetched_data = [np.mean(np.array(d)) for d in fetch_ret]
-                print("Pass %d, batch %d, loss %s, acc1: %s, acc5: %s, avg batch time %.4f" %
-                      (pass_id, batch_id, fetched_data[0], fetched_data[1],
+                print("Pass [%d/%d], batch [%d/%d], loss %s, acc1: %s, acc5: %s, avg batch time %.4f" %
+                      (pass_id, args.num_epochs, batch_id, steps_per_pass, fetched_data[0], fetched_data[1],
                        fetched_data[2], (time.time()-start_time) / batch_id))
             else:
                 fetch_ret = exe.run([])
@@ -333,8 +328,7 @@ def train_parallel(args):
         print_train_time(start_time, time.time(), num_samples)
         train_pyreader.reset()
-        if pass_id > args.start_test_pass:
+        if pass_id >= args.start_test_pass:
             if args.multi_batch_repeat > 1:
                 copyback_repeat_bn_params(train_prog)
             test_fetch_list = [test_cost.name, test_acc1.name, test_acc5.name]
@@ -349,7 +343,6 @@ def train_parallel(args):
             if not os.path.isdir(model_path):
                 os.makedirs(model_path)
             fluid.io.save_persistables(startup_exe, model_path, main_program=train_prog)
     startup_exe.close()
     print("total train time: ", time.time() - over_all_start)
...
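The visible change to train_parallel's logging swaps bare pass/batch counters for [current/total] progress fractions built from args.num_epochs and the steps_per_pass value computed above. A minimal standalone sketch of the refined log line; every value below is a hypothetical stand-in for the real training state:

    import time

    # Hypothetical stand-ins for the real training state.
    pass_id, num_epochs = 2, 120
    batch_id, steps_per_pass = 30, 5004
    fetched_data = [6.91, 0.12, 0.31]   # mean loss, acc1, acc5 from exe.run
    start_time = time.time() - 42.0     # pretend 42 seconds of training elapsed

    # Same format string as the refined code: pass and batch now print as
    # [current/total], so progress through the whole run is visible at a glance.
    print("Pass [%d/%d], batch [%d/%d], loss %s, acc1: %s, acc5: %s, avg batch time %.4f" %
          (pass_id, num_epochs, batch_id, steps_per_pass, fetched_data[0],
           fetched_data[1], fetched_data[2], (time.time() - start_time) / batch_id))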
@@ -54,7 +54,7 @@ def _update_role_var_grad(prog, params_grads):
         role = op.attr("op_role")
         if role & int(BACKWARD) and op.has_attr("op_role_var"):
             # have backward bits then remove all op_role_var
-            op.desc._remove_attr("op_role_var")
+            op.desc.remove_attr("op_role_var")
     for op in prog.global_block().ops:
         if op.type == "allreduce":
             allreduce_role_var = []
@@ -95,6 +95,7 @@ def create_master_params_grads(params_grads, main_prog, startup_prog, scale_loss
         else:
             reduced_master_grad = master_grad
+        params_grads_to_apply.append([master_param, reduced_master_grad])
     # update program op role var according to master grads before allreduce.
     _update_role_var_grad(main_prog, master_params_grads)
     main_prog._current_role = tmp_role
...
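The second file is the mixed-precision utility pass: each fp16 gradient gets a float32 "master" copy that is reduced and applied in full precision, and _update_role_var_grad then rewrites the op_role_var attributes to point at those master gradients before the allreduce ops run. A minimal NumPy sketch of the master-copy idea only; the names, shapes, and values here are hypothetical, not the Fluid API:

    import numpy as np

    # Hypothetical fp16 (param, grad) pairs standing in for Fluid variables.
    params_grads = [(np.ones(4, np.float16), np.full(4, 0.01, np.float16))]

    params_grads_to_apply = []
    for param, grad in params_grads:
        # Cast to float32 "master" copies so gradient reduction and the
        # weight update keep full precision even though training ran in fp16.
        master_param = param.astype(np.float32)
        master_grad = grad.astype(np.float32)
        reduced_master_grad = master_grad  # a multi-GPU run would allreduce here
        params_grads_to_apply.append([master_param, reduced_master_grad])

    print(params_grads_to_apply)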