diff --git a/fs/io_uring.c b/fs/io_uring.c index 95501f8793640163c073f2b02dfd8e0fa622e4c6..d1142c12186598d43e8a7c64e4fbca1bec1f02b4 100644 --- a/fs/io_uring.c +++ b/fs/io_uring.c @@ -5812,8 +5812,7 @@ static void io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req, } static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, - struct file *ring_file, int ring_fd, - struct mm_struct **mm, bool async) + struct file *ring_file, int ring_fd, bool async) { struct io_submit_state state, *statep = NULL; struct io_kiocb *link = NULL; @@ -5870,13 +5869,12 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, break; } - if (io_op_defs[req->opcode].needs_mm && !*mm) { + if (io_op_defs[req->opcode].needs_mm && !current->mm) { if (unlikely(!mmget_not_zero(ctx->sqo_mm))) { err = -EFAULT; goto fail_req; } use_mm(ctx->sqo_mm); - *mm = ctx->sqo_mm; } req->needs_fixed_file = async; @@ -5902,10 +5900,19 @@ static int io_submit_sqes(struct io_ring_ctx *ctx, unsigned int nr, return submitted; } +static inline void io_sq_thread_drop_mm(struct io_ring_ctx *ctx) +{ + struct mm_struct *mm = current->mm; + + if (mm) { + unuse_mm(mm); + mmput(mm); + } +} + static int io_sq_thread(void *data) { struct io_ring_ctx *ctx = data; - struct mm_struct *cur_mm = NULL; const struct cred *old_cred; mm_segment_t old_fs; DEFINE_WAIT(wait); @@ -5946,11 +5953,7 @@ static int io_sq_thread(void *data) * adding ourselves to the waitqueue, as the unuse/drop * may sleep. */ - if (cur_mm) { - unuse_mm(cur_mm); - mmput(cur_mm); - cur_mm = NULL; - } + io_sq_thread_drop_mm(ctx); /* * We're polling. If we're within the defined idle @@ -6013,7 +6016,7 @@ static int io_sq_thread(void *data) } mutex_lock(&ctx->uring_lock); - ret = io_submit_sqes(ctx, to_submit, NULL, -1, &cur_mm, true); + ret = io_submit_sqes(ctx, to_submit, NULL, -1, true); mutex_unlock(&ctx->uring_lock); timeout = jiffies + ctx->sq_thread_idle; } @@ -6022,10 +6025,7 @@ static int io_sq_thread(void *data) task_work_run(); set_fs(old_fs); - if (cur_mm) { - unuse_mm(cur_mm); - mmput(cur_mm); - } + io_sq_thread_drop_mm(ctx); revert_creds(old_cred); kthread_parkme(); @@ -7507,13 +7507,8 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit, wake_up(&ctx->sqo_wait); submitted = to_submit; } else if (to_submit) { - struct mm_struct *cur_mm; - mutex_lock(&ctx->uring_lock); - /* already have mm, so io_submit_sqes() won't try to grab it */ - cur_mm = ctx->sqo_mm; - submitted = io_submit_sqes(ctx, to_submit, f.file, fd, - &cur_mm, false); + submitted = io_submit_sqes(ctx, to_submit, f.file, fd, false); mutex_unlock(&ctx->uring_lock); if (submitted != to_submit)