Unverified commit 792443ef authored by zhaoyuchen2018, committed by GitHub

Refine elementwise kernel. (#16952)

* Refine elementwise kernel.

Add a simple CUDA kernel for the case where both grad x and grad y exist.
Use a 2D-block CUDA kernel to do the broadcast.

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>

* refine code.

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>

* refine code.

test=develop
Signed-off-by: zhaoyuchen <zhaoyuchen01@baidu.com>
Parent d8af44a5
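For the same-shape fast path added below, the multiply gradients are simply dx = dout * y and dy = dout * x, computed element by element. A minimal CPU reference of that math, as a hedged sketch (the helper name MulGradReference and its signature are illustrative only, not part of this patch):

// Hypothetical CPU reference for the same-shape case handled by
// SimpleElemwiseMulGradCUDAKernel: dx[i] = dout[i] * y[i], dy[i] = dout[i] * x[i].
template <typename T>
void MulGradReference(const T* x, const T* y, const T* dout, int64_t n, T* dx,
                      T* dy) {
  for (int64_t i = 0; i < n; ++i) {
    dx[i] = dout[i] * y[i];
    dy[i] = dout[i] * x[i];
  }
}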
......@@ -14,9 +14,67 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/platform/float16.h"
#define TILE_SIZE 512
namespace ops = paddle::operators;
namespace plat = paddle::platform;
namespace paddle {
namespace operators {
template <typename T>
static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
const T* out,
const T* dout,
int64_t size, T* dx,
T* dy) {
int col = blockIdx.x * blockDim.x + threadIdx.x;
while (col < size) {
T o = dout[col];
dx[col] = y[col] * o;
dy[col] = x[col] * o;
col += blockDim.x * gridDim.x;
}
}
template <typename T>
class ElementwiseMulGradKernel<plat::CUDADeviceContext, T>
: public ElemwiseGradKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx);
using Tensor = framework::Tensor;
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* out = dout; // out is not necessary
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
if (x->dims() == y->dims() && dx && dy) {
dim3 block_size = dim3(TILE_SIZE, 1);
auto size = x->numel();
dim3 grid_size = dim3((size + TILE_SIZE - 1) / TILE_SIZE, 1);
SimpleElemwiseMulGradCUDAKernel<T><<<
grid_size, block_size, 0,
ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
dx->mutable_data<T>(ctx.GetPlace()),
dy->mutable_data<T>(ctx.GetPlace()));
return;
} else {
ElemwiseGradCompute<plat::CUDADeviceContext, T, MulGradDX<T>,
MulGradDY<T>>(ctx, *x, *y, *out, *dout, axis, dx, dy,
MulGradDX<T>(), MulGradDY<T>());
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
......
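The fast path above launches one thread per element: TILE_SIZE (512)-thread blocks, a ceiling-divide grid size, and a grid-stride loop inside the kernel to cover any elements beyond the grid. A hedged standalone sketch of the same launch pattern (the kernel name ScaleKernel and the scale operation are illustrative only, not part of this patch):

// Hypothetical standalone kernel using the same 1-D grid-stride pattern:
// ceiling-divide grid size, each thread strides by blockDim.x * gridDim.x.
template <typename T>
__global__ void ScaleKernel(const T* in, T* out, int64_t size, T alpha) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  for (; i < size; i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    out[i] = alpha * in[i];
  }
}
// Launch (assuming a valid cudaStream_t stream):
//   ScaleKernel<T><<<(size + 511) / 512, 512, 0, stream>>>(in, out, size, alpha);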
......@@ -351,14 +351,65 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
}
}
#define BLOCK_X 32
#define BLOCK_Y 32
// A 2D block is expected to be faster because it exposes more parallelism
// and keeps memory accesses coalesced.
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
const T *x, const T *y, const T *out, const T *dout, int h, int w,
DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
__shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
T val(0);
size_t width_stride = gridDim.x * blockDim.x;
size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
size_t full_width =
(w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
size_t full_height =
(h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
for (int m = idx; m < full_width; m += width_stride) {
sdata[threadIdx.y][threadIdx.x] = 0;
for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
int x_offset = n * w + m;
if (dx && m < w && n < h) {
dx[x_offset] = dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
}
if (dy) {
if (m < w && n < h) {
T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
sdata[threadIdx.y][threadIdx.x] += val;
}
__syncthreads();
}
}
if (dy) {
T my_val = sdata[threadIdx.x][threadIdx.y];
for (int i = warpSize >> 1; i > 0; i >>= 1)
my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
__syncthreads();
if ((threadIdx.x == 0)) {
sdata[0][threadIdx.y] = my_val;
}
__syncthreads();
if (threadIdx.y == 0 && m < w) {
dy[m] = sdata[0][threadIdx.x];
}
}
}
}
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
const T *y, const T *out, const T *dout,
int h, int w, DX_OP dx_op, DY_OP dy_op,
T *dx, T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
int gird_size = w;
ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
// Performance is expected to improve as h increases.
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
}
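ElemwiseGradBroadcast1* treats x, out, and dout as an h x w matrix and y, dy as a length-w vector: dx is computed elementwise, while dy[m] reduces dy_op over column m. That column reduction is what the shared-memory tile plus the CudaShuffleXorSync butterfly implement in the fast kernel above. A hedged CPU sketch of the same semantics (the function name and placement are illustrative only, not part of this patch):

// Hypothetical CPU reference for the broadcast-1 gradient:
// x, out, dout hold h*w elements; y and dy hold w elements.
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcast1Reference(const T *x, const T *y,
                                            const T *out, const T *dout, int h,
                                            int w, DX_OP dx_op, DY_OP dy_op,
                                            T *dx, T *dy) {
  for (int m = 0; m < w; ++m) {
    T sum(0);
    for (int n = 0; n < h; ++n) {
      int idx = n * w + m;
      if (dx) dx[idx] = dx_op(x[idx], y[m], out[idx], dout[idx]);
      sum += dy_op(x[idx], y[m], out[idx], dout[idx]);
    }
    if (dy) dy[m] = sum;
  }
}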
......@@ -619,7 +670,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
auto y_dims_untrimed = y->dims();
PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
"Rank of first input must >= rank of second input.");
if (x_dims == y_dims_untrimed) {
functor.Run();
return;
......@@ -1559,7 +1609,8 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
// z = f1(f2(x, y))
if (bcast_y) { // Y should be broadcast.
// In this case,
// for 'f2(y)', the shape of intermediate_out should be equal to the shape
// for 'f2(y)', the shape of intermediate_out should be equal to the
// shape
// of Y.
// for 'f2(x, y)', the shape of intermediate_out should be equal to the
// shape of Out.
......@@ -1571,7 +1622,8 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
intermediate_out);
} else {
// In this case,
// for 'f2(y)', the shape of intermediate_out should be equal to the shape
// for 'f2(y)', the shape of intermediate_out should be equal to the
// shape
// of Out.
// for 'f2(x, y)', the shape of intermediate_out should be equal to the
// shape of Out.
......
......@@ -63,7 +63,8 @@ inline static int RoundToPowerOfTwo(int dim) {
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
int delta, int width = 32) {
int delta,
int width = warpSize) {
#if CUDA_VERSION < 9000
return __shfl_down(val, delta, width);
#else
......@@ -71,6 +72,16 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
#endif
}
template <typename T>
__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val,
int width = warpSize) {
#if CUDA_VERSION < 9000
return __shfl_xor(val, width);
#else
return __shfl_xor_sync(mask, val, width);
#endif
}
// CUDA 9.0 has a native compatible float16 shfl_down
#if CUDA_VERSION < 9000
template <>
......@@ -80,6 +91,11 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
return float16(
__shfl_down(static_cast<half>(val), static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
float16 val, int width) {
return float16(__shfl_xor(static_cast<half>(val), width));
}
#else
template <>
__forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
......@@ -88,6 +104,11 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
return float16(__shfl_down_sync(mask, static_cast<half>(val),
static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
float16 val, int width) {
return float16(__shfl_xor_sync(mask, static_cast<half>(val), width));
}
#endif
template <typename T>
......
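The new CudaShuffleXorSync helper is what the broadcast kernel uses for its warp-level butterfly reduction: XOR-ing lane ids with halving offsets sums a value across all 32 lanes of a warp. Note that the third parameter, although named width, is forwarded into the laneMask argument of __shfl_xor / __shfl_xor_sync, so callers pass the butterfly offset there. A minimal usage sketch, assuming a fully active warp (the helper name WarpReduceSum is illustrative only, not part of this patch):

// Hypothetical device helper: butterfly-sum a value across one warp.
// After the loop, every lane in the warp holds the warp-wide sum.
template <typename T>
__forceinline__ __device__ T WarpReduceSum(T val) {
  for (int offset = warpSize >> 1; offset > 0; offset >>= 1) {
    // The third argument acts as the XOR lane mask (the butterfly offset).
    val += paddle::platform::CudaShuffleXorSync(0xFFFFFFFFu, val, offset);
  }
  return val;
}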