diff --git a/include/linux/kthread.h b/include/linux/kthread.h
index 82e197eeac91f2bff8b62df9a58c75b5645b1ca2..bd4369c83dfba7221d23dc25a37332d9efb961f0 100644
--- a/include/linux/kthread.h
+++ b/include/linux/kthread.h
@@ -3,6 +3,7 @@
 /* Simple interface for creating and stopping kernel threads without mess. */
 #include <linux/err.h>
 #include <linux/sched.h>
+#include <linux/cgroup.h>
 
 __printf(4, 5)
 struct task_struct *kthread_create_on_node(int (*threadfn)(void *data),
@@ -198,4 +199,14 @@ bool kthread_cancel_delayed_work_sync(struct kthread_delayed_work *work);
 
 void kthread_destroy_worker(struct kthread_worker *worker);
 
+#ifdef CONFIG_CGROUPS
+void kthread_associate_blkcg(struct cgroup_subsys_state *css);
+struct cgroup_subsys_state *kthread_blkcg(void);
+#else
+static inline void kthread_associate_blkcg(struct cgroup_subsys_state *css) { }
+static inline struct cgroup_subsys_state *kthread_blkcg(void)
+{
+	return NULL;
+}
+#endif
 #endif /* _LINUX_KTHREAD_H */
diff --git a/kernel/kthread.c b/kernel/kthread.c
index 1c19edf824272db47a48730478af5f3b582bbf67..b011ea08967fc44bcc55ef475a3e15e39d6ef2ee 100644
--- a/kernel/kthread.c
+++ b/kernel/kthread.c
@@ -20,7 +20,6 @@
 #include <linux/freezer.h>
 #include <linux/ptrace.h>
 #include <linux/uaccess.h>
-#include <linux/cgroup.h>
 #include <trace/events/sched.h>
 
 static DEFINE_SPINLOCK(kthread_create_lock);
@@ -47,6 +46,9 @@ struct kthread {
 	void *data;
 	struct completion parked;
 	struct completion exited;
+#ifdef CONFIG_CGROUPS
+	struct cgroup_subsys_state *blkcg_css;
+#endif
 };
 
 enum KTHREAD_BITS {
@@ -74,11 +76,17 @@ static inline struct kthread *to_kthread(struct task_struct *k)
 
 void free_kthread_struct(struct task_struct *k)
 {
+	struct kthread *kthread;
+
 	/*
 	 * Can be NULL if this kthread was created by kernel_thread()
 	 * or if kmalloc() in kthread() failed.
 	 */
-	kfree(to_kthread(k));
+	kthread = to_kthread(k);
+#ifdef CONFIG_CGROUPS
+	WARN_ON_ONCE(kthread && kthread->blkcg_css);
+#endif
+	kfree(kthread);
 }
 
 /**
@@ -216,6 +224,9 @@ static int kthread(void *_create)
 	self->data = data;
 	init_completion(&self->exited);
 	init_completion(&self->parked);
+#ifdef CONFIG_CGROUPS
+	self->blkcg_css = NULL;
+#endif
 	current->vfork_done = &self->exited;
 
 	/* OK, tell user we're spawned, wait for stop or wakeup */
@@ -1154,3 +1165,54 @@ void kthread_destroy_worker(struct kthread_worker *worker)
 	kfree(worker);
 }
 EXPORT_SYMBOL(kthread_destroy_worker);
+
+#ifdef CONFIG_CGROUPS
+/**
+ * kthread_associate_blkcg - associate blkcg to current kthread
+ * @css: the cgroup info
+ *
+ * Current thread must be a kthread. The thread is running jobs on behalf of
+ * other threads. In some cases, we expect the jobs attach cgroup info of
+ * original threads instead of that of current thread. This function stores
+ * original thread's cgroup info in current kthread context for later
+ * retrieval.
+ */
+void kthread_associate_blkcg(struct cgroup_subsys_state *css)
+{
+	struct kthread *kthread;
+
+	if (!(current->flags & PF_KTHREAD))
+		return;
+	kthread = to_kthread(current);
+	if (!kthread)
+		return;
+
+	if (kthread->blkcg_css) {
+		css_put(kthread->blkcg_css);
+		kthread->blkcg_css = NULL;
+	}
+	if (css) {
+		css_get(css);
+		kthread->blkcg_css = css;
+	}
+}
+EXPORT_SYMBOL(kthread_associate_blkcg);
+
+/**
+ * kthread_blkcg - get associated blkcg css of current kthread
+ *
+ * Current thread must be a kthread.
+ */
+struct cgroup_subsys_state *kthread_blkcg(void)
+{
+	struct kthread *kthread;
+
+	if (current->flags & PF_KTHREAD) {
+		kthread = to_kthread(current);
+		if (kthread)
+			return kthread->blkcg_css;
+	}
+	return NULL;
+}
+EXPORT_SYMBOL(kthread_blkcg);
+#endif