未验证 提交 b5e60db4 编写于 作者: L lilong12 提交者: GitHub

Fix hang problem when some workers exit unexpectedly (#42708)

* update
上级 1280f294
......@@ -19,21 +19,25 @@
#include "paddle/fluid/distributed/store/tcp_store.h"
#include "paddle/fluid/distributed/store/tcp_utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/flags.h"
namespace paddle {
namespace distributed {
namespace detail {
constexpr int INFTIME = -1;
// NOTE(review): despite the name, the value below is finite. The trailing
// comment implies the unit is milliseconds (10000 == 10 seconds) — confirm
// against the call site that consumes INFTIME.
constexpr int INFTIME = 10000; // 10 seconds
std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket,
int nranks) {
return std::make_unique<MasterDaemon>(socket, nranks);
// Factory helper: heap-allocates a MasterDaemon for the given listening
// socket, forwarding the rank count and the stop-check timeout to its
// constructor, and hands ownership to the caller.
std::unique_ptr<MasterDaemon> MasterDaemon::start(SocketType socket, int nranks,
                                                  int stop_check_timeout) {
  auto daemon =
      std::make_unique<MasterDaemon>(socket, nranks, stop_check_timeout);
  return daemon;
}
MasterDaemon::MasterDaemon(SocketType socket, int nranks)
: _listen_socket(socket), _nranks(nranks) {
// Stores the listening socket, the number of ranks and the stop-check timeout,
// then launches the background thread that executes MasterDaemon::run().
//
// socket:             listening socket handed over by the caller.
// nranks:             rank count; _do_stop decrements it as workers stop.
// stop_check_timeout: limit, in seconds, that run() enforces on the time
//                     elapsed since the first worker stopped.
MasterDaemon::MasterDaemon(SocketType socket, int nranks,
int stop_check_timeout)
: _listen_socket(socket),
_nranks(nranks),
_stop_check_timeout(stop_check_timeout) {
// The thread is launched from the constructor body, i.e. after all members
// (including the _stop / _has_stop flags) have been initialized.
_background_thread = std::thread{&MasterDaemon::run, this};
}
......@@ -86,6 +90,10 @@ void MasterDaemon::_do_get(SocketType socket) {
void MasterDaemon::_do_stop(SocketType socket) {
VLOG(3) << "MasterDaemon::_do_stop";
if (!_has_stop) {
_stop_time = std::chrono::system_clock::now();
}
_has_stop = true;
ReplyType value = ReplyType::STOP_WAIT;
tcputils::send_value<ReplyType>(socket, value);
if (--_nranks == 0) {
......@@ -115,6 +123,20 @@ void MasterDaemon::run() {
#endif
while (!_stop) {
auto end_time = std::chrono::system_clock::now();
if (_has_stop) {
std::chrono::duration<double> diff = end_time - _stop_time;
int elapsed_seconds = static_cast<int>(diff.count());
PADDLE_ENFORCE_LT(
elapsed_seconds, _stop_check_timeout,
platform::errors::Fatal(
"%d seconds elapsed after the first worker "
"stopped, so we think there may be something wrong and will "
"stop the master worker. You can use "
"'export FLAGS_stop_check_timeout=3600'"
" to change the timeout value in seconds. The default one is 900",
elapsed_seconds));
}
for (size_t i = 0; i < fds.size(); i++) {
fds[i].revents = 0;
}
......@@ -173,10 +195,12 @@ void MasterDaemon::run() {
}
}
std::unique_ptr<TCPServer> TCPServer::create(uint16_t port, int nranks) {
// Creates a TCPServer whose MasterDaemon serves on a socket listening on
// `port`. `nranks` and `stop_check_timeout` are forwarded unchanged to
// MasterDaemon::start.
std::unique_ptr<TCPServer> TCPServer::create(uint16_t port, int nranks,
int stop_check_timeout) {
// An empty host string is passed to tcp_listen.
// NOTE(review): presumably this binds to all local interfaces — confirm in
// tcputils::tcp_listen.
int socket = tcputils::tcp_listen("", std::to_string(port), AF_INET);
auto server = std::make_unique<TCPServer>();
// The two-argument call below is the pre-change line of the diff; the
// three-argument call that follows it is its replacement.
server->_master_daemon = MasterDaemon::start(socket, nranks);
server->_master_daemon =
MasterDaemon::start(socket, nranks, stop_check_timeout);
return server;
}
......@@ -219,10 +243,11 @@ std::vector<T> TCPClient::receive_vector() {
} // namespace detail
TCPStore::TCPStore(std::string host, uint16_t port, bool is_master,
size_t num_workers, std::chrono::seconds timeout)
size_t num_workers, std::chrono::seconds timeout,
int stop_check_timeout)
: Store(timeout), _is_master(is_master), _num_workers(num_workers) {
if (_is_master) {
_server = detail::TCPServer::create(port, num_workers);
_server = detail::TCPServer::create(port, num_workers, stop_check_timeout);
}
_client = detail::TCPClient::connect(host, port);
......
......@@ -34,9 +34,11 @@ namespace detail {
class MasterDaemon {
public:
static std::unique_ptr<MasterDaemon> start(SocketType listen_socket,
int nranks);
int nranks,
int stop_check_timeout);
MasterDaemon() = delete;
explicit MasterDaemon(SocketType listen_socket, int nranks);
explicit MasterDaemon(SocketType listen_socket, int nranks,
int stop_check_timeout);
~MasterDaemon();
private:
......@@ -51,13 +53,17 @@ class MasterDaemon {
std::unordered_map<std::string, std::vector<uint8_t>> _store;
std::thread _background_thread{};
int _nranks;
bool _stop = false;
int _stop_check_timeout;
bool _stop = false; // all workers stopped
std::chrono::time_point<std::chrono::system_clock> _stop_time;
bool _has_stop = false; // at least one worker stopped
};
class TCPServer {
public:
TCPServer() = default;
static std::unique_ptr<TCPServer> create(std::uint16_t port, int nranks);
static std::unique_ptr<TCPServer> create(std::uint16_t port, int nranks,
int stop_check_timeout);
private:
std::unique_ptr<MasterDaemon> _master_daemon;
......@@ -93,7 +99,8 @@ class TCPStore : public Store {
static constexpr std::uint16_t kDefaultPort = 6170;
explicit TCPStore(std::string host, uint16_t port = kDefaultPort,
bool is_master = false, size_t num_workers = 1,
std::chrono::seconds timeout = tcputils::kDefaultTimeout);
std::chrono::seconds timeout = tcputils::kDefaultTimeout,
int stop_check_timeout = 900);
~TCPStore();
......
......@@ -58,13 +58,16 @@ void BindTCPStore(py::module *m) {
// Registers the Python-visible "TCPStore" class on module `m`, held by
// std::shared_ptr. `Store` is passed as an extra class_ argument
// (NOTE(review): presumably the binding of the parent Store class — confirm).
//
// The constructor lambda forwards its arguments to std::make_shared<TCPStore>.
// Python keyword arguments: hostname, port, is_master, world_size,
// timeout (defaults to tcputils::kNoTimeout) and stop_check_timeout
// (defaults to 900). The GIL is released while the constructor runs
// (py::call_guard<py::gil_scoped_release>).
py::class_<TCPStore, std::shared_ptr<TCPStore>>(*m, "TCPStore", Store)
.def(py::init([](std::string hostname, uint16_t port, bool is_master,
size_t world_size, std::chrono::seconds timeout) {
size_t world_size, std::chrono::seconds timeout,
int stop_check_timeout) {
return std::make_shared<TCPStore>(hostname, port, is_master,
world_size, timeout);
world_size, timeout,
stop_check_timeout);
}),
py::arg("hostname"), py::arg("port"), py::arg("is_master"),
py::arg("world_size"),
py::arg("timeout") = distributed::tcputils::kNoTimeout,
py::arg("stop_check_timeout") = 900,
py::call_guard<py::gil_scoped_release>());
}
......
......@@ -233,8 +233,13 @@ def init_parallel_env():
master_addr, master_port = endpoints.split(":")
master_port = int(master_port)
is_master = rank == 0
default_store = core.TCPStore(master_addr, master_port, is_master,
world_size)
stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
default_store = core.TCPStore(
master_addr,
master_port,
is_master,
world_size,
stop_check_timeout=stop_check_timeout)
_set_default_store(default_store)
pg = _new_process_group_impl(
backend,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册