diff --git a/fs/io_uring.c b/fs/io_uring.c
index cdec595104332f7fceec0965a3663a85c10fdd23..70286b393c0e861e7ccc4d9cbe022aab09099e9f 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -6665,6 +6665,7 @@ static int io_sq_thread(void *data)
 			up_read(&sqd->rw_lock);
 			cond_resched();
 			down_read(&sqd->rw_lock);
+			io_run_task_work();
 			timeout = jiffies + sqd->sq_thread_idle;
 			continue;
 		}
@@ -6720,18 +6721,22 @@ static int io_sq_thread(void *data)
 		finish_wait(&sqd->wait, &wait);
 		timeout = jiffies + sqd->sq_thread_idle;
 	}
-
-	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-		io_uring_cancel_sqpoll(ctx);
 	up_read(&sqd->rw_lock);
-
+	down_write(&sqd->rw_lock);
+	/*
+	 * someone may have parked and added a cancellation task_work, run
+	 * it first because we don't want it in io_uring_cancel_sqpoll()
+	 */
 	io_run_task_work();
 
-	down_write(&sqd->rw_lock);
+	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+		io_uring_cancel_sqpoll(ctx);
 	sqd->thread = NULL;
 	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
 		io_ring_set_wakeup_flag(ctx);
 	up_write(&sqd->rw_lock);
+
+	io_run_task_work();
 	complete(&sqd->exited);
 	do_exit(0);
 }
@@ -7033,8 +7038,8 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 static void io_sq_thread_unpark(struct io_sq_data *sqd)
 	__releases(&sqd->rw_lock)
 {
-	if (sqd->thread == current)
-		return;
+	WARN_ON_ONCE(sqd->thread == current);
+
 	clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
 	up_write(&sqd->rw_lock);
 }
@@ -7042,8 +7047,8 @@ static void io_sq_thread_unpark(struct io_sq_data *sqd)
 static void io_sq_thread_park(struct io_sq_data *sqd)
 	__acquires(&sqd->rw_lock)
 {
-	if (sqd->thread == current)
-		return;
+	WARN_ON_ONCE(sqd->thread == current);
+
 	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
 	down_write(&sqd->rw_lock);
 	/* set again for consistency, in case concurrent parks are happening */
@@ -7054,8 +7059,8 @@ static void io_sq_thread_park(struct io_sq_data *sqd)
 
 static void io_sq_thread_stop(struct io_sq_data *sqd)
 {
-	if (test_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state))
-		return;
+	WARN_ON_ONCE(sqd->thread == current);
+
 	down_write(&sqd->rw_lock);
 	set_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
 	if (sqd->thread)
@@ -7078,7 +7083,7 @@ static void io_sq_thread_finish(struct io_ring_ctx *ctx)
 
 	if (sqd) {
 		io_sq_thread_park(sqd);
-		list_del(&ctx->sqd_list);
+		list_del_init(&ctx->sqd_list);
 		io_sqd_update_thread_idle(sqd);
 		io_sq_thread_unpark(sqd);
 
@@ -7760,7 +7765,6 @@ static int io_uring_alloc_task_context(struct task_struct *task,
 	init_waitqueue_head(&tctx->wait);
 	tctx->last = NULL;
 	atomic_set(&tctx->in_idle, 0);
-	tctx->sqpoll = false;
 	task->io_uring = tctx;
 	spin_lock_init(&tctx->task_lock);
 	INIT_WQ_LIST(&tctx->task_list);
@@ -8719,43 +8723,12 @@ static void io_uring_cancel_files(struct io_ring_ctx *ctx,
 
 		io_uring_try_cancel_requests(ctx, task, files);
 
-		if (ctx->sq_data)
-			io_sq_thread_unpark(ctx->sq_data);
 		prepare_to_wait(&task->io_uring->wait, &wait,
 				TASK_UNINTERRUPTIBLE);
 		if (inflight == io_uring_count_inflight(ctx, task, files))
 			schedule();
 		finish_wait(&task->io_uring->wait, &wait);
-		if (ctx->sq_data)
-			io_sq_thread_park(ctx->sq_data);
-	}
-}
-
-/*
- * We need to iteratively cancel requests, in case a request has dependent
- * hard links. These persist even for failure of cancelations, hence keep
- * looping until none are found.
- */
-static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
-					  struct files_struct *files)
-{
-	struct task_struct *task = current;
-
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
-		io_sq_thread_park(ctx->sq_data);
-		task = ctx->sq_data->thread;
-		if (task)
-			atomic_inc(&task->io_uring->in_idle);
 	}
-
-	io_uring_cancel_files(ctx, task, files);
-	if (!files)
-		io_uring_try_cancel_requests(ctx, task, NULL);
-
-	if (task)
-		atomic_dec(&task->io_uring->in_idle);
-	if (ctx->sq_data)
-		io_sq_thread_unpark(ctx->sq_data);
 }
 
 /*
@@ -8796,15 +8769,6 @@ static int io_uring_add_task_file(struct io_ring_ctx *ctx)
 		}
 		tctx->last = ctx;
 	}
-
-	/*
-	 * This is race safe in that the task itself is doing this, hence it
-	 * cannot be going through the exit/cancel paths at the same time.
-	 * This cannot be modified while exit/cancel is running.
-	 */
-	if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
-		tctx->sqpoll = true;
-
 	return 0;
 }
 
