diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu index 74ed1b0ed358afc4f1a4e6a0c322eb032029d551..95e13c38a8dd234f49393d2d4808607a447b0d4c 100644 --- a/paddle/operators/conv_shift_op.cu +++ b/paddle/operators/conv_shift_op.cu @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/operators/conv_shift_op.h" +#include "paddle/operators/math/math_function.h" #include "paddle/platform/cuda_helper.h" namespace paddle { @@ -22,7 +23,7 @@ using framework::Tensor; namespace { -inline int div_up(int x, int y) { return (x + y - 1) / y; } +inline int DivUp(int x, int y) { return (x + y - 1) / y; } // Some notes on the design: // @@ -33,9 +34,9 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; } // y is fairly small. For large y, it would probably be more efficient // to also tile across y. template -__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, - int y_width, int y_half_width, - int batch_size) { +__global__ void ConvShiftForward(const T *x, const T *y, int x_width, + int y_width, int y_half_width, int batch_size, + T *out) { extern __shared__ T mem[]; int tx = threadIdx.x; @@ -62,25 +63,26 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, if (tx < num_x) { int load_i = (i - y_half_width + x_width) % x_width; sx[tx] = x[k * x_width + load_i]; - } else { - return; } __syncthreads(); - // Compute dot product of sx[tx:tx + y_width] and sy. - T sum = 0; - for (int j = 0; j < y_width; ++j) { - sum += sx[tx + j] * sy[j]; - } + if (tx < num_x) { + // Compute dot product of sx[tx:tx + y_width] and sy. + T sum = 0; + for (int j = 0; j < y_width; ++j) { + sum += sx[tx + j] * sy[j]; + } - // Save to out[k, i]. - out[k * x_width + i] = sum; + // Save to out[k, i]. + out[k * x_width + i] = sum; + } } // Compute x gradient - initial naive implementation with atomic add. template -__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, - int y_width, int y_half_width, int batch_size) { +__global__ void ConvShiftGradX(const T *dout, const T *y, int x_width, + int y_width, int y_half_width, int batch_size, + T *dx) { int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int j = blockIdx.y; // y index int k = blockIdx.z; // batch index @@ -94,8 +96,8 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, // Compute y gradient - initial naive implementation with atomic add. template -__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width, - int y_width, int y_half_width, int batch_size) { +__global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width, + int y_half_width, int batch_size, T *dy) { int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int j = blockIdx.y; // y index int k = blockIdx.z; // batch index @@ -125,15 +127,15 @@ class ConvShiftKernel : public framework::OpKernel { int y_half_width = (y_width - 1) / 2; const int x_per_block = 256; - int num_x_blocks = div_up(x_width, x_per_block); + int num_x_blocks = DivUp(x_width, x_per_block); int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T); dim3 grid_dim(num_x_blocks, batch_size); auto stream = context.cuda_device_context().stream(); - conv_shift_forward<<>>( - x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); + ConvShiftForward<<>>( + x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data); } }; @@ -157,25 +159,26 @@ class ConvShiftGradKernel int y_width = Y->dims()[1]; int y_half_width = (y_width - 1) / 2; - auto stream = context.cuda_device_context().stream(); + auto &device_ctx = context.cuda_device_context(); + math::SetConstant zero; const int x_per_block = 256; - int num_x_blocks = div_up(x_width, x_per_block); + int num_x_blocks = DivUp(x_width, x_per_block); dim3 grid_dim(num_x_blocks, y_width, batch_size); if (dX) { T *dx_data = dX->mutable_data(context.GetPlace()); - cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); - conv_shift_dx<<>>( - dout_data, y_data, dx_data, x_width, y_width, y_half_width, - batch_size); + zero(device_ctx, dX, static_cast(0.0)); + ConvShiftGradX<<>>( + dout_data, y_data, x_width, y_width, y_half_width, batch_size, + dx_data); } if (dY) { T *dy_data = dY->mutable_data(context.GetPlace()); - cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); - conv_shift_dy<<>>( - x_data, dout_data, dy_data, x_width, y_width, y_half_width, - batch_size); + zero(device_ctx, dY, static_cast(0.0)); + ConvShiftDy<<>>( + x_data, dout_data, x_width, y_width, y_half_width, batch_size, + dy_data); } } };