diff --git a/paddle/phi/core/distributed/store/tcp_store.cc b/paddle/phi/core/distributed/store/tcp_store.cc index 34aa24216826b368392670b504adff2c00a9336c..f938a3cc06f663d0773392f948d326861bab3e59 100644 --- a/paddle/phi/core/distributed/store/tcp_store.cc +++ b/paddle/phi/core/distributed/store/tcp_store.cc @@ -71,6 +71,7 @@ void MasterDaemon::_do_add(SocketType socket) { VLOG(4) << "TCPStore: new value (" << new_value << ") for key (" << key << ") " << GetSockName(socket); tcputils::send_value(socket, new_value); + _notify_waiting_sockets(key); } void MasterDaemon::_do_set(SocketType socket) { @@ -79,6 +80,19 @@ void MasterDaemon::_do_set(SocketType socket) { auto value = tcputils::receive_vector(socket); _store[key] = value; + _notify_waiting_sockets(key); +} + +void MasterDaemon::_notify_waiting_sockets(const std::string& key) { + if (_waiting_sockets.find(key) != _waiting_sockets.end()) { + for (auto waiting_socket : _waiting_sockets.at(key)) { + auto reply = ReplyType::STOP_WAIT; + VLOG(3) << "TCPStore: nofify the socket: " << GetSockName(waiting_socket) + << " that key: " << key << " is ready."; + tcputils::send_value(waiting_socket, reply); + } + _waiting_sockets.erase(key); + } } void MasterDaemon::_do_get(SocketType socket) { @@ -136,13 +150,15 @@ void MasterDaemon::_do_wait(SocketType socket) { << GetSockName(socket); auto iter = _store.find(key); - auto reply = ReplyType::STOP_WAIT; if (iter == _store.end()) { - reply = ReplyType::WAITING; + // The key can not be found in store currently. Record and check later. + _waiting_sockets[key].emplace_back(socket); + } else { + auto reply = ReplyType::STOP_WAIT; + VLOG(3) << "TCPStore: wait reply (" << static_cast(reply) + << ") for key (" << key << ")."; + tcputils::send_value(socket, reply); } - VLOG(3) << "TCPStore: wait reply (" << static_cast(reply) - << ") for key (" << key << ")."; - tcputils::send_value(socket, reply); } void MasterDaemon::ProcessCommands(std::vector* p_fds) { @@ -160,6 +176,7 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { continue; } + VLOG(4) << "Plan to receive command from " << GetSockName(fds[i].fd); Command command = tcputils::receive_value(fds[i].fd); VLOG(3) << "TCPStore: recv command: " << static_cast(command) << "."; @@ -177,10 +194,27 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { _do_wait(fds[i].fd); break; default: - LOG(WARNING) << "Unknown command: " << static_cast(command) - << " from addr info:" << GetSockName(fds[i].fd); + VLOG(4) << "Unknown command: " << static_cast(command) + << " from addr info:" << GetSockName(fds[i].fd); } } catch (const std::exception& ex) { + auto map_iter = _waiting_sockets.begin(); + while (map_iter != _waiting_sockets.end()) { + auto vec_iter = map_iter->second.begin(); + while (vec_iter != map_iter->second.end()) { + if (*vec_iter == fds[i].fd) { + vec_iter = map_iter->second.erase(vec_iter); + } else { + ++vec_iter; + } + } + if (map_iter->second.empty()) { + map_iter = _waiting_sockets.erase(map_iter); + } else { + ++map_iter; + } + } + tcputils::close_socket(fds[i].fd); fds.erase(fds.begin() + i); #ifdef _WIN32 @@ -189,7 +223,7 @@ void MasterDaemon::ProcessCommands(std::vector* p_fds) { _sockets.erase(_sockets.begin() + i - 2); #endif - VLOG(3) << "Meet some exceptions during run:" << ex.what(); + VLOG(5) << "Meet some exceptions during run:" << ex.what(); } } } @@ -328,34 +362,36 @@ void TCPStore::waitWorkers() { } add(_init_key, 1); - VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout); - auto begin = std::chrono::steady_clock::now(); - do { - auto value = get(_init_key); - int completed = std::stoi(std::string(value.begin(), value.end())); - VLOG(3) << completed << " worker ready, total " << _num_workers - << ", _timeout:" << _timeout; - if (completed >= _num_workers) { - break; - } - const auto elapsed = std::chrono::duration_cast( - std::chrono::steady_clock::now() - begin); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - if (_timeout != 0 && elapsed.count() > _timeout) { - LOG(FATAL) << paddle::string::Sprintf( - "_timeout:%d elapsed:%d (elapsed > _timeout)=%d", - _timeout, - elapsed.count(), - elapsed.count() > _timeout); - - PADDLE_ENFORCE_EQ( - completed, - _num_workers, - phi::errors::InvalidArgument( - "TCPStore timeouted and not all workers got ready.")); - } - } while (true); + if (_is_master) { + VLOG(3) << paddle::string::Sprintf("_timeout:%d", _timeout); + auto begin = std::chrono::steady_clock::now(); + do { + auto value = get(_init_key); + int completed = std::stoi(std::string(value.begin(), value.end())); + VLOG(3) << completed << " worker ready, total " << _num_workers + << ", _timeout:" << _timeout; + if (completed >= _num_workers) { + break; + } + const auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - begin); + + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + if (_timeout != 0 && elapsed.count() > _timeout) { + LOG(FATAL) << paddle::string::Sprintf( + "_timeout:%d elapsed:%d (elapsed > _timeout)=%d", + _timeout, + elapsed.count(), + elapsed.count() > _timeout); + + PADDLE_ENFORCE_EQ( + completed, + _num_workers, + phi::errors::InvalidArgument( + "TCPStore timeouted and not all workers got ready.")); + } + } while (true); + } VLOG(3) << "TCPStore initialized."; } @@ -384,12 +420,9 @@ void TCPStore::wait(const std::string& key) { VLOG(3) << "TCPStore wait."; _client->send_command_for_key(Command::WAIT, _key_prefix + key); reply = _client->receive_value(); - while (reply != ReplyType::STOP_WAIT) { - std::this_thread::sleep_for(std::chrono::milliseconds(500)); - - _client->send_command_for_key(Command::WAIT, _key_prefix + key); - reply = _client->receive_value(); - } + PADDLE_ENFORCE( + reply == ReplyType::STOP_WAIT, + phi::errors::InvalidArgument("Stop_waiting response is expected")); } TCPStore::~TCPStore() { VLOG(3) << "TCPStore destructure"; } diff --git a/paddle/phi/core/distributed/store/tcp_store.h b/paddle/phi/core/distributed/store/tcp_store.h index 663275242d8ab4f28fc3ab296614accc09bc4e3b..0f17bc9b58bd45d27fa7dcc45b96aadc32538b35 100644 --- a/paddle/phi/core/distributed/store/tcp_store.h +++ b/paddle/phi/core/distributed/store/tcp_store.h @@ -60,12 +60,15 @@ class MasterDaemon { void _do_wait(SocketType socket); void _do_get(SocketType socket); void _do_set(SocketType socket); + void _notify_waiting_sockets(const std::string&); SocketType _listen_socket; std::vector _sockets; std::unordered_map> _store; std::thread _background_thread{}; int _nranks = -1; int _timeout = 0; + std::unordered_map> + _waiting_sockets; // key -> list of waiting sockets void InitControlFd(); void CloseControlFd();