未验证 提交 4acc87be 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize the perf of SameDimsAdd CUDA Kernel (#31872)

上级 980227f9
......@@ -24,7 +24,10 @@ namespace paddle {
namespace operators {
template <typename T>
struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
struct SameDimsElemwiseAdd<
platform::CUDADeviceContext, T,
typename std::enable_if<!std::is_same<T, platform::float16>::value &&
!std::is_same<T, float>::value>::type> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
......@@ -36,38 +39,68 @@ struct SameDimsElemwiseAdd<platform::CUDADeviceContext, T> {
}
};
template <>
struct SameDimsElemwiseAdd<platform::CUDADeviceContext, platform::float16> {
template <typename T>
struct SameDimsElemwiseAdd<
platform::CUDADeviceContext, T,
typename std::enable_if<std::is_same<T, platform::float16>::value ||
std::is_same<T, float>::value>::type> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
int vec_size = sizeof(float4) / sizeof(T);
dim3 grid_size =
dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
const half* y2 =
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseAddCUDAKernel<<<
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>().stream()>>>(
x2, y2, z2, size);
if (std::is_same<T, float>::value) {
SameDimsElemwiseAddCUDAKernel<<<
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>()
.stream()>>>(x->data<float>(), y->data<float>(), z->data<float>(),
size);
} else {
const half* x2 =
reinterpret_cast<const half*>(x->data<platform::float16>());
const half* y2 =
reinterpret_cast<const half*>(y->data<platform::float16>());
half* z2 = reinterpret_cast<half*>(z->data<platform::float16>());
SameDimsElemwiseAddCUDAKernel<<<
grid_size, block_size, 0,
ctx.template device_context<platform::CUDADeviceContext>()
.stream()>>>(x2, y2, z2, size);
}
}
};
template <typename T>
static __global__ void SimpleElemwiseAddGradCUDAKernel(const T* dout,
int64_t size, T* dx,
T* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
static __global__ void SimpleElemwiseAddGradCUDAKernel(
const T* __restrict__ dout, int size, int vec_size, T* dx, T* dy) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
int loop = size / vec_size;
int remainder = size % vec_size;
const float4* dout_vec = reinterpret_cast<const float4*>(dout);
float4* dx_vec = reinterpret_cast<float4*>(dx);
float4* dy_vec = reinterpret_cast<float4*>(dy);
float4 tmp_loop;
for (int i = tid; i < loop; i += stride) {
tmp_loop = dout_vec[i];
dx_vec[i] = tmp_loop;
dy_vec[i] = tmp_loop;
}
while (col < size) {
dx[col] = dout[col];
dy[col] = dout[col];
col += blockDim.x * gridDim.x;
if (tid == loop && remainder != 0) {
T tmp_rem;
while (remainder) {
int idx = size - remainder;
remainder--;
tmp_rem = dout[idx];
dx[idx] = tmp_rem;
dy[idx] = tmp_rem;
}
}
}
......@@ -79,14 +112,17 @@ elementwise_add_grad(const framework::ExecutionContext& ctx,
const framework::Tensor* out,
const framework::Tensor* dout, framework::Tensor* dx,
framework::Tensor* dy) {
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
auto size = x->numel();
int vec_size = max(static_cast<int>(sizeof(float4) / sizeof(T)), 1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
dim3 grid_size =
dim3((size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
dim3(((size + vec_size - 1) / vec_size + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
SimpleElemwiseAddGradCUDAKernel<
T><<<grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
dout->data<T>(), size, dx->mutable_data<T>(ctx.GetPlace()),
dout->data<T>(), size, vec_size, dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
}
......
......@@ -43,7 +43,7 @@ struct SameDimsElemwiseDiv<platform::CUDADeviceContext, platform::float16> {
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
......
......@@ -43,7 +43,7 @@ struct SameDimsElemwiseMul<platform::CUDADeviceContext, platform::float16> {
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
......
......@@ -18,7 +18,11 @@ limitations under the License. */
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"
#ifdef __HIPCC__
#define PADDLE_CUDA_THREAD_SIZE 256
#else
#define PADDLE_CUDA_THREAD_SIZE 512
#endif
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
......@@ -158,32 +162,62 @@ inline DEVICE half2 half2_div(const half2& a, const half2& b) {
#endif
}
#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \
template <typename T> \
__global__ void SameDimsElemwise##Func##CUDAKernel(const T* x, const T* y, \
T* z, int64_t size) { \
int col = blockIdx.x * blockDim.x + threadIdx.x; \
while (col < size) { \
z[col] = x[col] expr y[col]; \
col += blockDim.x * gridDim.x; \
} \
} \
template <> \
inline __global__ void SameDimsElemwise##Func##CUDAKernel<half>( \
const half* x, const half* y, half* z, int64_t size) { \
int start = threadIdx.x + blockDim.x * blockIdx.x; \
int stride = blockDim.x * gridDim.x; \
int n2 = size / 2; \
const half2* x2 = reinterpret_cast<const half2*>(x); \
const half2* y2 = reinterpret_cast<const half2*>(y); \
half2* z2 = reinterpret_cast<half2*>(z); \
for (int i = start; i < n2; i += stride) { \
z2[i] = FP16Function(x2[i], y2[i]); \
} \
if (start == 0 && (size % 2)) { \
z[size - 1] = __float2half(__half2float(x[size - 1]) \
expr __half2float(y[size - 1])); \
} \
#define DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Func, expr, FP16Function) \
inline __global__ void SameDimsElemwise##Func##CUDAKernel( \
const float* __restrict__ x, const float* __restrict__ y, float* z, \
int64_t size) { \
int tid = blockIdx.x * blockDim.x + threadIdx.x; \
int stride = gridDim.x * blockDim.x; \
int loop = size / 4; \
int remainder = size % 4; \
const float4* x_vec = reinterpret_cast<const float4*>(x); \
const float4* y_vec = reinterpret_cast<const float4*>(y); \
float4* z_vec = reinterpret_cast<float4*>(z); \
float4 x_f4, y_f4; \
for (int i = tid; i < loop; i += stride) { \
x_f4 = x_vec[i]; \
y_f4 = y_vec[i]; \
z_vec[i] = make_float4(x_f4.x expr y_f4.x, x_f4.y expr y_f4.y, \
x_f4.z expr y_f4.z, x_f4.w expr y_f4.w); \
} \
if (tid == loop && remainder != 0) { \
while (remainder) { \
int idx = size - remainder; \
remainder--; \
z[idx] = x[idx] expr y[idx]; \
} \
} \
} \
inline __global__ void SameDimsElemwise##Func##CUDAKernel( \
const half* __restrict__ x, const half* __restrict__ y, half* z, \
int64_t size) { \
int tid = blockIdx.x * blockDim.x + threadIdx.x; \
int stride = gridDim.x * blockDim.x; \
int loop = size / 8; \
int remainder = size % 8; \
const float4* x_vec = reinterpret_cast<const float4*>(x); \
const float4* y_vec = reinterpret_cast<const float4*>(y); \
float4* z_vec = reinterpret_cast<float4*>(z); \
float4 x_h8, y_h8, z_h8; \
for (int i = tid; i < loop; i += stride) { \
x_h8 = x_vec[i]; \
y_h8 = y_vec[i]; \
half2* x_h2 = reinterpret_cast<half2*>(&x_h8); \
half2* y_h2 = reinterpret_cast<half2*>(&y_h8); \
half2* z_h2 = reinterpret_cast<half2*>(&z_h8); \
z_h2[0] = FP16Function(x_h2[0], y_h2[0]); \
z_h2[1] = FP16Function(x_h2[1], y_h2[1]); \
z_h2[2] = FP16Function(x_h2[2], y_h2[2]); \
z_h2[3] = FP16Function(x_h2[3], y_h2[3]); \
z_vec[i] = z_h8; \
} \
if (tid == loop && remainder != 0) { \
while (remainder) { \
int idx = size - remainder; \
remainder--; \
z[idx] = __float2half(__half2float(x[idx]) expr __half2float(y[idx])); \
} \
} \
}
DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Add, +, half2_add)
DEFINE_SIMPLE_CUDA_BINARY_KERNEL(Sub, -, half2_sub)
......
......@@ -43,7 +43,7 @@ struct SameDimsElemwiseSub<platform::CUDADeviceContext, platform::float16> {
const framework::Tensor* x, const framework::Tensor* y,
framework::Tensor* z) {
auto size = x->numel();
dim3 grid_size = dim3(((size + 1) / 2 + PADDLE_CUDA_THREAD_SIZE - 1) /
dim3 grid_size = dim3(((size + 7) / 8 + PADDLE_CUDA_THREAD_SIZE - 1) /
PADDLE_CUDA_THREAD_SIZE,
1);
dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册