diff --git a/fs/io-wq.c b/fs/io-wq.c
index 02894df7656dd1df1b4a13141c3089bf4ce03518..b53c055bea6a36dba4d3d0d7d957798d5aceec29 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -482,6 +482,10 @@ static void io_impersonate_work(struct io_worker *worker,
 		current->files = work->identity->files;
 		current->nsproxy = work->identity->nsproxy;
 		task_unlock(current);
+		if (!work->identity->files) {
+			/* failed grabbing files, ensure work gets cancelled */
+			work->flags |= IO_WQ_WORK_CANCEL;
+		}
 	}
 	if ((work->flags & IO_WQ_WORK_FS) && current->fs != work->identity->fs)
 		current->fs = work->identity->fs;
diff --git a/fs/io_uring.c b/fs/io_uring.c
index a7429c977eb3c6c35543f96902556f8c58b36ffd..8018c7076b25ce66a8215500ffa3531d43f4204b 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -995,20 +995,33 @@ static void io_sq_thread_drop_mm(void)
 	if (mm) {
 		kthread_unuse_mm(mm);
 		mmput(mm);
+		current->mm = NULL;
 	}
 }
 
 static int __io_sq_thread_acquire_mm(struct io_ring_ctx *ctx)
 {
-	if (!current->mm) {
-		if (unlikely(!(ctx->flags & IORING_SETUP_SQPOLL) ||
-			     !ctx->sqo_task->mm ||
-			     !mmget_not_zero(ctx->sqo_task->mm)))
-			return -EFAULT;
-		kthread_use_mm(ctx->sqo_task->mm);
+	struct mm_struct *mm;
+
+	if (current->mm)
+		return 0;
+
+	/* Should never happen */
+	if (unlikely(!(ctx->flags & IORING_SETUP_SQPOLL)))
+		return -EFAULT;
+
+	task_lock(ctx->sqo_task);
+	mm = ctx->sqo_task->mm;
+	if (unlikely(!mm || !mmget_not_zero(mm)))
+		mm = NULL;
+	task_unlock(ctx->sqo_task);
+
+	if (mm) {
+		kthread_use_mm(mm);
+		return 0;
 	}
 
-	return 0;
+	return -EFAULT;
 }
 
 static int io_sq_thread_acquire_mm(struct io_ring_ctx *ctx,
@@ -1274,9 +1287,12 @@ static bool io_identity_cow(struct io_kiocb *req)
 	/* add one for this request */
 	refcount_inc(&id->count);
 
-	/* drop old identity, assign new one. one ref for req, one for tctx */
-	if (req->work.identity != tctx->identity &&
-	    refcount_sub_and_test(2, &req->work.identity->count))
+	/* drop tctx and req identity references, if needed */
+	if (tctx->identity != &tctx->__identity &&
+	    refcount_dec_and_test(&tctx->identity->count))
+		kfree(tctx->identity);
+	if (req->work.identity != &tctx->__identity &&
+	    refcount_dec_and_test(&req->work.identity->count))
 		kfree(req->work.identity);
 
 	req->work.identity = id;
@@ -1577,14 +1593,29 @@ static void io_cqring_mark_overflow(struct io_ring_ctx *ctx)
 	}
 }
 
-static inline bool io_match_files(struct io_kiocb *req,
-				       struct files_struct *files)
+static inline bool __io_match_files(struct io_kiocb *req,
+				    struct files_struct *files)
 {
+	return ((req->flags & REQ_F_WORK_INITIALIZED) &&
+	        (req->work.flags & IO_WQ_WORK_FILES)) &&
+		req->work.identity->files == files;
+}
+
+static bool io_match_files(struct io_kiocb *req,
+			   struct files_struct *files)
+{
+	struct io_kiocb *link;
+
 	if (!files)
 		return true;
-	if ((req->flags & REQ_F_WORK_INITIALIZED) &&
-	    (req->work.flags & IO_WQ_WORK_FILES))
-		return req->work.identity->files == files;
+	if (__io_match_files(req, files))
+		return true;
+	if (req->flags & REQ_F_LINK_HEAD) {
+		list_for_each_entry(link, &req->link_list, link_list) {
+			if (__io_match_files(link, files))
+				return true;
+		}
+	}
 	return false;
 }
 
@@ -1668,7 +1699,8 @@ static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
 		WRITE_ONCE(cqe->user_data, req->user_data);
 		WRITE_ONCE(cqe->res, res);
 		WRITE_ONCE(cqe->flags, cflags);
