Commit 379d33eb authored by Qiao Longfei

support dist train

Parent 6ee6ca69
+from __future__ import print_function
 import argparse
 import os
 import time
@@ -47,19 +48,42 @@ def parse_args():
         default='models',
         help='The path for model to store (default: models)')
+    parser.add_argument(
+        '--is_local',
+        type=bool,
+        default=True,
+        help='Local train or distributed train (default: True)')
+    # The following arguments are used for distributed training
+    # and must be set when is_local == False.
+    parser.add_argument(
+        '--role',
+        type=str,
+        default='trainer',  # trainer or pserver
+        help='The role of the current node, trainer or pserver (default: trainer)')
+    parser.add_argument(
+        '--endpoints',
+        type=str,
+        default='127.0.0.1:6000',
+        help='The pserver endpoints, like: 127.0.0.1:6000,127.0.0.1:6001')
+    parser.add_argument(
+        '--current_endpoint',
+        type=str,
+        default='127.0.0.1:6000',
+        help='The endpoint of the current pserver (default: 127.0.0.1:6000)')
+    parser.add_argument(
+        '--trainer_id',
+        type=int,
+        default=0,
+        help='The id of the current trainer (default: 0)')
+    parser.add_argument(
+        '--trainers',
+        type=int,
+        default=1,
+        help='The num of trainers (default: 1)')
     return parser.parse_args()


-def train():
-    args = parse_args()
-
-    if not os.path.isdir(args.model_output_dir):
-        os.mkdir(args.model_output_dir)
-
-    loss, data_list, auc_var, batch_auc_var = ctr_dnn_model(args.embedding_size)
-    optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
-    optimizer.minimize(loss)
-
+def train_loop(args, train_program, data_list, loss, auc_var, batch_auc_var):
     dataset = reader.Dataset()
     train_reader = paddle.batch(
         paddle.reader.shuffle(
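
An aside on the new --is_local flag, not addressed in this commit: argparse's type=bool does not parse strings the way one might hope, because bool('False') is True (any non-empty string is truthy), so passing --is_local False on the command line still enables local mode. A common workaround is a small string-to-bool converter; the str2bool helper below is my own illustration, not part of the commit:

    import argparse

    def str2bool(v):
        # argparse hands us the raw command-line string; map common
        # spellings onto a real bool instead of relying on bool(str).
        if v.lower() in ('yes', 'true', 't', '1'):
            return True
        if v.lower() in ('no', 'false', 'f', '0'):
            return False
        raise argparse.ArgumentTypeError('boolean value expected, got %r' % v)

    parser = argparse.ArgumentParser()
    parser.add_argument('--is_local', type=str2bool, default=True,
                        help='Local train or distributed train (default: True)')
    print(parser.parse_args(['--is_local', 'False']).is_local)  # prints False
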
@@ -76,7 +100,7 @@ def train():
     for pass_id in range(args.num_passes):
         for batch_id, data in enumerate(train_reader()):
             loss_val, auc_val, batch_auc_val = exe.run(
-                fluid.default_main_program(),
+                train_program,
                 feed=feeder.feed(data),
                 fetch_list=[loss, auc_var, batch_auc_var]
             )
@@ -89,5 +113,32 @@ def train():
     fluid.io.save_inference_model(model_dir, data_name_list, [loss, auc_var], exe)


+def train():
+    args = parse_args()
+
+    if not os.path.isdir(args.model_output_dir):
+        os.mkdir(args.model_output_dir)
+
+    loss, data_list, auc_var, batch_auc_var = ctr_dnn_model(args.embedding_size)
+    optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
+    optimizer.minimize(loss)
+
+    if args.is_local:
+        main_program = fluid.default_main_program()
+        train_loop(args, main_program, data_list, loss, auc_var, batch_auc_var)
+    else:
+        t = fluid.DistributeTranspiler()
+        t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
+        if args.role == "pserver":
+            prog = t.get_pserver_program(args.current_endpoint)
+            startup = t.get_startup_program(args.current_endpoint, pserver_program=prog)
+            exe = fluid.Executor(fluid.CPUPlace())
+            exe.run(startup)
+            exe.run(prog)
+        elif args.role == "trainer":
+            train_prog = t.get_trainer_program()
+            train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var)
+

 if __name__ == '__main__':
     train()
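
To see how the pieces fit together, here is a hypothetical single-machine launcher that starts one pserver and two trainers with the new flags. The script name train.py, the port, and the launcher itself are assumptions for illustration, not part of the commit; note that because of the type=bool caveat above, an empty string is the one command-line value that bool() coerces to False for --is_local:

    # Launch a 1-pserver, 2-trainer job locally (illustrative only).
    import subprocess

    endpoints = '127.0.0.1:6000'
    common = ['python', 'train.py', '--is_local', '',  # '' -> bool('') == False
              '--endpoints', endpoints, '--trainers', '2']

    # One parameter server; it blocks serving parameters until we stop it.
    pserver = subprocess.Popen(common + ['--role', 'pserver',
                                         '--current_endpoint', '127.0.0.1:6000'])

    # Two trainers, identified by --trainer_id.
    trainers = [subprocess.Popen(common + ['--role', 'trainer',
                                           '--trainer_id', str(tid)])
                for tid in range(2)]

    for p in trainers:
        p.wait()           # wait for both trainers to finish
    pserver.terminate()    # then shut the pserver down

In a real deployment each process would run on its own node, with --endpoints listing every pserver endpoint.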