Unverified commit fcfaa104, authored by ming1753, committed by GitHub

fc support fp16 (#44540)

Parent 3b0aa75e
@@ -36,6 +36,24 @@ struct FcTypeTraits<double> {
typedef double4 Type;
};
#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
template <>
struct FcTypeTraits<float16> {
typedef half2 Type;
};
#else
struct float16_4 {
float16 x, y, z, w;
};
template <>
struct FcTypeTraits<float16> {
typedef float16_4 Type;
};
#endif
template <typename T, bool DoRelu>
__global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
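
The hunk above moves the FcTypeTraits<float16> specializations ahead of the generic bias_relu_v4 kernel, so the same packed-type dispatch already used for float and double also covers fp16. Below is a minimal sketch of how such a trait is typically consumed on the host side; the helper name LaunchPackedBiasRelu, the block size of 256, and the assumption that M * N is divisible by the pack width are illustrative and not part of the commit.

template <typename T>
void LaunchPackedBiasRelu(cudaStream_t stream, int M, int N, const T* bias, T* data, bool relu) {
  // Pick the packed vector type declared by FcTypeTraits (e.g. half2 or float16_4 for float16).
  using Packed = typename FcTypeTraits<T>::Type;
  constexpr int kPack = sizeof(Packed) / sizeof(T);   // elements folded into one packed value
  const int num = M * N / kPack;                      // packed elements to process
  const int threads = 256;
  const int blocks = (num + threads - 1) / threads;
  const Packed* packed_bias = reinterpret_cast<const Packed*>(bias);
  Packed* packed_data = reinterpret_cast<Packed*>(data);
  if (relu) {
    bias_relu_v4<Packed, true><<<blocks, threads, 0, stream>>>(num, packed_bias, packed_data, N / kPack);
  } else {
    bias_relu_v4<Packed, false><<<blocks, threads, 0, stream>>>(num, packed_bias, packed_data, N / kPack);
  }
}
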
@@ -109,14 +127,6 @@ void AddReluKernel(
}
#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
template <>
struct FcTypeTraits<float16> {
typedef half2 Type;
};
template <bool DoRelu>
__global__ void bias_relu_v2(const int num,
const half2* bias,
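
The CUDA branch keeps the half2 specialization and routes fp16 through bias_relu_v2, whose body is elided in this excerpt. A representative half2 body is sketched below purely as an illustration (the committed code may differ); it uses the standard __hadd2 / __hgt2 / __hmul2 intrinsics, which require compute capability 5.3 or higher.

template <bool DoRelu>
__global__ void bias_relu_v2_sketch(const int num, const half2* bias, half2* data, int K) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    int bias_idx = tid % K;
    // Two fp16 lanes are added per instruction.
    half2 packed_val = __hadd2(data[tid], bias[bias_idx]);
    if (DoRelu) {
      // __hgt2 yields 1.0/0.0 per lane; multiplying zeroes out the negative lanes.
      const half2 zero = __float2half2_rn(0.f);
      packed_val = __hmul2(__hgt2(packed_val, zero), packed_val);
    }
    data[tid] = packed_val;
  }
}
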
@@ -200,46 +210,11 @@ void AddReluKernel(cudaStream_t stream,
}
#else
struct float16_4 {
float16 x, y, z, w;
};
template <>
struct FcTypeTraits<float16> {
typedef float16_4 Type;
};
template <bool DoRelu>
__global__ void bias_relu_v4(const int num,
const float16_4* bias,
float16_4* data,
int K) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const float16_4 bias_ptr = bias[bias_idx];
const float16_4 in_ptr = data[tid];
float16_4 packed_val;
packed_val.x = in_ptr.x + bias_ptr.x;
packed_val.y = in_ptr.y + bias_ptr.y;
packed_val.z = in_ptr.z + bias_ptr.z;
packed_val.w = in_ptr.w + bias_ptr.w;
if (DoRelu) {
packed_val.x = fmaxf(0.f, packed_val.x);
packed_val.y = fmaxf(0.f, packed_val.y);
packed_val.z = fmaxf(0.f, packed_val.z);
packed_val.w = fmaxf(0.f, packed_val.w);
}
data[tid] = packed_val;
}
}
template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
const float16* bias,
float16* data) {
int offset = blockIdx.x * N;
for (int i = threadIdx.x; i < N; i += BlockDim) {
float16 temp;
temp = data[offset + i] + bias[i];
......
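
The non-CUDA branch shown last removes the duplicated float16_4 definitions and keeps the per-row InplaceAddReluKernel, whose loop body is cut off above. A plausible completion is sketched below, assuming the kernel mirrors the float version and that float16 provides the overloaded addition and comparison operators used here; it is an illustration, not the exact committed code.

template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernelSketch(const int N, const float16* bias, float16* data) {
  int offset = blockIdx.x * N;  // each block handles one row of the M x N output
  for (int i = threadIdx.x; i < N; i += BlockDim) {
    float16 temp = data[offset + i] + bias[i];  // bias add
    if (DoRelu) {
      // ReLU: keep positive values, clamp the rest to zero.
      data[offset + i] = temp > static_cast<float16>(0) ? temp : static_cast<float16>(0);
    } else {
      data[offset + i] = temp;
    }
  }
}
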