diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index ce7849cb419950dc2ede4182d108e51bcf6e9945..da9721ea73d28f48ddfc12672fc6249a2a23c9df 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -87,9 +87,14 @@ void ProcessGraph(std::vector graphs, Scope *scope) { // init communicator here if (send_varname_to_ctx.size() > 0) { VLOG(3) << "this is distribute mode, will use communicator"; - operators::distributed::Communicator::Init(send_varname_to_ctx, - recv_varname_to_ctx, scope); - operators::distributed::Communicator::GetInstance()->Start(); + + if (operators::distributed::Communicator::GetInstance() == nullptr) { + operators::distributed::Communicator::Init(send_varname_to_ctx, + recv_varname_to_ctx, scope); + operators::distributed::Communicator::GetInstance()->Start(); + } else { + VLOG(3) << "communicator has been initialized, skip"; + } } #endif } diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 17f68fb4f1b86b22e9d422e4c0421a2bd2515586..6db02fc84025fffc75e2512ea91100b481fa884c 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -167,6 +167,8 @@ class Communicator { void Start(); void Stop(); + bool IsRunning() { return running_; } + // send grad void Send(const std::string& var_name, const framework::Scope& scope); diff --git a/paddle/fluid/pybind/communicator_py.cc b/paddle/fluid/pybind/communicator_py.cc index 1d4052358b3d0dff2324a7e9c1f4a6b8a689119c..5b576f06dab9fba4cccdff35647c8bc9cebcbdc9 100644 --- a/paddle/fluid/pybind/communicator_py.cc +++ b/paddle/fluid/pybind/communicator_py.cc @@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) { return Communicator::GetInstantcePtr(); })) .def("stop", &Communicator::Stop) - .def("start", 
&Communicator::Start); + .def("start", &Communicator::Start) + .def("is_running", &Communicator::IsRunning); } } // namespace pybind diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py index 7d0db90b6ade2465a676ecc41fc410d4a3a97de6..2fecdd34c1569145981f56746e218a3ecf6bb9b4 100644 --- a/python/paddle/fluid/communicator.py +++ b/python/paddle/fluid/communicator.py @@ -86,3 +86,21 @@ class Communicator(object): comm.stop() """ self.communicator_.stop() + + def is_running(self): + """ + Check whether the communicator is running or stopped. + + Returns: + bool + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + prog = fluid.Program() + comm = fluid.communicator.Communicator(prog) + comm.is_running() + """ + return self.communicator_.is_running() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 5b80bdb95d8639bcf21fc62b988bec26c7db4b0a..3854f258be73e368f545db0307c901e10ab89e02 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os +import warnings import paddle.fluid.io as io from paddle.fluid.communicator import Communicator @@ -53,7 +54,11 @@ class DistributedTranspiler(Fleet): """ if not self._transpile_config.sync_mode: self._communicator = Communicator(self.main_program) - self._communicator.start() + + if not self._communicator.is_running(): + self._communicator.start() + else: + warnings.warn("communicator has been initialized, skip") def init_server(self, model_dir=None): """ @@ -104,7 +109,8 @@ class DistributedTranspiler(Fleet): Returns: None """ - if not self._transpile_config.sync_mode: + if not self._transpile_config.sync_mode and self._communicator.is_running( + ): self._communicator.stop() self._executor.close()