diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 555a5df3713091bb5c02de9b0c38cf255f83ef0a..67acbd7e511795ac1f59a5799e2d5a498797c442 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -136,7 +136,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
 const std::vector<std::string> kDlnneSubgraphPasses({
     "is_test_pass",                  //
-    "delete_dropout_op_pass"         //
+    "delete_dropout_op_pass",        //
     "simplify_with_basic_ops_pass",  //
     "conv_bn_fuse_pass",             //
     "depthwise_conv_bn_fuse_pass",   //
@@ -158,7 +158,10 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
     "conv_eltwiseadd_bn_fuse_pass",
     "conv_elementwise_add_act_fuse_pass",
     "conv_elementwise_add2_act_fuse_pass",
-    "conv_elementwise_add_fuse_pass"};
+    "conv_elementwise_add_fuse_pass",
+    "gpu_cpu_map_matmul_v2_to_mul_pass",     //
+    "gpu_cpu_map_matmul_v2_to_matmul_pass",  //
+    "fc_fuse_pass"};
 
 const std::vector<std::string> kTrtLowerPrecisionPasses{
     // "conv_bn_fuse_pass",
diff --git a/paddle/fluid/operators/fc_op.cu.cc b/paddle/fluid/operators/fc_op.cu.cc
index b5f260e60511f1e4d6ef6f4b1f5d88f55eae1581..4147903551d5e54802075b38f45ac67b9132173c 100644
--- a/paddle/fluid/operators/fc_op.cu.cc
+++ b/paddle/fluid/operators/fc_op.cu.cc
@@ -17,5 +17,6 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     fc,
+    ops::FCOpKernel<phi::GPUContext, paddle::platform::float16>,
     ops::FCOpKernel<phi::GPUContext, float>,
     ops::FCOpKernel<phi::GPUContext, double>);
diff --git a/paddle/phi/kernels/funcs/fc_functor.cu b/paddle/phi/kernels/funcs/fc_functor.cu
index b441ad581793dfbdca3a8441c8821ee6b1ce1ca8..1f2db5583295a408b5c68a7aa8d91db1026cd615 100644
--- a/paddle/phi/kernels/funcs/fc_functor.cu
+++ b/paddle/phi/kernels/funcs/fc_functor.cu
@@ -21,6 +21,8 @@ limitations under the License. */
 namespace phi {
 namespace funcs {
 
+using float16 = phi::dtype::float16;
+
 template <typename T>
 struct FcTypeTraits;
 
@@ -75,6 +77,216 @@ __global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
   }
 }
 
+template <typename T>
+void AddReluKernel(
+    gpuStream_t stream, const int M, const int N, T* Y, const T* B, bool relu) {
+  if (N % 4 == 0) {
+    const int threads = 256;
+    const int num = M * N / 4;
+    const int blocks = (num + threads - 1) / threads;
+    typedef typename FcTypeTraits<T>::Type trans_type;
+    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
+    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
+    if (relu) {
+      bias_relu_v4<trans_type, true><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v4, data_ptr_v4, N / 4);
+    } else {
+      bias_relu_v4<trans_type, false><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v4, data_ptr_v4, N / 4);
+    }
+  } else {
+    const int threads = 256;
+    const int blocks = M;
+
+    if (relu) {
+      InplaceAddReluKernel<T, true, threads>
+          <<<blocks, threads, 0, stream>>>(N, B, Y);
+    } else {
+      InplaceAddReluKernel<T, false, threads>
+          <<<blocks, threads, 0, stream>>>(N, B, Y);
+    }
+  }
+}
+
+#if defined(PADDLE_WITH_CUDA)
+
+#include <cuda_fp16.h>
+
+template <>
+struct FcTypeTraits<float16> {
+  typedef half2 Type;
+};
+
+template <bool DoRelu>
+__global__ void bias_relu_v2(const int num,
+                             const half2* bias,
+                             half2* data,
+                             int K) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < num) {
+    int bias_idx = tid % K;
+    const half2 bias_ptr = bias[bias_idx];
+    const half2 in_ptr = data[tid];
+    half2 packed_val = __hadd2(bias_ptr, in_ptr);
+    if (DoRelu) {
+#if __CUDA_ARCH__ >= 800
+      packed_val = __hmax2(__half2(0, 0), packed_val);
+#else
+      packed_val = __hmul2(__hgt2(packed_val, __half2(0, 0)), packed_val);
+#endif
+    }
+    data[tid] = packed_val;
+  }
+}
+
+template <bool DoRelu, int BlockDim>
+__global__ void InplaceAddReluKernel(const int N,
+                                     const half* bias,
+                                     half* data) {
+  int offset = blockIdx.x * N;
+  for (int i = threadIdx.x; i < N; i += BlockDim) {
+    half temp;
+#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
+    temp = __ldg(data + offset + i) + __ldg(bias + i);
+#else
+    temp = data[offset + i] + bias[i];
+#endif
+    if (DoRelu) {
+#if __CUDA_ARCH__ >= 800
+      data[offset + i] = __hmax(0, temp);
+#else
+      data[offset + i] = __hmul(__hgt(temp, 0), temp);
+#endif
+    } else {
+      data[offset + i] = temp;
+    }
+  }
+}
+
+template <>
+void AddReluKernel(cudaStream_t stream,
+                   const int M,
+                   const int N,
+                   float16* Y,
+                   const float16* B,
+                   bool relu) {
+  if (N % 2 == 0) {
+    const int threads = 256;
+    const int num = M * N / 2;
+    const int blocks = (num + threads - 1) / threads;
+    typedef typename FcTypeTraits<float16>::Type trans_type;
+    auto* bias_ptr_v2 = reinterpret_cast<const trans_type*>(B);
+    auto* data_ptr_v2 = reinterpret_cast<trans_type*>(Y);
+    if (relu) {
+      bias_relu_v2<true><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v2, data_ptr_v2, N / 2);
+    } else {
+      bias_relu_v2<false><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v2, data_ptr_v2, N / 2);
+    }
+  } else {
+    const int threads = 256;
+    const int blocks = M;
+    auto* halfB = reinterpret_cast<const half*>(B);
+    auto* halfY = reinterpret_cast<half*>(Y);
+    if (relu) {
+      InplaceAddReluKernel<true, threads>
+          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
+    } else {
+      InplaceAddReluKernel<false, threads>
+          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
+    }
+  }
+}
+
+#else
+
+struct float16_4 {
+  float16 x, y, z, w;
+};
+template <>
+struct FcTypeTraits<float16> {
+  typedef float16_4 Type;
+};
+
+template <bool DoRelu>
+__global__ void bias_relu_v4(const int num,
+                             const float16_4* bias,
+                             float16_4* data,
+                             int K) {
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+  if (tid < num) {
+    int bias_idx = tid % K;
+    const float16_4 bias_ptr = bias[bias_idx];
+    const float16_4 in_ptr = data[tid];
+    float16_4 packed_val;
+    packed_val.x = in_ptr.x + bias_ptr.x;
+    packed_val.y = in_ptr.y + bias_ptr.y;
+    packed_val.z = in_ptr.z + bias_ptr.z;
+    packed_val.w = in_ptr.w + bias_ptr.w;
+    if (DoRelu) {
+      packed_val.x = fmaxf(0.f, packed_val.x);
+      packed_val.y = fmaxf(0.f, packed_val.y);
+      packed_val.z = fmaxf(0.f, packed_val.z);
+      packed_val.w = fmaxf(0.f, packed_val.w);
+    }
+    data[tid] = packed_val;
+  }
+}
+
+template <bool DoRelu, int BlockDim>
+__global__ void InplaceAddReluKernel(const int N,
+                                     const float16* bias,
+                                     float16* data) {
+  int offset = blockIdx.x * N;
+
+  for (int i = threadIdx.x; i < N; i += BlockDim) {
+    float16 temp;
+    temp = data[offset + i] + bias[i];
+    if (DoRelu) {
+      data[offset + i] = fmaxf(0.f, temp);
+    } else {
+      data[offset + i] = temp;
+    }
+  }
+}
+
+template <>
+void AddReluKernel(gpuStream_t stream,
+                   const int M,
+                   const int N,
+                   float16* Y,
+                   const float16* B,
+                   bool relu) {
+  if (N % 4 == 0) {
+    const int threads = 256;
+    const int num = M * N / 4;
+    const int blocks = (num + threads - 1) / threads;
+    typedef typename FcTypeTraits<float16>::Type trans_type;
+    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
+    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
+    if (relu) {
+      bias_relu_v4<true><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v4, data_ptr_v4, N / 4);
+    } else {
+      bias_relu_v4<false><<<blocks, threads, 0, stream>>>(
+          num, bias_ptr_v4, data_ptr_v4, N / 4);
+    }
+  } else {
+    const int threads = 256;
+    const int blocks = M;
+
+    if (relu) {
+      InplaceAddReluKernel<true, threads>
+          <<<blocks, threads, 0, stream>>>(N, B, Y);
+    } else {
+      InplaceAddReluKernel<false, threads>
+          <<<blocks, threads, 0, stream>>>(N, B, Y);
+    }
+  }
+}
+#endif
+
 template <typename DeviceContext, typename T>
 void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
                                              const int M,
@@ -109,36 +321,14 @@ void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
   }
 
   // M * N
-  if (N % 4 == 0) {
-    const int threads = 256;
-    const int num = M * N / 4;
-    const int blocks = (num + threads - 1) / threads;
-    typedef typename FcTypeTraits<T>::Type trans_type;
-    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
-    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
-    if (relu) {
-      bias_relu_v4<trans_type, true><<<blocks, threads, 0, context.stream()>>>(
-          num, bias_ptr_v4, data_ptr_v4, N / 4);
-    } else {
-      bias_relu_v4<trans_type, false><<<blocks, threads, 0, context.stream()>>>(
-          num, bias_ptr_v4, data_ptr_v4, N / 4);
-    }
-  } else {
-    const int threads = 256;
-    const int blocks = M;
-    if (relu) {
-      InplaceAddReluKernel<T, true, threads>
-          <<<blocks, threads, 0, context.stream()>>>(N, B, Y);
-    } else {
-      InplaceAddReluKernel<T, false, threads>
-          <<<blocks, threads, 0, context.stream()>>>(N, B, Y);
-    }
-  }
+  AddReluKernel(context.stream(), M, N, Y, B, relu);
 }
 
+template class FCFunctor<GPUContext, float16>;
 template class FCFunctor<GPUContext, float>;
 template class FCFunctor<GPUContext, double>;
 
+template class FCFunctor<paddle::platform::CUDADeviceContext, float16>;
 template class FCFunctor<paddle::platform::CUDADeviceContext, float>;
 template class FCFunctor<paddle::platform::CUDADeviceContext, double>;
 
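
Usage note: the kGpuLowerPrecisionPasses additions and the float16 FC kernel above are exercised through GPU FP16 inference. Below is a minimal C++ sketch, assuming the experimental AnalysisConfig::Exp_EnableUseGpuFp16() switch (which drives kGpuLowerPrecisionPasses) is available on this branch; the model paths and pool size are placeholders, not part of this diff.

    // Minimal sketch: run GPU inference in FP16 so the fused fc op dispatches
    // to the new float16 kernel registered above.
    // Exp_EnableUseGpuFp16() availability and the model paths are assumptions.
    #include "paddle_inference_api.h"

    int main() {
      paddle_infer::Config config;
      config.SetModel("inference.pdmodel", "inference.pdiparams");
      config.EnableUseGpu(256 /*memory_pool_mb*/, 0 /*gpu_id*/);
      config.Exp_EnableUseGpuFp16();  // applies kGpuLowerPrecisionPasses, incl. fc_fuse_pass
      auto predictor = paddle_infer::CreatePredictor(config);
      // Feed inputs via predictor->GetInputHandle(...) and call predictor->Run().
      return 0;
    }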