Commit e4e73df6 authored by Liang

add paddle_cloud_train of ctr

Parent 19b08cfd
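The --cloud_train path added in the diff below derives all distributed settings from environment variables expected to be set by the PaddleCloud scheduler (PADDLE_PORT, PADDLE_PSERVERS, PADDLE_TRAINERS_NUM, POD_IP, TRAINING_ROLE, PADDLE_TRAINER_ID, PADDLE_IS_LOCAL). For reference, here is a minimal sketch of how the pserver endpoint list is assembled from those variables; the IP addresses are hypothetical placeholders, not values from the commit:

# Sketch only: how the cloud_train branch builds args.endpoints (hypothetical values).
import os

os.environ["PADDLE_PORT"] = "6174"                          # port shared by all pservers
os.environ["PADDLE_PSERVERS"] = "192.168.1.2,192.168.1.3"   # hypothetical pserver IPs

port = os.getenv("PADDLE_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVERS", "")
eplist = [":".join([ip, port]) for ip in pserver_ips.split(",")]
endpoints = ",".join(eplist)
print(endpoints)  # -> 192.168.1.2:6174,192.168.1.3:6174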
@@ -56,12 +56,26 @@ def parse_args():
type=int,
default=1000001,
help='sparse feature hashing space for index processing')
parser.add_argument(
'--is_local',
type=int,
default=1,
help='Local train or distributed train (default: 1)')
parser.add_argument(
'--cloud_train',
type=int,
default=0,
help='Local train or distributed train on paddlecloud (default: 0)')
parser.add_argument(
'--async_mode',
action='store_true',
default=False,
help='Whether start pserver in async mode to support ASGD')
parser.add_argument(
'--no_split_var',
action='store_true',
default=False,
help='Whether split variables into blocks when update_method is pserver')
# the following arguments are used for distributed train; if is_local == false, then you should set them
parser.add_argument(
'--role',
@@ -134,6 +148,21 @@ def train():
loss, data_list, auc_var, batch_auc_var = ctr_dnn_model(args.embedding_size, args.sparse_feature_dim)
optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
optimizer.minimize(loss)
if args.cloud_train:
# the port of all pservers, needed by both trainer and pserver
port = os.getenv("PADDLE_PORT", "6174")
# comma separated ips of all pservers, needed by trainer and pserver
pserver_ips = os.getenv("PADDLE_PSERVERS", "")
eplist = []
for ip in pserver_ips.split(","):
eplist.append(':'.join([ip, port]))
args.endpoints = ",".join(eplist)
args.trainers = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
args.current_endpoint = os.getenv("POD_IP", "localhost") + ":" + port
args.role = os.getenv("TRAINING_ROLE", "TRAINER")
args.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
args.is_local = bool(int(os.getenv("PADDLE_IS_LOCAL", 0)))
if args.is_local:
logger.info("run local training")
@@ -143,18 +172,22 @@ def train():
logger.info("run dist training")
t = fluid.DistributeTranspiler()
t.transpile(args.trainer_id, pservers=args.endpoints, trainers=args.trainers)
if args.role == "pserver":
if args.role == "pserver" or args.role == "PSERVER":
logger.info("run pserver")
prog = t.get_pserver_program(args.current_endpoint)
startup = t.get_startup_program(args.current_endpoint, pserver_program=prog)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
exe.run(prog)
elif args.role == "trainer":
elif args.role == "trainer" or args.role == "TRAINING_ROLE":
logger.info("run trainer")
train_prog = t.get_trainer_program()
train_loop(args, train_prog, data_list, loss, auc_var, batch_auc_var,
args.trainers, args.trainer_id)
else:
raise ValueError(
'TRAINING_ROLE environment variable must be either TRAINER or PSERVER'
)
if __name__ == '__main__':
......
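For reference, a hedged sketch of how the pserver and trainer roles introduced above could be exercised on one machine once the PaddleCloud-style environment variables are set; the script name train.py and all literal values are assumptions for illustration, not part of this commit:

# Hypothetical single-machine smoke test of the cloud_train path (not part of the commit).
import os
import subprocess

common_env = {
    "PADDLE_PORT": "6174",
    "PADDLE_PSERVERS": "127.0.0.1",   # one pserver on localhost (hypothetical)
    "PADDLE_TRAINERS_NUM": "1",
    "POD_IP": "127.0.0.1",
    "PADDLE_TRAINER_ID": "0",
    "PADDLE_IS_LOCAL": "0",
}

# train.py is assumed to be the script containing the train() shown above.
pserver = subprocess.Popen(["python", "train.py", "--cloud_train", "1"],
                           env=dict(os.environ, **common_env, TRAINING_ROLE="PSERVER"))
trainer = subprocess.Popen(["python", "train.py", "--cloud_train", "1"],
                           env=dict(os.environ, **common_env, TRAINING_ROLE="TRAINER"))
trainer.wait()       # the trainer exits when train_loop() finishes
pserver.terminate()  # the pserver blocks in exe.run(prog), so stop it explicitly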