未验证 提交 e77d1cac 编写于 作者: MarDino's avatar MarDino 提交者: GitHub

Optimize bias_add reluv2 in half2 (#49048)

* optimize bias_add reluv2 in half2

* Add annotation

* refine code format
上级 a5ce60b8
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"
......@@ -127,37 +128,54 @@ void AddReluKernel(
}
#if defined(PADDLE_WITH_CUDA)
template <bool DoRelu>
__global__ void bias_relu_v2(const int num,
template <bool DoRelu, int Half2VecSize>
__global__ void bias_relu_v4_half2(const int num,
const half2* bias,
half2* data,
int K) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
using LoadT = phi::AlignedVector<half2, Half2VecSize>;
LoadT data_vec;
LoadT bias_vec;
const int32_t global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int32_t grid_stride = gridDim.x * blockDim.x;
if (tid < num) {
int bias_idx = tid % K;
const half2 bias_ptr = bias[bias_idx];
const half2 in_ptr = data[tid];
half2 packed_val;
for (int32_t linear_idx = global_thread_idx * Half2VecSize; linear_idx < num;
linear_idx += grid_stride * Half2VecSize) {
phi::Load<half2, Half2VecSize>(&data[linear_idx], &data_vec);
const int bias_idx = linear_idx % K;
phi::Load<half2, Half2VecSize>(&bias[bias_idx], &bias_vec);
#pragma unroll
for (int unroll_idx = 0; unroll_idx < Half2VecSize; unroll_idx++) {
// Do biasAdd
#if __CUDA_ARCH__ >= 530
packed_val = __hadd2(bias_ptr, in_ptr);
data_vec[unroll_idx] =
__hadd2(data_vec[unroll_idx], bias_vec[unroll_idx]);
#else
packed_val.x = __hadd(bias_ptr.x, in_ptr.x);
packed_val.y = __hadd(bias_ptr.y, in_ptr.y);
data_vec[unroll_idx].x =
__hadd(data_vec[unroll_idx].x, bias_vec[unroll_idx].x);
data_vec[unroll_idx].y =
__hadd(data_vec[unroll_idx].y, bias_vec[unroll_idx].y);
#endif
// Do relu
if (DoRelu) {
#if __CUDA_ARCH__ >= 800
packed_val = __hmax2(__half2(0, 0), packed_val);
data_vec[unroll_idx] = __hmax2(__half2(0, 0), data_vec[unroll_idx]);
#elif __CUDA_ARCH__ >= 530
packed_val = __hmul2(__hgt2(__half2(0, 0), packed_val), packed_val);
data_vec[unroll_idx] = __hmul2(
__hgt2(__half2(0, 0), data_vec[unroll_idx]), data_vec[unroll_idx]);
#else
packed_val.x = static_cast<int>(static_cast<float>(packed_val.x) > 0) *
static_cast<float>(packed_val.x);
packed_val.y = static_cast<int>(static_cast<float>(packed_val.y) > 0) *
static_cast<float>(packed_val.y);
data_vec[unroll_idx].x =
static_cast<int>(static_cast<float>(data_vec[unroll_idx].x) > 0) *
static_cast<float>(data_vec[unroll_idx].x);
data_vec[unroll_idx].y =
static_cast<int>(static_cast<float>(data_vec[unroll_idx].y) > 0) *
static_cast<float>(data_vec[unroll_idx].y);
#endif
}
data[tid] = packed_val;
}
phi::Store<half2, Half2VecSize>(data_vec, &data[linear_idx]);
}
}
......@@ -188,27 +206,62 @@ __global__ void InplaceAddReluKernel(const int N,
}
}
template <>
void AddReluKernel(cudaStream_t stream,
const int M,
const int N,
/**
* brief: Launch BiasAddReluKernel with relu or not.
**/
template <int Half2VecSize>
void LaunchBiasAddReluHalf2Kernel(cudaStream_t stream,
const int32_t rows,
const int32_t cols,
float16* Y,
const float16* B,
bool relu) {
if (N % 2 == 0) {
const int threads = 256;
const int num = M * N / 2;
const int blocks = (num + threads - 1) / threads;
const int vec_num = rows * cols / (Half2VecSize * 2);
const int half2_num = rows * cols / 2;
const int blocks = (vec_num + threads - 1) / threads;
// Here reinterpret_cast to half2 type.
typedef typename FcTypeTraits<float16>::Type trans_type;
auto* bias_ptr_v2 = reinterpret_cast<const trans_type*>(B);
auto* data_ptr_v2 = reinterpret_cast<trans_type*>(Y);
auto* bias_half2_ptr = reinterpret_cast<const trans_type*>(B);
auto* data_half2_ptr = reinterpret_cast<trans_type*>(Y);
if (relu) {
bias_relu_v2<true><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
bias_relu_v4_half2<true, Half2VecSize><<<blocks, threads, 0, stream>>>(
half2_num, bias_half2_ptr, data_half2_ptr, cols / 2);
} else {
bias_relu_v4_half2<false, Half2VecSize><<<blocks, threads, 0, stream>>>(
half2_num, bias_half2_ptr, data_half2_ptr, cols / 2);
}
}
/**
* brief: Dispatch BiasAddReluKernel half2 type with 8 / 4 / 2 vecsize.
**/
void DispatchBiasAddReluKernelHalf2VecSize(cudaStream_t stream,
const int32_t rows,
const int32_t cols,
float16* Y,
const float16* B,
bool relu) {
// Half Max Vecsize is 128 / 16 = 8, since we use half2 type, here
// Half2VecSize need divide 2.
if (cols % 8 == 0) {
LaunchBiasAddReluHalf2Kernel<4>(stream, rows, cols, Y, B, relu);
} else if (cols % 4 == 0) {
LaunchBiasAddReluHalf2Kernel<2>(stream, rows, cols, Y, B, relu);
} else {
bias_relu_v2<false><<<blocks, threads, 0, stream>>>(
num, bias_ptr_v2, data_ptr_v2, N / 2);
LaunchBiasAddReluHalf2Kernel<1>(stream, rows, cols, Y, B, relu);
}
}
template <>
void AddReluKernel(cudaStream_t stream,
const int M,
const int N,
float16* Y,
const float16* B,
bool relu) {
if (N % 2 == 0) {
DispatchBiasAddReluKernelHalf2VecSize(stream, M, N, Y, B, relu);
} else {
const int threads = 256;
const int blocks = M;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册