Commit 7d968ab3 authored by Qiao Longfei

add dist train

Parent 379d33eb
@@ -50,14 +50,14 @@ def parse_args():
     parser.add_argument(
         '--is_local',
-        type=bool,
-        default=True,
-        help='Local train or distributed train (default: True)')
+        type=int,
+        default=1,
+        help='Local train or distributed train (default: 1)')
     # the following arguments is used for distributed train, if is_local == false, then you should set them
     parser.add_argument(
         '--role',
         type=str,
-        default='trainer',  # trainer or pserver
+        default='pserver',  # trainer or pserver
         help='The path for model to store (default: models)')
     parser.add_argument(
         '--endpoints',
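A note on the first change: `argparse` with `type=bool` is a known pitfall, because any non-empty string (including "0" or "False") converts to `True`, so the flag could never actually be switched off from the command line; parsing an integer and treating non-zero as true sidesteps this. A minimal sketch of the pitfall and the integer workaround (the `--buggy_flag` name and the demo values are illustrative, not from the repository):

```python
import argparse

parser = argparse.ArgumentParser()
# Pitfall: bool("False") is True, so this flag can never be disabled from the CLI.
parser.add_argument('--buggy_flag', type=bool, default=True)
# Workaround used in the commit: parse an int and treat non-zero as "local".
parser.add_argument('--is_local', type=int, default=1)

args = parser.parse_args(['--buggy_flag', 'False', '--is_local', '0'])
print(args.buggy_flag)      # True  -- the bool() conversion never yields False here
print(bool(args.is_local))  # False -- 0 cleanly selects distributed mode
```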
@@ -124,18 +124,22 @@ def train():
     optimizer.minimize(loss)
     if args.is_local:
+        print_log("run local training")
         main_program = fluid.default_main_program()
         train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var)
     else:
+        print_log("run dist training")
         t = fluid.DistributeTranspiler()
         t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
         if args.role == "pserver":
-            prog = t.get_pserver_program(args.curargs.rent_endpoint)
+            print_log("run pserver")
+            prog = t.get_pserver_program(args.current_endpoint)
             startup = t.get_startup_program(args.current_endpoint, pserver_program=prog)
             exe = fluid.Executor(fluid.CPUPlace())
             exe.run(startup)
             exe.run(prog)
         elif args.role == "trainer":
+            print_log("run trainer")
             train_prog = t.get_trainer_program()
             train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var)
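For reference, one way to exercise this branch is to start one process per parameter-server endpoint with `--role pserver` and one process per trainer with `--role trainer`, all sharing the same `--endpoints` list. The sketch below is only an illustration under stated assumptions: it presumes the script is saved as `train.py`, that `--current_endpoint`, `--trainer_id`, and `--trainers` are defined by `parse_args()` alongside `--endpoints`, and the localhost ports are placeholders.

```python
# Minimal local launch sketch: two pservers and two trainers on one machine.
import subprocess

endpoints = "127.0.0.1:6000,127.0.0.1:6001"  # hypothetical pserver endpoints
trainers = 2

procs = []
for ep in endpoints.split(","):
    # One parameter-server process per endpoint.
    procs.append(subprocess.Popen([
        "python", "train.py", "--is_local", "0", "--role", "pserver",
        "--endpoints", endpoints, "--current_endpoint", ep,
        "--trainers", str(trainers),
    ]))
for trainer_id in range(trainers):
    # One trainer process per trainer id.
    procs.append(subprocess.Popen([
        "python", "train.py", "--is_local", "0", "--role", "trainer",
        "--endpoints", endpoints, "--trainer_id", str(trainer_id),
        "--trainers", str(trainers),
    ]))
for p in procs:
    p.wait()
```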