未验证 提交 b5e60db4 编写于 作者: L lilong12 提交者: GitHub

Fix hang problem when some workers exited unexcepted (#42708)

* update
上级 1280f294
...@@ -19,21 +19,25 @@ ...@@ -19,21 +19,25 @@
#include "paddle/fluid/distributed/store/tcp_store.h" #include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h" #include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/flags.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
namespace detail { namespace detail {
constexpr int INFTIME = -1; constexpr int INFTIME = 10000; // 10 seconds
std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket, std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket, int nranks,
int nranks) { int stop_check_timeout) {
return std::make_unique<MasterDaemon>(socket, nranks); return std::make_unique<MasterDaemon>(socket, nranks, stop_check_timeout);
} }
MasterDaemon::MasterDaemon(SocketType socket, int nranks) MasterDaemon::MasterDaemon(SocketType socket, int nranks,
: _listen_socket(socket), _nranks(nranks) { int stop_check_timeout)
: _listen_socket(socket),
_nranks(nranks),
_stop_check_timeout(stop_check_timeout) {
_background_thread = std::thread{&MasterDaemon::run, this}; _background_thread = std::thread{&MasterDaemon::run, this};
} }
...@@ -86,6 +90,10 @@ void MasterDaemon::_do_get(SocketType socket) { ...@@ -86,6 +90,10 @@ void MasterDaemon::_do_get(SocketType socket) {
void MasterDaemon::_do_stop(SocketType socket) { void MasterDaemon::_do_stop(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_stop"; VLOG(3) << "MasterDaemon::_do_stop";
if (!_has_stop) {
_stop_time = std::chrono::system_clock::now();
}
_has_stop = true;
ReplyType value = ReplyType::STOP_WAIT; ReplyType value = ReplyType::STOP_WAIT;
tcputils::send_value<ReplyType>(socket, value); tcputils::send_value<ReplyType>(socket, value);
if (--_nranks == 0) { if (--_nranks == 0) {
...@@ -115,6 +123,20 @@ void MasterDaemon::run() { ...@@ -115,6 +123,20 @@ void MasterDaemon::run() {
#endif #endif
while (!_stop) { while (!_stop) {
auto end_time = std::chrono::system_clock::now();
if (_has_stop) {
std::chrono::duration<double> diff = end_time - _stop_time;
int elapsed_seconds = static_cast<int>(diff.count());
PADDLE_ENFORCE_LT(
elapsed_seconds, _stop_check_timeout,
platform::errors::Fatal(
"%d seconds elapsed after the first worker "
"stopped, so we think there may be something wrong and will "
"stop the master worker. You can use "
"'export FLAGS_stop_check_timeout=3600'"
" to change the timeout value in seconds. The default one is 900",
elapsed_seconds));
}
for (size_t i = 0; i < fds.size(); i++) { for (size_t i = 0; i < fds.size(); i++) {
fds[i].revents = 0; fds[i].revents = 0;
} }
...@@ -173,10 +195,12 @@ void MasterDaemon::run() { ...@@ -173,10 +195,12 @@ void MasterDaemon::run() {
} }
} }
std::unique_ptr<TCPServer> TCPServer::create(uint16_t port, int nranks) { std::unique_ptr<TCPServer> TCPServer::create(uint16_t port, int nranks,
int stop_check_timeout) {
int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET); int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET);
auto server = std::make_unique<TCPServer>(); auto server = std::make_unique<TCPServer>();
server->_master_daemon = MasterDaemon::start(socket, nranks); server->_master_daemon =
MasterDaemon::start(socket, nranks, stop_check_timeout);
return server; return server;
} }
...@@ -219,10 +243,11 @@ std::vector<T> TCPClient::receive_vector() { ...@@ -219,10 +243,11 @@ std::vector<T> TCPClient::receive_vector() {
} // namespace detail } // namespace detail
TCPStore::TCPStore(std::string host, uint16_t port, bool is_master, TCPStore::TCPStore(std::string host, uint16_t port, bool is_master,
size_t num_workers, std::chrono::seconds timeout) size_t num_workers, std::chrono::seconds timeout,
int stop_check_timeout)
: Store(timeout), _is_master(is_master), _num_workers(num_workers) { : Store(timeout), _is_master(is_master), _num_workers(num_workers) {
if (_is_master) { if (_is_master) {
_server = detail::TCPServer::create(port, num_workers); _server = detail::TCPServer::create(port, num_workers, stop_check_timeout);
} }
_client = detail::TCPClient::connect(host, port); _client = detail::TCPClient::connect(host, port);
......
...@@ -34,9 +34,11 @@ namespace detail { ...@@ -34,9 +34,11 @@ namespace detail {
class MasterDaemon { class MasterDaemon {
public: public:
static std::unique_ptr<MasterDaemon> start(SocketType listen_socket, static std::unique_ptr<MasterDaemon> start(SocketType listen_socket,
int nranks); int nranks,
int stop_check_timeout);
MasterDaemon() = delete; MasterDaemon() = delete;
explicit MasterDaemon(SocketType listen_socket, int nranks); explicit MasterDaemon(SocketType listen_socket, int nranks,
int stop_check_timeout);
~MasterDaemon(); ~MasterDaemon();
private: private:
...@@ -51,13 +53,17 @@ class MasterDaemon { ...@@ -51,13 +53,17 @@ class MasterDaemon {
std::unordered_map<std::string, std::vector<uint8_t>> _store; std::unordered_map<std::string, std::vector<uint8_t>> _store;
std::thread _background_thread{}; std::thread _background_thread{};
int _nranks; int _nranks;
bool _stop = false; int _stop_check_timeout;
bool _stop = false; // all workers stopped
std::chrono::time_point<std::chrono::system_clock> _stop_time;
bool _has_stop = false; // at least one worker stopped
}; };
class TCPServer { class TCPServer {
public: public:
TCPServer() = default; TCPServer() = default;
static std::unique_ptr<TCPServer> create(std::uint16_t port, int nranks); static std::unique_ptr<TCPServer> create(std::uint16_t port, int nranks,
int stop_check_timeout);
private: private:
std::unique_ptr<MasterDaemon> _master_daemon; std::unique_ptr<MasterDaemon> _master_daemon;
...@@ -93,7 +99,8 @@ class TCPStore : public Store { ...@@ -93,7 +99,8 @@ class TCPStore : public Store {
static constexpr std::uint16_t kDefaultPort = 6170; static constexpr std::uint16_t kDefaultPort = 6170;
explicit TCPStore(std::string host, uint16_t port = kDefaultPort, explicit TCPStore(std::string host, uint16_t port = kDefaultPort,
bool is_master = false, size_t num_workers = 1, bool is_master = false, size_t num_workers = 1,
std::chrono::seconds timeout = tcputils::kDefaultTimeout); std::chrono::seconds timeout = tcputils::kDefaultTimeout,
int stop_check_timeout = 900);
~TCPStore(); ~TCPStore();
......
...@@ -58,13 +58,16 @@ void BindTCPStore(py::module *m) { ...@@ -58,13 +58,16 @@ void BindTCPStore(py::module *m) {
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store) py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
.def(py::init([](std::string hostname, uint16_t port, bool is_master, .def(py::init([](std::string hostname, uint16_t port, bool is_master,
size_t world_size, std::chrono::seconds timeout) { size_t world_size, std::chrono::seconds timeout,
int stop_check_timeout) {
return std::make_shared<TCPStore>(hostname, port, is_master, return std::make_shared<TCPStore>(hostname, port, is_master,
world_size, timeout); world_size, timeout,
stop_check_timeout);
}), }),
py::arg("hostname"), py::arg("port"), py::arg("is_master"), py::arg("hostname"), py::arg("port"), py::arg("is_master"),
py::arg("world_size"), py::arg("world_size"),
py::arg("timeout") = distributed::tcputils::kNoTimeout, py::arg("timeout") = distributed::tcputils::kNoTimeout,
py::arg("stop_check_timeout") = 900,
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} }
......
...@@ -233,8 +233,13 @@ def init_parallel_env(): ...@@ -233,8 +233,13 @@ def init_parallel_env():
master_addr, master_port = endpoints.split(":") master_addr, master_port = endpoints.split(":")
master_port = int(master_port) master_port = int(master_port)
is_master = rank == 0 is_master = rank == 0
default_store = core.TCPStore(master_addr, master_port, is_master, stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
world_size) default_store = core.TCPStore(
master_addr,
master_port,
is_master,
world_size,
stop_check_timeout=stop_check_timeout)
_set_default_store(default_store) _set_default_store(default_store)
pg = _new_process_group_impl( pg = _new_process_group_impl(
backend, backend,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册