Commit 17790188 authored by dongdaxiang

make role maker and distributed optimizer private

Parent d87ba58c
@@ -28,19 +28,19 @@ class RoleMakerBase(object):
         self.pserver_endpoints_ = []
         self.role_is_generated_ = False

-    def is_worker(self):
+    def _is_worker(self):
         """
         return is_worker() of current process
         """
         raise NotImplementedError("Please implement this method in child class")

-    def is_server(self):
+    def _is_server(self):
         """
         return is_server() of current process
         """
         raise NotImplementedError("Please implement this method in child class")

-    def get_local_ip(self):
+    def _get_local_ip(self):
         """
         return get local ip
         """
@@ -48,19 +48,19 @@ class RoleMakerBase(object):
         self.ip_ = socket.gethostbyname(socket.gethostname())
         return self.ip_

-    def get_trainer_endpoints(self):
+    def _get_trainer_endpoints(self):
         """
         return trainer endpoints
         """
         return self.trainer_endpoints_

-    def get_pserver_endpoints(self):
+    def _get_pserver_endpoints(self):
         """
         return pserver endpoints
         """
         return self.pserver_endpoints_

-    def generate_role(self):
+    def _generate_role(self):
         """
         generate_role() should be called to identify current process's role
         """
@@ -80,34 +80,34 @@ class MPIRoleMaker(RoleMakerBase):
         self.MPI = MPI
         self.ips_ = None

-    def get_rank(self):
+    def _get_rank(self):
         """
         return rank
         """
         self.rank_ = self.comm_.Get_rank()
         return self.rank_

-    def get_size(self):
+    def _get_size(self):
         """
         return size
         """
         self.size_ = self.comm_.Get_size()
         return self.size_

-    def all_gather(self, obj):
+    def _all_gather(self, obj):
         """
         all_gather(obj) will call MPI's allgather function
         """
         self.barrier_all()
         return self.comm_.allgather(obj)

-    def barrier_all(self):
+    def _barrier_all(self):
         """
         barrier_all() will call MPI's barrier_all function
         """
         self.comm_.barrier()

-    def get_ips(self):
+    def _get_ips(self):
         """
         collect current distributed job's ip list
         """
@@ -115,7 +115,7 @@ class MPIRoleMaker(RoleMakerBase):
             self.ips_ = self.comm_.allgather(self.get_local_ip())
         return self.ips_

-    def finalize(self):
+    def _finalize(self):
         """
         finalize the current MPI instance.
         """
@@ -141,7 +141,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
             return False
         return True

-    def is_first_worker(self):
+    def _is_first_worker(self):
         """
         return whether current process is the first worker assigned by role maker
         """
@@ -149,7 +149,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
             return self.is_worker() and 0 == self.worker_index()
         return False

-    def is_worker(self):
+    def _is_worker(self):
         """
         return whether current process is worker assigned by role maker
         """
@@ -157,7 +157,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
             return self.node_type_ == 1
         return False

-    def is_server(self):
+    def _is_server(self):
         """
         return whether current process is server assigned by role maker
         """
@@ -165,25 +165,25 @@ class MPISymetricRoleMaker(MPIRoleMaker):
             return self.node_type_ == 0
         return False

-    def worker_num(self):
+    def _worker_num(self):
         """
         return the current number of worker
         """
         if self._check_role_generation():
             if self.is_worker():
-                return self.get_size() / 2;
+                return self.get_size() / 2
         return 0

-    def server_num(self):
+    def _server_num(self):
         """
         return the current number of server
         """
         if self._check_role_generation():
             if self.is_server():
-                return self.get_size() / 2;
+                return self.get_size() / 2
         return 0

-    def worker_index(self):
+    def _worker_index(self):
         """
         return the index of worker
         """
@@ -191,7 +191,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
             return self.rank_ / self.proc_per_node_
         return 0

-    def server_index(self):
+    def _server_index(self):
         """
         return the index of server
         """
@@ -199,7 +199,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
             return self.rank_ / self.proc_per_node_
         return 0

-    def barrier_worker(self):
+    def _barrier_worker(self):
         """
         barrier all workers in current distributed job
         """
@@ -207,7 +207,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         if self.is_worker():
             self.node_type_comm_.barrier()

-    def barrier_server(self):
+    def _barrier_server(self):
         """
         barrier all servers in current distributed job
         """
@@ -215,7 +215,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
         if self.is_server():
             self.node_type_comm_.barrier()

-    def generate_role(self):
+    def _generate_role(self):
         """
         generate currently process's role
         """
...
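A side note on the `return self.get_size() / 2` lines above: the role maker assumes that half of the MPI ranks act as workers and the other half as servers, and the `/` relies on Python 2 integer division. A minimal, hedged sketch of the same split written with `//` so it stays an integer under Python 3 as well; the helper name is invented for illustration and is not part of Paddle:

    # Illustrative only: mirrors the "half workers, half servers" split used by
    # MPISymetricRoleMaker._worker_num() / _server_num(); not Paddle code.
    def half_of_ranks(total_ranks):
        # "//" floors on both Python 2 and 3; "/" would return a float on Python 3
        return total_ranks // 2


    assert half_of_ranks(8) == 4
    assert half_of_ranks(7) == 3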
@@ -79,7 +79,7 @@ class Fleet(object):
         """
         if not self.is_initialized_:
             self.role_maker_ = MPISymetricRoleMaker()
-            self.role_maker_.generate_role()
+            self.role_maker_._generate_role()
             self._fleet_ptr = fluid.core.Fleet()
             self.is_initialized_ = True

@@ -89,11 +89,11 @@ class Fleet(object):
             destroyed when stop() is called.
         """
         self.role_maker_.barrier_worker()
-        if self.role_maker_.is_first_worker():
+        if self.role_maker_._is_first_worker():
             self._fleet_ptr.stop_server()
-        self.role_maker_.barrier_worker()
-        self.role_maker_.barrier_all()
-        self.role_maker_.finalize()
+        self.role_maker_._barrier_worker()
+        self.role_maker_._barrier_all()
+        self.role_maker_._finalize()

     def init_pserver(self):
         """
@@ -109,15 +109,15 @@ class Fleet(object):
                 print("You should run DistributedOptimizer.minimize() first")
                 sys.exit(-1)
             self._fleet_ptr.init_server(self._dist_desc_str,
-                                        self.role_maker_.get_rank())
+                                        self.role_maker_._get_rank())
             self.local_ip_ = self._fleet_ptr.run_server()
-            self.role_maker_.barrier_all()
-            self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)
+            self.role_maker_._barrier_all()
+            self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
             self._fleet_ptr.gather_servers(self.all_ips_,
-                                           self.role_maker_.get_size())
+                                           self.role_maker_._get_size())
             # wait all workers start
-            self.role_maker_.barrier_all()
+            self.role_maker_._barrier_all()
         else:
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
@@ -142,14 +142,14 @@ class Fleet(object):
             else:
                 print("You should run DistributedOptimizer.minimize() first")
                 sys.exit(-1)
-            self.role_maker_.barrier_all()  # wait for server starts
-            self.all_ips_ = self.role_maker_.all_gather(self.local_ip_)
+            self.role_maker_._barrier_all()  # wait for server starts
+            self.all_ips_ = self.role_maker_._all_gather(self.local_ip_)
             self._fleet_ptr.init_worker(self._dist_desc_str, self.all_ips_,
-                                        self.role_maker_.get_size(),
-                                        self.role_maker_.get_rank())
-            self.role_maker_.barrier_all()
-            self.role_maker_.barrier_worker()
-            if self.role_maker_.is_first_worker():
+                                        self.role_maker_._get_size(),
+                                        self.role_maker_._get_rank())
+            self.role_maker_._barrier_all()
+            self.role_maker_._barrier_worker()
+            if self.role_maker_._is_first_worker():
                 tables = self._dist_desc.trainer_param.dense_table
                 for prog in programs:
                     prog_id = str(id(prog))
@@ -169,9 +169,9 @@ class Fleet(object):
                        #print "table id ", table.table_id
                        #print "var_name_list ", var_name_list
                        self._fleet_ptr.init_model(prog.desc,
                                                   int(table.table_id),
                                                   var_name_list)
-            self.role_maker_.barrier_worker()
+            self.role_maker_._barrier_worker()
         else:
             print("You should run DistributedOptimizer.minimize() first")
             sys.exit(-1)
@@ -180,39 +180,39 @@ class Fleet(object):
         """
         return the number of current job's worker num
         """
-        return self.role_maker_.worker_num()
+        return self.role_maker_._worker_num()

     def get_server_num(self):
         """
         return the number of current job's server num
         """
-        return self.role_maker_.server_num()
+        return self.role_maker_._server_num()

     def get_worker_index(self):
         """
         return the mpi rank of current worker
         """
-        return self.role_maker_.worker_index();
+        return self.role_maker_._worker_index()

     def is_worker(self):
         """
         return whether current node is a worker
         """
-        return self.role_maker_.is_worker()
+        return self.role_maker_._is_worker()

     def is_server(self):
         """
         return whether current node is pserver
         """
-        return self.role_maker_.is_server()
+        return self.role_maker_._is_server()

     def init_pserver_model(self):
         """
         init pserver model called from pserver
         """
-        if self.role_maker_.is_first_worker():
+        if self.role_maker_._is_first_worker():
             self._fleet_ptr.init_model()
-        self.role_maker_.barrier_worker()
+        self.role_maker_._barrier_worker()

     def save_pserver_model(self, save_path):
         """
@@ -290,7 +290,7 @@ class DistributedOptimizer(object):
             need to care about how to startup a pserver node.
         """
         optimize_ops, param_grads, opt_info = \
-            self._distributed_optimizer.minimize(
+            self._distributed_optimizer._minimize(
                 loss,
                 startup_program,
                 parameter_list,
...
@@ -48,11 +48,11 @@ class DistributedAdam(DistributedOptimizerImplBase):
             ".batch_size@GRAD", ".batch_square_sum@GRAD", ".batch_sum@GRAD"
         ]

-    def minimize(self,
-                 losses,
-                 startup_program=None,
-                 parameter_list=None,
-                 no_grad_set=None):
+    def _minimize(self,
+                  losses,
+                  startup_program=None,
+                  parameter_list=None,
+                  no_grad_set=None):
         """
         DownpounSGD is a distributed optimizer so
         that user can call minimize to generate backward
...
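Taken together, the diff keeps Fleet's public surface (`is_worker()`, `get_worker_index()`, `DistributedOptimizer.minimize()`, and so on) intact while pushing the role maker and optimizer internals behind underscore-prefixed names. Below is a small, self-contained toy sketch of that facade pattern; the classes and values are invented for illustration and are not Paddle's real implementations:

    # Toy illustration of the commit's pattern: internal helpers get a leading
    # underscore, while the user-facing facade keeps thin public wrappers.
    class _ToyRoleMaker(object):
        def __init__(self, rank, worker_num):
            self.rank_ = rank
            self.worker_num_ = worker_num

        def _is_worker(self):
            # internal: user code is expected to go through the facade instead
            return self.rank_ < self.worker_num_

        def _worker_index(self):
            return self.rank_


    class ToyFleet(object):
        def __init__(self, role_maker):
            self.role_maker_ = role_maker

        def is_worker(self):
            # public wrapper delegating to the private role maker method
            return self.role_maker_._is_worker()

        def get_worker_index(self):
            return self.role_maker_._worker_index()


    if __name__ == "__main__":
        fleet = ToyFleet(_ToyRoleMaker(rank=0, worker_num=2))
        print(fleet.is_worker(), fleet.get_worker_index())  # True 0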