diff --git a/include/linux/nospec.h b/include/linux/nospec.h index b99bced39ac2fa1b91cea50a8301336c3cbff133..fbc98e2c8228d0b8ca17e4af9be1cb1cddaccd27 100644 --- a/include/linux/nospec.h +++ b/include/linux/nospec.h @@ -19,20 +19,6 @@ static inline unsigned long array_index_mask_nospec(unsigned long index, unsigned long size) { - /* - * Warn developers about inappropriate array_index_nospec() usage. - * - * Even if the CPU speculates past the WARN_ONCE branch, the - * sign bit of @index is taken into account when generating the - * mask. - * - * This warning is compiled out when the compiler can infer that - * @index and @size are less than LONG_MAX. - */ - if (WARN_ONCE(index > LONG_MAX || size > LONG_MAX, - "array_index_nospec() limited to range of [0, LONG_MAX]\n")) - return 0; - /* * Always calculate and emit the mask even if the compiler * thinks the mask is not needed. The compiler does not take @@ -43,6 +29,26 @@ static inline unsigned long array_index_mask_nospec(unsigned long index, } #endif +/* + * Warn developers about inappropriate array_index_nospec() usage. + * + * Even if the CPU speculates past the WARN_ONCE branch, the + * sign bit of @index is taken into account when generating the + * mask. + * + * This warning is compiled out when the compiler can infer that + * @index and @size are less than LONG_MAX. + */ +#define array_index_mask_nospec_check(index, size) \ +({ \ + if (WARN_ONCE(index > LONG_MAX || size > LONG_MAX, \ + "array_index_nospec() limited to range of [0, LONG_MAX]\n")) \ + _mask = 0; \ + else \ + _mask = array_index_mask_nospec(index, size); \ + _mask; \ +}) + /* * array_index_nospec - sanitize an array index after a bounds check * @@ -61,7 +67,7 @@ static inline unsigned long array_index_mask_nospec(unsigned long index, ({ \ typeof(index) _i = (index); \ typeof(size) _s = (size); \ - unsigned long _mask = array_index_mask_nospec(_i, _s); \ + unsigned long _mask = array_index_mask_nospec_check(_i, _s); \ \ BUILD_BUG_ON(sizeof(_i) > sizeof(long)); \ BUILD_BUG_ON(sizeof(_s) > sizeof(long)); \