diff --git a/arch/x86/include/asm/pmem.h b/arch/x86/include/asm/pmem.h index 643eba42d6206aa0fbcb57baa150606269577523..2c1ebeb4d7376db6350b7266048a1d37d800ac86 100644 --- a/arch/x86/include/asm/pmem.h +++ b/arch/x86/include/asm/pmem.h @@ -46,10 +46,7 @@ static inline void arch_memcpy_to_pmem(void *dst, const void *src, size_t n) static inline int arch_memcpy_from_pmem(void *dst, const void *src, size_t n) { - if (static_cpu_has(X86_FEATURE_MCE_RECOVERY)) - return memcpy_mcsafe(dst, src, n); - memcpy(dst, src, n); - return 0; + return memcpy_mcsafe(dst, src, n); } /** diff --git a/arch/x86/include/asm/string_64.h b/arch/x86/include/asm/string_64.h index 877a1dfbf7707e5a2297cbbef8b65c55212f8e94..a164862d77e3b4573e7a264c8712fac06ef7a830 100644 --- a/arch/x86/include/asm/string_64.h +++ b/arch/x86/include/asm/string_64.h @@ -79,6 +79,7 @@ int strcmp(const char *cs, const char *ct); #define memset(s, c, n) __memset(s, c, n) #endif +__must_check int memcpy_mcsafe_unrolled(void *dst, const void *src, size_t cnt); DECLARE_STATIC_KEY_FALSE(mcsafe_key); /** @@ -89,10 +90,23 @@ DECLARE_STATIC_KEY_FALSE(mcsafe_key); * @cnt: number of bytes to copy * * Low level memory copy function that catches machine checks + * We only call into the "safe" function on systems that can + * actually do machine check recovery. Everyone else can just + * use memcpy(). * * Return 0 for success, -EFAULT for fail */ -int memcpy_mcsafe(void *dst, const void *src, size_t cnt); +static __always_inline __must_check int +memcpy_mcsafe(void *dst, const void *src, size_t cnt) +{ +#ifdef CONFIG_X86_MCE + if (static_branch_unlikely(&mcsafe_key)) + return memcpy_mcsafe_unrolled(dst, src, cnt); + else +#endif + memcpy(dst, src, cnt); + return 0; +} #endif /* __KERNEL__ */ diff --git a/arch/x86/kernel/x8664_ksyms_64.c b/arch/x86/kernel/x8664_ksyms_64.c index 95e49f6e4fc303a9340199588cbd54983cb13149..b2cee3d19477688a982a1ca7baaa7c4fefc311b2 100644 --- a/arch/x86/kernel/x8664_ksyms_64.c +++ b/arch/x86/kernel/x8664_ksyms_64.c @@ -38,7 +38,7 @@ EXPORT_SYMBOL(__copy_user_nocache); EXPORT_SYMBOL(_copy_from_user); EXPORT_SYMBOL(_copy_to_user); -EXPORT_SYMBOL_GPL(memcpy_mcsafe); +EXPORT_SYMBOL_GPL(memcpy_mcsafe_unrolled); EXPORT_SYMBOL(copy_page); EXPORT_SYMBOL(clear_page); diff --git a/arch/x86/lib/memcpy_64.S b/arch/x86/lib/memcpy_64.S index 2ec0b0abbfaa876fb242b71061b70cdb7dc9db20..49e6ebac7e73e33b0a03327cb65c95a29afc1c67 100644 --- a/arch/x86/lib/memcpy_64.S +++ b/arch/x86/lib/memcpy_64.S @@ -181,11 +181,11 @@ ENDPROC(memcpy_orig) #ifndef CONFIG_UML /* - * memcpy_mcsafe - memory copy with machine check exception handling + * memcpy_mcsafe_unrolled - memory copy with machine check exception handling * Note that we only catch machine checks when reading the source addresses. * Writes to target are posted and don't generate machine checks. */ -ENTRY(memcpy_mcsafe) +ENTRY(memcpy_mcsafe_unrolled) cmpl $8, %edx /* Less than 8 bytes? Go to byte copy loop */ jb .L_no_whole_words @@ -273,7 +273,7 @@ ENTRY(memcpy_mcsafe) .L_done_memcpy_trap: xorq %rax, %rax ret -ENDPROC(memcpy_mcsafe) +ENDPROC(memcpy_mcsafe_unrolled) .section .fixup, "ax" /* Return -EFAULT for any failure */