提交 7d968ab3 编写于 作者: Q Qiao Longfei

add dist train

上级 379d33eb
...@@ -50,14 +50,14 @@ def parse_args(): ...@@ -50,14 +50,14 @@ def parse_args():
parser.add_argument( parser.add_argument(
'--is_local', '--is_local',
type=bool, type=int,
default=True, default=1,
help='Local train or distributed train (default: True)') help='Local train or distributed train (default: 1)')
# the following arguments are used for distributed training; if is_local == false, you should set them # the following arguments are used for distributed training; if is_local == false, you should set them
parser.add_argument( parser.add_argument(
'--role', '--role',
type=str, type=str,
default='trainer', # trainer or pserver default='pserver', # trainer or pserver
help='The path for model to store (default: models)') help='The path for model to store (default: models)')
parser.add_argument( parser.add_argument(
'--endpoints', '--endpoints',
...@@ -124,18 +124,22 @@ def train(): ...@@ -124,18 +124,22 @@ def train():
optimizer.minimize(loss) optimizer.minimize(loss)
if args.is_local: if args.is_local:
print_log("run local training")
main_program = fluid.default_main_program() main_program = fluid.default_main_program()
train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var) train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var)
else: else:
print_log("run dist training")
t = fluid.DistributeTranspiler() t = fluid.DistributeTranspiler()
t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers) t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver": if args.role == "pserver":
prog = t.get_pserver_program(args.curargs.rent_endpoint) print_log("run pserver")
prog = t.get_pserver_program(args.current_endpoint)
startup = t.get_startup_program(args.current_endpoint, pserver_program=prog) startup = t.get_startup_program(args.current_endpoint, pserver_program=prog)
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup) exe.run(startup)
exe.run(prog) exe.run(prog)
elif args.role == "trainer": elif args.role == "trainer":
print_log("run trainer")
train_prog = t.get_trainer_program() train_prog = t.get_trainer_program()
train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var) train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册