#include "oneflow/core/comm_network/epoll/epoll_comm_network.h"
#include "oneflow/core/control/ctrl_client.h"
#include "oneflow/core/job/machine_context.h"

#include <vector>

#ifdef PLATFORM_POSIX

// NOTE(review): the explicit template arguments in this file (Global<...>,
// reinterpret_cast<...>, static_cast<...>, oneflow_cast<...>, MaxVal<...>,
// std::unique_lock<...>) were reconstructed from how the values are used;
// confirm them against the declarations in the included headers.

namespace oneflow {

namespace {

// Builds an IPv4 socket address from a textual address and a host-order port.
// Fails the PCHECK if `addr` is not a valid IPv4 address string.
sockaddr_in GetSockAddr(const std::string& addr, uint16_t port) {
  sockaddr_in sa;
  // Zero the whole struct so padding (sin_zero) is not left uninitialized.
  memset(&sa, 0, sizeof(sa));
  sa.sin_family = AF_INET;
  sa.sin_port = htons(port);
  PCHECK(inet_pton(AF_INET, addr.c_str(), &(sa.sin_addr)) == 1);
  return sa;
}

// Binds `listen_sockfd` to 0.0.0.0:`listen_port` and, if the bind succeeds,
// starts listening with a backlog of `total_machine_num`.
// Returns the result of bind(): 0 on success. A failed bind is tolerated only
// for EACCES / EADDRINUSE (so the caller may try another port); any other
// errno fails the PCHECK.
int32_t SockListen(int32_t listen_sockfd, uint16_t listen_port, int32_t total_machine_num) {
  sockaddr_in sa = GetSockAddr("0.0.0.0", listen_port);
  int32_t bind_result = bind(listen_sockfd, reinterpret_cast<sockaddr*>(&sa), sizeof(sa));
  if (bind_result == 0) {
    PCHECK(listen(listen_sockfd, total_machine_num) == 0);
    LOG(INFO) << "CommNet:Epoll listening on "
              << "0.0.0.0:" + std::to_string(listen_port);
  } else {
    PCHECK(errno == EACCES || errno == EADDRINUSE);
  }
  return bind_result;
}

// Returns the index of the configured machine whose addr() equals the textual
// form of the IPv4 address in `sa`. Hits UNIMPLEMENTED() if no machine matches.
int64_t GetMachineId(const sockaddr_in& sa) {
  char addr[INET_ADDRSTRLEN];
  memset(addr, '\0', sizeof(addr));
  PCHECK(inet_ntop(AF_INET, &(sa.sin_addr), addr, INET_ADDRSTRLEN));
  for (int64_t i = 0; i < Global<JobDesc>::Get()->TotalMachineNum(); ++i) {
    if (Global<JobDesc>::Get()->resource().machine(i).addr() == addr) { return i; }
  }
  UNIMPLEMENTED();
}

// Key under which a machine publishes its epoll listen port.
std::string GenPortKey(int64_t machine_id) { return "EpollPort/" + std::to_string(machine_id); }

// Publishes `port` as the listen port of `machine_id` through the control client.
void PushPort(int64_t machine_id, uint16_t port) {
  Global<CtrlClient>::Get()->PushKV(GenPortKey(machine_id), std::to_string(port));
}

// Removes the published listen port of `machine_id`.
void ClearPort(int64_t machine_id) { Global<CtrlClient>::Get()->ClearKV(GenPortKey(machine_id)); }

// Fetches the listen port published by `machine_id`; returns 0 if the callback
// never assigns a value.
uint16_t PullPort(int64_t machine_id) {
  uint16_t port = 0;
  Global<CtrlClient>::Get()->PullKV(
      GenPortKey(machine_id), [&](const std::string& v) { port = oneflow_cast<uint16_t>(v); });
  return port;
}

}  // namespace

// Stops every poller, waits at the barrier, then frees the pollers and the
// socket helpers owned by this object.
EpollCommNet::~EpollCommNet() {
  for (size_t i = 0; i < pollers_.size(); ++i) {
    LOG(INFO) << "CommNet Thread " << i << " finish";
    pollers_.at(i)->Stop();
  }
  OF_BARRIER();
  for (IOEventPoller* poller : pollers_) { delete poller; }
  for (auto& pair : sockfd2helper_) { delete pair.second; }
}

// Creates a zeroed part-done counter for every registered memory descriptor.
void EpollCommNet::RegisterMemoryDone() {
  for (void* dst_token : mem_descs()) { dst_token2part_done_cnt_[dst_token] = 0; }
}

// Sends `actor_msg` to `dst_machine_id` over the last link (index link_num - 1).
void EpollCommNet::SendActorMsg(int64_t dst_machine_id, const ActorMsg& actor_msg) const {
  SocketMsg msg;
  msg.msg_type = SocketMsgType::kActor;
  msg.actor_msg = actor_msg;
  GetSocketHelper(dst_machine_id, epoll_conf_.link_num() - 1)->AsyncWrite(msg);
}

// Splits the memory described by `src_token` into at most link_num parts and
// sends one kRequestRead message per part, part i going over link i.
// Each part length is the per-link share rounded up to a multiple of
// msg_segment_kbyte * 1024 bytes; the last part carries the remainder.
void EpollCommNet::RequestRead(int64_t dst_machine_id, void* src_token, void* dst_token,
                               void* read_id) const {
  int32_t total_byte_size = static_cast<const SocketMemDesc*>(src_token)->byte_size;
  CHECK_GT(total_byte_size, 0);
  int32_t part_length =
      (total_byte_size + epoll_conf_.link_num() - 1) / epoll_conf_.link_num();
  part_length = RoundUp(part_length, epoll_conf_.msg_segment_kbyte() * 1024);
  int32_t part_num = (total_byte_size + part_length - 1) / part_length;
  CHECK_LE(part_num, epoll_conf_.link_num());
  for (int32_t link_i = 0; link_i < part_num; ++link_i) {
    int32_t byte_size = (total_byte_size > part_length) ? (part_length) : (total_byte_size);
    CHECK_GT(byte_size, 0);
    total_byte_size -= byte_size;
    SocketMsg msg;
    msg.msg_type = SocketMsgType::kRequestRead;
    msg.request_read_msg.src_machine_id = Global<MachineCtx>::Get()->this_machine_id();
    msg.request_read_msg.src_token = src_token;
    msg.request_read_msg.dst_token = dst_token;
    msg.request_read_msg.offset = link_i * part_length;
    msg.request_read_msg.byte_size = byte_size;
    msg.request_read_msg.read_id = read_id;
    msg.request_read_msg.part_num = part_num;
    GetSocketHelper(dst_machine_id, link_i)->AsyncWrite(msg);
  }
  // Every byte must have been assigned to exactly one part.
  CHECK_EQ(total_byte_size, 0);
}

// Allocates a SocketMemDesc recording `ptr` and `byte_size`; the caller
// receives the raw pointer.
SocketMemDesc* EpollCommNet::NewMemDesc(void* ptr, size_t byte_size) const {
  SocketMemDesc* mem_desc = new SocketMemDesc;
  mem_desc->mem_ptr = ptr;
  mem_desc->byte_size = byte_size;
  return mem_desc;
}

// Creates CommNetWorkerNum() pollers, establishes all peer sockets, then
// starts the pollers.
EpollCommNet::EpollCommNet(const Plan& plan)
    : CommNetIf(plan), epoll_conf_(Global<JobDesc>::Get()->epoll_conf()) {
  pollers_.resize(Global<JobDesc>::Get()->CommNetWorkerNum(), nullptr);
  for (size_t i = 0; i < pollers_.size(); ++i) { pollers_.at(i) = new IOEventPoller; }
  InitSockets();
  for (IOEventPoller* poller : pollers_) { poller->Start(); }
}

// Establishes link_num TCP connections with every peer machine:
//   1. listen on the configured data port (or the first bindable port >= 1024)
//      and publish it;
//   2. connect to every peer whose id is greater than this machine's id;
//   3. accept the connections of every peer whose id is smaller.
// The resulting fds are stored in machine_link_id2sockfds_ at index
// peer_id * link_num + link_index, and each fd gets a SocketHelper attached to
// a poller chosen round-robin.
void EpollCommNet::InitSockets() {
  int64_t this_machine_id = Global<MachineCtx>::Get()->this_machine_id();
  auto this_machine = Global<JobDesc>::Get()->resource().machine(this_machine_id);
  int64_t total_machine_num = Global<JobDesc>::Get()->TotalMachineNum();
  machine_link_id2sockfds_.assign(total_machine_num * epoll_conf_.link_num(), -1);
  sockfd2helper_.clear();
  size_t poller_idx = 0;
  auto NewSocketHelper = [&](int32_t sockfd) {
    IOEventPoller* poller = pollers_.at(poller_idx);
    poller_idx = (poller_idx + 1) % pollers_.size();
    return new SocketHelper(sockfd, poller);
  };

  // listen
  int32_t listen_sockfd = socket(AF_INET, SOCK_STREAM, 0);
  PCHECK(listen_sockfd != -1);
  int32_t this_listen_port = Global<JobDesc>::Get()->resource().data_port();
  if (this_listen_port != -1) {
    CHECK_EQ(SockListen(listen_sockfd, this_listen_port, total_machine_num), 0);
    PushPort(this_machine_id,
             ((this_machine.data_port_agent() != -1) ? (this_machine.data_port_agent())
                                                     : (this_listen_port)));
  } else {
    // No port configured: probe upwards from 1024 until a bind succeeds.
    for (this_listen_port = 1024; this_listen_port < MaxVal<uint16_t>(); ++this_listen_port) {
      if (SockListen(listen_sockfd, this_listen_port, total_machine_num) == 0) {
        PushPort(this_machine_id, this_listen_port);
        break;
      }
    }
    CHECK_LT(this_listen_port, MaxVal<uint16_t>());
  }
  int32_t src_machine_count = 0;

  // connect
  for (int64_t peer_mchn_id : peer_machine_id()) {
    if (peer_mchn_id < this_machine_id) {
      // Peers with a smaller id connect to us; they are handled in the accept phase.
      ++src_machine_count;
      continue;
    }
    uint16_t peer_port = PullPort(peer_mchn_id);
    auto peer_machine = Global<JobDesc>::Get()->resource().machine(peer_mchn_id);
    sockaddr_in peer_sockaddr = GetSockAddr(peer_machine.addr(), peer_port);
    for (int32_t link_i = 0; link_i < epoll_conf_.link_num(); ++link_i) {
      int32_t sockfd = socket(AF_INET, SOCK_STREAM, 0);
      PCHECK(sockfd != -1);
      PCHECK(connect(sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), sizeof(peer_sockaddr))
             == 0);
      CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
      machine_link_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i) = sockfd;
    }
  }

  // accept
  // Connections from different peers may arrive interleaved, so the link index
  // of an accepted socket is the number of links already accepted from that
  // same peer, not the position in the overall accept sequence.
  std::vector<int32_t> accepted_link_cnt(total_machine_num, 0);
  const int32_t expected_accept_num = src_machine_count * epoll_conf_.link_num();
  for (int32_t accept_i = 0; accept_i < expected_accept_num; ++accept_i) {
    sockaddr_in peer_sockaddr;
    // `len` is a value-result argument of accept(); reset it for every call.
    socklen_t len = sizeof(peer_sockaddr);
    int32_t sockfd = accept(listen_sockfd, reinterpret_cast<sockaddr*>(&peer_sockaddr), &len);
    PCHECK(sockfd != -1);
    CHECK(sockfd2helper_.emplace(sockfd, NewSocketHelper(sockfd)).second);
    int64_t peer_mchn_id = GetMachineId(peer_sockaddr);
    int32_t link_i = accepted_link_cnt.at(peer_mchn_id)++;
    CHECK_LT(link_i, epoll_conf_.link_num());
    machine_link_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i) = sockfd;
  }
  PCHECK(close(listen_sockfd) == 0);
  ClearPort(this_machine_id);

  // useful log
  for (int64_t peer_mchn_id : peer_machine_id()) {
    FOR_RANGE(int32_t, link_i, 0, epoll_conf_.link_num()) {
      int32_t sockfd = machine_link_id2sockfds_.at(peer_mchn_id * epoll_conf_.link_num() + link_i);
      CHECK_GT(sockfd, 0);
      LOG(INFO) << "machine: " << peer_mchn_id << ", link index: " << link_i
                << ", sockfd: " << sockfd;
    }
  }
}

