未验证 提交 dea24544 编写于 作者: Z Zhang Zheng 提交者: GitHub

Restrict compilation conditions of optimized topk kernel (#41153)

* Restrict compilation conditions of optimized topk kernel

* fix
上级 23a69bc7
...@@ -361,7 +361,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, ...@@ -361,7 +361,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
} }
/*---------------------------Radix TopK Begin------------------*/ /*---------------------------Radix TopK Begin------------------*/
#if defined(PADDLE_WITH_CUDA) #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS) constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_SIZE - 1); constexpr int RADIX_MASK = (RADIX_SIZE - 1);
...@@ -479,15 +479,25 @@ struct RadixTypeConfig<platform::float16> { ...@@ -479,15 +479,25 @@ struct RadixTypeConfig<platform::float16> {
typedef uint32_t RadixType; typedef uint32_t RadixType;
static inline __device__ RadixType Convert(platform::float16 v) { static inline __device__ RadixType Convert(platform::float16 v) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
half v_h = v.to_half(); half v_h = v.to_half();
RadixType x = __half_as_ushort(v_h); RadixType x = __half_as_ushort(v_h);
RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000; RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
return (v_h == v_h) ? (x ^ mask) : 0xffff; return (v_h == v_h) ? (x ^ mask) : 0xffff;
#else
assert(false);
return 0u;
#endif
} }
static inline __device__ platform::float16 Deconvert(RadixType v) { static inline __device__ platform::float16 Deconvert(RadixType v) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff; RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
return static_cast<platform::float16>(__ushort_as_half(v ^ mask)); return static_cast<platform::float16>(__ushort_as_half(v ^ mask));
#else
assert(false);
return static_cast<platform::float16>(0);
#endif
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册