diff --git a/fs/io_uring.c b/fs/io_uring.c index 0d973de7512776691cadf739530c75bebbf05859..8c976fde40bd69c2aac413ce34332db5600d6ae2 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -487,6 +487,7 @@ enum { REQ_F_COMP_LOCKED_BIT, REQ_F_NEED_CLEANUP_BIT, REQ_F_OVERFLOW_BIT, + REQ_F_POLLED_BIT, }; enum { @@ -529,6 +530,13 @@ enum { REQ_F_NEED_CLEANUP = BIT(REQ_F_NEED_CLEANUP_BIT), /* in overflow list */ REQ_F_OVERFLOW = BIT(REQ_F_OVERFLOW_BIT), + /* already went through poll handler */ + REQ_F_POLLED = BIT(REQ_F_POLLED_BIT), +}; + +struct async_poll { + struct io_poll_iocb poll; + struct io_wq_work work; }; /* @@ -562,27 +570,29 @@ struct io_kiocb { u8 opcode; struct io_ring_ctx *ctx; - union { - struct list_head list; - struct hlist_node hash_node; - }; - struct list_head link_list; + struct list_head list; unsigned int flags; refcount_t refs; + struct task_struct *task; u64 user_data; u32 result; u32 sequence; + struct list_head link_list; + struct list_head inflight_entry; union { /* * Only commands that never go async can use the below fields, - * obviously. Right now only IORING_OP_POLL_ADD uses them. + * obviously. Right now only IORING_OP_POLL_ADD uses them, and + * async armed poll handlers for regular commands. The latter + * restore the work, if needed. */ struct { - struct task_struct *task; struct callback_head task_work; + struct hlist_node hash_node; + struct async_poll *apoll; }; struct io_wq_work work; }; @@ -3563,9 +3573,209 @@ static int io_connect(struct io_kiocb *req, struct io_kiocb **nxt, #endif } -static bool io_poll_remove_one(struct io_kiocb *req) +struct io_poll_table { + struct poll_table_struct pt; + struct io_kiocb *req; + int error; +}; + +static void __io_queue_proc(struct io_poll_iocb *poll, struct io_poll_table *pt, + struct wait_queue_head *head) +{ + if (unlikely(poll->head)) { + pt->error = -EINVAL; + return; + } + + pt->error = 0; + poll->head = head; + add_wait_queue(head, &poll->wait); +} + +static void io_async_queue_proc(struct file *file, struct wait_queue_head *head, + struct poll_table_struct *p) +{ + struct io_poll_table *pt = container_of(p, struct io_poll_table, pt); + + __io_queue_proc(&pt->req->apoll->poll, pt, head); +} + +static int __io_async_wake(struct io_kiocb *req, struct io_poll_iocb *poll, + __poll_t mask, task_work_func_t func) +{ + struct task_struct *tsk; + + /* for instances that support it check for an event match first: */ + if (mask && !(mask & poll->events)) + return 0; + + trace_io_uring_task_add(req->ctx, req->opcode, req->user_data, mask); + + list_del_init(&poll->wait.entry); + + tsk = req->task; + req->result = mask; + init_task_work(&req->task_work, func); + /* + * If this fails, then the task is exiting. If that is the case, then + * the exit check will ultimately cancel these work items. Hence we + * don't need to check here and handle it specifically. + */ + task_work_add(tsk, &req->task_work, true); + wake_up_process(tsk); + return 1; +} + +static void io_async_task_func(struct callback_head *cb) +{ + struct io_kiocb *req = container_of(cb, struct io_kiocb, task_work); + struct async_poll *apoll = req->apoll; + struct io_ring_ctx *ctx = req->ctx; + + trace_io_uring_task_run(req->ctx, req->opcode, req->user_data); + + WARN_ON_ONCE(!list_empty(&req->apoll->poll.wait.entry)); + + if (hash_hashed(&req->hash_node)) { + spin_lock_irq(&ctx->completion_lock); + hash_del(&req->hash_node); + spin_unlock_irq(&ctx->completion_lock); + } + + /* restore ->work in case we need to retry again */ + memcpy(&req->work, &apoll->work, sizeof(req->work)); + + __set_current_state(TASK_RUNNING); + mutex_lock(&ctx->uring_lock); + __io_queue_sqe(req, NULL); + mutex_unlock(&ctx->uring_lock); + + kfree(apoll); +} + +static int io_async_wake(struct wait_queue_entry *wait, unsigned mode, int sync, + void *key) +{ + struct io_kiocb *req = wait->private; + struct io_poll_iocb *poll = &req->apoll->poll; + + trace_io_uring_poll_wake(req->ctx, req->opcode, req->user_data, + key_to_poll(key)); + + return __io_async_wake(req, poll, key_to_poll(key), io_async_task_func); +} + +static void io_poll_req_insert(struct io_kiocb *req) +{ + struct io_ring_ctx *ctx = req->ctx; + struct hlist_head *list; + + list = &ctx->cancel_hash[hash_long(req->user_data, ctx->cancel_hash_bits)]; + hlist_add_head(&req->hash_node, list); +} + +static __poll_t __io_arm_poll_handler(struct io_kiocb *req, + struct io_poll_iocb *poll, + struct io_poll_table *ipt, __poll_t mask, + wait_queue_func_t wake_func) + __acquires(&ctx->completion_lock) +{ + struct io_ring_ctx *ctx = req->ctx; + bool cancel = false; + + poll->file = req->file; + poll->head = NULL; + poll->done = poll->canceled = false; + poll->events = mask; + + ipt->pt._key = mask; + ipt->req = req; + ipt->error = -EINVAL; + + INIT_LIST_HEAD(&poll->wait.entry); + init_waitqueue_func_entry(&poll->wait, wake_func); + poll->wait.private = req; + + mask = vfs_poll(req->file, &ipt->pt) & poll->events; + + spin_lock_irq(&ctx->completion_lock); + if (likely(poll->head)) { + spin_lock(&poll->head->lock); + if (unlikely(list_empty(&poll->wait.entry))) { + if (ipt->error) + cancel = true; + ipt->error = 0; + mask = 0; + } + if (mask || ipt->error) + list_del_init(&poll->wait.entry); + else if (cancel) + WRITE_ONCE(poll->canceled, true); + else if (!poll->done) /* actually waiting for an event */ + io_poll_req_insert(req); + spin_unlock(&poll->head->lock); + } + + return mask; +} + +static bool io_arm_poll_handler(struct io_kiocb *req) +{ + const struct io_op_def *def = &io_op_defs[req->opcode]; + struct io_ring_ctx *ctx = req->ctx; + struct async_poll *apoll; + struct io_poll_table ipt; + __poll_t mask, ret; + + if (!req->file || !file_can_poll(req->file)) + return false; + if (req->flags & (REQ_F_MUST_PUNT | REQ_F_POLLED)) + return false; + if (!def->pollin && !def->pollout) + return false; + + apoll = kmalloc(sizeof(*apoll), GFP_ATOMIC); + if (unlikely(!apoll)) + return false; + + req->flags |= REQ_F_POLLED; + memcpy(&apoll->work, &req->work, sizeof(req->work)); + + /* + * Don't need a reference here, as we're adding it to the task + * task_works list. If the task exits, the list is pruned. + */ + req->task = current; + req->apoll = apoll; + INIT_HLIST_NODE(&req->hash_node); + + if (def->pollin) + mask = POLLIN | POLLRDNORM; + if (def->pollout) + mask |= POLLOUT | POLLWRNORM; + mask |= POLLERR | POLLPRI; + + ipt.pt._qproc = io_async_queue_proc; + + ret = __io_arm_poll_handler(req, &apoll->poll, &ipt, mask, + io_async_wake); + if (ret) { + ipt.error = 0; + apoll->poll.done = true; + spin_unlock_irq(&ctx->completion_lock); + memcpy(&req->work, &apoll->work, sizeof(req->work)); + kfree(apoll); + return false; + } + spin_unlock_irq(&ctx->completion_lock); + trace_io_uring_poll_arm(ctx, req->opcode, req->user_data, mask, + apoll->poll.events); + return true; +} + +static bool __io_poll_remove_one(struct io_kiocb *req, + struct io_poll_iocb *poll) { - struct io_poll_iocb *poll = &req->poll; bool do_complete = false; spin_lock(&poll->head->lock); @@ -3575,7 +3785,24 @@ static bool io_poll_remove_one(struct io_kiocb *req) do_complete = true; } spin_unlock(&poll->head->lock); + return do_complete; +} + +static bool io_poll_remove_one(struct io_kiocb *req) +{ + bool do_complete; + + if (req->opcode == IORING_OP_POLL_ADD) { + do_complete = __io_poll_remove_one(req, &req->poll); + } else { + /* non-poll requests have submit ref still */ + do_complete = __io_poll_remove_one(req, &req->apoll->poll); + if (do_complete) + io_put_req(req); + } + hash_del(&req->hash_node); + if (do_complete) { io_cqring_fill_event(req, -ECANCELED); io_commit_cqring(req->ctx); @@ -3686,8 +3913,13 @@ static void io_poll_task_func(struct callback_head *cb) struct io_kiocb *nxt = NULL; io_poll_task_handler(req, &nxt); - if (nxt) + if (nxt) { + struct io_ring_ctx *ctx = nxt->ctx; + + mutex_lock(&ctx->uring_lock); __io_queue_sqe(nxt, NULL); + mutex_unlock(&ctx->uring_lock); + } } static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync, @@ -3695,51 +3927,16 @@ static int io_poll_wake(struct wait_queue_entry *wait, unsigned mode, int sync, { struct io_kiocb *req = wait->private; struct io_poll_iocb *poll = &req->poll; - __poll_t mask = key_to_poll(key); - struct task_struct *tsk; - /* for instances that support it check for an event match first: */ - if (mask && !(mask & poll->events)) - return 0; - - list_del_init(&poll->wait.entry); - - tsk = req->task; - req->result = mask; - init_task_work(&req->task_work, io_poll_task_func); - task_work_add(tsk, &req->task_work, true); - wake_up_process(tsk); - return 1; + return __io_async_wake(req, poll, key_to_poll(key), io_poll_task_func); } -struct io_poll_table { - struct poll_table_struct pt; - struct io_kiocb *req; - int error; -}; - static void io_poll_queue_proc(struct file *file, struct wait_queue_head *head, struct poll_table_struct *p) { struct io_poll_table *pt = container_of(p, struct io_poll_table, pt); - if (unlikely(pt->req->poll.head)) { - pt->error = -EINVAL; - return; - } - - pt->error = 0; - pt->req->poll.head = head; - add_wait_queue(head, &pt->req->poll.wait); -} - -static void io_poll_req_insert(struct io_kiocb *req) -{ - struct io_ring_ctx *ctx = req->ctx; - struct hlist_head *list; - - list = &ctx->cancel_hash[hash_long(req->user_data, ctx->cancel_hash_bits)]; - hlist_add_head(&req->hash_node, list); + __io_queue_proc(&pt->req->poll, pt, head); } static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe) @@ -3757,7 +3954,10 @@ static int io_poll_add_prep(struct io_kiocb *req, const struct io_uring_sqe *sqe events = READ_ONCE(sqe->poll_events); poll->events = demangle_poll(events) | EPOLLERR | EPOLLHUP; - /* task will wait for requests on exit, don't need a ref */ + /* + * Don't need a reference here, as we're adding it to the task + * task_works list. If the task exits, the list is pruned. + */ req->task = current; return 0; } @@ -3767,46 +3967,15 @@ static int io_poll_add(struct io_kiocb *req, struct io_kiocb **nxt) struct io_poll_iocb *poll = &req->poll; struct io_ring_ctx *ctx = req->ctx; struct io_poll_table ipt; - bool cancel = false; __poll_t mask; INIT_HLIST_NODE(&req->hash_node); - - poll->head = NULL; - poll->done = false; - poll->canceled = false; - - ipt.pt._qproc = io_poll_queue_proc; - ipt.pt._key = poll->events; - ipt.req = req; - ipt.error = -EINVAL; /* same as no support for IOCB_CMD_POLL */ - - /* initialized the list so that we can do list_empty checks */ - INIT_LIST_HEAD(&poll->wait.entry); - init_waitqueue_func_entry(&poll->wait, io_poll_wake); - poll->wait.private = req; - INIT_LIST_HEAD(&req->list); + ipt.pt._qproc = io_poll_queue_proc; - mask = vfs_poll(poll->file, &ipt.pt) & poll->events; + mask = __io_arm_poll_handler(req, &req->poll, &ipt, poll->events, + io_poll_wake); - spin_lock_irq(&ctx->completion_lock); - if (likely(poll->head)) { - spin_lock(&poll->head->lock); - if (unlikely(list_empty(&poll->wait.entry))) { - if (ipt.error) - cancel = true; - ipt.error = 0; - mask = 0; - } - if (mask || ipt.error) - list_del_init(&poll->wait.entry); - else if (cancel) - WRITE_ONCE(poll->canceled, true); - else if (!poll->done) /* actually waiting for an event */ - io_poll_req_insert(req); - spin_unlock(&poll->head->lock); - } if (mask) { /* no async, we'd stolen it */ ipt.error = 0; io_poll_complete(req, mask, 0); @@ -4751,6 +4920,9 @@ static struct io_kiocb *io_prep_linked_timeout(struct io_kiocb *req) if (!(req->flags & REQ_F_LINK)) return NULL; + /* for polled retry, if flag is set, we already went through here */ + if (req->flags & REQ_F_POLLED) + return NULL; nxt = list_first_entry_or_null(&req->link_list, struct io_kiocb, link_list); @@ -4788,6 +4960,11 @@ static void __io_queue_sqe(struct io_kiocb *req, const struct io_uring_sqe *sqe) */ if (ret == -EAGAIN && (!(req->flags & REQ_F_NOWAIT) || (req->flags & REQ_F_MUST_PUNT))) { + if (io_arm_poll_handler(req)) { + if (linked_timeout) + io_queue_linked_timeout(linked_timeout); + goto done_req; + } punt: if (io_op_defs[req->opcode].file_table) { ret = io_grab_files(req); @@ -6782,6 +6959,17 @@ static void __io_uring_show_fdinfo(struct io_ring_ctx *ctx, struct seq_file *m) seq_printf(m, "Personalities:\n"); idr_for_each(&ctx->personality_idr, io_uring_show_cred, m); } + seq_printf(m, "PollList:\n"); + spin_lock_irq(&ctx->completion_lock); + for (i = 0; i < (1U << ctx->cancel_hash_bits); i++) { + struct hlist_head *list = &ctx->cancel_hash[i]; + struct io_kiocb *req; + + hlist_for_each_entry(req, list, hash_node) + seq_printf(m, " op=%d, task_works=%d\n", req->opcode, + req->task->task_works != NULL); + } + spin_unlock_irq(&ctx->completion_lock); mutex_unlock(&ctx->uring_lock); } @@ -6998,7 +7186,7 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p) p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP | IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS | - IORING_FEAT_CUR_PERSONALITY; + IORING_FEAT_CUR_PERSONALITY | IORING_FEAT_FAST_POLL; trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags); return ret; err: diff --git a/include/trace/events/io_uring.h b/include/trace/events/io_uring.h index 27bd9e4f927bb6f01415cba6566570b4c0a92652..9f0d3b7d56b0f31351e0d498be643524b3e895dc 100644 --- a/include/trace/events/io_uring.h +++ b/include/trace/events/io_uring.h @@ -357,6 +357,109 @@ TRACE_EVENT(io_uring_submit_sqe, __entry->force_nonblock, __entry->sq_thread) ); +TRACE_EVENT(io_uring_poll_arm, + + TP_PROTO(void *ctx, u8 opcode, u64 user_data, int mask, int events), + + TP_ARGS(ctx, opcode, user_data, mask, events), + + TP_STRUCT__entry ( + __field( void *, ctx ) + __field( u8, opcode ) + __field( u64, user_data ) + __field( int, mask ) + __field( int, events ) + ), + + TP_fast_assign( + __entry->ctx = ctx; + __entry->opcode = opcode; + __entry->user_data = user_data; + __entry->mask = mask; + __entry->events = events; + ), + + TP_printk("ring %p, op %d, data 0x%llx, mask 0x%x, events 0x%x", + __entry->ctx, __entry->opcode, + (unsigned long long) __entry->user_data, + __entry->mask, __entry->events) +); + +TRACE_EVENT(io_uring_poll_wake, + + TP_PROTO(void *ctx, u8 opcode, u64 user_data, int mask), + + TP_ARGS(ctx, opcode, user_data, mask), + + TP_STRUCT__entry ( + __field( void *, ctx ) + __field( u8, opcode ) + __field( u64, user_data ) + __field( int, mask ) + ), + + TP_fast_assign( + __entry->ctx = ctx; + __entry->opcode = opcode; + __entry->user_data = user_data; + __entry->mask = mask; + ), + + TP_printk("ring %p, op %d, data 0x%llx, mask 0x%x", + __entry->ctx, __entry->opcode, + (unsigned long long) __entry->user_data, + __entry->mask) +); + +TRACE_EVENT(io_uring_task_add, + + TP_PROTO(void *ctx, u8 opcode, u64 user_data, int mask), + + TP_ARGS(ctx, opcode, user_data, mask), + + TP_STRUCT__entry ( + __field( void *, ctx ) + __field( u8, opcode ) + __field( u64, user_data ) + __field( int, mask ) + ), + + TP_fast_assign( + __entry->ctx = ctx; + __entry->opcode = opcode; + __entry->user_data = user_data; + __entry->mask = mask; + ), + + TP_printk("ring %p, op %d, data 0x%llx, mask %x", + __entry->ctx, __entry->opcode, + (unsigned long long) __entry->user_data, + __entry->mask) +); + +TRACE_EVENT(io_uring_task_run, + + TP_PROTO(void *ctx, u8 opcode, u64 user_data), + + TP_ARGS(ctx, opcode, user_data), + + TP_STRUCT__entry ( + __field( void *, ctx ) + __field( u8, opcode ) + __field( u64, user_data ) + ), + + TP_fast_assign( + __entry->ctx = ctx; + __entry->opcode = opcode; + __entry->user_data = user_data; + ), + + TP_printk("ring %p, op %d, data 0x%llx", + __entry->ctx, __entry->opcode, + (unsigned long long) __entry->user_data) +); + #endif /* _TRACE_IO_URING_H */ /* This part must be outside protection */ diff --git a/include/uapi/linux/io_uring.h b/include/uapi/linux/io_uring.h index 08891cc1c1e784845abda5643209bc45c3fa075c..53b36311cdacb76d37c1252b29848798f7eeed5b 100644 --- a/include/uapi/linux/io_uring.h +++ b/include/uapi/linux/io_uring.h @@ -216,6 +216,7 @@ struct io_uring_params { #define IORING_FEAT_SUBMIT_STABLE (1U << 2) #define IORING_FEAT_RW_CUR_POS (1U << 3) #define IORING_FEAT_CUR_PERSONALITY (1U << 4) +#define IORING_FEAT_FAST_POLL (1U << 5) /* * io_uring_register(2) opcodes and arguments