From a41a94f2ee9af8abcbfcc17b87641e4af9efcac8 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Wed, 16 May 2018 13:39:02 +0800 Subject: [PATCH] support nccl2 dist train in trainer --- python/paddle/fluid/trainer.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/python/paddle/fluid/trainer.py b/python/paddle/fluid/trainer.py index c24662ac211..a47af7ccb21 100644 --- a/python/paddle/fluid/trainer.py +++ b/python/paddle/fluid/trainer.py @@ -131,7 +131,40 @@ class Trainer(object): # load params from param_path into scope io.load_persistables(exe, dirname=param_path) + def _transpile_nccl2_dist(self): + # PADDLE_TRAINER_IPS + if "PADDLE_TRAINER_IPS" not in os.environ: + self.nccl_id_var = None + else: + self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID")) + port = os.getenv("PADDLE_PSERVER_PORT") + worker_ips = os.getenv("PADDLE_TRAINER_IPS") + worker_endpoints = [] + for ip in worker_ips.split(","): + worker_endpoints.append(':'.join([ip, port])) + self.num_trainers = len(worker_endpoints) + current_endpoint = os.getenv("POD_IP") + ":" + port + worker_endpoints.remove(current_endpoint) + # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id + # in ParallelExecutor to start + # distributed training using NCCL2 + self.nccl_id_var = self.startup_program.global_block().create_var( + name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) + self.startup_program.global_block().append_op( + type="gen_nccl_id", + inputs={}, + outputs={"NCCLID": self.nccl_id_var}, + attrs={ + "endpoint": current_endpoint, + "endpoint_list": worker_endpoints, + "trainer_id": self.trainer_id + }) + def _dist_transpile_if_necessary(self, optimize_ops, params_grads): + self._transpile_nccl2_dist() + if self.nccl_id_var != None: + return + if "PADDLE_TRAINING_ROLE" not in os.environ: return -- GitLab