未验证 提交 999d9a59 编写于 作者: T tangwei12 提交者: GitHub

fix communicator with pyreader (#18350)

* add is_running in communicator, test=develop
上级 cff2c2d8
...@@ -87,9 +87,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -87,9 +87,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
// init communicator here // init communicator here
if (send_varname_to_ctx.size() > 0) { if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator"; VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope); if (operators::distributed::Communicator::GetInstance() == nullptr) {
operators::distributed::Communicator::GetInstance()->Start(); operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
} else {
VLOG(3) << "communicator has been initialized, skip";
}
} }
#endif #endif
} }
......
...@@ -167,6 +167,8 @@ class Communicator { ...@@ -167,6 +167,8 @@ class Communicator {
void Start(); void Start();
void Stop(); void Stop();
// Returns the current value of the running_ member.
bool IsRunning() { return running_; }
// send grad // send grad
void Send(const std::string& var_name, const framework::Scope& scope); void Send(const std::string& var_name, const framework::Scope& scope);
......
...@@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) { ...@@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) {
return Communicator::GetInstantcePtr(); return Communicator::GetInstantcePtr();
})) }))
.def("stop", &Communicator::Stop) .def("stop", &Communicator::Stop)
.def("start", &Communicator::Start); .def("start", &Communicator::Start)
.def("is_running", &Communicator::IsRunning);
} }
} // namespace pybind } // namespace pybind
......
...@@ -86,3 +86,21 @@ class Communicator(object): ...@@ -86,3 +86,21 @@ class Communicator(object):
comm.stop() comm.stop()
""" """
self.communicator_.stop() self.communicator_.stop()
def is_running(self):
    """
    Report whether the wrapped communicator is currently running.

    Returns:
        bool: the value returned by ``self.communicator_.is_running()``.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            prog = fluid.Program()
            comm = fluid.communicator.Communicator(prog)
            comm.is_running()
    """
    # The result must be returned explicitly; without `return` this method
    # always yields None, so callers' `if not comm.is_running()` checks
    # would always see a falsy value regardless of the real state.
    return self.communicator_.is_running()
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os
import warnings
import paddle.fluid.io as io import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator from paddle.fluid.communicator import Communicator
...@@ -53,7 +54,11 @@ class DistributedTranspiler(Fleet): ...@@ -53,7 +54,11 @@ class DistributedTranspiler(Fleet):
""" """
if not self._transpile_config.sync_mode: if not self._transpile_config.sync_mode:
self._communicator = Communicator(self.main_program) self._communicator = Communicator(self.main_program)
self._communicator.start()
if not self._communicator.is_running():
self._communicator.start()
else:
warnings.warn("communicator has been initialized, skip")
def init_server(self, model_dir=None): def init_server(self, model_dir=None):
""" """
...@@ -104,7 +109,8 @@ class DistributedTranspiler(Fleet): ...@@ -104,7 +109,8 @@ class DistributedTranspiler(Fleet):
Returns: Returns:
None None
""" """
if not self._transpile_config.sync_mode: if not self._transpile_config.sync_mode and self._communicator.is_running(
):
self._communicator.stop() self._communicator.stop()
self._executor.close() self._executor.close()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册