diff --git a/python/paddle/fluid/incubate/fleet/base/role_maker.py b/python/paddle/fluid/incubate/fleet/base/role_maker.py index fc4f53ff1dd46221ae3f065ad5d3414df9318d86..0867b7f65d7cbf3b68410379ffe0ed15afba3ea5 100644 --- a/python/paddle/fluid/incubate/fleet/base/role_maker.py +++ b/python/paddle/fluid/incubate/fleet/base/role_maker.py @@ -115,6 +115,7 @@ class MPISymetricRoleMaker(MPIRoleMaker): self.node_type_comm_.barrier() def generate_role(self): + # TODO(guru4elephant): only allow this method to be called once self.trainer_endpoints_ = self.get_ips() self.pserver_endpoints_ = self.get_ips() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py index e2e0f5ff10d3b0000956bb9fe95a90fb4628936e..4c1d97b57bd5e8fd9fa0824d27bbdd31b3050983 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/__init__.py @@ -89,6 +89,12 @@ class Fleet(object): print("You should run DistributedOptimizer.minimize() first") sys.exit(-1) + def get_worker_num(self): + return self.role_maker_.worker_num() + + def get_server_num(self): + return self.role_maker_.server_num() + def is_worker(self): return self.role_maker_.is_worker() @@ -161,3 +167,5 @@ is_worker = fleet_instance.is_worker is_server = fleet_instance.is_server init_pserver_model = fleet_instance.init_pserver_model save_pserver_model = fleet_instance.save_pserver_model +worker_num = fleet_instance.get_worker_num +server_num = fleet_instance.get_server_num