From fcfaa10478e907179a5187420784ac56cba37d1e Mon Sep 17 00:00:00 2001
From: ming1753 <61511741+ming1753@users.noreply.github.com>
Date: Fri, 22 Jul 2022 15:50:27 +0800
Subject: [PATCH] (modified) fc support fp16 (#44540)

---
 paddle/phi/kernels/funcs/fc_functor.cu | 61 ++++++++------------------
 1 file changed, 18 insertions(+), 43 deletions(-)

diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index 1f2db558329..d0bd7567c7d 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -36,6 +36,24 @@ struct FcTypeTraits<double> {
   typedef double4 Type;
 };
 
+#if defined(PADDLE_WITH_CUDA)
+#include <cuda_fp16.h>
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef half2 Type;
+};
+#else
+struct float16_4 {
+  float16 x, y, z, w;
+};
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef float16_4 Type;
+};
+#endif
+
 template <typename T, bool DoRelu>
 __global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
   int tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -109,14 +127,6 @@ void AddReluKernel(
 }
 
 #if defined(PADDLE_WITH_CUDA)
-
-#include <cuda_fp16.h>
-
-template <>
-struct FcTypeTraits<float16> {
-  typedef half2 Type;
-};
-
 template <bool DoRelu>
 __global__ void bias_relu_v2(const int num,
                              const half2* bias,
@@ -200,46 +210,11 @@ void AddReluKernel(cudaStream_t stream,
 }
 
 #else
-
-struct float16_4 {
-  float16 x, y, z, w;
-};
-template <>
-struct FcTypeTraits<float16> {
-  typedef float16_4 Type;
-};
-
-template <bool DoRelu>
-__global__ void bias_relu_v4(const int num,
-                             const float16_4* bias,
-                             float16_4* data,
-                             int K) {
-  int tid = blockIdx.x * blockDim.x + threadIdx.x;
-  if (tid < num) {
-    int bias_idx = tid % K;
-    const float16_4 bias_ptr = bias[bias_idx];
-    const float16_4 in_ptr = data[tid];
-    float16_4 packed_val;
-    packed_val.x = in_ptr.x + bias_ptr.x;
-    packed_val.y = in_ptr.y + bias_ptr.y;
-    packed_val.z = in_ptr.z + bias_ptr.z;
-    packed_val.w = in_ptr.w + bias_ptr.w;
-    if (DoRelu) {
-      packed_val.x = fmaxf(0.f, packed_val.x);
-      packed_val.y = fmaxf(0.f, packed_val.y);
-      packed_val.z = fmaxf(0.f, packed_val.z);
-      packed_val.w = fmaxf(0.f, packed_val.w);
-    }
-    data[tid] = packed_val;
-  }
-}
-
 template <int BlockDim>
 __global__ void InplaceAddReluKernel(const int N,
                                      const float16* bias,
                                      float16* data) {
   int offset = blockIdx.x * N;
-
   for (int i = threadIdx.x; i < N; i += BlockDim) {
     float16 temp;
     temp = data[offset + i] + bias[i];
-- 
GitLab
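
Reviewer note (not part of the patch): the change hoists the fp16 FcTypeTraits specializations (half2 under PADDLE_WITH_CUDA, float16_4 otherwise) above the generic kernels, so the packed types are declared once and the duplicated float16_4 bias_relu_v4 overload can be dropped. The sketch below is a minimal, self-contained illustration of the same traits-plus-packed-vector pattern, written in plain CUDA with float/float4 so it compiles without Paddle; VecTraits, BiasReluVec4, and the demo sizes are illustrative names and values, not Paddle's API.

// bias_relu_sketch.cu -- illustrative only; build with: nvcc bias_relu_sketch.cu
#include <cstdio>
#include <cuda_runtime.h>

// Map a scalar type to a packed 4-wide vector type, analogous to FcTypeTraits.
template <typename T>
struct VecTraits;

template <>
struct VecTraits<float> {
  typedef float4 Type;
};

// Add a per-column bias to packed output elements and optionally apply ReLU.
// num = total number of packs (M * K), K = packs per row (N / 4).
template <typename T, bool DoRelu>
__global__ void BiasReluVec4(const int num, const T* bias, T* data, int K) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    const T b = bias[tid % K];  // the bias row repeats for every output row
    T v = data[tid];
    v.x += b.x; v.y += b.y; v.z += b.z; v.w += b.w;
    if (DoRelu) {
      v.x = fmaxf(0.f, v.x); v.y = fmaxf(0.f, v.y);
      v.z = fmaxf(0.f, v.z); v.w = fmaxf(0.f, v.w);
    }
    data[tid] = v;
  }
}

int main() {
  const int M = 2, N = 8;            // output rows / columns, N divisible by 4
  const int K = N / 4, num = M * K;  // packs per row, total packs
  float h_data[M * N], h_bias[N];
  for (int i = 0; i < M * N; ++i) h_data[i] = (i % 2) ? 1.f : -1.f;
  for (int i = 0; i < N; ++i) h_bias[i] = 0.5f;

  float *d_data, *d_bias;
  cudaMalloc(&d_data, sizeof(h_data));
  cudaMalloc(&d_bias, sizeof(h_bias));
  cudaMemcpy(d_data, h_data, sizeof(h_data), cudaMemcpyHostToDevice);
  cudaMemcpy(d_bias, h_bias, sizeof(h_bias), cudaMemcpyHostToDevice);

  // Reinterpret the contiguous buffers as packed vectors, as the FC code does.
  typedef VecTraits<float>::Type Vec;
  const int threads = 256, blocks = (num + threads - 1) / threads;
  BiasReluVec4<Vec, true><<<blocks, threads>>>(
      num, reinterpret_cast<const Vec*>(d_bias), reinterpret_cast<Vec*>(d_data), K);

  cudaMemcpy(h_data, d_data, sizeof(h_data), cudaMemcpyDeviceToHost);
  printf("out[0]=%.1f out[1]=%.1f\n", h_data[0], h_data[1]);  // 0.0 and 1.5
  cudaFree(d_data);
  cudaFree(d_bias);
  return 0;
}

Reinterpreting the row-major output as 4-wide packs is only valid when N is a multiple of 4 (and the buffers are suitably aligned), which is presumably why the scalar InplaceAddReluKernel fallback above is kept.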