未验证 提交 3e1280ea 编写于 作者: M ming1753 提交者: GitHub

Fc fp16 (#44505)

* fc support fp16

* add a ‘,’ on paddle_pass_builder.cc

* fc support fp16 on non-cuda.
上级 185a900f
......@@ -136,7 +136,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
const std::vector<std::string> kDlnneSubgraphPasses({
"is_test_pass", //
"delete_dropout_op_pass" //
"delete_dropout_op_pass", //
"simplify_with_basic_ops_pass", //
"conv_bn_fuse_pass", //
"depthwise_conv_bn_fuse_pass", //
......@@ -158,7 +158,10 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"conv_eltwiseadd_bn_fuse_pass",
"conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass"};
"conv_elementwise_add_fuse_pass",
"gpu_cpu_map_matmul_v2_to_mul_pass", //
"gpu_cpu_map_matmul_v2_to_matmul_pass", //
"fc_fuse_pass"};
const std::vector<std::string> kTrtLowerPrecisionPasses{
// "conv_bn_fuse_pass",
......
......@@ -17,5 +17,6 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fc,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, phi::dtype::float16>,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::FCOpKernel<paddle::platform::CUDADeviceContext, double>);
......@@ -21,6 +21,8 @@ limitations under the License. */
namespace phi {
namespace funcs {
using float16 = phi::dtype::float16;
template <typename T>
struct FcTypeTraits;
......@@ -75,6 +77,216 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
}
}
template <typename T>
void AddReluKernel(
gpuStream_t stream, const int M, const int N, T* Y, const T* B, bool relu) {
if (N % 4 == 0) {
const int threads = 256;
const int num = M * N / 4;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<T>::Type trans_type;
auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v4<trans_type, true><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else {
bias_relu_v4<trans_type, false><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
}
} else {
const int threads = 256;
const int blocks = M;
if (relu) {
InplaceAddReluKernel<T, true, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
} else {
InplaceAddReluKernel<T, false, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
}
}
}
#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
template <>
struct FcTypeTraits<float16> {
typedef half2 Type;
};
template <bool DoRelu>
__global__ void bias_relu_v2(const int num,
const half2* bias,
half2* data,
int K) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const half2 bias_ptr = bias[bias_idx];
const half2 in_ptr = data[tid];
half2 packed_val = __hadd2(bias_ptr, in_ptr);
if (DoRelu) {
#if __CUDA_ARCH__ >= 800
packed_val = __hmax2(__half2(0, 0), packed_val);
#else
packed_val = __hmul2(__hgt2(__half2(0, 0), packed_val), packed_val);
#endif
}
data[tid] = packed_val;
}
}
template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
const half* bias,
half* data) {
int offset = blockIdx.x * N;
for (int i = threadIdx.x; i < N; i += BlockDim) {
half temp;
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
temp = __ldg(data + offset + i) + __ldg(bias + i);
#else
temp = data[offset + i] + bias[i];
#endif
if (DoRelu) {
#if __CUDA_ARCH__ >= 800
data[offset + i] = __hmax(0, temp);
#else
data[offset + i] = __hmul(__hgt(temp, 0), temp);
#endif
} else {
data[offset + i] = temp;
}
}
}
template <>
void AddReluKernel(cudaStream_t stream,
const int M,
const int N,
float16* Y,
const float16* B,
bool relu) {
if (N % 2 == 0) {
const int threads = 256;
const int num = M * N / 2;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<float16>::Type trans_type;
auto* bias_ptr_v2 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v2 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v2<true><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
} else {
bias_relu_v2<false><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
}
} else {
const int threads = 256;
const int blocks = M;
auto* halfB = reinterpret_cast<const half*>(B);
auto* halfY = reinterpret_cast<half*>(Y);
if (relu) {
InplaceAddReluKernel<true, threads>
<<<blocks, threads, 0, stream>>>(N, halfB, halfY);
} else {
InplaceAddReluKernel<false, threads>
<<<blocks, threads, 0, stream>>>(N, halfB, halfY);
}
}
}
#else
struct float16_4 {
float16 x, y, z, w;
};
template <>
struct FcTypeTraits<float16> {
typedef float16_4 Type;
};
template <bool DoRelu>
__global__ void bias_relu_v4(const int num,
const float16_4* bias,
float16_4* data,
int K) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < num) {
int bias_idx = tid % K;
const float16_4 bias_ptr = bias[bias_idx];
const float16_4 in_ptr = data[tid];
float16_4 packed_val;
packed_val.x = in_ptr.x + bias_ptr.x;
packed_val.y = in_ptr.y + bias_ptr.y;
packed_val.z = in_ptr.z + bias_ptr.z;
packed_val.w = in_ptr.w + bias_ptr.w;
if (DoRelu) {
packed_val.x = fmaxf(0.f, packed_val.x);
packed_val.y = fmaxf(0.f, packed_val.y);
packed_val.z = fmaxf(0.f, packed_val.z);
packed_val.w = fmaxf(0.f, packed_val.w);
}
data[tid] = packed_val;
}
}
template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
const float16* bias,
float16* data) {
int offset = blockIdx.x * N;
for (int i = threadIdx.x; i < N; i += BlockDim) {
float16 temp;
temp = data[offset + i] + bias[i];
if (DoRelu) {
data[offset + i] = fmaxf(0.f, temp);
} else {
data[offset + i] = temp;
}
}
}
template <>
void AddReluKernel(gpuStream_t stream,
const int M,
const int N,
float16* Y,
const float16* B,
bool relu) {
if (N % 4 == 0) {
const int threads = 256;
const int num = M * N / 4;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<float16>::Type trans_type;
auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v4<trans_type, true><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else {
bias_relu_v4<trans_type, false><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
}
} else {
const int threads = 256;
const int blocks = M;
if (relu) {
InplaceAddReluKernel<true, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
} else {
InplaceAddReluKernel<false, threads>
<<<blocks, threads, 0, stream>>>(N, B, Y);
}
}
}
#endif
template <typename DeviceContext, typename T>
void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
const int M,
......@@ -109,36 +321,14 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
}
// M * N
if (N % 4 == 0) {
const int threads = 256;
const int num = M * N / 4;
const int blocks = (num + threads - 1) / threads;
typedef typename FcTypeTraits<T>::Type trans_type;
auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v4<trans_type, true><<<blocks, threads, 0, context.stream()>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
} else {
bias_relu_v4<trans_type, false><<<blocks, threads, 0, context.stream()>>>(
num, bias_ptr_v4, data_ptr_v4, N / 4);
}
} else {
const int threads = 256;
const int blocks = M;
if (relu) {
InplaceAddReluKernel<T, true, threads>
<<<blocks, threads, 0, context.stream()>>>(N, B, Y);
} else {
InplaceAddReluKernel<T, false, threads>
<<<blocks, threads, 0, context.stream()>>>(N, B, Y);
}
}
AddReluKernel(context.stream(), M, N, Y, B, relu);
}
template class FCFunctor<paddle::platform::CUDADeviceContext, float16>;
template class FCFunctor<paddle::platform::CUDADeviceContext, float>;
template class FCFunctor<paddle::platform::CUDADeviceContext, double>;
template class FCFunctor<GPUContext, float16>;
template class FCFunctor<GPUContext, float>;
template class FCFunctor<GPUContext, double>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册