未验证 提交 b93e5119 编写于 作者: L LiYuRio 提交者: GitHub

Optimize the performance of tcp store initialized and connection (#49810)

* Optimize tcp store initialized

* fix: use phi::errors

---------
Co-authored-by: 元无心's avatarWen Sun <syl1887415157@126.com>
上级 6c650445
......@@ -71,6 +71,7 @@ void MasterDaemon::_do_add(SocketType socket) {
VLOG(4) << "TCPStore: new value (" << new_value << ") for key (" << key
<< ") " << GetSockName(socket);
tcputils::send_value<int64_t>(socket, new_value);
_notify_waiting_sockets(key);
}
void MasterDaemon::_do_set(SocketType socket) {
......@@ -79,6 +80,19 @@ void MasterDaemon::_do_set(SocketType socket) {
auto value = tcputils::receive_vector<uint8_t>(socket);
_store[key] = value;
_notify_waiting_sockets(key);
}
void MasterDaemon::_notify_waiting_sockets(const std::string& key) {
if (_waiting_sockets.find(key) != _waiting_sockets.end()) {
for (auto waiting_socket : _waiting_sockets.at(key)) {
auto reply = ReplyType::STOP_WAIT;
VLOG(3) << "TCPStore: nofify the socket: " << GetSockName(waiting_socket)
<< " that key: " << key << " is ready.";
tcputils::send_value<ReplyType>(waiting_socket, reply);
}
_waiting_sockets.erase(key);
}
}
void MasterDaemon::_do_get(SocketType socket) {
......@@ -136,13 +150,15 @@ void MasterDaemon::_do_wait(SocketType socket) {
<< GetSockName(socket);
auto iter = _store.find(key);
auto reply = ReplyType::STOP_WAIT;
if (iter == _store.end()) {
reply = ReplyType::WAITING;
// The key can not be found in store currently. Record and check later.
_waiting_sockets[key].emplace_back(socket);
} else {
auto reply = ReplyType::STOP_WAIT;
VLOG(3) << "TCPStore: wait reply (" << static_cast<int>(reply)
<< ") for key (" << key << ").";
tcputils::send_value<ReplyType>(socket, reply);
}
VLOG(3) << "TCPStore: wait reply (" << static_cast<int>(reply)
<< ") for key (" << key << ").";
tcputils::send_value<ReplyType>(socket, reply);
}
void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
......@@ -160,6 +176,7 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
continue;
}
VLOG(4) << "Plan to receive command from " << GetSockName(fds[i].fd);
Command command = tcputils::receive_value<Command>(fds[i].fd);
VLOG(3) << "TCPStore: recv command: " << static_cast<int>(command) << ".";
......@@ -177,10 +194,27 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
_do_wait(fds[i].fd);
break;
default:
LOG(WARNING) << "Unknown command: " << static_cast<int>(command)
<< " from addr info:" << GetSockName(fds[i].fd);
VLOG(4) << "Unknown command: " << static_cast<int>(command)
<< " from addr info:" << GetSockName(fds[i].fd);
}
} catch (const std::exception& ex) {
auto map_iter = _waiting_sockets.begin();
while (map_iter != _waiting_sockets.end()) {
auto vec_iter = map_iter->second.begin();
while (vec_iter != map_iter->second.end()) {
if (*vec_iter == fds[i].fd) {
vec_iter = map_iter->second.erase(vec_iter);
} else {
++vec_iter;
}
}
if (map_iter->second.empty()) {
map_iter = _waiting_sockets.erase(map_iter);
} else {
++map_iter;
}
}
tcputils::close_socket(fds[i].fd);
fds.erase(fds.begin() + i);
#ifdef _WIN32
......@@ -189,7 +223,7 @@ void MasterDaemon::ProcessCommands(std::vector<struct pollfd>* p_fds) {
_sockets.erase(_sockets.begin() + i - 2);
#endif
VLOG(3) << "Meet some exceptions during run:" << ex.what();
VLOG(5) << "Meet some exceptions during run:" << ex.what();
}
}
}
......@@ -328,34 +362,36 @@ void TCPStore::waitWorkers() {
}
add(_init_key, 1);
VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout);
auto begin = std::chrono::steady_clock::now();
do {
auto value = get(_init_key);
int completed = std::stoi(std::string(value.begin(), value.end()));
VLOG(3) << completed << " worker ready, total " << _num_workers
<< ", _timeout:" << _timeout;
if (completed >= _num_workers) {
break;
}
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - begin);
std::this_thread::sleep_for(std::chrono::milliseconds(100));
if (_timeout != 0 && elapsed.count() > _timeout) {
LOG(FATAL) << paddle::string::Sprintf(
"_timeout:%d elapsed:%d (elapsed > _timeout)=%d",
_timeout,
elapsed.count(),
elapsed.count() > _timeout);
PADDLE_ENFORCE_EQ(
completed,
_num_workers,
phi::errors::InvalidArgument(
"TCPStore timeouted and not all workers got ready."));
}
} while (true);
if (_is_master) {
VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout);
auto begin = std::chrono::steady_clock::now();
do {
auto value = get(_init_key);
int completed = std::stoi(std::string(value.begin(), value.end()));
VLOG(3) << completed << " worker ready, total " << _num_workers
<< ", _timeout:" << _timeout;
if (completed >= _num_workers) {
break;
}
const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::steady_clock::now() - begin);
std::this_thread::sleep_for(std::chrono::milliseconds(10));
if (_timeout != 0 && elapsed.count() > _timeout) {
LOG(FATAL) << paddle::string::Sprintf(
"_timeout:%d elapsed:%d (elapsed > _timeout)=%d",
_timeout,
elapsed.count(),
elapsed.count() > _timeout);
PADDLE_ENFORCE_EQ(
completed,
_num_workers,
phi::errors::InvalidArgument(
"TCPStore timeouted and not all workers got ready."));
}
} while (true);
}
VLOG(3) << "TCPStore initialized.";
}
......@@ -384,12 +420,9 @@ void TCPStore::wait(const std::string& key) {
VLOG(3) << "TCPStore wait.";
_client->send_command_for_key(Command::WAIT, _key_prefix + key);
reply = _client->receive_value<ReplyType>();
while (reply != ReplyType::STOP_WAIT) {
std::this_thread::sleep_for(std::chrono::milliseconds(500));
_client->send_command_for_key(Command::WAIT, _key_prefix + key);
reply = _client->receive_value<ReplyType>();
}
PADDLE_ENFORCE(
reply == ReplyType::STOP_WAIT,
phi::errors::InvalidArgument("Stop_waiting response is expected"));
}
TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; }
......
......@@ -60,12 +60,15 @@ class MasterDaemon {
void _do_wait(SocketType socket);
void _do_get(SocketType socket);
void _do_set(SocketType socket);
void _notify_waiting_sockets(const std::string&);
SocketType _listen_socket;
std::vector<SocketType> _sockets;
std::unordered_map<std::string, std::vector<uint8_t>> _store;
std::thread _background_thread{};
int _nranks = -1;
int _timeout = 0;
std::unordered_map<std::string, std::vector<SocketType>>
_waiting_sockets; // key -> list of waiting sockets
void InitControlFd();
void CloseControlFd();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册