From bafd823666bde1098cf07eb23d406bc9780c7b28 Mon Sep 17 00:00:00 2001 From: heqiaozhi Date: Fri, 14 Dec 2018 13:18:28 +0800 Subject: [PATCH] test --- .../paddle/fluid/distributed/ps_instance.py | 58 ++++++++++++++++--- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/distributed/ps_instance.py b/python/paddle/fluid/distributed/ps_instance.py index b93da053a36..6b44d0cd16f 100644 --- a/python/paddle/fluid/distributed/ps_instance.py +++ b/python/paddle/fluid/distributed/ps_instance.py @@ -14,27 +14,36 @@ import helper as dist_helper import sys + class PaddlePSInstance(object): + """ + PaddlePSInstance class is used to generate A instance of server or worker + Args: + server_worker_mode: is a value 0 or 1, default is 1 + proc_per_node: process per node, default is 2 + Examples: + instance = PaddlePSInstance(1, 2) + """ + def __init__(self, server_worker_mode, proc_per_node): self.dh = dist_helper.MPIHelper() self._rankid = self.dh.get_rank() self._server_worker_mode = server_worker_mode self._proc_per_node = proc_per_node self._nodes = self.dh.get_size() - + self._ip = 0 self._worker_num = self._nodes * self._proc_per_node / 2 self._server_num = self._nodes * self._proc_per_node / 2 self._total_server_worker = self._worker_num + self._server_num - self._node_type = None #IDLE=-1, WORKER=1, SERVER=0 + self._node_type = None #IDLE=-1, WORKER=1, SERVER=0 self._set_nodetype() self._comm = None self._split_comm() - def _set_nodetype(self): if self._server_worker_mode == 0: - if self._rankid < self._server_num: + if self._rankid < self._server_num: self._node_type = 1 elif self._rankid < self._total_server_worker: self._node_type = 0 @@ -46,13 +55,13 @@ class PaddlePSInstance(object): self._node_type = 0 else: self._node_type = 1 - else: - self._node_type = -1; + else: + self._node_type = -1 else: self._node_type = -1 - + #if self._rankid == 0: - #print "node type: ", self._node_type + #print "node type: ", self._node_type def _split_comm(self): if self.is_server(): @@ -62,45 +71,78 @@ class PaddlePSInstance(object): pass def get_worker_index(self): + """ + Return worker index + """ if self._server_worker_mode == 0: return self._rankid == self.server_num else: return self._rankid / self._proc_per_node def get_server_index(self): + """ + Return server index + """ if self._server_worker_mode == 0: return self.rank_id else: return self.rank_id / self._proc_per_node def is_worker(self): + """ + Return instance is worker or not + """ return self._node_type == 1 def is_server(self): + """ + Return instance is server or not + """ return self._node_type == 0 def is_first_worker(self): + """ + Return instance is first worker or not + """ return self.is_worker() and 0 == self.get_worker_index() def set_ip(self, ip): + """ + set server ip + """ self._ip = ip def gather_ips(self): + """ + Return all servers and workers ip throught mpi allgather + """ self._ips = self.dh.comm.allgather(self._ip) return self._ips def get_node_cnt(self): + """ + Return node cnt + """ return self._nodes def barrier_all(self): + """ + barrier workers and servers + """ self.dh.comm.barrier() def barrier_worker(self): + """ + barrier workers + """ if self.is_worker(): self._comm.barrier() pass def finalize(self): + """ + MPI finalize + """ self.dh.finalize() pass -- GitLab