Commit a41a94f2 authored by typhoonzero

support nccl2 dist train in trainer

Parent 9707aa6b
...@@ -131,7 +131,40 @@ class Trainer(object):
            # load params from param_path into scope
            io.load_persistables(exe, dirname=param_path)

    def _transpile_nccl2_dist(self):
        # PADDLE_TRAINER_IPS
        if "PADDLE_TRAINER_IPS" not in os.environ:
            self.nccl_id_var = None
        else:
            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
            port = os.getenv("PADDLE_PSERVER_PORT")
            worker_ips = os.getenv("PADDLE_TRAINER_IPS")
            worker_endpoints = []
            for ip in worker_ips.split(","):
                worker_endpoints.append(':'.join([ip, port]))
            self.num_trainers = len(worker_endpoints)
            current_endpoint = os.getenv("POD_IP") + ":" + port
            worker_endpoints.remove(current_endpoint)
            # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
            # in ParallelExecutor to start
            # distributed training using NCCL2
            self.nccl_id_var = self.startup_program.global_block().create_var(
                name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
            self.startup_program.global_block().append_op(
                type="gen_nccl_id",
                inputs={},
                outputs={"NCCLID": self.nccl_id_var},
                attrs={
                    "endpoint": current_endpoint,
                    "endpoint_list": worker_endpoints,
                    "trainer_id": self.trainer_id
                })

    def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
        self._transpile_nccl2_dist()
        if self.nccl_id_var != None:
            return
        if "PADDLE_TRAINING_ROLE" not in os.environ:
            return
......
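Note (not part of the commit): the NCCL2 path added above is driven entirely by environment variables. Below is a minimal sketch of how a launcher might set them on one of two assumed worker hosts before constructing the Trainer; only the variable names come from the diff, while the IPs, port, and trainer index are illustrative assumptions.

import os

# Hypothetical launcher-side environment for the NCCL2 code path above.
# Variable names are taken from the diff; the concrete values are assumed.
os.environ["PADDLE_TRAINER_IPS"] = "192.168.0.2,192.168.0.3"  # comma-separated IPs of all workers
os.environ["PADDLE_TRAINER_ID"] = "0"                         # index of this worker
os.environ["PADDLE_PSERVER_PORT"] = "7164"                    # port shared by all workers
os.environ["POD_IP"] = "192.168.0.2"                          # IP of the current worker

# With PADDLE_TRAINER_IPS set, _transpile_nccl2_dist() records
# self.trainer_id / self.num_trainers, creates the RAW "NCCLID" variable in the
# startup program, and appends a gen_nccl_id op that exchanges the NCCL unique
# id with the other endpoints; _dist_transpile_if_necessary() then returns
# early instead of running the parameter-server transpiler.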