diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index f6d3505c8d18c48de7142b916ca1de2da04b9022..cfef959693355ca45008cccd348a66d7b0f2de4a 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -504,14 +504,16 @@ static bool spte_has_volatile_bits(u64 spte)
 	return true;
 }
 
-static bool spte_is_bit_cleared(u64 old_spte, u64 new_spte, u64 bit_mask)
+static bool is_accessed_spte(u64 spte)
 {
-	return (old_spte & bit_mask) && !(new_spte & bit_mask);
+	return shadow_accessed_mask ? spte & shadow_accessed_mask
+				    : true;
 }
 
-static bool spte_is_bit_changed(u64 old_spte, u64 new_spte, u64 bit_mask)
+static bool is_dirty_spte(u64 spte)
 {
-	return (old_spte & bit_mask) != (new_spte & bit_mask);
+	return shadow_dirty_mask ? spte & shadow_dirty_mask
+				 : spte & PT_WRITABLE_MASK;
 }
 
 /* Rules for using mmu_spte_set:
@@ -534,17 +536,19 @@ static void mmu_spte_set(u64 *sptep, u64 new_spte)
  * will find a read-only spte, even though the writable spte
  * might be cached on a CPU's TLB, the return value indicates this
  * case.
+ *
+ * Returns true if the TLB needs to be flushed
  */
 static bool mmu_spte_update(u64 *sptep, u64 new_spte)
 {
 	u64 old_spte = *sptep;
-	bool ret = false;
+	bool flush = false;
 
 	WARN_ON(!is_shadow_present_pte(new_spte));
 
 	if (!is_shadow_present_pte(old_spte)) {
 		mmu_spte_set(sptep, new_spte);
-		return ret;
+		return flush;
 	}
 
 	if (!spte_has_volatile_bits(old_spte))
@@ -552,6 +556,8 @@ static bool mmu_spte_update(u64 *sptep, u64 new_spte)
 	else
 		old_spte = __update_clear_spte_slow(sptep, new_spte);
 
+	WARN_ON(spte_to_pfn(old_spte) != spte_to_pfn(new_spte));
+
 	/*
 	 * For the spte updated out of mmu-lock is safe, since
 	 * we always atomically update it, see the comments in
@@ -559,38 +565,31 @@ static bool mmu_spte_update(u64 *sptep, u64 new_spte)
 	 */
 	if (spte_can_locklessly_be_made_writable(old_spte) &&
 	      !is_writable_pte(new_spte))
-		ret = true;
-
-	if (!shadow_accessed_mask) {
-		/*
-		 * We don't set page dirty when dropping non-writable spte.
-		 * So do it now if the new spte is becoming non-writable.
-		 */
-		if (ret)
-			kvm_set_pfn_dirty(spte_to_pfn(old_spte));
-		return ret;
-	}
+		flush = true;
 
 	/*
-	 * Flush TLB when accessed/dirty bits are changed in the page tables,
+	 * Flush TLB when accessed/dirty states are changed in the page tables,
 	 * to guarantee consistency between TLB and page tables.
 	 */
-	if (spte_is_bit_changed(old_spte, new_spte,
-                                shadow_accessed_mask | shadow_dirty_mask))
-		ret = true;
 
-	if (spte_is_bit_cleared(old_spte, new_spte, shadow_accessed_mask))
+	if (is_accessed_spte(old_spte) && !is_accessed_spte(new_spte)) {
+		flush = true;
 		kvm_set_pfn_accessed(spte_to_pfn(old_spte));
-	if (spte_is_bit_cleared(old_spte, new_spte, shadow_dirty_mask))
+	}
+
+	if (is_dirty_spte(old_spte) && !is_dirty_spte(new_spte)) {
+		flush = true;
 		kvm_set_pfn_dirty(spte_to_pfn(old_spte));
+	}
 
-	return ret;
+	return flush;
 }
 
 /*
  * Rules for using mmu_spte_clear_track_bits:
  * It sets the sptep from present to nonpresent, and track the
  * state bits, it is used to clear the last level sptep.
+ * Returns non-zero if the PTE was previously valid.
  */
 static int mmu_spte_clear_track_bits(u64 *sptep)
 {
@@ -614,11 +613,12 @@ static int mmu_spte_clear_track_bits(u64 *sptep)
 	 */
 	WARN_ON(!kvm_is_reserved_pfn(pfn) && !page_count(pfn_to_page(pfn)));
 
-	if (!shadow_accessed_mask || old_spte & shadow_accessed_mask)
+	if (is_accessed_spte(old_spte))
 		kvm_set_pfn_accessed(pfn);
-	if (old_spte & (shadow_dirty_mask ? shadow_dirty_mask :
-					    PT_WRITABLE_MASK))
+
+	if (is_dirty_spte(old_spte))
 		kvm_set_pfn_dirty(pfn);
+
 	return 1;
 }
 
@@ -1616,7 +1616,6 @@ static int kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
 {
 	u64 *sptep;
 	struct rmap_iterator iter;
-	int young = 0;
 
 	/*
 	 * If there's no access bit in the secondary pte set by the
@@ -1626,14 +1625,11 @@ static int kvm_test_age_rmapp(struct kvm *kvm, struct kvm_rmap_head *rmap_head,
 	if (!shadow_accessed_mask)
 		goto out;
 
-	for_each_rmap_spte(rmap_head, &iter, sptep) {
-		if (*sptep & shadow_accessed_mask) {
-			young = 1;
-			break;
-		}
-	}
+	for_each_rmap_spte(rmap_head, &iter, sptep)
+		if (is_accessed_spte(*sptep))
+			return 1;
 out:
-	return young;
+	return 0;
 }
 
 #define RMAP_RECYCLE_THRESHOLD 1000