diff --git a/fs/io_uring.c b/fs/io_uring.c
index 5762750c666c6ffcd0440f5692128cafb5ebefe8..d30cbf0f7b1c2e0f846adc4598eb06e02b6f9048 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -274,7 +274,7 @@ struct io_sq_data {
 
 	unsigned long		state;
 	struct completion	startup;
-	struct completion	completion;
+	struct completion	parked;
 	struct completion	exited;
 };
 
@@ -6656,7 +6656,7 @@ static void io_sq_thread_parkme(struct io_sq_data *sqd)
 		 * wait_task_inactive().
 		 */
 		preempt_disable();
-		complete(&sqd->completion);
+		complete(&sqd->parked);
 		schedule_preempt_disabled();
 		preempt_enable();
 	}
@@ -6751,14 +6751,18 @@ static int io_sq_thread(void *data)
 
 	io_run_task_work();
 
-	if (io_sq_thread_should_park(sqd))
-		io_sq_thread_parkme(sqd);
-
 	/*
-	 * Clear thread under lock so that concurrent parks work correctly
+	 * Handle a park request that races with our exit. A parker sets the
+	 * park bit before taking sqd->lock, so if the trylock fails we check
+	 * the bit and park if asked to before blocking on the lock; that
+	 * ordering is what lets us catch the request reliably.
 	 */
-	complete(&sqd->completion);
-	mutex_lock(&sqd->lock);
+	if (!mutex_trylock(&sqd->lock)) {
+		if (io_sq_thread_should_park(sqd))
+			io_sq_thread_parkme(sqd);
+		mutex_lock(&sqd->lock);
+	}
+
 	sqd->thread = NULL;
 	list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
 		ctx->sqo_exec = 1;
@@ -7067,29 +7071,25 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 static void io_sq_thread_unpark(struct io_sq_data *sqd)
 	__releases(&sqd->lock)
 {
-	if (!sqd->thread)
-		return;
 	if (sqd->thread == current)
 		return;
 	clear_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	wake_up_state(sqd->thread, TASK_PARKED);
+	if (sqd->thread)
+		wake_up_state(sqd->thread, TASK_PARKED);
 	mutex_unlock(&sqd->lock);
 }
 
-static bool io_sq_thread_park(struct io_sq_data *sqd)
+static void io_sq_thread_park(struct io_sq_data *sqd)
 	__acquires(&sqd->lock)
 {
 	if (sqd->thread == current)
-		return true;
+		return;
+	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
 	mutex_lock(&sqd->lock);
-	if (!sqd->thread) {
-		mutex_unlock(&sqd->lock);
-		return false;
+	if (sqd->thread) {
+		wake_up_process(sqd->thread);
+		wait_for_completion(&sqd->parked);
 	}
-	set_bit(IO_SQ_THREAD_SHOULD_PARK, &sqd->state);
-	wake_up_process(sqd->thread);
-	wait_for_completion(&sqd->completion);
-	return true;
 }
 
 static void io_sq_thread_stop(struct io_sq_data *sqd)
@@ -7185,7 +7185,7 @@ static struct io_sq_data *io_get_sq_data(struct io_uring_params *p)
 	mutex_init(&sqd->lock);
 	init_waitqueue_head(&sqd->wait);
 	init_completion(&sqd->startup);
-	init_completion(&sqd->completion);
+	init_completion(&sqd->parked);
 	init_completion(&sqd->exited);
 	return sqd;
 }
@@ -7829,7 +7829,7 @@ static int io_sq_thread_fork(struct io_sq_data *sqd, struct io_ring_ctx *ctx)
 	int ret;
 
 	clear_bit(IO_SQ_THREAD_SHOULD_STOP, &sqd->state);
-	reinit_completion(&sqd->completion);
+	reinit_completion(&sqd->parked);
 	ctx->sqo_exec = 0;
 	sqd->task_pid = current->pid;
 	tsk = create_io_thread(io_sq_thread, sqd, NUMA_NO_NODE);
@@ -8712,7 +8712,6 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 					  struct files_struct *files)
 {
 	struct task_struct *task = current;
-	bool did_park = false;
 
 	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
 		/* never started, nothing to cancel */
@@ -8720,11 +8719,10 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 			io_sq_offload_start(ctx);
 			return;
 		}
-		did_park = io_sq_thread_park(ctx->sq_data);
-		if (did_park) {
-			task = ctx->sq_data->thread;
+		io_sq_thread_park(ctx->sq_data);
+		task = ctx->sq_data->thread;
+		if (task)
 			atomic_inc(&task->io_uring->in_idle);
-		}
 	}
 
 	io_cancel_defer_files(ctx, task, files);
@@ -8733,10 +8731,10 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 	if (!files)
 		io_uring_try_cancel_requests(ctx, task, NULL);
 
-	if (did_park) {
+	if (ctx->sq_data && task)	/* pairs with the SQPOLL-only inc */
 		atomic_dec(&task->io_uring->in_idle);
+	if (ctx->sq_data)
 		io_sq_thread_unpark(ctx->sq_data);
-	}
 }
 
 /*
@@ -8836,15 +8834,12 @@ static void io_uring_cancel_sqpoll(struct io_ring_ctx *ctx)
 
 	if (!sqd)
 		return;
-	if (!io_sq_thread_park(sqd))
-		return;
-	tctx = ctx->sq_data->thread->io_uring;
-	/* can happen on fork/alloc failure, just ignore that state */
-	if (!tctx) {
+	io_sq_thread_park(sqd);
+	if (!sqd->thread || !sqd->thread->io_uring) {
 		io_sq_thread_unpark(sqd);
 		return;
 	}
-
+	tctx = ctx->sq_data->thread->io_uring;
 	atomic_inc(&tctx->in_idle);
 	do {
 		/* read completions before cancelations */