diff --git a/fs/io_uring.c b/fs/io_uring.c index b476bd304045ca82e47329f05d50ee1dbab9177e..b0411406c50a76dda2b3cc57afc729bca337023d 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -289,7 +289,10 @@ struct io_ring_ctx { */ struct io_poll_iocb { struct file *file; - struct wait_queue_head *head; + union { + struct wait_queue_head *head; + u64 addr; + }; __poll_t events; bool done; bool canceled; @@ -2490,24 +2493,40 @@ static int io_poll_cancel(struct io_ring_ctx *ctx, __u64 sqe_addr) return -ENOENT; } +static int io_poll_remove_prep(struct io_kiocb *req) +{ + const struct io_uring_sqe *sqe = req->sqe; + + if (req->flags & REQ_F_PREPPED) + return 0; + if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL)) + return -EINVAL; + if (sqe->ioprio || sqe->off || sqe->len || sqe->buf_index || + sqe->poll_events) + return -EINVAL; + + req->poll.addr = READ_ONCE(sqe->addr); + req->flags |= REQ_F_PREPPED; + return 0; +} + /* * Find a running poll command that matches one specified in sqe->addr, * and remove it if found. */ static int io_poll_remove(struct io_kiocb *req) { - const struct io_uring_sqe *sqe = req->sqe; struct io_ring_ctx *ctx = req->ctx; + u64 addr; int ret; - if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL)) - return -EINVAL; - if (sqe->ioprio || sqe->off || sqe->len || sqe->buf_index || - sqe->poll_events) - return -EINVAL; + ret = io_poll_remove_prep(req); + if (ret) + return ret; + addr = req->poll.addr; spin_lock_irq(&ctx->completion_lock); - ret = io_poll_cancel(ctx, READ_ONCE(sqe->addr)); + ret = io_poll_cancel(ctx, addr); spin_unlock_irq(&ctx->completion_lock); io_cqring_add_event(req, ret); @@ -2642,16 +2661,14 @@ static void io_poll_req_insert(struct io_kiocb *req) hlist_add_head(&req->hash_node, list); } -static int io_poll_add(struct io_kiocb *req, struct io_kiocb **nxt) +static int io_poll_add_prep(struct io_kiocb *req) { const struct io_uring_sqe *sqe = req->sqe; struct io_poll_iocb *poll = &req->poll; - struct io_ring_ctx *ctx = req->ctx; - struct io_poll_table ipt; - bool cancel = false; - __poll_t mask; u16 events; + if (req->flags & REQ_F_PREPPED) + return 0; if (unlikely(req->ctx->flags & IORING_SETUP_IOPOLL)) return -EINVAL; if (sqe->addr || sqe->ioprio || sqe->off || sqe->len || sqe->buf_index) @@ -2659,9 +2676,26 @@ static int io_poll_add(struct io_kiocb *req, struct io_kiocb **nxt) if (!poll->file) return -EBADF; - INIT_IO_WORK(&req->work, io_poll_complete_work); + req->flags |= REQ_F_PREPPED; events = READ_ONCE(sqe->poll_events); poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP; + return 0; +} + +static int io_poll_add(struct io_kiocb *req, struct io_kiocb **nxt) +{ + struct io_poll_iocb *poll = &req->poll; + struct io_ring_ctx *ctx = req->ctx; + struct io_poll_table ipt; + bool cancel = false; + __poll_t mask; + int ret; + + ret = io_poll_add_prep(req); + if (ret) + return ret; + + INIT_IO_WORK(&req->work, io_poll_complete_work); INIT_HLIST_NODE(&req->hash_node); poll->head = NULL; @@ -3029,6 +3063,12 @@ static int io_req_defer_prep(struct io_kiocb *req) io_req_map_rw(req, ret, iovec, inline_vecs, &iter); ret = 0; break; + case IORING_OP_POLL_ADD: + ret = io_poll_add_prep(req); + break; + case IORING_OP_POLL_REMOVE: + ret = io_poll_remove_prep(req); + break; case IORING_OP_FSYNC: ret = io_prep_fsync(req); break;