Unverified commit 357311fd, authored by guru4elephant, committed by GitHub

make fleet support mpi job submit directly (#18441)

Parent e0d8c6ac
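This change lets a fleet program be submitted straight to an MPI cluster: handing fleet an MPISymetricRoleMaker makes each MPI rank derive its role (worker or pserver) from its rank instead of reading cloud environment variables. Below is a hedged sketch of a trainer script using this path; build_net and the training loop are hypothetical placeholders, not part of this commit:

    # Launched directly under MPI, e.g.:  mpirun -np 4 python train.py
    import paddle.fluid as fluid
    from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

    fleet.init(MPISymetricRoleMaker())      # role comes from the MPI rank

    avg_cost = build_net()                  # hypothetical model-building helper
    optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer)
    optimizer.minimize(avg_cost)

    if fleet.is_server():
        fleet.init_server()
        fleet.run_server()
    elif fleet.is_worker():
        fleet.init_worker()                 # under MPI this now waits for pservers
        # ... run the training loop on fleet.main_program ...
        fleet.stop_worker()                 # tears MPI down via _finalize()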
@@ -198,7 +198,7 @@ class MPIRoleMaker(RoleMakerBase):
         """
         finalize the current MPI instance.
         """
-        pass
+        self.MPI.Finalize()

     def _get_ips(self):
         """
@@ -356,6 +356,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
         print("PaddleCloudRoleMaker() endpoints: %s" % self.endpoints)
         self.endpoints = self.endpoints.split(",")
         self._server_endpoints = self.endpoints
+        self._worker_endpoints = self.endpoints
         if self.role.upper() == "PSERVER":
             self._current_id = self.endpoints.index(self.current_endpoint)
             self._role = Role.SERVER
......
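The added _worker_endpoints assignment exposes the same parsed endpoint list through the worker-side accessor as well. A hedged sketch of the bookkeeping; the environment variable name here is illustrative, not necessarily the one PaddleCloudRoleMaker reads:

    import os

    # One comma-separated string of "ip:port" pairs becomes a shared list.
    endpoints = os.getenv("PADDLE_PSERVER_ENDPOINTS",
                          "127.0.0.1:6170,127.0.0.1:6171")
    endpoint_list = endpoints.split(",")

    server_endpoints = endpoint_list    # where the pservers listen
    worker_endpoints = endpoint_list    # now also visible to workers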
@@ -26,6 +26,7 @@ from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
 from paddle.fluid.incubate.fleet.base.fleet_base import DistributedOptimizer
 from paddle.fluid.incubate.fleet.base.fleet_base import Fleet
 from paddle.fluid.incubate.fleet.base.fleet_base import Mode
+from paddle.fluid.incubate.fleet.base.role_maker import MPISymetricRoleMaker


 class DistributedTranspiler(Fleet):
@@ -52,6 +53,13 @@ class DistributedTranspiler(Fleet):
         Returns:
             None
         """
+        # If an MPISymetricRoleMaker is in use, we assume the user wants to
+        # submit the job on an MPI cluster.
+        if isinstance(self._role_maker, MPISymetricRoleMaker):
+            # Check whether the servers have been initialized.
+            from paddle.fluid.transpiler.details.checkport import wait_server_ready
+            wait_server_ready(fleet.server_endpoints(to_string=False))
+
         if not self._transpile_config.sync_mode:
             self._communicator = Communicator(self.main_program)
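wait_server_ready blocks until every pserver endpoint accepts a TCP connection, so an MPI-launched worker cannot start pulling parameters before the servers are up. A hedged sketch of such a readiness probe; the real helper lives in paddle.fluid.transpiler.details.checkport and may differ in detail:

    import socket
    import time

    def wait_server_ready_sketch(endpoints):
        """Block until every "ip:port" endpoint accepts a TCP connection."""
        not_ready = list(endpoints)
        while not_ready:
            still_waiting = []
            for ep in not_ready:
                ip, port = ep.split(":")
                sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                sock.settimeout(2)
                try:
                    sock.connect((ip, int(port)))
                except socket.error:
                    still_waiting.append(ep)   # server not listening yet
                finally:
                    sock.close()
            not_ready = still_waiting
            if not_ready:
                time.sleep(3)                  # back off before probing again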
@@ -114,6 +122,9 @@ class DistributedTranspiler(Fleet):
             self._communicator.stop()
         self._executor.close()

+        if isinstance(self._role_maker, MPISymetricRoleMaker):
+            self._role_maker._finalize()
+
     def distributed_optimizer(self, optimizer, strategy=None):
         """
         Optimizer for distributed training.
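Note the teardown order in the stop path above: the communicator and the executor are closed first, and only then is the role maker finalized, because _finalize() ends in MPI.Finalize() (see the first hunk), after which no further MPI communication is possible.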
@@ -199,13 +210,24 @@ class DistributedTranspiler(Fleet):
         self._transpile_config = config
         self._transpiler = OriginTranspiler(config)
+        print("server endpoints")
+        print(fleet.server_endpoints(to_string=True))
+        print("worker index: %d" % fleet.worker_index())
+        print("worker num: %d" % fleet.worker_num())

         if self.is_worker():
             self._transpiler.transpile(
                 trainer_id=fleet.worker_index(),
                 pservers=fleet.server_endpoints(to_string=True),
                 trainers=fleet.worker_num(),
                 sync_mode=config.sync_mode)
-            self.main_program = self._transpiler.get_trainer_program()
+            wait_port = True
+            if isinstance(self._role_maker, MPISymetricRoleMaker):
+                wait_port = False
+            self.main_program = self._transpiler.get_trainer_program(
+                wait_port=wait_port)
             self.startup_program = default_startup_program()
         else:
             self._transpiler.transpile(
......
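The wait_port switch in the last hunk avoids blocking twice on the MPI path: init_worker() has already gone through wait_server_ready, so the trainer program is built with wait_port=False and skips the transpiler's own port check.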