-	} else if (ctx->cq_overflow_flushed || req->task->io_uring->in_idle) {
+	} else if (ctx->cq_overflow_flushed ||
+		   atomic_read(&req->task->io_uring->in_idle)) {
 		/*
 		 * If we're in ring overflow flush mode, or in task cancel mode,
 		 * then we cannot store the request for later flushing, we need
@@ -1838,7 +1870,7 @@ static void __io_free_req(struct io_kiocb *req)
 	io_dismantle_req(req);
 
 	percpu_counter_dec(&tctx->inflight);
-	if (tctx->in_idle)
+	if (atomic_read(&tctx->in_idle))
 		wake_up(&tctx->wait);
 	put_task_struct(req->task);
 
@@ -7695,7 +7727,8 @@ static int io_uring_alloc_task_context(struct task_struct *task)
 	xa_init(&tctx->xa);
 	init_waitqueue_head(&tctx->wait);
 	tctx->last = NULL;
-	tctx->in_idle = 0;
+	atomic_set(&tctx->in_idle, 0);
+	tctx->sqpoll = false;
 	io_init_identity(&tctx->__identity);
 	tctx->identity = &tctx->__identity;
 	task->io_uring = tctx;
@@ -8388,22 +8421,6 @@ static bool io_match_link(struct io_kiocb *preq, struct io_kiocb *req)
 	return false;
 }
 
-static bool io_match_link_files(struct io_kiocb *req,
-				struct files_struct *files)
-{
-	struct io_kiocb *link;
-
-	if (io_match_files(req, files))
-		return true;
-	if (req->flags & REQ_F_LINK_HEAD) {
-		list_for_each_entry(link, &req->link_list, link_list) {
-			if (io_match_files(link, files))
-				return true;
-		}
-	}
-	return false;
-}
-
 /*
  * We're looking to cancel 'req' because it's holding on to our files, but
  * 'req' could be a link to another request. See if it is, and cancel that
@@ -8453,7 +8470,21 @@ static bool io_timeout_remove_link(struct io_ring_ctx *ctx,
 
 static bool io_cancel_link_cb(struct io_wq_work *work, void *data)
 {
-	return io_match_link(container_of(work, struct io_kiocb, work), data);
+	struct io_kiocb *req = container_of(work, struct io_kiocb, work);
+	bool ret;
+
+	if (req->flags & REQ_F_LINK_TIMEOUT) {
+		unsigned long flags;
+		struct io_ring_ctx *ctx = req->ctx;
+
+		/* protect against races with linked timeouts */
+		spin_lock_irqsave(&ctx->completion_lock, flags);
+		ret = io_match_link(req, data);
+		spin_unlock_irqrestore(&ctx->completion_lock, flags);
+	} else {
+		ret = io_match_link(req, data);
+	}
+	return ret;
 }
 
 static void io_attempt_cancel(struct io_ring_ctx *ctx, struct io_kiocb *req)
@@ -8479,6 +8510,7 @@ static void io_attempt_cancel(struct io_ring_ctx *ctx, struct io_kiocb *req)
 }
 
 static void io_cancel_defer_files(struct io_ring_ctx *ctx,
+				  struct task_struct *task,
 				  struct files_struct *files)
 {
 	struct io_defer_entry *de = NULL;
@@ -8486,7 +8518,8 @@ static void io_cancel_defer_files(struct io_ring_ctx *ctx,
 
 	spin_lock_irq(&ctx->completion_lock);
 	list_for_each_entry_reverse(de, &ctx->defer_list, list) {
-		if (io_match_link_files(de->req, files)) {
+		if (io_task_match(de->req, task) &&
+		    io_match_files(de->req, files)) {
 			list_cut_position(&list, &ctx->defer_list, &de->list);
 			break;
 		}
@@ -8512,7 +8545,6 @@ static bool io_uring_cancel_files(struct io_ring_ctx *ctx,
 	if (list_empty_careful(&ctx->inflight_list))
 		return false;
 
-	io_cancel_defer_files(ctx, files);
 	/* cancel all at once, should be faster than doing it one by one*/
 	io_wq_cancel_cb(ctx->io_wq, io_wq_files_match, files, true);
 
@@ -8598,8 +8630,16 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 {
 	struct task_struct *task = current;
 
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data)
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
 		task = ctx->sq_data->thread;
+		atomic_inc(&task->io_uring->in_idle);
+		io_sq_thread_park(ctx->sq_data);
+	}
+
+	if (files)
+		io_cancel_defer_files(ctx, NULL, files);
+	else
+		io_cancel_defer_files(ctx, task, NULL);
 
 	io_cqring_overflow_flush(ctx, true, task, files);
 
@@ -8607,12 +8647,23 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 		io_run_task_work();
 		cond_resched();
 	}
+
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
+		atomic_dec(&task->io_uring->in_idle);
+		/*
+		 * If the files that are going away are the ones in the thread
+		 * identity, clear them out.
+		 */
+		if (task->io_uring->identity->files == files)
+			task->io_uring->identity->files = NULL;
+		io_sq_thread_unpark(ctx->sq_data);
+	}
 }
 
 /*
  * Note that this task has used io_uring. We use it for cancelation purposes.
  */
-static int io_uring_add_task_file(struct file *file)
+static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
 {
 	struct io_uring_task *tctx = current->io_uring;
 
@@ -8634,6 +8685,14 @@ static int io_uring_add_task_file(struct file *file)
 		tctx->last = file;
 	}
 
+	/*
+	 * This is race safe in that the task itself is doing this, hence it
+	 * cannot be going through the exit/cancel paths at the same time.
+	 * This cannot be modified while exit/cancel is running.
+	 */
+	if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
+		tctx->sqpoll = true;
+
 	return 0;
 }
 
