diff --git a/mm/vmscan.c b/mm/vmscan.c
index a01897fdfdac3bbfa28bafd50c07614907c0bb2d..88d740db3216bc0bdc97983ff94863997df82aa6 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -3194,11 +3194,13 @@ unsigned long try_to_free_pages(struct zonelist *zonelist, int order,
 	if (throttle_direct_reclaim(sc.gfp_mask, zonelist, nodemask))
 		return 1;
 
+	current->reclaim_state = &sc.reclaim_state;
 	trace_mm_vmscan_direct_reclaim_begin(order, sc.gfp_mask);
 
 	nr_reclaimed = do_try_to_free_pages(zonelist, &sc);
 
 	trace_mm_vmscan_direct_reclaim_end(nr_reclaimed);
+	current->reclaim_state = NULL;
 
 	return nr_reclaimed;
 }
@@ -3221,6 +3223,7 @@ unsigned long mem_cgroup_shrink_node(struct mem_cgroup *memcg,
 	};
 	unsigned long lru_pages;
 
+	current->reclaim_state = &sc.reclaim_state;
 	sc.gfp_mask = (gfp_mask & GFP_RECLAIM_MASK) |
 			(GFP_HIGHUSER_MOVABLE & ~GFP_RECLAIM_MASK);
 
@@ -3238,7 +3241,9 @@ unsigned long mem_cgroup_shrink_node(struct mem_cgroup *memcg,
 
 	trace_mm_vmscan_memcg_softlimit_reclaim_end(sc.nr_reclaimed);
 
+	current->reclaim_state = NULL;
 	*nr_scanned = sc.nr_scanned;
+
 	return sc.nr_reclaimed;
 }
 
@@ -3265,6 +3270,7 @@ unsigned long try_to_free_mem_cgroup_pages(struct mem_cgroup *memcg,
 		.may_shrinkslab = 1,
 	};
 
+	current->reclaim_state = &sc.reclaim_state;
 	/*
 	 * Unlike direct reclaim via alloc_pages(), memcg's reclaim doesn't
 	 * take care of from where we get pages. So the node where we start the
@@ -3285,6 +3291,7 @@ unsigned long try_to_free_mem_cgroup_pages(struct mem_cgroup *memcg,
 	psi_memstall_leave(&pflags);
 
 	trace_mm_vmscan_memcg_reclaim_end(nr_reclaimed);
+	current->reclaim_state = NULL;
 
 	return nr_reclaimed;
 }