/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL) #include "paddle/fluid/platform/gen_comm_id_helper.h" #include #include #include #include #include #include #include #include "glog/logging.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/string/split.h" namespace paddle { namespace platform { constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; // Check system calls, such as socket, bind. #define CHECK_SYS_CALL(call, name) \ do { \ int retval; \ CHECK_SYS_CALL_VAL(call, name, retval); \ } while (false) #define CHECK_SYS_CALL_VAL(call, name, retval) \ do { \ RETRY_SYS_CALL_VAL(call, name, retval); \ if (retval == -1) { \ PADDLE_THROW(platform::errors::Unavailable("Call to %s failed: %s", \ name, strerror(errno))); \ } \ } while (false) #define RETRY_SYS_CALL_VAL(call, name, retval) \ do { \ retval = (call); \ if (retval == -1 && \ (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \ LOG(WARNING) << "Call " << name << " returned " << strerror(errno) \ << " retry"; \ } else { \ break; \ } \ } while (true) static int SocketSend(int fd, const char* buffer, int size) { int offset = 0; int bytes = 0; while (offset < size) { bytes = send(fd, buffer + offset, size - offset, 0); if (bytes == -1) { if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { // send failed return -1; } else { bytes = 0; } } offset += bytes; } return offset; } static int SocketRecv(int fd, char* buffer, int size) { int offset = 0; int bytes = 0; while (offset < size) { bytes = recv(fd, buffer + offset, size - offset, 0); if (bytes == 0) { // closed by client, maybe probing alive client return 0; } if (bytes == -1) { if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { return -1; } else { bytes = 0; } } offset += bytes; } return offset; } static void BindOrConnectFailed(int timeout, int* try_times, int* total_time, const char* op, const std::string& ep) { PADDLE_ENFORCE_LT( *total_time, timeout, platform::errors::Unavailable("%s addr=%s timeout, failed reason: %s", op, ep.c_str(), strerror(errno))); ++(*try_times); int retry_time = std::min(*try_times * 500, 3000); // max 3 seconds *total_time += retry_time; LOG(WARNING) << op << " addr=" << ep << " failed " << *try_times << " times with reason: " << strerror(errno) << " retry after " << retry_time / 1000.0 << " seconds"; std::this_thread::sleep_for(std::chrono::milliseconds(retry_time)); } int CreateListenSocket(const std::string& ep) { auto addr = paddle::string::Split(ep, ':'); PADDLE_ENFORCE_EQ( addr.size(), 2UL, platform::errors::InvalidArgument( "The endpoint should contain host and port, but got %s.", ep)); std::string host = addr[0]; int port = std::stoi(addr[1]); // creating socket fd int server_fd = -1; CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", server_fd); // NOTE. Solutions to `Address already in use`. // 1. Reuse addr&port. Otherwise, once the server closes the socket // before client, the server will enter TIME-WAIT status. If we bind port // again, the error `Address already in use` will appear. // 2. Or we can close the client first to ensure that the server does // not enter the TIME-WAIT state. But this is obviously not as convenient // as the reuse method. int opt = 1; #if defined(SO_REUSEPORT) // since Linux kernel 3.9 CHECK_SYS_CALL(setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); #else CHECK_SYS_CALL( setsockopt(server_fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); #endif struct sockaddr_in address; address.sin_family = AF_INET; address.sin_addr.s_addr = INADDR_ANY; address.sin_port = htons(port); // TODO(wangxi) Set from env, default 900s=15min int timeout = 900 * 1000; int try_times = 0; int total_time = 0; while (true) { int ret_val = -1; RETRY_SYS_CALL_VAL( bind(server_fd, (struct sockaddr*)&address, sizeof(address)), "bind", ret_val); if (ret_val == -1) { BindOrConnectFailed(timeout, &try_times, &total_time, "bind", ep); continue; } break; } CHECK_SYS_CALL(listen(server_fd, 3), "listen"); LOG(INFO) << "Server listening on: " << ep << " successful."; return server_fd; } void CloseSocket(int fd) { CHECK_SYS_CALL(close(fd), "close"); } static int SocketAccept(int server_fd, const char* head) { struct sockaddr_in client_addr; socklen_t addr_length = sizeof(client_addr); char buffer[1024] = {0}; int conn = -1; while (true) { CHECK_SYS_CALL_VAL( accept(server_fd, reinterpret_cast(&client_addr), &addr_length), "accept", conn); int ret_val = SocketRecv(conn, buffer, strlen(head)); if (ret_val > 0 && strncmp(buffer, head, strlen(head)) == 0) { break; // accept client } else { VLOG(3) << "socket read failed with ret_val=" << ret_val; CloseSocket(conn); } } return conn; } static int ConnectAddr(const std::string& ep, const char* head) { auto addr = paddle::string::Split(ep, ':'); PADDLE_ENFORCE_EQ( addr.size(), 2UL, platform::errors::InvalidArgument( "The endpoint should contain host and port, but got %s.", ep)); std::string host = addr[0]; int port = std::stoi(addr[1]); int sock = -1; CHECK_SYS_CALL_VAL(socket(AF_INET, SOCK_STREAM, 0), "socket", sock); struct sockaddr_in server_addr; memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_port = htons(port); char* ip = NULL; struct hostent* hp = NULL; hp = gethostbyname(host.c_str()); PADDLE_ENFORCE_NOT_NULL(hp, platform::errors::InvalidArgument( "Fail to get host by name %s.", host)); int i = 0; while (hp->h_addr_list[i] != NULL) { ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]); VLOG(3) << "gethostbyname host:" << host << " ->ip: " << ip; break; } PADDLE_ENFORCE_GT(inet_pton(AF_INET, ip, &server_addr.sin_addr), 0, platform::errors::Unavailable("Open address %s failed: %s", ep, strerror(errno))); // TODO(wangxi) Set from env, default 900s=15min int timeout = 900 * 1000; int try_times = 0; int total_time = 0; while (true) { int ret_val = -1; RETRY_SYS_CALL_VAL( connect(sock, (struct sockaddr*)&server_addr, sizeof(server_addr)), "connect", ret_val); if (ret_val == -1) { BindOrConnectFailed(timeout, &try_times, &total_time, "connect", ep); continue; } CHECK_SYS_CALL(SocketSend(sock, head, strlen(head)), "send"); break; } return sock; } template static void RecvCommID(int conn, CommUniqueId* nccl_id) { char buffer[1024] = {0}; static_assert(sizeof(CommUniqueId) <= 1024, "nccl id bytes must <= buffer size"); CHECK_SYS_CALL(SocketRecv(conn, buffer, sizeof(CommUniqueId)), "recv comm unique id"); memcpy(nccl_id, buffer, sizeof(CommUniqueId)); } template static void SendCommID(int conn, CommUniqueId* nccl_id) { char buffer[1024] = {0}; memcpy(buffer, nccl_id, sizeof(CommUniqueId)); CHECK_SYS_CALL(SocketSend(conn, buffer, sizeof(CommUniqueId)), "send comm unique id"); } template void SendBroadCastCommID(std::vector servers, std::vector* nccl_ids) { // connect with server std::vector connects; for (auto server : servers) { VLOG(3) << "connecting endpoint: " << server; int conn = ConnectAddr(server, COMM_HEAD); connects.push_back(conn); } VLOG(3) << "connecting completed..."; for (size_t i = 0; i < nccl_ids->size(); ++i) { int j = 0; for (auto conn : connects) { VLOG(3) << "sending comm_id to " << servers[j] << " nccl_comm_no: " << i; SendCommID(conn, &(*nccl_ids)[i]); ++j; } } // close client for (auto conn : connects) { CloseSocket(conn); } } template void RecvBroadCastCommID(std::string endpoint, std::vector* nccl_ids) { int server = CreateListenSocket(endpoint); RecvBroadCastCommID(server, endpoint, nccl_ids); CloseSocket(server); } template void RecvBroadCastCommID(int server_fd, std::string endpoint, std::vector* nccl_ids) { int client = SocketAccept(server_fd, COMM_HEAD); for (size_t i = 0; i < nccl_ids->size(); ++i) { VLOG(3) << "trainer: " << endpoint << " receiving comm_id from trainer 0, nccl_comm_no: " << i; RecvCommID(client, &(*nccl_ids)[i]); } VLOG(3) << "receiving completed..."; CloseSocket(client); } /// template instantiation #define INSTANT_TEMPLATE(Type) \ template void SendBroadCastCommID(std::vector servers, \ std::vector * nccl_ids); \ template void RecvBroadCastCommID(std::string endpoint, \ std::vector * nccl_ids); #ifdef PADDLE_WITH_NCCL INSTANT_TEMPLATE(ncclUniqueId) #endif #ifdef PADDLE_WITH_XPU_BKCL INSTANT_TEMPLATE(BKCLUniqueId) #endif } // namespace platform } // namespace paddle #endif