@@ -8675,7 +8734,7 @@ void __io_uring_files_cancel(struct files_struct *files)
 	unsigned long index;
 
 	/* make sure overflow events are dropped */
-	tctx->in_idle = true;
+	atomic_inc(&tctx->in_idle);
 
 	xa_for_each(&tctx->xa, index, file) {
 		struct io_ring_ctx *ctx = file->private_data;
@@ -8684,6 +8743,35 @@ void __io_uring_files_cancel(struct files_struct *files)
 		if (files)
 			io_uring_del_task_file(file);
 	}
+
+	atomic_dec(&tctx->in_idle);
+}
+
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+	unsigned long index;
+	struct file *file;
+	s64 inflight;
+
+	inflight = percpu_counter_sum(&tctx->inflight);
+	if (!tctx->sqpoll)
+		return inflight;
+
+	/*
+	 * If we have SQPOLL rings, then we need to iterate and find them, and
+	 * add the pending count for those.
+	 */
+	xa_for_each(&tctx->xa, index, file) {
+		struct io_ring_ctx *ctx = file->private_data;
+
+		if (ctx->flags & IORING_SETUP_SQPOLL) {
+			struct io_uring_task *__tctx = ctx->sqo_task->io_uring;
+
+			inflight += percpu_counter_sum(&__tctx->inflight);
+		}
+	}
+
+	return inflight;
 }
 
 /*
@@ -8697,11 +8785,11 @@ void __io_uring_task_cancel(void)
 	s64 inflight;
 
 	/* make sure overflow events are dropped */
-	tctx->in_idle = true;
+	atomic_inc(&tctx->in_idle);
 
 	do {
 		/* read completions before cancelations */
-		inflight = percpu_counter_sum(&tctx->inflight);
+		inflight = tctx_inflight(tctx);
 		if (!inflight)
 			break;
 		__io_uring_files_cancel(NULL);
@@ -8712,13 +8800,13 @@ void __io_uring_task_cancel(void)
 		 * If we've seen completions, retry. This avoids a race where
 		 * a completion comes in before we did prepare_to_wait().
 		 */
-		if (inflight != percpu_counter_sum(&tctx->inflight))
+		if (inflight != tctx_inflight(tctx))
 			continue;
 		schedule();
 	} while (1);
 
 	finish_wait(&tctx->wait, &wait);
-	tctx->in_idle = false;
+	atomic_dec(&tctx->in_idle);
 }
 
 static int io_uring_flush(struct file *file, void *data)
@@ -8863,7 +8951,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 			io_sqpoll_wait_sq(ctx);
 		submitted = to_submit;
 	} else if (to_submit) {
-		ret = io_uring_add_task_file(f.file);
+		ret = io_uring_add_task_file(ctx, f.file);
 		if (unlikely(ret))
 			goto out;
 		mutex_lock(&ctx->uring_lock);
@@ -8900,7 +8988,8 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 #ifdef CONFIG_PROC_FS
 static int io_uring_show_cred(int id, void *p, void *data)
 {
-	const struct cred *cred = p;
+	struct io_identity *iod = p;
+	const struct cred *cred = iod->creds;
 	struct seq_file *m = data;
 	struct user_namespace *uns = seq_user_ns(m);
 	struct group_info *gi;
@@ -9092,7 +9181,7 @@ static int io_uring_get_fd(struct io_ring_ctx *ctx)
 #if defined(CONFIG_UNIX)
 	ctx->ring_sock->file = file;
 #endif
-	if (unlikely(io_uring_add_task_file(file))) {
+	if (unlikely(io_uring_add_task_file(ctx, file))) {
 		file = ERR_PTR(-ENOMEM);
 		goto err_fd;
 	}
diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
index 868364cea3b740429a7ca4e84337b7ad79846956..35b2d845704d9175466d979138d413d85423664b 100644
--- a/include/linux/io_uring.h
+++ b/include/linux/io_uring.h
@@ -30,7 +30,8 @@ struct io_uring_task {
 	struct percpu_counter	inflight;
 	struct io_identity	__identity;
 	struct io_identity	*identity;
-	bool			in_idle;
+	atomic_t		in_idle;
+	bool			sqpoll;
 };
 
 #if defined(CONFIG_IO_URING)