Unverified · commit 357311fd, authored by guru4elephant, committed by GitHub

make fleet support mpi job submit directly (#18441)

Make fleet support submitting an MPI job directly.
Parent: e0d8c6ac
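For orientation before the diff itself: below is a minimal sketch of the workflow this patch enables, the same fleet training script submitted directly through MPI, with MPISymetricRoleMaker splitting the ranks into parameter servers and trainers. This is a usage sketch, not part of the patch; the mpirun command line and the model are hypothetical placeholders, and the import paths are assumed from the Paddle Fluid incubate API of this period.

# Hypothetical train.py, launched directly under MPI as:
#     mpirun -np 4 python train.py
# (usage sketch, not part of this patch)
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

# MPISymetricRoleMaker splits the MPI ranks symmetrically:
# half become parameter servers, half become trainers.
fleet.init(MPISymetricRoleMaker())

# placeholder model: linear regression
x = fluid.layers.data(name="x", shape=[13], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(fluid.layers.square_error_cost(input=pred, label=y))

optimizer = fleet.distributed_optimizer(fluid.optimizer.SGD(learning_rate=0.01))
optimizer.minimize(loss)

if fleet.is_server():
    fleet.init_server()
    fleet.run_server()
elif fleet.is_worker():
    fleet.init_worker()  # with this patch: blocks until pservers are reachable
    # ... training loop over fleet.main_program goes here ...
    fleet.stop_worker()  # with this patch: also finalizes the MPI environment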
@@ -198,7 +198,7 @@ class MPIRoleMaker(RoleMakerBase):
         """
         finalize the current MPI instance.
         """
-        pass
+        self.MPI.Finalize()
 
     def _get_ips(self):
         """
@@ -356,6 +356,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
             print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints)
             self.endpoints = self.endpoints.split(",")
             self._server_endpoints = self.endpoints
+            self._worker_endpoints = self.endpoints
             if self.role.upper() == "PSERVER":
                 self._current_id = self.endpoints.index(self.current_endpoint)
                 self._role = Role.SERVER
...
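The one-line fix above exists because PaddleCloudRoleMaker previously filled in only _server_endpoints, so worker-side endpoint queries came back empty. A simplified illustration of the endpoint bookkeeping; the environment variable name here is an assumption for illustration, not taken from this diff:

import os

# e.g. PADDLE_PSERVER_ENDPOINTS="127.0.0.1:6170,127.0.0.1:6171" (assumed name)
endpoints = os.environ.get("PADDLE_PSERVER_ENDPOINTS", "").split(",")
server_endpoints = endpoints
worker_endpoints = endpoints  # newly mirrored, so worker-side queries work too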
@@ -26,6 +26,7 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
 from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
 from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
 from paddle.fluid.incubate.fleet.base.fleet_base import Mode
+from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
 
 
 class DistributedTranspiler(Fleet):
@@ -52,6 +53,13 @@ class DistributedTranspiler(Fleet):
         Returns:
             None
         """
+        # if MPISymetricRoleMaker is defined,
+        # we assume the user wants to submit the job on an MPI cluster
+        if isinstance(self._role_maker, MPISymetricRoleMaker):
+            # check whether the servers have been initialized
+            from paddle.fluid.transpiler.details.checkport import wait_server_ready
+            wait_server_ready(fleet.server_endpoints(to_string=False))
+
         if not self._transpile_config.sync_mode:
             self._communicator = Communicator(self.main_program)
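wait_server_ready is an existing Paddle helper in paddle.fluid.transpiler.details.checkport; the gist, sketched below under the assumption that it simply polls each pserver endpoint with a TCP connect until all of them accept connections (names here are illustrative, not the real implementation):

# Simplified sketch of the wait_server_ready idea
import socket
import time

def wait_server_ready_sketch(endpoints):
    not_ready = list(endpoints)
    while not_ready:
        remaining = []
        for ep in not_ready:
            ip, port = ep.split(":")
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(2)
            if sock.connect_ex((ip, int(port))) != 0:
                remaining.append(ep)  # port not open yet, retry later
            sock.close()
        not_ready = remaining
        if not_ready:
            time.sleep(3)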
@@ -114,6 +122,9 @@ class DistributedTranspiler(Fleet):
             self._communicator.stop()
         self._executor.close()
+
+        if isinstance(self._role_maker, MPISymetricRoleMaker):
+            self._role_maker._finalize()
 
     def distributed_optimizer(self, optimizer, strategy=None):
         """
         Optimizer for distributed training.
@@ -199,13 +210,24 @@ class DistributedTranspiler(Fleet):
         self._transpile_config = config
         self._transpiler = OriginTranspiler(config)
+
+        print("server endpoints")
+        print(fleet.server_endpoints(to_string=True))
+        print("worker index: %d" % fleet.worker_index())
+        print("worker num: %d" % fleet.worker_num())
+
         if self.is_worker():
             self._transpiler.transpile(
                 trainer_id=fleet.worker_index(),
                 pservers=fleet.server_endpoints(to_string=True),
                 trainers=fleet.worker_num(),
                 sync_mode=config.sync_mode)
-            self.main_program = self._transpiler.get_trainer_program()
+            wait_port = True
+            if isinstance(self._role_maker, MPISymetricRoleMaker):
+                wait_port = False
+            self.main_program = self._transpiler.get_trainer_program(
+                wait_port=wait_port)
             self.startup_program = default_startup_program()
         else:
             self._transpiler.transpile(
...
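The wait_port flag in the last hunk exists because, on the MPI path, init_worker has already blocked in wait_server_ready, so having get_trainer_program wait on the pserver ports a second time would be redundant. The decision reduces to a one-liner; the helper below is illustrative, not part of the patch:

from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker

def should_wait_port(role_maker):
    # MPI submissions already waited for pservers in init_worker,
    # so only non-MPI role makers need get_trainer_program to wait.
    return not isinstance(role_maker, MPISymetricRoleMaker)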