diff --git a/oneflow/core/comm_network/epoll/io_event_poller.cpp b/oneflow/core/comm_network/epoll/io_event_poller.cpp index 020d6c0f8bdbb146ef85199d318a5922d00fd98f..c71add53c9bc19e5863a6ae2033a1cefbd39ede6 100644 --- a/oneflow/core/comm_network/epoll/io_event_poller.cpp +++ b/oneflow/core/comm_network/epoll/io_event_poller.cpp @@ -25,28 +25,44 @@ IOEventPoller::~IOEventPoller() { void IOEventPoller::AddFd(int fd, std::function read_handler, std::function write_handler) { + AddFd(fd, &read_handler, &write_handler); +} + +void IOEventPoller::AddFdWithOnlyReadHandler( + int fd, std::function read_handler) { + AddFd(fd, &read_handler, nullptr); +} + +void IOEventPoller::Start() { + thread_ = std::thread(&IOEventPoller::EpollLoop, this); +} + +void IOEventPoller::AddFd(int fd, std::function* read_handler, + std::function* write_handler) { unclosed_fd_cnt_ += 1; fds_.push_back(fd); // Set Fd NONBLOCK int opt = fcntl(fd, F_GETFL); PCHECK(opt != -1); PCHECK(fcntl(fd, F_SETFL, opt | O_NONBLOCK) == 0); + // Set CLOEXEC + opt = fcntl(fd, F_GETFD); + PCHECK(opt != -1); + PCHECK(fcntl(fd, F_SETFD, opt | FD_CLOEXEC) == 0); // New IOHandler on Heap IOHandler* io_handler = new IOHandler; - io_handler->read_handler = read_handler; - io_handler->write_handler = write_handler; + if (read_handler) { io_handler->read_handler = *read_handler; } + if (write_handler) { io_handler->write_handler = *write_handler; } io_handlers_.push_front(io_handler); // Add Fd to Epoll epoll_event ep_event; - ep_event.events = EPOLLIN | EPOLLOUT | EPOLLET; + ep_event.events = EPOLLET; + if (read_handler) { ep_event.events |= EPOLLIN; } + if (write_handler) { ep_event.events |= EPOLLOUT; } ep_event.data.ptr = io_handler; PCHECK(epoll_ctl(epfd_, EPOLL_CTL_ADD, fd, &ep_event) == 0); } -void IOEventPoller::Start() { - thread_ = std::thread(&IOEventPoller::EpollLoop, this); -} - void IOEventPoller::EpollLoop() { while (unclosed_fd_cnt_ > 0) { int event_num = epoll_wait(epfd_, ep_events_, max_event_num_, -1); diff --git a/oneflow/core/comm_network/epoll/io_event_poller.h b/oneflow/core/comm_network/epoll/io_event_poller.h index 6a56bd16c666872f7976efc5b9db29f6169d01f6..08b9822c4e15200086e604d7d0a2975c3394489f 100644 --- a/oneflow/core/comm_network/epoll/io_event_poller.h +++ b/oneflow/core/comm_network/epoll/io_event_poller.h @@ -15,15 +15,23 @@ class IOEventPoller final { void AddFd(int fd, std::function read_handler, std::function write_handler); + void AddFdWithOnlyReadHandler(int fd, std::function read_handler); void Start(); private: struct IOHandler { + IOHandler() { + read_handler = []() { UNEXPECTED_RUN(); }; + write_handler = []() { UNEXPECTED_RUN(); }; + } std::function read_handler; std::function write_handler; }; + void AddFd(int fd, std::function* read_handler, + std::function* write_handler); + void EpollLoop(); static const int max_event_num_; diff --git a/oneflow/core/comm_network/epoll/socket_read_helper.cpp b/oneflow/core/comm_network/epoll/socket_read_helper.cpp index 8298716e835e0cbd17a54508a9f2330d0c9d3ea0..f012afb9ff1485bd0a9c74fd15bf15440eb4901e 100644 --- a/oneflow/core/comm_network/epoll/socket_read_helper.cpp +++ b/oneflow/core/comm_network/epoll/socket_read_helper.cpp @@ -57,7 +57,7 @@ bool SocketReadHelper::DoCurRead( void SocketReadHelper::SetStatusWhenMsgHeadDone() { switch (cur_msg_.msg_type) { #define MAKE_ENTRY(x, y) \ - case SocketMsgType::k##x: SetStatusWhen##x##MsgHeadDone(); + case SocketMsgType::k##x: SetStatusWhen##x##MsgHeadDone(); break; OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, SOCKET_MSG_TYPE_SEQ); #undef MAKE_ENTRY default: UNEXPECTED_RUN(); diff --git a/oneflow/core/comm_network/epoll/socket_write_helper.cpp b/oneflow/core/comm_network/epoll/socket_write_helper.cpp index 065cf87f88fecc24aabea224384e084543b2008f..7749640e3f1c310c93ef048c55ea7533170e0f98 100644 --- a/oneflow/core/comm_network/epoll/socket_write_helper.cpp +++ b/oneflow/core/comm_network/epoll/socket_write_helper.cpp @@ -21,12 +21,9 @@ SocketWriteHelper::SocketWriteHelper(int sockfd, IOEventPoller* poller) { sockfd_ = sockfd; queue_not_empty_fd_ = eventfd(0, 0); PCHECK(queue_not_empty_fd_ != -1); - poller->AddFd(queue_not_empty_fd_, - std::bind(&SocketWriteHelper::ProcessQueueNotEmptyEvent, this), - [this]() { - // TODO: delete this log - LOG(INFO) << "fd " << queue_not_empty_fd_ << " writeable"; - }); + poller->AddFdWithOnlyReadHandler( + queue_not_empty_fd_, + std::bind(&SocketWriteHelper::ProcessQueueNotEmptyEvent, this)); cur_msg_queue_ = new std::queue; pending_msg_queue_ = new std::queue; cur_write_handle_ = &SocketWriteHelper::InitMsgWriteHandle;