diff --git a/mm/slab_common.c b/mm/slab_common.c
index ec832904f4084140a36e1b50850608ddd071d966..e8519d5f3a32357e89251c2243ef2ae4b90d726f 100644
--- a/mm/slab_common.c
+++ b/mm/slab_common.c
@@ -308,7 +308,6 @@ kmem_cache_create_usercopy(const char *name,
 	int err;
 
 	get_online_cpus();
-	get_online_mems();
 
 	mutex_lock(&slab_mutex);
 
@@ -358,7 +357,6 @@ kmem_cache_create_usercopy(const char *name,
 out_unlock:
 	mutex_unlock(&slab_mutex);
 
-	put_online_mems();
 	put_online_cpus();
 
 	if (err) {
@@ -485,7 +483,6 @@ void kmem_cache_destroy(struct kmem_cache *s)
 		return;
 
 	get_online_cpus();
-	get_online_mems();
 
 	mutex_lock(&slab_mutex);
 
@@ -502,7 +499,6 @@ void kmem_cache_destroy(struct kmem_cache *s)
 out_unlock:
 	mutex_unlock(&slab_mutex);
 
-	put_online_mems();
 	put_online_cpus();
 }
 EXPORT_SYMBOL(kmem_cache_destroy);
@@ -521,10 +517,10 @@ int kmem_cache_shrink(struct kmem_cache *cachep)
 	int ret;
 
 	get_online_cpus();
-	get_online_mems();
+
 	kasan_cache_shrink(cachep);
 	ret = __kmem_cache_shrink(cachep);
-	put_online_mems();
+
 	put_online_cpus();
 	return ret;
 }
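
For reference, this is how kmem_cache_shrink() reads once the hunk above is applied, assembled from the context and added lines (the memory hotplug lock is gone; the slub.c half of this patch makes slab_mutex sufficient to synchronize against the hotplug callbacks):

int kmem_cache_shrink(struct kmem_cache *cachep)
{
	int ret;

	get_online_cpus();

	kasan_cache_shrink(cachep);
	ret = __kmem_cache_shrink(cachep);

	put_online_cpus();
	return ret;
}
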
diff --git a/mm/slub.c b/mm/slub.c
index 259fd20dfe3bc28738e77033b71cf8cee503f0ef..f8e9bbf0f762c50d9eb33d97ade70617bd290009 100644
--- a/mm/slub.c
+++ b/mm/slub.c
@@ -236,6 +236,14 @@ static inline void stat(const struct kmem_cache *s, enum stat_item si)
 #endif
 }
 
+/*
+ * Tracks which NUMA nodes have kmem_cache_node structures allocated.
+ * Corresponds to node_states[N_NORMAL_MEMORY], but can temporarily
+ * differ during memory hotplug/hotremove operations.
+ * Protected by slab_mutex.
+ */
+static nodemask_t slab_nodes;
+
 /********************************************************************
  * 			Core slab cache functions
  *******************************************************************/
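
The hunks below manipulate slab_nodes exclusively through the standard nodemask primitives from include/linux/nodemask.h; a quick orientation, with illustrative placeholders rather than patch content:

	node_set(nid, slab_nodes);		/* nid now has kmem_cache_node structures */
	node_clear(nid, slab_nodes);		/* nid going away; stop allocating there */
	if (node_isset(nid, slab_nodes))	/* safe to honor a request for nid? */
		/* ... */;
	for_each_node_mask(nid, slab_nodes)	/* visit each node that has structures */
		/* ... */;
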
@@ -2670,7 +2678,7 @@ static void *___slab_alloc(struct kmem_cache *s, gfp_t gfpflags, int node,
 		 * ignore the node constraint
 		 */
 		if (unlikely(node != NUMA_NO_NODE &&
-			     !node_state(node, N_NORMAL_MEMORY)))
+			     !node_isset(node, slab_nodes)))
 			node = NUMA_NO_NODE;
 		goto new_slab;
 	}
@@ -2681,7 +2689,7 @@ static void *___slab_alloc(struct kmem_cache *s, gfp_t gfpflags, int node,
 		 * same as above but node_match() being false already
 		 * implies node != NUMA_NO_NODE
 		 */
-		if (!node_state(node, N_NORMAL_MEMORY)) {
+		if (!node_isset(node, slab_nodes)) {
 			node = NUMA_NO_NODE;
 			goto redo;
 		} else {
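
Taken together, these two hunks implement a single rule, restated here as a minimal sketch (not part of the patch): a node-constrained allocation drops the constraint whenever the requested node has no kmem_cache_node, i.e. its bit in slab_nodes is clear.

	if (node != NUMA_NO_NODE && !node_isset(node, slab_nodes))
		node = NUMA_NO_NODE;

Testing slab_nodes rather than N_NORMAL_MEMORY matters because the two can briefly disagree during hotplug, and slab_nodes is what actually tracks whether s->node[node] is usable.
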
@@ -3569,7 +3577,7 @@ static int init_kmem_cache_nodes(struct kmem_cache *s)
 {
 	int node;
 
-	for_each_node_state(node, N_NORMAL_MEMORY) {
+	for_each_node_mask(node, slab_nodes) {
 		struct kmem_cache_node *n;
 
 		if (slab_state == DOWN) {
@@ -4217,6 +4225,7 @@ static void slab_mem_offline_callback(void *arg)
 		return;
 
 	mutex_lock(&slab_mutex);
+	node_clear(offline_node, slab_nodes);
 	/*
 	 * We no longer free kmem_cache_node structures here, as it would be
 	 * racy with all get_node() users, and infeasible to protect them with
@@ -4266,6 +4275,11 @@ static int slab_mem_going_online_callback(void *arg)
 		init_kmem_cache_node(n);
 		s->node[nid] = n;
 	}
+	/*
+	 * Any cache created after this point will also have kmem_cache_node
+	 * initialized for the new node.
+	 */
+	node_set(nid, slab_nodes);
 out:
 	mutex_unlock(&slab_mutex);
 	return ret;
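
These two callbacks keep slab_nodes coherent with the per-cache node arrays: the offline path clears the bit but, per the comment above, no longer frees the kmem_cache_node structures (freeing would race with get_node() users), while the online path allocates a structure for every existing cache before setting the bit, all under slab_mutex. Below is a condensed reconstruction of the online callback, assuming the lines this hunk doesn't show match mainline slub.c; the check for an already-present structure (needed precisely because offline keeps them) is elided:

static int slab_mem_going_online_callback(void *arg)
{
	struct memory_notify *marg = arg;
	int nid = marg->status_change_nid_normal;
	struct kmem_cache *s;
	int ret = 0;

	/* Node already had normal memory: structures exist. */
	if (nid < 0)
		return 0;

	mutex_lock(&slab_mutex);
	list_for_each_entry(s, &slab_caches, list) {
		struct kmem_cache_node *n;

		n = kmem_cache_alloc(kmem_cache_node, GFP_KERNEL);
		if (!n) {
			ret = -ENOMEM;
			goto out;
		}
		init_kmem_cache_node(n);
		s->node[nid] = n;
	}
	/*
	 * Any cache created after this point will also have kmem_cache_node
	 * initialized for the new node.
	 */
	node_set(nid, slab_nodes);
out:
	mutex_unlock(&slab_mutex);
	return ret;
}
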
@@ -4346,6 +4360,7 @@ void __init kmem_cache_init(void)
 {
 	static __initdata struct kmem_cache boot_kmem_cache,
 		boot_kmem_cache_node;
+	int node;
 
 	if (debug_guardpage_minorder())
 		slub_max_order = 0;
@@ -4353,6 +4368,13 @@ void __init kmem_cache_init(void)
 	kmem_cache_node = &boot_kmem_cache_node;
 	kmem_cache = &boot_kmem_cache;
 
+	/*
+	 * Initialize the nodemask for which we will allocate per node
+	 * structures. Here we don't need to take slab_mutex yet.
+	 */
+	for_each_node_state(node, N_NORMAL_MEMORY)
+		node_set(node, slab_nodes);
+
 	create_boot_cache(kmem_cache_node, "kmem_cache_node",
 		sizeof(struct kmem_cache_node), SLAB_HWCACHE_ALIGN, 0, 0);
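
The boot-time hunks close the loop: kmem_cache_init() seeds slab_nodes from N_NORMAL_MEMORY before the first create_boot_cache() call, and init_kmem_cache_nodes() (changed earlier in this patch) consumes the mask, so every cache created from here on covers exactly the nodes in slab_nodes. The ordering, sketched with intermediate calls abbreviated:

/*
 * kmem_cache_init()
 *   for_each_node_state(node, N_NORMAL_MEMORY)
 *           node_set(node, slab_nodes);              seed the mask (early boot,
 *                                                    so no slab_mutex needed)
 *   create_boot_cache(kmem_cache_node, ...)
 *     ... -> init_kmem_cache_nodes()
 *              for_each_node_mask(node, slab_nodes)  consume the mask:
 *                      allocate s->node[node]        one kmem_cache_node each
 */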