@@ -8847,6 +8811,44 @@ static void io_uring_clean_tctx(struct io_uring_task *tctx)
 	}
 }
 
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+	return percpu_counter_sum(&tctx->inflight);
+}
+
+static void io_sqpoll_cancel_cb(struct callback_head *cb)
+{
+	struct io_tctx_exit *work = container_of(cb, struct io_tctx_exit, task_work);
+	struct io_ring_ctx *ctx = work->ctx;
+	struct io_sq_data *sqd = ctx->sq_data;
+
+	if (sqd->thread)
+		io_uring_cancel_sqpoll(ctx);
+	complete(&work->completion);
+}
+
+static void io_sqpoll_cancel_sync(struct io_ring_ctx *ctx)
+{
+	struct io_sq_data *sqd = ctx->sq_data;
+	struct io_tctx_exit work = { .ctx = ctx, };
+	struct task_struct *task;
+
+	io_sq_thread_park(sqd);
+	list_del_init(&ctx->sqd_list);
+	io_sqd_update_thread_idle(sqd);
+	task = sqd->thread;
+	if (task) {
+		init_completion(&work.completion);
+		init_task_work(&work.task_work, io_sqpoll_cancel_cb);
+		WARN_ON_ONCE(task_work_add(task, &work.task_work, TWA_SIGNAL));
+		wake_up_process(task);
+	}
+	io_sq_thread_unpark(sqd);
+
+	if (task)
+		wait_for_completion(&work.completion);
+}
+
 void __io_uring_files_cancel(struct files_struct *files)
 {
 	struct io_uring_task *tctx = current->io_uring;
@@ -8855,41 +8857,40 @@ void __io_uring_files_cancel(struct files_struct *files)
 
 	/* make sure overflow events are dropped */
 	atomic_inc(&tctx->in_idle);
-	xa_for_each(&tctx->xa, index, node)
-		io_uring_cancel_task_requests(node->ctx, files);
+	xa_for_each(&tctx->xa, index, node) {
+		struct io_ring_ctx *ctx = node->ctx;
+
+		if (ctx->sq_data) {
+			io_sqpoll_cancel_sync(ctx);
+			continue;
+		}
+		io_uring_cancel_files(ctx, current, files);
+		if (!files)
+			io_uring_try_cancel_requests(ctx, current, NULL);
+	}
 	atomic_dec(&tctx->in_idle);
 
 	if (files)
 		io_uring_clean_tctx(tctx);
 }
 
-static s64 tctx_inflight(struct io_uring_task *tctx)
-{
-	return percpu_counter_sum(&tctx->inflight);
-}
-
+/* should only be called by SQPOLL task */
 static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 {
 	struct io_sq_data *sqd = ctx->sq_data;
-	struct io_uring_task *tctx;
+	struct io_uring_task *tctx = current->io_uring;
 	s64 inflight;
 	DEFINE_WAIT(wait);
 
-	if (!sqd)
-		return;
-	io_sq_thread_park(sqd);
-	if (!sqd->thread || !sqd->thread->io_uring) {
-		io_sq_thread_unpark(sqd);
-		return;
-	}
-	tctx = ctx->sq_data->thread->io_uring;
+	WARN_ON_ONCE(!sqd || ctx->sq_data->thread != current);
+
 	atomic_inc(&tctx->in_idle);
 	do {
 		/* read completions before cancelations */
 		inflight = tctx_inflight(tctx);
 		if (!inflight)
 			break;
-		io_uring_cancel_task_requests(ctx, NULL);
+		io_uring_try_cancel_requests(ctx, current, NULL);
 
 		prepare_to_wait(&tctx->wait, &wait, TASK_UNINTERRUPTIBLE);
 		/*
@@ -8902,7 +8903,6 @@ static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 		finish_wait(&tctx->wait, &wait);
 	} while (1);
 	atomic_dec(&tctx->in_idle);
-	io_sq_thread_unpark(sqd);
 }
 
 /*
@@ -8917,15 +8917,6 @@ void __io_uring_task_cancel(void)
 
 	/* make sure overflow events are dropped */
 	atomic_inc(&tctx->in_idle);
-
-	if (tctx->sqpoll) {
-		struct io_tctx_node *node;
-		unsigned long index;
-
-		xa_for_each(&tctx->xa, index, node)
-			io_uring_cancel_sqpoll(node->ctx);
-	}
-
 	do {
 		/* read completions before cancelations */
 		inflight = tctx_inflight(tctx);