diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/fluid/distributed/store/tcp_store.cc index b0d5add49565ffb19762778ddd44a388b140c0ee..ec6f0e26a08fa303d6f7bd66f199f6a9362e5b5a 100644 --- a/paddle/fluid/distributed/store/tcp_store.cc +++ b/paddle/fluid/distributed/store/tcp_store.cc @@ -19,21 +19,25 @@ #include "paddle/fluid/distributed/store/tcp_store.h" #include "paddle/fluid/distributed/store/tcp_utils.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/flags.h" namespace paddle { namespace distributed { namespace detail { -constexpr int INFTIME = -1; +constexpr int INFTIME = 10000; // 10 seconds -std::unique_ptr MasterDaemon::start(SocketType socket, - int nranks) { - return std::make_unique(socket, nranks); +std::unique_ptr MasterDaemon::start(SocketType socket, int nranks, + int stop_check_timeout) { + return std::make_unique(socket, nranks, stop_check_timeout); } -MasterDaemon::MasterDaemon(SocketType socket, int nranks) - : _listen_socket(socket), _nranks(nranks) { +MasterDaemon::MasterDaemon(SocketType socket, int nranks, + int stop_check_timeout) + : _listen_socket(socket), + _nranks(nranks), + _stop_check_timeout(stop_check_timeout) { _background_thread = std::thread{&MasterDaemon::run, this}; } @@ -86,6 +90,10 @@ void MasterDaemon::_do_get(SocketType socket) { void MasterDaemon::_do_stop(SocketType socket) { VLOG(3) << "MasterDaemon::_do_stop"; + if (!_has_stop) { + _stop_time = std::chrono::system_clock::now(); + } + _has_stop = true; ReplyType value = ReplyType::STOP_WAIT; tcputils::send_value(socket, value); if (--_nranks == 0) { @@ -115,6 +123,20 @@ void MasterDaemon::run() { #endif while (!_stop) { + auto end_time = std::chrono::system_clock::now(); + if (_has_stop) { + std::chrono::duration diff = end_time - _stop_time; + int elapsed_seconds = static_cast(diff.count()); + PADDLE_ENFORCE_LT( + elapsed_seconds, _stop_check_timeout, + platform::errors::Fatal( + "%d seconds elapsed after the first worker " + "stopped, so we think there may be something wrong and will " + "stop the master worker. You can use " + "'export FLAGS_stop_check_timeout=3600'" + " to change the timeout value in seconds. The default one is 900", + elapsed_seconds)); + } for (size_t i = 0; i < fds.size(); i++) { fds[i].revents = 0; } @@ -173,10 +195,12 @@ void MasterDaemon::run() { } } -std::unique_ptr TCPServer::create(uint16_t port, int nranks) { +std::unique_ptr TCPServer::create(uint16_t port, int nranks, + int stop_check_timeout) { int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET); auto server = std::make_unique(); - server->_master_daemon = MasterDaemon::start(socket, nranks); + server->_master_daemon = + MasterDaemon::start(socket, nranks, stop_check_timeout); return server; } @@ -219,10 +243,11 @@ std::vector TCPClient::receive_vector() { } // namespace detail TCPStore::TCPStore(std::string host, uint16_t port, bool is_master, - size_t num_workers, std::chrono::seconds timeout) + size_t num_workers, std::chrono::seconds timeout, + int stop_check_timeout) : Store(timeout), _is_master(is_master), _num_workers(num_workers) { if (_is_master) { - _server = detail::TCPServer::create(port, num_workers); + _server = detail::TCPServer::create(port, num_workers, stop_check_timeout); } _client = detail::TCPClient::connect(host, port); diff --git a/paddle/fluid/distributed/store/tcp_store.h b/paddle/fluid/distributed/store/tcp_store.h index 17c1d8ea30a421f04d054d59ac93c8c60406ef68..4ca9a673bf57562609e45d090741eab96f92a6c8 100644 --- a/paddle/fluid/distributed/store/tcp_store.h +++ b/paddle/fluid/distributed/store/tcp_store.h @@ -34,9 +34,11 @@ namespace detail { class MasterDaemon { public: static std::unique_ptr start(SocketType listen_socket, - int nranks); + int nranks, + int stop_check_timeout); MasterDaemon() = delete; - explicit MasterDaemon(SocketType listen_socket, int nranks); + explicit MasterDaemon(SocketType listen_socket, int nranks, + int stop_check_timeout); ~MasterDaemon(); private: @@ -51,13 +53,17 @@ class MasterDaemon { std::unordered_map> _store; std::thread _background_thread{}; int _nranks; - bool _stop = false; + int _stop_check_timeout; + bool _stop = false; // all workers stopped + std::chrono::time_point _stop_time; + bool _has_stop = false; // at least one worker stopped }; class TCPServer { public: TCPServer() = default; - static std::unique_ptr create(std::uint16_t port, int nranks); + static std::unique_ptr create(std::uint16_t port, int nranks, + int stop_check_timeout); private: std::unique_ptr _master_daemon; @@ -93,7 +99,8 @@ class TCPStore : public Store { static constexpr std::uint16_t kDefaultPort = 6170; explicit TCPStore(std::string host, uint16_t port = kDefaultPort, bool is_master = false, size_t num_workers = 1, - std::chrono::seconds timeout = tcputils::kDefaultTimeout); + std::chrono::seconds timeout = tcputils::kDefaultTimeout, + int stop_check_timeout = 900); ~TCPStore(); diff --git a/paddle/fluid/pybind/communication.cc b/paddle/fluid/pybind/communication.cc index 1a6a395545a96b1980cae73ff65de3daef0acafc..aef02d65b4dbd22df05cb3fa0156588c9f6f412b 100644 --- a/paddle/fluid/pybind/communication.cc +++ b/paddle/fluid/pybind/communication.cc @@ -58,13 +58,16 @@ void BindTCPStore(py::module *m) { py::class_>(*m, "TCPStore", Store) .def(py::init([](std::string hostname, uint16_t port, bool is_master, - size_t world_size, std::chrono::seconds timeout) { + size_t world_size, std::chrono::seconds timeout, + int stop_check_timeout) { return std::make_shared(hostname, port, is_master, - world_size, timeout); + world_size, timeout, + stop_check_timeout); }), py::arg("hostname"), py::arg("port"), py::arg("is_master"), py::arg("world_size"), py::arg("timeout") = distributed::tcputils::kNoTimeout, + py::arg("stop_check_timeout") = 900, py::call_guard()); } diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 53d35a251c8c81e1d5f903043baf7c4e012a2ae5..8cd6c4647dce4b3304181379177639f88adf8008 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -233,8 +233,13 @@ def init_parallel_env(): master_addr, master_port = endpoints.split(":") master_port = int(master_port) is_master = rank == 0 - default_store = core.TCPStore(master_addr, master_port, is_master, - world_size) + stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900")) + default_store = core.TCPStore( + master_addr, + master_port, + is_master, + world_size, + stop_check_timeout=stop_check_timeout) _set_default_store(default_store) pg = _new_process_group_impl( backend,