diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 1bab87a444d781ffc66e494a13cf1cdacc66d90f..ce60f4c38843f7897244be1b9d7962736cf602e4 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -1136,7 +1136,8 @@ struct kvm_x86_ops {
 	int (*get_tdp_level)(struct kvm_vcpu *vcpu);
 	u64 (*get_mt_mask)(struct kvm_vcpu *vcpu, gfn_t gfn, bool is_mmio);
 
-	void (*load_mmu_pgd)(struct kvm_vcpu *vcpu, unsigned long cr3);
+	void (*load_mmu_pgd)(struct kvm_vcpu *vcpu, unsigned long pgd,
+			     int pgd_level);
 
 	bool (*has_wbinvd_exit)(void);
 
diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index 9f6554613babc978373ad60ec62182a0415cca38..5efc6081ca138ed601e457041ba2cb12f96f6970 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -90,9 +90,13 @@ static inline unsigned long kvm_get_active_pcid(struct kvm_vcpu *vcpu)
 
 static inline void kvm_mmu_load_pgd(struct kvm_vcpu *vcpu)
 {
-	if (VALID_PAGE(vcpu->arch.mmu->root_hpa))
-		kvm_x86_ops.load_mmu_pgd(vcpu, vcpu->arch.mmu->root_hpa |
-					       kvm_get_active_pcid(vcpu));
+	u64 root_hpa = vcpu->arch.mmu->root_hpa;
+
+	if (!VALID_PAGE(root_hpa))
+		return;
+
+	kvm_x86_ops.load_mmu_pgd(vcpu, root_hpa | kvm_get_active_pcid(vcpu),
+				 vcpu->arch.mmu->shadow_root_level);
 }
 
 int kvm_tdp_page_fault(struct kvm_vcpu *vcpu, gpa_t gpa, u32 error_code,
diff --git a/arch/x86/kvm/svm/svm.c b/arch/x86/kvm/svm/svm.c
index 783330d0e7b889678b6a86e933b4d746e6539f39..c70d7dd3330612a2e272a72f8deab9eea586b873 100644
--- a/arch/x86/kvm/svm/svm.c
+++ b/arch/x86/kvm/svm/svm.c
@@ -3541,7 +3541,8 @@ static __no_kcsan fastpath_t svm_vcpu_run(struct kvm_vcpu *vcpu)
 	return exit_fastpath;
 }
 
-static void svm_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long root)
+static void svm_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long root,
+			     int root_level)
 {
 	struct vcpu_svm *svm = to_svm(vcpu);
 	unsigned long cr3;
diff --git a/arch/x86/kvm/vmx/nested.c b/arch/x86/kvm/vmx/nested.c
index 4d561edf6f9ca6b1edf1e07f65603e0a313d4b21..e405e754b592be58c911bbe9cd16839f150cabc2 100644
--- a/arch/x86/kvm/vmx/nested.c
+++ b/arch/x86/kvm/vmx/nested.c
@@ -2162,7 +2162,8 @@ static void prepare_vmcs02_constant_state(struct vcpu_vmx *vmx)
 	 * consistency checks.
 	 */
 	if (enable_ept && nested_early_check)
-		vmcs_write64(EPT_POINTER, construct_eptp(&vmx->vcpu, 0));
+		vmcs_write64(EPT_POINTER,
+			     construct_eptp(&vmx->vcpu, 0, PT64_ROOT_4LEVEL));
 
 	/* All VMFUNCs are currently emulated through L0 vmexits.  */
 	if (cpu_has_vmx_vmfunc())
diff --git a/arch/x86/kvm/vmx/vmx.c b/arch/x86/kvm/vmx/vmx.c
index 791baa73e5786ba60100b282f7c99fa752d9b0ee..244053cff0a3a9d7357938688348645b4b10997f 100644
--- a/arch/x86/kvm/vmx/vmx.c
+++ b/arch/x86/kvm/vmx/vmx.c
@@ -2933,14 +2933,16 @@ static void vmx_flush_tlb_all(struct kvm_vcpu *vcpu)
 
 static void vmx_flush_tlb_current(struct kvm_vcpu *vcpu)
 {
-	u64 root_hpa = vcpu->arch.mmu->root_hpa;
+	struct kvm_mmu *mmu = vcpu->arch.mmu;
+	u64 root_hpa = mmu->root_hpa;
 
 	/* No flush required if the current context is invalid. */
 	if (!VALID_PAGE(root_hpa))
 		return;
 
 	if (enable_ept)
-		ept_sync_context(construct_eptp(vcpu, root_hpa));
+		ept_sync_context(construct_eptp(vcpu, root_hpa,
+						mmu->shadow_root_level));
 	else if (!is_guest_mode(vcpu))
 		vpid_sync_context(to_vmx(vcpu)->vpid);
 	else
@@ -3078,11 +3080,12 @@ static int get_ept_level(struct kvm_vcpu *vcpu)
 	return vmx_get_tdp_level(vcpu);
 }
 
-u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa)
+u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa,
+		   int root_level)
 {
 	u64 eptp = VMX_EPTP_MT_WB;
 
-	eptp |= (get_ept_level(vcpu) == 5) ? VMX_EPTP_PWL_5 : VMX_EPTP_PWL_4;
+	eptp |= (root_level == 5) ? VMX_EPTP_PWL_5 : VMX_EPTP_PWL_4;
 
 	if (enable_ept_ad_bits &&
 	    (!is_guest_mode(vcpu) || nested_ept_ad_enabled(vcpu)))
@@ -3092,7 +3095,8 @@ u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa)
 	return eptp;
 }
 
-static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long pgd)
+static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long pgd,
+			     int pgd_level)
 {
 	struct kvm *kvm = vcpu->kvm;
 	bool update_guest_cr3 = true;
@@ -3100,7 +3104,9 @@ static void vmx_load_mmu_pgd(struct kvm_vcpu *vcpu, unsigned long pgd)
 	u64 eptp;
 
 	if (enable_ept) {
-		eptp = construct_eptp(vcpu, pgd);
+		WARN_ON(pgd_level != get_ept_level(vcpu));
+
+		eptp = construct_eptp(vcpu, pgd, pgd_level);
 		vmcs_write64(EPT_POINTER, eptp);
 
 		if (kvm_x86_ops.tlb_remote_flush) {
diff --git a/arch/x86/kvm/vmx/vmx.h b/arch/x86/kvm/vmx/vmx.h
index 3c55433ac1b213eadb3a7a0c267b0aac3c353134..26175a4759fa5f8eef34b443dcb3fd6415b4f6b6 100644
--- a/arch/x86/kvm/vmx/vmx.h
+++ b/arch/x86/kvm/vmx/vmx.h
@@ -341,7 +341,8 @@ void set_cr4_guest_host_mask(struct vcpu_vmx *vmx);
 void ept_save_pdptrs(struct kvm_vcpu *vcpu);
 void vmx_get_segment(struct kvm_vcpu *vcpu, struct kvm_segment *var, int seg);
 void vmx_set_segment(struct kvm_vcpu *vcpu, struct kvm_segment *var, int seg);
-u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa);
+u64 construct_eptp(struct kvm_vcpu *vcpu, unsigned long root_hpa,
+		   int root_level);
 void update_exception_bitmap(struct kvm_vcpu *vcpu);
 void vmx_update_msr_bitmap(struct kvm_vcpu *vcpu);
 bool vmx_nmi_blocked(struct kvm_vcpu *vcpu);