未验证 提交 b93e5119 编写于 作者: L LiYuRio 提交者: GitHub

Optimize the performance of tcp store initialized and connection (#49810)

* Optimize tcp store initialized

* fix: use phi::errors

---------
Co-authored-by: 元无心's avatarWen Sun <syl1887415157@126.com>
上级 6c650445
...@@ -71,6 +71,7 @@ void MasterDaemon::_do_add(SocketType socket) { ...@@ -71,6 +71,7 @@ void MasterDaemon::_do_add(SocketType socket) {
VLOG(4) << "TCPStore: new value (" << new_value << ") for key (" << key VLOG(4) << "TCPStore: new value (" << new_value << ") for key (" << key
<< ") " << GetSockName(socket); << ") " << GetSockName(socket);
tcputils::send_value<int64_t>(socket, new_value); tcputils::send_value<int64_t>(socket, new_value);
_notify_waiting_sockets(key);
} }
void MasterDaemon::_do_set(SocketType socket) { void MasterDaemon::_do_set(SocketType socket) {
...@@ -79,6 +80,19 @@ void MasterDaemon::_do_set(SocketType socket) { ...@@ -79,6 +80,19 @@ void MasterDaemon::_do_set(SocketType socket) {
auto value = tcputils::receive_vector<uint8_t>(socket); auto value = tcputils::receive_vector<uint8_t>(socket);
_store[key] = value; _store[key] = value;
_notify_waiting_sockets(key);
}
void MasterDaemon::_notify_waiting_sockets(const std::string& key) {
if (_waiting_sockets.find(key) != _waiting_sockets.end()) {
for (auto waiting_socket : _waiting_sockets.at(key)) {
auto reply = ReplyType::STOP_WAIT;
VLOG(3) << "TCPStore: nofify the socket: " << GetSockName(waiting_socket)
<< " that key: " << key << " is ready.";
tcputils::send_value<ReplyType>(waiting_socket, reply);
}
_waiting_sockets.erase(key);
}
} }
void MasterDaemon::_do_get(SocketType socket) { void MasterDaemon::_do_get(SocketType socket) {
...@@ -136,13 +150,15 @@ void MasterDaemon::_do_wait(SocketType socket) { ...@@ -136,13 +150,15 @@ void MasterDaemon::_do_wait(SocketType socket) {
<< GetSockName(socket); << GetSockName(socket);
auto iter = _store.find(key); auto iter = _store.find(key);
auto reply = ReplyType::STOP_WAIT;
if (iter == _store.end()) { if (iter == _store.end()) {
reply = ReplyType::WAITING; // The key can not be found in store currently. Record and check later.
_waiting_sockets[key].emplace_back(socket);
} else {
auto reply = ReplyType::STOP_WAIT;
VLOG(3) << "TCPStore: wait reply (" << static_cast<int>(reply)
<< ") for key (" << key << ").";
tcputils::send_value<ReplyType>(socket, reply);
} }
VLOG(3) << "TCPStore: wait reply (" << static_cast<int>(reply)
<< ") for key (" << key << ").";
tcputils::send_value<ReplyType>(socket, reply);
} }
void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) { void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
...@@ -160,6 +176,7 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) { ...@@ -160,6 +176,7 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
continue; continue;
} }
VLOG(4) << "Plan to receive command from " << GetSockName(fds[i].fd);
Command command = tcputils::receive_value<Command>(fds[i].fd); Command command = tcputils::receive_value<Command>(fds[i].fd);
VLOG(3) << "TCPStore: recv command: " << static_cast<int>(command) << "."; VLOG(3) << "TCPStore: recv command: " << static_cast<int>(command) << ".";
...@@ -177,10 +194,27 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) { ...@@ -177,10 +194,27 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
_do_wait(fds[i].fd); _do_wait(fds[i].fd);
break; break;
default: default:
LOG(WARNING) << "Unknown command: " << static_cast<int>(command) VLOG(4) << "Unknown command: " << static_cast<int>(command)
<< " from addr info:" << GetSockName(fds[i].fd); << " from addr info:" << GetSockName(fds[i].fd);
} }
} catch (const std::exception& ex) { } catch (const std::exception& ex) {
auto map_iter = _waiting_sockets.begin();
while (map_iter != _waiting_sockets.end()) {
auto vec_iter = map_iter->second.begin();
while (vec_iter != map_iter->second.end()) {
if (*vec_iter == fds[i].fd) {
vec_iter = map_iter->second.erase(vec_iter);
} else {
++vec_iter;
}
}
if (map_iter->second.empty()) {
map_iter = _waiting_sockets.erase(map_iter);
} else {
++map_iter;
}
}
tcputils::close_socket(fds[i].fd); tcputils::close_socket(fds[i].fd);
fds.erase(fds.begin() + i); fds.erase(fds.begin() + i);
#ifdef _WIN32 #ifdef _WIN32
...@@ -189,7 +223,7 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) { ...@@ -189,7 +223,7 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
_sockets.erase(_sockets.begin() + i - 2); _sockets.erase(_sockets.begin() + i - 2);
#endif #endif
VLOG(3) << "Meet some exceptions during run:" << ex.what(); VLOG(5) << "Meet some exceptions during run:" << ex.what();
} }
} }
} }
...@@ -328,34 +362,36 @@ void TCPStore::waitWorkers() { ...@@ -328,34 +362,36 @@ void TCPStore::waitWorkers() {
} }
add(_init_key, 1); add(_init_key, 1);
VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout); if (_is_master) {
auto begin = std::chrono::steady_clock::now(); VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout);
do { auto begin = std::chrono::steady_clock::now();
auto value = get(_init_key); do {
int completed = std::stoi(std::string(value.begin(), value.end())); auto value = get(_init_key);
VLOG(3) << completed << " worker ready, total " << _num_workers int completed = std::stoi(std::string(value.begin(), value.end()));
<< ", _timeout:" << _timeout; VLOG(3) << completed << " worker ready, total " << _num_workers
if (completed >= _num_workers) { << ", _timeout:" << _timeout;
break; if (completed >= _num_workers) {
} break;
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>( }
std::chrono::steady_clock::now() - begin); const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - begin);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (_timeout != 0 && elapsed.count() > _timeout) { std::this_thread::sleep_for(std::chrono::milliseconds(10));
LOG(FATAL) << paddle::string::Sprintf( if (_timeout != 0 && elapsed.count() > _timeout) {
"_timeout:%d elapsed:%d (elapsed > _timeout)=%d", LOG(FATAL) << paddle::string::Sprintf(
_timeout, "_timeout:%d elapsed:%d (elapsed > _timeout)=%d",
elapsed.count(), _timeout,
elapsed.count() > _timeout); elapsed.count(),
elapsed.count() > _timeout);
PADDLE_ENFORCE_EQ(
completed, PADDLE_ENFORCE_EQ(
_num_workers, completed,
phi::errors::InvalidArgument( _num_workers,
"TCPStore timeouted and not all workers got ready.")); phi::errors::InvalidArgument(
} "TCPStore timeouted and not all workers got ready."));
} while (true); }
} while (true);
}
VLOG(3) << "TCPStore initialized."; VLOG(3) << "TCPStore initialized.";
} }
...@@ -384,12 +420,9 @@ void TCPStore::wait(const std::string& key) { ...@@ -384,12 +420,9 @@ void TCPStore::wait(const std::string& key) {
VLOG(3) << "TCPStore wait."; VLOG(3) << "TCPStore wait.";
_client->send_command_for_key(Command::WAIT, _key_prefix + key); _client->send_command_for_key(Command::WAIT, _key_prefix + key);
reply = _client->receive_value<ReplyType>(); reply = _client->receive_value<ReplyType>();
while (reply != ReplyType::STOP_WAIT) { PADDLE_ENFORCE(
std::this_thread::sleep_for(std::chrono::milliseconds(500)); reply == ReplyType::STOP_WAIT,
phi::errors::InvalidArgument("Stop_waiting response is expected"));
_client->send_command_for_key(Command::WAIT, _key_prefix + key);
reply = _client->receive_value<ReplyType>();
}
} }
TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; } TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; }
......
...@@ -60,12 +60,15 @@ class MasterDaemon { ...@@ -60,12 +60,15 @@ class MasterDaemon {
void _do_wait(SocketType socket); void _do_wait(SocketType socket);
void _do_get(SocketType socket); void _do_get(SocketType socket);
void _do_set(SocketType socket); void _do_set(SocketType socket);
void _notify_waiting_sockets(const std::string&);
SocketType _listen_socket; SocketType _listen_socket;
std::vector<SocketType> _sockets; std::vector<SocketType> _sockets;
std::unordered_map<std::string, std::vector<uint8_t>> _store; std::unordered_map<std::string, std::vector<uint8_t>> _store;
std::thread _background_thread{}; std::thread _background_thread{};
int _nranks = -1; int _nranks = -1;
int _timeout = 0; int _timeout = 0;
std::unordered_map<std::string, std::vector<SocketType>>
_waiting_sockets; // key -> list of waiting sockets
void InitControlFd(); void InitControlFd();
void CloseControlFd(); void CloseControlFd();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册