From 999d9a59a50c69ac53bd8c0bcb8ee74ab2bcbade Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 27 Jun 2019 11:00:03 +0800 Subject: [PATCH] fix communicator with pyreader (#18350) * add is_runnning in communicator, test=develop --- .../details/async_ssa_graph_executor.cc | 11 ++++++++--- .../fluid/operators/distributed/communicator.h | 2 ++ paddle/fluid/pybind/communicator_py.cc | 3 ++- python/paddle/fluid/communicator.py | 18 ++++++++++++++++++ .../distribute_transpiler/__init__.py | 10 ++++++++-- 5 files changed, 38 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index ce7849cb419..da9721ea73d 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -87,9 +87,14 @@ void ProcessGraph(std::vector graphs, Scope *scope) { // init communicator here if (send_varname_to_ctx.size() > 0) { VLOG(3) << "this is distribute mode, will use communicator"; - operators::distributed::Communicator::Init(send_varname_to_ctx, - recv_varname_to_ctx, scope); - operators::distributed::Communicator::GetInstance()->Start(); + + if (operators::distributed::Communicator::GetInstance() == nullptr) { + operators::distributed::Communicator::Init(send_varname_to_ctx, + recv_varname_to_ctx, scope); + operators::distributed::Communicator::GetInstance()->Start(); + } else { + VLOG(3) << "communicator has been initialized, skip"; + } } #endif } diff --git a/paddle/fluid/operators/distributed/communicator.h b/paddle/fluid/operators/distributed/communicator.h index 17f68fb4f1b..6db02fc8402 100644 --- a/paddle/fluid/operators/distributed/communicator.h +++ b/paddle/fluid/operators/distributed/communicator.h @@ -167,6 +167,8 @@ class Communicator { void Start(); void Stop(); + bool IsRunning() { return running_; } + // send grad void Send(const std::string& var_name, const framework::Scope& scope); diff --git a/paddle/fluid/pybind/communicator_py.cc b/paddle/fluid/pybind/communicator_py.cc index 1d4052358b3..5b576f06dab 100644 --- a/paddle/fluid/pybind/communicator_py.cc +++ b/paddle/fluid/pybind/communicator_py.cc @@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) { return Communicator::GetInstantcePtr(); })) .def("stop", &Communicator::Stop) - .def("start", &Communicator::Start); + .def("start", &Communicator::Start) + .def("is_running", &Communicator::IsRunning); } } // namespace pybind diff --git a/python/paddle/fluid/communicator.py b/python/paddle/fluid/communicator.py index 7d0db90b6ad..2fecdd34c15 100644 --- a/python/paddle/fluid/communicator.py +++ b/python/paddle/fluid/communicator.py @@ -86,3 +86,21 @@ class Communicator(object): comm.stop() """ self.communicator_.stop() + + def is_running(self): + """ + Get communicator is running or stop. + + Returns: + bool + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + + prog = fluid.Program() + comm = fluid.communicator.Communicator(prog) + comm.is_running() + """ + self.communicator_.is_running() diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py index 5b80bdb95d8..3854f258be7 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings import paddle.fluid.io as io from paddle.fluid.communicator import Communicator @@ -53,7 +54,11 @@ class DistributedTranspiler(Fleet): """ if not self._transpile_config.sync_mode: self._communicator = Communicator(self.main_program) - self._communicator.start() + + if not self._communicator.is_running(): + self._communicator.start() + else: + warnings.warn("communicator has been initialized, skip") def init_server(self, model_dir=None): """ @@ -104,7 +109,8 @@ class DistributedTranspiler(Fleet): Returns: None """ - if not self._transpile_config.sync_mode: + if not self._transpile_config.sync_mode and self._communicator.is_running( + ): self._communicator.stop() self._executor.close() -- GitLab