diff --git a/lib/rhashtable.c b/lib/rhashtable.c
index 81edf1ab38ab8aa1fabc3df799bf2ac8b2234e8f..9427b5766134cb139ef385b27f92f6027fecceca 100644
--- a/lib/rhashtable.c
+++ b/lib/rhashtable.c
@@ -727,6 +727,7 @@ int rhashtable_walk_start_check(struct rhashtable_iter *iter)
 	__acquires(RCU)
 {
 	struct rhashtable *ht = iter->ht;
+	bool rhlist = ht->rhlist;
 
 	rcu_read_lock();
 
@@ -735,13 +736,52 @@ int rhashtable_walk_start_check(struct rhashtable_iter *iter)
 		list_del(&iter->walker.list);
 	spin_unlock(&ht->lock);
 
-	if (!iter->walker.tbl && !iter->end_of_table) {
+	if (iter->end_of_table)
+		return 0;
+	if (!iter->walker.tbl) {
 		iter->walker.tbl = rht_dereference_rcu(ht->tbl, ht);
 		iter->slot = 0;
 		iter->skip = 0;
 		return -EAGAIN;
 	}
 
+	if (iter->p && !rhlist) {
+		/*
+		 * We need to validate that 'p' is still in the table, and
+		 * if so, update 'skip'
+		 */
+		struct rhash_head *p;
+		int skip = 0;
+		rht_for_each_rcu(p, iter->walker.tbl, iter->slot) {
+			skip++;
+			if (p == iter->p) {
+				iter->skip = skip;
+				goto found;
+			}
+		}
+		iter->p = NULL;
+	} else if (iter->p && rhlist) {
+		/* Need to validate that 'list' is still in the table, and
+		 * if so, update 'skip' and 'p'.
+		 */
+		struct rhash_head *p;
+		struct rhlist_head *list;
+		int skip = 0;
+		rht_for_each_rcu(p, iter->walker.tbl, iter->slot) {
+			for (list = container_of(p, struct rhlist_head, rhead);
+			     list;
+			     list = rcu_dereference(list->next)) {
+				skip++;
+				if (list == iter->list) {
+					iter->p = p;
+					skip = skip;
+					goto found;
+				}
+			}
+		}
+		iter->p = NULL;
+	}
+found:
 	return 0;
 }
 EXPORT_SYMBOL_GPL(rhashtable_walk_start_check);
@@ -917,8 +957,6 @@ void rhashtable_walk_stop(struct rhashtable_iter *iter)
 		iter->walker.tbl = NULL;
 	spin_unlock(&ht->lock);
 
-	iter->p = NULL;
-
 out:
 	rcu_read_unlock();
 }