diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index d50bec2f635e7cf6b0507d1c5235b3afdb50e7b1..901551f964bd4f81bfe25be5c744b40d9ba1de3f 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -36,6 +36,14 @@ struct FcTypeTraits<double> {
   typedef double4 Type;
 };
 
+#if defined(PADDLE_WITH_CUDA)
+#include <cuda_fp16.h>
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef half2 Type;
+};
+#else
 struct float16_4 {
   float16 x, y, z, w;
 };
@@ -44,6 +52,7 @@ template <>
 struct FcTypeTraits<float16> {
   typedef float16_4 Type;
 };
+#endif
 
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
@@ -117,12 +126,109 @@ void AddReluKernel(
   }
 }
 
+#if defined(PADDLE_WITH_CUDA)
+template <bool DoRelu>
+__global__ void bias_relu_v2(const int num,
+                             const half2* bias,
+                             half2* data,
+                             int K) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (tid < num) {
+    int bias_idx = tid % K;
+    const half2 bias_ptr = bias[bias_idx];
+    const half2 in_ptr = data[tid];
+    half2 packed_val;
+#if __CUDA_ARCH__ >= 530
+    packed_val = __hadd2(bias_ptr, in_ptr);
+#else
+    packed_val.x = __hadd(bias_ptr.x, in_ptr.x);
+    packed_val.y = __hadd(bias_ptr.y, in_ptr.y);
+#endif
+    if (DoRelu) {
+#if __CUDA_ARCH__ >= 800
+      packed_val = __hmax2(__half2(0, 0), packed_val);
+#elif __CUDA_ARCH__ >= 530
+      packed_val = __hmul2(__hgt2(__half2(0, 0), packed_val), packed_val);
+#else
+      packed_val.x = static_cast<int>(static_cast<float>(packed_val.x) > 0) *
+                     static_cast<float>(packed_val.x);
+      packed_val.y = static_cast<int>(static_cast<float>(packed_val.y) > 0) *
+                     static_cast<float>(packed_val.y);
+#endif
+    }
+    data[tid] = packed_val;
+  }
+}
+
+template <bool DoRelu, int BlockDim>
+__global__ void InplaceAddReluKernel(const int N,
+                                     const half* bias,
+                                     half* data) {
+  int offset = blockIdx.x * N;
+  for (int i = threadIdx.x; i < N; i += BlockDim) {
+    half temp;
+#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
+    temp = __hadd(__ldg(data + offset + i), __ldg(bias + i));
+#else
+    temp = __hadd(data[offset + i], bias[i]);
+#endif
+    if (DoRelu) {
+#if __CUDA_ARCH__ >= 800
+      data[offset + i] = __hmax(0, temp);
+#elif __CUDA_ARCH__ >= 530
+      data[offset + i] = __hmul(__hgt(temp, 0), temp);
+#else
+      data[offset + i] = static_cast<int>(static_cast<float>(temp) > 0) *
+                         static_cast<float>(temp);
+#endif
+    } else {
+      data[offset + i] = temp;
+    }
+  }
+}
+
+template <>
+void AddReluKernel(cudaStream_t stream,
+                   const int M,
+                   const int N,
+                   float16* Y,
+                   const float16* B,
+                   bool relu) {
+  if (N % 2 == 0) {
+    const int threads = 256;
+    const int num = M * N / 2;
+    const int blocks = (num + threads - 1) / threads;
+    typedef typename FcTypeTraits<float16>::Type trans_type;
+    auto* bias_ptr_v2 = reinterpret_cast<const trans_type*>(B);
+    auto* data_ptr_v2 = reinterpret_cast<trans_type*>(Y);
+    if (relu) {
+      bias_relu_v2<true><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v2, data_ptr_v2, N / 2);
+    } else {
+      bias_relu_v2<false><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v2, data_ptr_v2, N / 2);
+    }
+  } else {
+    const int threads = 256;
+    const int blocks = M;
+    auto* halfB = reinterpret_cast<const half*>(B);
+    auto* halfY = reinterpret_cast<half*>(Y);
+    if (relu) {
+      InplaceAddReluKernel<true, threads>
+          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
+    } else {
+      InplaceAddReluKernel<false, threads>
+          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
+    }
+  }
+}
+#else
 template <bool DoRelu, int BlockDim>
 __global__ void InplaceAddReluKernel(const int N,
                                      const float16* bias,
                                      float16* data) {
   int offset = blockIdx.x * N;
-
   for (int i = threadIdx.x; i < N; i += BlockDim) {
     float16 temp;
     temp = data[offset + i] + bias[i];
@@ -168,6 +274,7 @@ void AddReluKernel(gpuStream_t stream,
     }
   }
 }
+#endif
 
 template <typename DeviceContext, typename T>
 void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
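
For context, here is a minimal standalone sketch of the dispatch pattern the patch introduces; the kernel and function names below are illustrative, not part of the patch or the Paddle API. When N is even, the output and bias are reinterpreted as half2 so each thread processes two adjacent columns per load; otherwise one block per row falls back to scalar half. Unlike the patch, which uses __hadd2/__hmax2 behind __CUDA_ARCH__ checks, this sketch does the arithmetic in float so it compiles for any architecture.

```cuda
#include <cuda_fp16.h>

template <bool DoRelu>
__global__ void BiasAddHalf2Sketch(int num, const half2* bias, half2* data, int K) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    // Each thread owns one half2 (two adjacent columns); the bias row repeats
    // every K half2 elements, where K = N / 2.
    float2 d = __half22float2(data[tid]);
    float2 b = __half22float2(bias[tid % K]);
    float x = d.x + b.x;
    float y = d.y + b.y;
    if (DoRelu) {
      x = x > 0.f ? x : 0.f;
      y = y > 0.f ? y : 0.f;
    }
    data[tid] = __floats2half2_rn(x, y);
  }
}

template <bool DoRelu>
__global__ void BiasAddScalarSketch(int N, const half* bias, half* data) {
  int offset = blockIdx.x * N;  // one block per row, threads stride over columns
  for (int i = threadIdx.x; i < N; i += blockDim.x) {
    float v = __half2float(data[offset + i]) + __half2float(bias[i]);
    if (DoRelu && v < 0.f) v = 0.f;
    data[offset + i] = __float2half(v);
  }
}

// Host-side dispatch mirroring the float16 specialization added in the patch:
// the vectorized path is taken only when N is divisible by 2.
void BiasAddReluSketch(
    cudaStream_t stream, int M, int N, half* Y, const half* B, bool relu) {
  const int threads = 256;
  if (N % 2 == 0) {
    const int num = M * N / 2;
    const int blocks = (num + threads - 1) / threads;
    auto* y2 = reinterpret_cast<half2*>(Y);
    auto* b2 = reinterpret_cast<const half2*>(B);
    if (relu) {
      BiasAddHalf2Sketch<true><<<blocks, threads, 0, stream>>>(num, b2, y2, N / 2);
    } else {
      BiasAddHalf2Sketch<false><<<blocks, threads, 0, stream>>>(num, b2, y2, N / 2);
    }
  } else {
    if (relu) {
      BiasAddScalarSketch<true><<<M, threads, 0, stream>>>(N, B, Y);
    } else {
      BiasAddScalarSketch<false><<<M, threads, 0, stream>>>(N, B, Y);
    }
  }
}
```

The even-N requirement exists because the reinterpret_cast to half2 assumes the bias vector and each output row decompose into contiguous pairs of halfs; packing two values per load is what gives the vectorized path its bandwidth advantage over the per-element fallback.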