diff --git a/fs/io-wq.c b/fs/io-wq.c
index 36149e3cf3ed8dca1a41f78c873ce89ee57da7c4..5d089acdcfbd8e35885000828e28f18e7be7fb02 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -57,7 +57,8 @@ struct io_worker {
 
 	struct rcu_head rcu;
 	struct mm_struct *mm;
-	const struct cred *creds;
+	const struct cred *cur_creds;
+	const struct cred *saved_creds;
 	struct files_struct *restore_files;
 };
 
@@ -110,8 +111,6 @@ struct io_wq {
 
 	struct task_struct *manager;
 	struct user_struct *user;
-	const struct cred *creds;
-	struct mm_struct *mm;
 	refcount_t refs;
 	struct completion done;
 
@@ -138,9 +137,9 @@ static bool __io_worker_unuse(struct io_wqe *wqe, struct io_worker *worker)
 {
 	bool dropped_lock = false;
 
-	if (worker->creds) {
-		revert_creds(worker->creds);
-		worker->creds = NULL;
+	if (worker->saved_creds) {
+		revert_creds(worker->saved_creds);
+		worker->cur_creds = worker->saved_creds = NULL;
 	}
 
 	if (current->files != worker->restore_files) {
@@ -399,6 +398,43 @@ static struct io_wq_work *io_get_next_work(struct io_wqe *wqe, unsigned *hash)
 	return NULL;
 }
 
+static void io_wq_switch_mm(struct io_worker *worker, struct io_wq_work *work)
+{
+	if (worker->mm) {
+		unuse_mm(worker->mm);
+		mmput(worker->mm);
+		worker->mm = NULL;
+	}
+	if (!work->mm) {
+		set_fs(KERNEL_DS);
+		return;
+	}
+	if (mmget_not_zero(work->mm)) {
+		use_mm(work->mm);
+		if (!worker->mm)
+			set_fs(USER_DS);
+		worker->mm = work->mm;
+		/* hang on to this mm */
+		work->mm = NULL;
+		return;
+	}
+
+	/* failed grabbing mm, ensure work gets cancelled */
+	work->flags |= IO_WQ_WORK_CANCEL;
+}
+
+static void io_wq_switch_creds(struct io_worker *worker,
+			       struct io_wq_work *work)
+{
+	const struct cred *old_creds = override_creds(work->creds);
+
+	worker->cur_creds = work->creds;
+	if (worker->saved_creds)
+		put_cred(old_creds); /* creds set by previous switch */
+	else
+		worker->saved_creds = old_creds;
+}
+
 static void io_worker_handle_work(struct io_worker *worker)
 	__releases(wqe->lock)
 {
@@ -447,18 +483,10 @@ static void io_worker_handle_work(struct io_worker *worker)
 			current->files = work->files;
 			task_unlock(current);
 		}
-		if ((work->flags & IO_WQ_WORK_NEEDS_USER) && !worker->mm &&
-		    wq->mm) {
-			if (mmget_not_zero(wq->mm)) {
-				use_mm(wq->mm);
-				set_fs(USER_DS);
-				worker->mm = wq->mm;
-			} else {
-				work->flags |= IO_WQ_WORK_CANCEL;
-			}
-		}
-		if (!worker->creds)
-			worker->creds = override_creds(wq->creds);
+		if (work->mm != worker->mm)
+			io_wq_switch_mm(worker, work);
+		if (worker->cur_creds != work->creds)
+			io_wq_switch_creds(worker, work);
 		/*
 		 * OK to set IO_WQ_WORK_CANCEL even for uncancellable work,
 		 * the worker function will do the right thing.
@@ -1038,7 +1066,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 
 	/* caller must already hold a reference to this */
 	wq->user = data->user;
-	wq->creds = data->creds;
 
 	for_each_node(node) {
 		struct io_wqe *wqe;
@@ -1065,9 +1092,6 @@ struct io_wq *io_wq_create(unsigned bounded, struct io_wq_data *data)
 
 	init_completion(&wq->done);
 
-	/* caller must have already done mmgrab() on this mm */
-	wq->mm = data->mm;
-
 	wq->manager = kthread_create(io_wq_manager, wq, "io_wq_manager");
 	if (!IS_ERR(wq->manager)) {
 		wake_up_process(wq->manager);
diff --git a/fs/io-wq.h b/fs/io-wq.h
index 47cc2323e78ffb808773271b9fd093b5577a7e04..bbc5290421901e1220483f0d6cc51b969cf94bb3 100644
--- a/fs/io-wq.h
+++ b/fs/io-wq.h
@@ -7,7 +7,6 @@ enum {
 	IO_WQ_WORK_CANCEL	= 1,
 	IO_WQ_WORK_HAS_MM	= 2,
 	IO_WQ_WORK_HASHED	= 4,
-	IO_WQ_WORK_NEEDS_USER	= 8,
 	IO_WQ_WORK_NEEDS_FILES	= 16,
 	IO_WQ_WORK_UNBOUND	= 32,
 	IO_WQ_WORK_INTERNAL	= 64,
@@ -73,6 +72,8 @@ struct io_wq_work {
 	};
 	void (*func)(struct io_wq_work **);
 	struct files_struct *files;
+	struct mm_struct *mm;
+	const struct cred *creds;
 	unsigned flags;
 };
 
@@ -82,15 +83,15 @@ struct io_wq_work {
 		(work)->func = _func;			\
 		(work)->flags = 0;			\
 		(work)->files = NULL;			\
+		(work)->mm = NULL;			\
+		(work)->creds = NULL;			\
 	} while (0)					\
 
 typedef void (get_work_fn)(struct io_wq_work *);
 typedef void (put_work_fn)(struct io_wq_work *);
 
 struct io_wq_data {
-	struct mm_struct *mm;
 	struct user_struct *user;
-	const struct cred *creds;
 
 	get_work_fn *get_work;
 	put_work_fn *put_work;
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 0ffa8537c36a29b9324fc843f9362992aa313c06..15c1efbeec550b46b3d64d9f87b6c63401f168b0 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -875,6 +875,29 @@ static void __io_commit_cqring(struct io_ring_ctx *ctx)
 	}
 }
 
+static inline void io_req_work_grab_env(struct io_kiocb *req,
+					const struct io_op_def *def)
+{
+	if (!req->work.mm && def->needs_mm) {
+		mmgrab(current->mm);
+		req->work.mm = current->mm;
+	}
+	if (!req->work.creds)
+		req->work.creds = get_current_cred();
+}
+
+static inline void io_req_work_drop_env(struct io_kiocb *req)
+{
+	if (req->work.mm) {
+		mmdrop(req->work.mm);
+		req->work.mm = NULL;
+	}
+	if (req->work.creds) {
+		put_cred(req->work.creds);
+		req->work.creds = NULL;
+	}
+}
+
 static inline bool io_prep_async_work(struct io_kiocb *req,
 				      struct io_kiocb **link)
 {
@@ -888,8 +911,8 @@ static inline bool io_prep_async_work(struct io_kiocb *req,
 		if (def->unbound_nonreg_file)
 			req->work.flags |= IO_WQ_WORK_UNBOUND;
 	}
-	if (def->needs_mm)
-		req->work.flags |= IO_WQ_WORK_NEEDS_USER;
+
+	io_req_work_grab_env(req, def);
 
 	*link = io_prep_linked_timeout(req);
 	return do_hashed;
@@ -1180,6 +1203,8 @@ static void __io_req_aux_free(struct io_kiocb *req)
 		else
 			fput(req->file);
 	}
+
+	io_req_work_drop_env(req);
 }
 
 static void __io_free_req(struct io_kiocb *req)
@@ -3960,6 +3985,8 @@ static int io_req_defer_prep(struct io_kiocb *req,
 {
 	ssize_t ret = 0;
 
+	io_req_work_grab_env(req, &io_op_defs[req->opcode]);
+
 	switch (req->opcode) {
 	case IORING_OP_NOP:
 		break;
@@ -5712,9 +5739,7 @@ static int io_sq_offload_start(struct io_ring_ctx *ctx,
 		goto err;
 	}
 
-	data.mm = ctx->sqo_mm;
 	data.user = ctx->user;
-	data.creds = ctx->creds;
 	data.get_work = io_get_work;
 	data.put_work = io_put_work;
 
@@ -6520,7 +6545,8 @@ static int io_uring_create(unsigned entries, struct io_uring_params *p)
 		goto err;
 
 	p->features = IORING_FEAT_SINGLE_MMAP | IORING_FEAT_NODROP |
-			IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS;
+			IORING_FEAT_SUBMIT_STABLE | IORING_FEAT_RW_CUR_POS |
+			IORING_FEAT_CUR_PERSONALITY;
 	trace_io_uring_create(ret, ctx, p->sq_entries, p->cq_entries, p->flags);
 	return ret;
 err:
diff --git a/include/uapi/linux/io_uring.h b/include/uapi/linux/io_uring.h
index 57d05cc5e27159cb9ff93ee204907deca9b262d5..9988e82f858b80589b1bf0eb903b70ab7fe7d8ee 100644
--- a/include/uapi/linux/io_uring.h
+++ b/include/uapi/linux/io_uring.h
@@ -195,6 +195,7 @@ struct io_uring_params {
 #define IORING_FEAT_NODROP		(1U << 1)
 #define IORING_FEAT_SUBMIT_STABLE	(1U << 2)
 #define IORING_FEAT_RW_CUR_POS		(1U << 3)
+#define IORING_FEAT_CUR_PERSONALITY	(1U << 4)
 
 /*
  * io_uring_register(2) opcodes and arguments