diff --git a/drivers/md/bcache/writeback.c b/drivers/md/bcache/writeback.c
index b23f88d9f18cf1caa863b1b4ed56c37901296ece..b9346cd9cda192bdf9142bfa462dc0ab74383707 100644
--- a/drivers/md/bcache/writeback.c
+++ b/drivers/md/bcache/writeback.c
@@ -323,6 +323,10 @@ void bcache_dev_sectors_dirty_add(struct cache_set *c, unsigned inode,
 
 static bool dirty_pred(struct keybuf *buf, struct bkey *k)
 {
+	struct cached_dev *dc = container_of(buf, struct cached_dev, writeback_keys);
+
+	BUG_ON(KEY_INODE(k) != dc->disk.id);
+
 	return KEY_DIRTY(k);
 }
 
@@ -372,11 +376,24 @@ static void refill_full_stripes(struct cached_dev *dc)
 	}
 }
 
+/*
+ * Returns true if we scanned the entire disk
+ */
 static bool refill_dirty(struct cached_dev *dc)
 {
 	struct keybuf *buf = &dc->writeback_keys;
+	struct bkey start = KEY(dc->disk.id, 0, 0);
 	struct bkey end = KEY(dc->disk.id, MAX_KEY_OFFSET, 0);
-	bool searched_from_start = false;
+	struct bkey start_pos;
+
+	/*
+	 * make sure keybuf pos is inside the range for this disk - at bringup
+	 * we might not be attached yet so this disk's inode nr isn't
+	 * initialized then
+	 */
+	if (bkey_cmp(&buf->last_scanned, &start) < 0 ||
+	    bkey_cmp(&buf->last_scanned, &end) > 0)
+		buf->last_scanned = start;
 
 	if (dc->partial_stripes_expensive) {
 		refill_full_stripes(dc);
@@ -384,14 +401,20 @@ static bool refill_dirty(struct cached_dev *dc)
 			return false;
 	}
 
-	if (bkey_cmp(&buf->last_scanned, &end) >= 0) {
-		buf->last_scanned = KEY(dc->disk.id, 0, 0);
-		searched_from_start = true;
-	}
-
+	start_pos = buf->last_scanned;
 	bch_refill_keybuf(dc->disk.c, buf, &end, dirty_pred);
 
-	return bkey_cmp(&buf->last_scanned, &end) >= 0 && searched_from_start;
+	if (bkey_cmp(&buf->last_scanned, &end) < 0)
+		return false;
+
+	/*
+	 * If we get to the end start scanning again from the beginning, and
+	 * only scan up to where we initially started scanning from:
+	 */
+	buf->last_scanned = start;
+	bch_refill_keybuf(dc->disk.c, buf, &start_pos, dirty_pred);
+
+	return bkey_cmp(&buf->last_scanned, &start_pos) >= 0;
 }
 
 static int bch_writeback_thread(void *arg)