未验证 提交 999d9a59 编写于 作者: T tangwei12 提交者: GitHub

fix communicator with pyreader (#18350)

* add is_runnning in communicator, test=develop
上级 cff2c2d8
......@@ -87,9 +87,14 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
// init communicator here
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
if (operators::distributed::Communicator::GetInstance() == nullptr) {
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
} else {
VLOG(3) << "communicator has been initialized, skip";
}
}
#endif
}
......
......@@ -167,6 +167,8 @@ class Communicator {
void Start();
void Stop();
bool IsRunning() { return running_; }
// send grad
void Send(const std::string& var_name, const framework::Scope& scope);
......
......@@ -40,7 +40,8 @@ void BindCommunicator(py::module* m) {
return Communicator::GetInstantcePtr();
}))
.def("stop", &Communicator::Stop)
.def("start", &Communicator::Start);
.def("start", &Communicator::Start)
.def("is_running", &Communicator::IsRunning);
}
} // namespace pybind
......
......@@ -86,3 +86,21 @@ class Communicator(object):
comm.stop()
"""
self.communicator_.stop()
def is_running(self):
"""
Get communicator is running or stop.
Returns:
bool
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.Program()
comm = fluid.communicator.Communicator(prog)
comm.is_running()
"""
self.communicator_.is_running()
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
import paddle.fluid.io as io
from paddle.fluid.communicator import Communicator
......@@ -53,7 +54,11 @@ class DistributedTranspiler(Fleet):
"""
if not self._transpile_config.sync_mode:
self._communicator = Communicator(self.main_program)
self._communicator.start()
if not self._communicator.is_running():
self._communicator.start()
else:
warnings.warn("communicator has been initialized, skip")
def init_server(self, model_dir=None):
"""
......@@ -104,7 +109,8 @@ class DistributedTranspiler(Fleet):
Returns:
None
"""
if not self._transpile_config.sync_mode:
if not self._transpile_config.sync_mode and self._communicator.is_running(
):
self._communicator.stop()
self._executor.close()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册