// Returns the helper of the socket connecting this machine to `machine_id`
// over link `link_index`.
SocketHelper* EpollCommNet::GetSocketHelper(int64_t machine_id, int32_t link_index) const {
  int32_t sockfd = machine_link_id2sockfds_.at(machine_id * epoll_conf_.link_num() + link_index);
  return sockfd2helper_.at(sockfd);
}

// Starts a read of `src_token` on `src_machine_id` into `dst_token` by sending
// a kRequestWrite message over the last link, and records `read_id` as pending
// so that PartReadDone can report completions in request order.
void EpollCommNet::DoRead(void* read_id, int64_t src_machine_id, void* src_token,
                          void* dst_token) {
  SocketMsg msg;
  msg.msg_type = SocketMsgType::kRequestWrite;
  msg.request_write_msg.src_token = src_token;
  msg.request_write_msg.dst_machine_id = Global<MachineCtx>::Get()->this_machine_id();
  msg.request_write_msg.dst_token = dst_token;
  msg.request_write_msg.read_id = read_id;
  dst_token2part_done_cnt_.at(dst_token) = 0;
  // Register the pending read BEFORE sending the request: PartReadDone looks
  // read_id up with .at(), so it must already be present when a reply arrives.
  {
    std::unique_lock<std::mutex> lck(read_done_mtx_);
    machine_id2read_done_order_[src_machine_id].push(read_id);
    read_id2done_status_.emplace(read_id, false);
  }
  GetSocketHelper(src_machine_id, epoll_conf_.link_num() - 1)->AsyncWrite(msg);
}

// Called once per received part of a read. When the last of `part_num` parts
// for `dst_token` has arrived, marks `read_id` as done and calls ReadDone for
// every finished read at the head of the per-machine queue, in request order.
void EpollCommNet::PartReadDone(void* read_id, int64_t src_machine_id, void* dst_token,
                                int32_t part_num) {
  if (dst_token2part_done_cnt_.at(dst_token).fetch_add(1, std::memory_order_relaxed)
      != (part_num - 1)) {
    return;  // other parts are still outstanding
  }
  std::unique_lock<std::mutex> lck(read_done_mtx_);
  read_id2done_status_.at(read_id) = true;
  auto& read_done_order = machine_id2read_done_order_.at(src_machine_id);
  // Stop when the queue is drained: front() on an empty queue is undefined.
  while (!read_done_order.empty() && read_id2done_status_.at(read_done_order.front())) {
    void* item = read_done_order.front();
    ReadDone(item);
    read_id2done_status_.erase(item);
    read_done_order.pop();
  }
}

}  // namespace oneflow

#endif  // PLATFORM_POSIX