From d0b601c4a8219eef669b7e530a047bf898cf4cdc Mon Sep 17 00:00:00 2001
From: Markus Kliegl
Date: Wed, 15 Nov 2017 00:57:43 +0000
Subject: [PATCH] address PR feedback

---
 paddle/operators/conv_shift_op.cu | 37 ++++++++++++++++++++-----------------
 1 file changed, 20 insertions(+), 17 deletions(-)

diff --git a/paddle/operators/conv_shift_op.cu b/paddle/operators/conv_shift_op.cu
index 2a157f457a6..95e13c38a8d 100644
--- a/paddle/operators/conv_shift_op.cu
+++ b/paddle/operators/conv_shift_op.cu
@@ -13,6 +13,7 @@ limitations under the License. */
 
 #include "paddle/operators/conv_shift_op.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/platform/cuda_helper.h"
 
 namespace paddle {
@@ -33,9 +34,9 @@ inline int DivUp(int x, int y) { return (x + y - 1) / y; }
 // y is fairly small. For large y, it would probably be more efficient
 // to also tile across y.
 template <typename T>
-__global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
-                                 int y_width, int y_half_width,
-                                 int batch_size) {
+__global__ void ConvShiftForward(const T *x, const T *y, int x_width,
+                                 int y_width, int y_half_width, int batch_size,
+                                 T *out) {
   extern __shared__ T mem[];
 
   int tx = threadIdx.x;
@@ -79,8 +80,9 @@ __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
 
 // Compute x gradient - initial naive implementation with atomic add.
 template <typename T>
-__global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
-                               int y_width, int y_half_width, int batch_size) {
+__global__ void ConvShiftGradX(const T *dout, const T *y, int x_width,
+                               int y_width, int y_half_width, int batch_size,
+                               T *dx) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
   int j = blockIdx.y;                             // y index
   int k = blockIdx.z;                             // batch index
@@ -94,8 +96,8 @@ __global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
 
 // Compute y gradient - initial naive implementation with atomic add.
 template <typename T>
-__global__ void ConvShiftDy(const T *x, const T *dout, T *dy, int x_width,
-                            int y_width, int y_half_width, int batch_size) {
+__global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width,
+                            int y_half_width, int batch_size, T *dy) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
   int j = blockIdx.y;                             // y index
   int k = blockIdx.z;                             // batch index
@@ -133,7 +135,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
 
     auto stream = context.cuda_device_context().stream();
 
     ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
-        x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
+        x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data);
   }
 };
 
@@ -157,7 +159,8 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
     int y_width = Y->dims()[1];
     int y_half_width = (y_width - 1) / 2;
 
-    auto stream = context.cuda_device_context().stream();
+    auto &device_ctx = context.cuda_device_context();
+    math::SetConstant<platform::GPUPlace, T> zero;
 
     const int x_per_block = 256;
     int num_x_blocks = DivUp(x_width, x_per_block);
@@ -165,17 +168,17 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
 
     if (dX) {
       T *dx_data = dX->mutable_data<T>(context.GetPlace());
-      cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
-      ConvShiftGradX<T><<<grid_dim, x_per_block, 0, stream>>>(
-          dout_data, y_data, dx_data, x_width, y_width, y_half_width,
-          batch_size);
+      zero(device_ctx, dX, static_cast<T>(0.0));
+      ConvShiftGradX<T><<<grid_dim, x_per_block, 0, device_ctx.stream()>>>(
+          dout_data, y_data, x_width, y_width, y_half_width, batch_size,
+          dx_data);
     }
     if (dY) {
       T *dy_data = dY->mutable_data<T>(context.GetPlace());
-      cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
-      ConvShiftDy<T><<<grid_dim, x_per_block, 0, stream>>>(
-          x_data, dout_data, dy_data, x_width, y_width, y_half_width,
-          batch_size);
+      zero(device_ctx, dY, static_cast<T>(0.0));
+      ConvShiftDy<T><<<grid_dim, x_per_block, 0, device_ctx.stream()>>>(
+          x_data, dout_data, x_width, y_width, y_half_width, batch_size,
+          dy_data);
     }
   }
 };
--
GitLab
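
Note on the pattern the patch relies on: the gradient kernels accumulate with atomicAdd, so dX and dY must be zero-filled before launch; the patch switches that zeroing from a raw cudaMemsetAsync to the math::SetConstant functor driven by the device context. Below is a minimal standalone CUDA sketch of the same zero-then-accumulate idea, not Paddle code; the names zero_buffer and grad_x_accumulate, the sizes, and the toy gradient formula are illustrative assumptions.

// Standalone sketch (hypothetical names, not the Paddle API): zero a gradient
// buffer with a typed fill, then let a naive kernel accumulate into it with
// atomicAdd, mirroring the pattern the patch keeps for ConvShiftGradX/Dy.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// Typed fill; stands in for the role math::SetConstant plays in the patch
// (a per-element assignment rather than a byte-wise cudaMemsetAsync).
template <typename T>
__global__ void zero_buffer(T *data, int n, T value) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] = value;
}

// Toy gradient accumulation: several threads may add into the same dx
// element, so the buffer must start at zero and updates must be atomic.
template <typename T>
__global__ void grad_x_accumulate(const T *dout, const T *y, int x_width,
                                  int y_width, T *dx) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // x index
  int j = blockIdx.y;                             // y (filter) index
  if (i < x_width) {
    atomicAdd(&dx[i], dout[i] * y[j]);
  }
}

int main() {
  const int x_width = 8, y_width = 3;
  std::vector<float> h_dout(x_width, 1.0f), h_y(y_width, 0.5f), h_dx(x_width);

  float *d_dout, *d_y, *d_dx;
  cudaMalloc(&d_dout, x_width * sizeof(float));
  cudaMalloc(&d_y, y_width * sizeof(float));
  cudaMalloc(&d_dx, x_width * sizeof(float));
  cudaMemcpy(d_dout, h_dout.data(), x_width * sizeof(float),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_y, h_y.data(), y_width * sizeof(float), cudaMemcpyHostToDevice);

  // Zero dx first (as the patch does via SetConstant), then accumulate.
  zero_buffer<float><<<1, 256>>>(d_dx, x_width, 0.0f);
  dim3 grid(1, y_width);  // one block row per filter element, as in the patch
  grad_x_accumulate<float><<<grid, 256>>>(d_dout, d_y, x_width, y_width, d_dx);

  cudaMemcpy(h_dx.data(), d_dx, x_width * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (int i = 0; i < x_width; ++i) printf("dx[%d] = %f\n", i, h_dx[i]);

  cudaFree(d_dout);
  cudaFree(d_y);
  cudaFree(d_dx);
  return 0;
}

Expected output: each dx[i] equals y_width * 0.5 = 1.5, which only holds if the buffer really starts at zero; skipping the fill (the job done by the removed cudaMemsetAsync or the added SetConstant call) would leave whatever was previously in device memory.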