diff --git a/block/blk-cgroup.c b/block/blk-cgroup.c
index 527524134693005624d508b6720f8dfeefbcc47b..d4efb293b8e13a72d11d88b6b35b7648e9cd8d1f 100644
--- a/block/blk-cgroup.c
+++ b/block/blk-cgroup.c
@@ -28,6 +28,7 @@
 #include <linux/ctype.h>
 #include <linux/blk-cgroup.h>
 #include <linux/tracehook.h>
+#include <linux/psi.h>
 #include "blk.h"
 
 #define MAX_KEY_LEN 100
@@ -1674,6 +1675,7 @@ static void blkcg_scale_delay(struct blkcg_gq *blkg, u64 now)
  */
 static void blkcg_maybe_throttle_blkg(struct blkcg_gq *blkg, bool use_memdelay)
 {
+	unsigned long pflags;
 	u64 now = ktime_to_ns(ktime_get());
 	u64 exp;
 	u64 delay_nsec = 0;
@@ -1700,11 +1702,8 @@ static void blkcg_maybe_throttle_blkg(struct blkcg_gq *blkg, bool use_memdelay)
 	 */
 	delay_nsec = min_t(u64, delay_nsec, 250 * NSEC_PER_MSEC);
 
-	/*
-	 * TODO: the use_memdelay flag is going to be for the upcoming psi stuff
-	 * that hasn't landed upstream yet.  Once that stuff is in place we need
-	 * to do a psi_memstall_enter/leave if memdelay is set.
-	 */
+	if (use_memdelay)
+		psi_memstall_enter(&pflags);
 
 	exp = ktime_add_ns(now, delay_nsec);
 	tok = io_schedule_prepare();
@@ -1714,6 +1713,9 @@ static void blkcg_maybe_throttle_blkg(struct blkcg_gq *blkg, bool use_memdelay)
 			break;
 	} while (!fatal_signal_pending(current));
 	io_schedule_finish(tok);
+
+	if (use_memdelay)
+		psi_memstall_leave(&pflags);
 }
 
 /**