diff --git a/include/asm-x86/paravirt.h b/include/asm-x86/paravirt.h
index d9957ece49b578b75abc2cc1c023155bf9db9b0e..c56b17a5eba07bc6f42db05e62a25a7ad844fde0 100644
--- a/include/asm-x86/paravirt.h
+++ b/include/asm-x86/paravirt.h
@@ -1086,17 +1086,19 @@ static inline pmdval_t pmd_val(pmd_t pmd)
 
 	return ret;
 }
-#endif	/* PAGETABLE_LEVELS >= 3 */
-
-#ifdef CONFIG_X86_PAE
 
-static inline void set_pud(pud_t *pudp, pud_t pudval)
+static inline void set_pud(pud_t *pudp, pud_t pud)
 {
-	PVOP_VCALL3(pv_mmu_ops.set_pud, pudp,
-		    pudval.pgd.pgd, pudval.pgd.pgd >> 32);
-}
+	pudval_t val = native_pud_val(pud);
 
-#endif	/* CONFIG_X86_PAE */
+	if (sizeof(pudval_t) > sizeof(long))
+		PVOP_VCALL3(pv_mmu_ops.set_pud, pudp,
+			    val, (u64)val >> 32);
+	else
+		PVOP_VCALL2(pv_mmu_ops.set_pud, pudp,
+			    val);
+}
+#endif	/* PAGETABLE_LEVELS >= 3 */
 
 /* Lazy mode for batching updates / context switch */
 enum paravirt_lazy_mode {