diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index 50b2322b17bdba44f8c5c1dd4a9f0b2160f6a7d8..e36cc8f9f28d0ed3d3693e0a38d8bb17fa4ba25d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -14,9 +14,67 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
 #include "paddle/fluid/platform/float16.h"
+#define TILE_SIZE 512
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
+namespace paddle {
+namespace operators {
+
+// Same-shape case: dx = dout * y and dy = dout * x, computed element-wise
+// with a grid-stride loop.
+template <typename T>
+static __global__ void SimpleElemwiseMulGradCUDAKernel(const T* x, const T* y,
+                                                       const T* out,
+                                                       const T* dout,
+                                                       int64_t size, T* dx,
+                                                       T* dy) {
+  int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+  while (col < size) {
+    T o = dout[col];
+    dx[col] = y[col] * o;
+    dy[col] = x[col] * o;
+    col += blockDim.x * gridDim.x;
+  }
+}
+
+template <typename T>
+class ElementwiseMulGradKernel<plat::CUDADeviceContext, T>
+    : public ElemwiseGradKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    ElemwiseGradKernel<T>::Compute(ctx);
+    using Tensor = framework::Tensor;
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* out = dout;  // out is not necessary
+    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    int axis = ctx.Attr<int>("axis");
+
+    if (x->dims() == y->dims() && dx && dy) {
+      // No broadcast: launch the dedicated element-wise kernel.
+      dim3 block_size = dim3(TILE_SIZE, 1);
+      auto size = x->numel();
+      dim3 grid_size = dim3((size + TILE_SIZE - 1) / TILE_SIZE, 1);
+      SimpleElemwiseMulGradCUDAKernel<T><<<
+          grid_size, block_size, 0,
+          ctx.template device_context<plat::CUDADeviceContext>().stream()>>>(
+          x->data<T>(), y->data<T>(), out->data<T>(), dout->data<T>(), size,
+          dx->mutable_data<T>(ctx.GetPlace()),
+          dy->mutable_data<T>(ctx.GetPlace()));
+      return;
+    } else {
+      // Broadcast case: fall back to the generic gradient path.
+      ElemwiseGradCompute<plat::CUDADeviceContext, T, MulGradDX<T>,
+                          MulGradDY<T>>(ctx, *x, *y, *out, *dout, axis, dx, dy,
+                                        MulGradDX<T>(), MulGradDY<T>());
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 REGISTER_OP_CUDA_KERNEL(
     elementwise_mul, ops::ElementwiseMulKernel<plat::CUDADeviceContext, float>,
     ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
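Note (not part of the patch): the same-shape path above is just a grid-stride loop plus the usual ceil-divide launch arithmetic. The standalone sketch below reproduces that pattern outside the Paddle framework so it can be compiled and checked with nvcc alone; the kernel name, kTileSize, and the host-side check are illustrative, not Paddle APIs.

#include <cstdio>
#include <cuda_runtime.h>

constexpr int kTileSize = 512;  // plays the role of TILE_SIZE in the patch

template <typename T>
__global__ void MulGradKernel(const T* x, const T* y, const T* dout,
                              int64_t size, T* dx, T* dy) {
  // Grid-stride loop: each thread handles elements col, col + stride, ...
  for (int64_t col = blockIdx.x * blockDim.x + threadIdx.x; col < size;
       col += (int64_t)blockDim.x * gridDim.x) {
    T o = dout[col];
    dx[col] = y[col] * o;  // dL/dx = y * dL/dout
    dy[col] = x[col] * o;  // dL/dy = x * dL/dout
  }
}

int main() {
  const int64_t n = 1 << 20;
  float *x, *y, *dout, *dx, *dy;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  cudaMallocManaged(&dout, n * sizeof(float));
  cudaMallocManaged(&dx, n * sizeof(float));
  cudaMallocManaged(&dy, n * sizeof(float));
  for (int64_t i = 0; i < n; ++i) { x[i] = 2.f; y[i] = 3.f; dout[i] = 1.f; }

  // Same launch arithmetic as the patch: ceil(size / TILE_SIZE) blocks.
  int grid = (n + kTileSize - 1) / kTileSize;
  MulGradKernel<float><<<grid, kTileSize>>>(x, y, dout, n, dx, dy);
  cudaDeviceSynchronize();
  printf("dx[0]=%f dy[0]=%f\n", dx[0], dy[0]);  // expect 3.0 and 2.0
  return 0;
}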
diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index cb8a4e7e1502e7e6ceb48e51452c2c7ab8313972..2e91ec84848b0f491dca0a271d9326e3c37632ea 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -351,14 +351,65 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
   }
 }
 
+#define BLOCK_X 32
+#define BLOCK_Y 32
+
+// A 2-D block should be faster here: it exposes more parallelism along the
+// reduced dimension and keeps the global-memory accesses coalesced along w.
+template <typename T, typename DX_OP, typename DY_OP>
+static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
+    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+  __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];
+
+  T val(0);
+  size_t width_stride = gridDim.x * blockDim.x;
+  size_t idx = threadIdx.x + blockDim.x * blockIdx.x;
+  size_t full_width =
+      (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
+  size_t full_height =
+      (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
+
+  for (int m = idx; m < full_width; m += width_stride) {
+    sdata[threadIdx.y][threadIdx.x] = 0;
+    for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
+      int x_offset = n * w + m;
+      if (dx && m < w && n < h) {
+        dx[x_offset] = dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
+      }
+      if (dy) {
+        if (m < w && n < h) {
+          val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
+          sdata[threadIdx.y][threadIdx.x] += val;
+        }
+        __syncthreads();
+      }
+    }
+    if (dy) {
+      // Transposed read: the partial sums of one column now sit in one warp,
+      // so a butterfly (xor) shuffle finishes the column reduction.
+      T my_val = sdata[threadIdx.x][threadIdx.y];
+      for (int i = warpSize >> 1; i > 0; i >>= 1)
+        my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
+      __syncthreads();
+      if (threadIdx.x == 0) {
+        sdata[0][threadIdx.y] = my_val;
+      }
+      __syncthreads();
+      if (threadIdx.y == 0 && m < w) {
+        dy[m] = sdata[0][threadIdx.x];
+      }
+    }
+  }
+}
+
 template <typename T, typename DX_OP, typename DY_OP>
 static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
                                        const T *y, const T *out, const T *dout,
                                        int h, int w, DX_OP dx_op, DY_OP dy_op,
                                        T *dx, T *dy) {
-  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-  int gird_size = w;
-  ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+  // Performance is expected to improve as h grows, since each block now
+  // reduces BLOCK_Y rows in parallel instead of looping over h sequentially.
+  dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
+  int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+  FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
       x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
 }
 
@@ -619,7 +670,6 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
   auto y_dims_untrimed = y->dims();
   PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(),
                     "Rank of first input must >= rank of second input.");
-
   if (x_dims == y_dims_untrimed) {
     functor.Run();
     return;
@@ -1559,7 +1609,8 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
     // z = f1(f2(x, y))
     if (bcast_y) {  // Y should be broadcast.
       // In this case,
-      // for 'f2(y)', the shape of intermediate_out should be equal to the shape
+      // for 'f2(y)', the shape of intermediate_out should be equal to the
+      // shape
       // of Y.
       // for 'f2(x, y)', the shape of intermediate_out should be equal to the
       // shape of Out.
@@ -1571,7 +1622,8 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
           intermediate_out);
     } else {
       // In this case,
-      // for 'f2(y)', the shape of intermediate_out should be equal to the shape
+      // for 'f2(y)', the shape of intermediate_out should be equal to the
+      // shape
      // of Out.
       // for 'f2(x, y)', the shape of intermediate_out should be equal to the
       // shape of Out.
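Note (not part of the patch): in the broadcast-over-rows case, dy[j] reduces dy_op over the h rows of column j. The sketch below isolates that reduction pattern, a shared-memory tile, a transposed read, then a butterfly (xor) shuffle, as a plain column-sum kernel; all names are illustrative and the dx/dy functors are dropped to keep it short.

#include <cstdio>
#include <cuda_runtime.h>

#define BX 32
#define BY 32

__global__ void ColumnSum(const float* dout, int h, int w, float* dy) {
  __shared__ float sdata[BY][BX + 1];  // +1 avoids shared-memory bank conflicts
  int col = blockIdx.x * BX + threadIdx.x;

  // 1) Each thread accumulates a partial sum over rows threadIdx.y, +BY, ...
  float partial = 0.f;
  for (int row = threadIdx.y; row < h; row += BY) {
    if (col < w) partial += dout[row * w + col];
  }
  sdata[threadIdx.y][threadIdx.x] = partial;
  __syncthreads();

  // 2) Transposed read: the 32 partials of one column now sit in one warp,
  //    so a butterfly (xor) shuffle finishes the reduction without shared mem.
  float val = sdata[threadIdx.x][threadIdx.y];
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_xor_sync(0xFFFFFFFFu, val, offset);

  // After the butterfly every lane holds the total of column
  // blockIdx.x * BX + threadIdx.y; lane 0 stores it for step 3.
  if (threadIdx.x == 0) sdata[0][threadIdx.y] = val;
  __syncthreads();

  // 3) The first warp writes one result per column of this block's tile.
  if (threadIdx.y == 0 && col < w) dy[col] = sdata[0][threadIdx.x];
}

int main() {
  const int h = 1000, w = 100;
  float *dout, *dy;
  cudaMallocManaged(&dout, h * w * sizeof(float));
  cudaMallocManaged(&dy, w * sizeof(float));
  for (int i = 0; i < h * w; ++i) dout[i] = 1.f;

  dim3 block(BX, BY);
  int grid = (w + BX - 1) / BX;  // same launch shape as the patched launcher
  ColumnSum<<<grid, block>>>(dout, h, w, dy);
  cudaDeviceSynchronize();
  printf("dy[0]=%f (expect %d)\n", dy[0], h);  // every column sums to h
  return 0;
}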
diff --git a/paddle/fluid/platform/cuda_device_function.h b/paddle/fluid/platform/cuda_device_function.h
index 31b6c38d613cf9df8fa7e8f6a8e1cfa310280968..202613244deb02c05c39ed18abaa18d79078db33 100644
--- a/paddle/fluid/platform/cuda_device_function.h
+++ b/paddle/fluid/platform/cuda_device_function.h
@@ -63,7 +63,8 @@ inline static int RoundToPowerOfTwo(int dim) {
 
 template <typename T>
 __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
-                                                 int delta, int width = 32) {
+                                                 int delta,
+                                                 int width = warpSize) {
 #if CUDA_VERSION < 9000
   return __shfl_down(val, delta, width);
 #else
@@ -71,6 +72,16 @@ __forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
 #endif
 }
 
+template <typename T>
+__forceinline__ __device__ T CudaShuffleXorSync(unsigned mask, T val,
+                                                int width = warpSize) {
+#if CUDA_VERSION < 9000
+  return __shfl_xor(val, width);
+#else
+  return __shfl_xor_sync(mask, val, width);
+#endif
+}
+
 // CUDA 9.0 have native compatible float16 shfl_down
 #if CUDA_VERSION < 9000
 template <>
@@ -80,6 +91,11 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
   return float16(
       __shfl_down(static_cast<float>(val), static_cast<unsigned>(delta), width));
 }
+template <>
+__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
+                                                      float16 val, int width) {
+  return float16(__shfl_xor(static_cast<float>(val), width));
+}
 #else
 template <>
 __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
@@ -88,6 +104,11 @@ __forceinline__ __device__ float16 CudaShuffleDownSync(unsigned mask,
   return float16(__shfl_down_sync(mask, static_cast<half>(val),
                                   static_cast<unsigned>(delta), width));
 }
+template <>
+__forceinline__ __device__ float16 CudaShuffleXorSync(unsigned mask,
+                                                      float16 val, int width) {
+  return float16(__shfl_xor_sync(mask, static_cast<half>(val), width));
+}
 #endif
 
 template <typename T>
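Note (not part of the patch): CudaShuffleXorSync forwards its third argument to __shfl_xor / __shfl_xor_sync, where it acts as the xor lane mask, which is why the broadcast kernel above passes the shrinking offset i rather than a warp width. The minimal sketch below shows the resulting butterfly reduction with the raw intrinsic; names are illustrative.

#include <cstdio>
#include <cuda_runtime.h>

__global__ void WarpSum(const float* in, float* out) {
  float val = in[threadIdx.x];
  // Pairing lanes whose indices differ in one bit (the xor lane mask) halves
  // the number of distinct partial sums each step: 16, 8, 4, 2, 1.
  for (int lane_mask = warpSize / 2; lane_mask > 0; lane_mask >>= 1)
    val += __shfl_xor_sync(0xFFFFFFFFu, val, lane_mask);
  out[threadIdx.x] = val;  // every lane ends up with the same total
}

int main() {
  float *in, *out;
  cudaMallocManaged(&in, 32 * sizeof(float));
  cudaMallocManaged(&out, 32 * sizeof(float));
  for (int i = 0; i < 32; ++i) in[i] = 1.f;
  WarpSum<<<1, 32>>>(in, out);
  cudaDeviceSynchronize();
  printf("out[0]=%f out[31]=%f\n", out[0], out[31]);  // both 32.0
  return 0;
}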