diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/fluid/distributed/store/tcp_store.cc index 28387af44df17cb88f07b5615b745b4f43a5d11d..25b66718e4a92ea33693375164d29f34af7816a7 100644 --- a/paddle/fluid/distributed/store/tcp_store.cc +++ b/paddle/fluid/distributed/store/tcp_store.cc @@ -95,19 +95,6 @@ void MasterDaemon::_do_get(SocketType socket) { tcputils::send_vector(socket, value); } -void MasterDaemon::_do_stop(SocketType socket) { - VLOG(4) << "MasterDaemon::_do_stop " << GetSockName(socket); - if (!_has_stop) { - _stop_time = std::chrono::system_clock::now(); - } - _has_stop = true; - ReplyType value = ReplyType::STOP_WAIT; - tcputils::send_value(socket, value); - if (--_nranks == 0) { - _stop = true; - } -} - #ifndef _WIN32 void MasterDaemon::InitControlFd() { PADDLE_ENFORCE_NE( @@ -135,9 +122,13 @@ void MasterDaemon::StopByControlFd() { } } #else -void MasterDaemon::InitControlFd() {} -void MasterDaemon::CloseControlFd() {} -void MasterDaemon::StopByControlFd() {} +void MasterDaemon::InitControlFd() { + ghStopEvent_ = CreateEvent(NULL, TRUE, FALSE, NULL); + PADDLE_ENFORCE(ghStopEvent_, + platform::errors::Fatal("failed to cread control pipe")); +} +void MasterDaemon::CloseControlFd() { CloseHandle(ghStopEvent_); } +void MasterDaemon::StopByControlFd() { SetEvent(ghStopEvent_); } #endif void MasterDaemon::_do_wait(SocketType socket) { @@ -186,9 +177,6 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { case Command::WAIT: _do_wait(fds[i].fd); break; - case Command::STOP: - _do_stop(fds[i].fd); - break; default: LOG(WARNING) << "Unknown command: " << static_cast(command) << " from addr info:" << GetSockName(fds[i].fd); @@ -208,7 +196,6 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { } void MasterDaemon::run() { - VLOG(4) << "begin to run run _stop:" << _stop << " _has_stop:" << _has_stop; std::vector fds; #ifdef _WIN32 fds.push_back({_listen_socket, POLLIN}); @@ -218,23 +205,8 @@ void MasterDaemon::run() { {.fd = _control_fd[0], .events = POLLIN | POLLHUP, .revents = 0}); #endif - while (!_stop) { - auto end_time = std::chrono::system_clock::now(); - if (_has_stop) { - std::chrono::duration diff = end_time - _stop_time; - int elapsed_seconds = static_cast(diff.count()); - PADDLE_ENFORCE_LT( - elapsed_seconds, - _timeout, - platform::errors::Fatal( - "%d seconds elapsed after the first worker " - "stopped, so we think there may be something wrong and will " - "stop the master worker. You can use " - "'export FLAGS_stop_check_timeout=3600'" - " to change the timeout value in seconds. The default one is 900", - elapsed_seconds)); - } - + bool finished = false; + while (!finished) { for (size_t i = 0; i < fds.size(); i++) { fds[i].revents = 0; } @@ -242,7 +214,15 @@ void MasterDaemon::run() { VLOG(9) << "begin to poll fds_size:" << paddle::string::Sprintf("%d", fds.size()); #ifdef _WIN32 - ::WSAPoll(fds.data(), fds.size(), INFTIME); + int res = ::WSAPoll(fds.data(), fds.size(), INFTIME); + if (res == 0) { + auto rv = WaitForSingleObject(ghStopEvent_, 0); + if (rv != WAIT_TIMEOUT) { + finished = true; + break; + } + continue; + } #else ::poll(fds.data(), fds.size(), INFTIME); @@ -256,7 +236,7 @@ void MasterDaemon::run() { } VLOG(0) << "receive shutdown event and so quit from MasterDaemon run loop"; - _stop = true; + finished = true; break; } #endif diff --git a/paddle/fluid/distributed/store/tcp_store.h b/paddle/fluid/distributed/store/tcp_store.h index 2d25987dc19e2bb542d7478efcbf3e9a5af08079..06f2ce55041b1f0a976caedcd13924c62b03e901 100644 --- a/paddle/fluid/distributed/store/tcp_store.h +++ b/paddle/fluid/distributed/store/tcp_store.h @@ -60,21 +60,18 @@ class MasterDaemon { void _do_wait(SocketType socket); void _do_get(SocketType socket); void _do_set(SocketType socket); - void _do_stop(SocketType socket); SocketType _listen_socket; std::vector _sockets; std::unordered_map> _store; std::thread _background_thread{}; int _nranks = -1; int _timeout = 0; - bool _stop = false; // all workers stopped - std::chrono::time_point _stop_time; - bool _has_stop = false; // at least one worker stopped void InitControlFd(); void CloseControlFd(); void StopByControlFd(); #ifdef _WIN32 + HANDLE ghStopEvent_{}; #else std::array _control_fd{{-1, -1}}; #endif