diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h index a3da3d572e5e75387ae617d2fdad70e95b2e761c..848ab1cb556e098b4df001ab9a9c751082dc72b9 100644 --- a/paddle/fluid/operators/top_k_function_cuda.h +++ b/paddle/fluid/operators/top_k_function_cuda.h @@ -361,7 +361,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, } /*---------------------------Radix TopK Begin------------------*/ -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000 constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS constexpr int RADIX_MASK = (RADIX_SIZE - 1); @@ -479,15 +479,25 @@ struct RadixTypeConfig { typedef uint32_t RadixType; static inline __device__ RadixType Convert(platform::float16 v) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) half v_h = v.to_half(); RadixType x = __half_as_ushort(v_h); RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; return (v_h == v_h) ? (x ^ mask) : 0xffff; +#else + assert(false); + return 0u; +#endif } static inline __device__ platform::float16 Deconvert(RadixType v) { +#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; return static_cast(__ushort_as_half(v ^ mask)); +#else + assert(false); + return static_cast(0); +#endif } };