未验证 提交 dc78f3ca 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #5558 from mkliegl/conv_shift_fix_camel_case

conv shift op: change to CamelCase & fix bug
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/conv_shift_op.h" #include "paddle/operators/conv_shift_op.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/cuda_helper.h" #include "paddle/platform/cuda_helper.h"
namespace paddle { namespace paddle {
...@@ -22,7 +23,7 @@ using framework::Tensor; ...@@ -22,7 +23,7 @@ using framework::Tensor;
namespace { namespace {
inline int div_up(int x, int y) { return (x + y - 1) / y; } inline int DivUp(int x, int y) { return (x + y - 1) / y; }
// Some notes on the design: // Some notes on the design:
// //
...@@ -33,9 +34,9 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; } ...@@ -33,9 +34,9 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; }
// y is fairly small. For large y, it would probably be more efficient // y is fairly small. For large y, it would probably be more efficient
// to also tile across y. // to also tile across y.
template <typename T> template <typename T>
__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, __global__ void ConvShiftForward(const T *x, const T *y, int x_width,
int y_width, int y_half_width, int y_width, int y_half_width, int batch_size,
int batch_size) { T *out) {
extern __shared__ T mem[]; extern __shared__ T mem[];
int tx = threadIdx.x; int tx = threadIdx.x;
...@@ -62,25 +63,26 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, ...@@ -62,25 +63,26 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
if (tx < num_x) { if (tx < num_x) {
int load_i = (i - y_half_width + x_width) % x_width; int load_i = (i - y_half_width + x_width) % x_width;
sx[tx] = x[k * x_width + load_i]; sx[tx] = x[k * x_width + load_i];
} else {
return;
} }
__syncthreads(); __syncthreads();
// Compute dot product of sx[tx:tx + y_width] and sy. if (tx < num_x) {
T sum = 0; // Compute dot product of sx[tx:tx + y_width] and sy.
for (int j = 0; j < y_width; ++j) { T sum = 0;
sum += sx[tx + j] * sy[j]; for (int j = 0; j < y_width; ++j) {
} sum += sx[tx + j] * sy[j];
}
// Save to out[k, i]. // Save to out[k, i].
out[k * x_width + i] = sum; out[k * x_width + i] = sum;
}
} }
// Compute x gradient - initial naive implementation with atomic add. // Compute x gradient - initial naive implementation with atomic add.
template <typename T> template <typename T>
__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, __global__ void ConvShiftGradX(const T *dout, const T *y, int x_width,
int y_width, int y_half_width, int batch_size) { int y_width, int y_half_width, int batch_size,
T *dx) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index int j = blockIdx.y; // y index
int k = blockIdx.z; // batch index int k = blockIdx.z; // batch index
...@@ -94,8 +96,8 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, ...@@ -94,8 +96,8 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
// Compute y gradient - initial naive implementation with atomic add. // Compute y gradient - initial naive implementation with atomic add.
template <typename T> template <typename T>
__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width, __global__ void ConvShiftDy(const T *x, const T *dout, int x_width, int y_width,
int y_width, int y_half_width, int batch_size) { int y_half_width, int batch_size, T *dy) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index int j = blockIdx.y; // y index
int k = blockIdx.z; // batch index int k = blockIdx.z; // batch index
...@@ -125,15 +127,15 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { ...@@ -125,15 +127,15 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
int y_half_width = (y_width - 1) / 2; int y_half_width = (y_width - 1) / 2;
const int x_per_block = 256; const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block); int num_x_blocks = DivUp(x_width, x_per_block);
int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T); int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T);
dim3 grid_dim(num_x_blocks, batch_size); dim3 grid_dim(num_x_blocks, batch_size);
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>( ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); x_data, y_data, x_width, y_width, y_half_width, batch_size, out_data);
} }
}; };
...@@ -157,25 +159,26 @@ class ConvShiftGradKernel<platform::GPUPlace, T> ...@@ -157,25 +159,26 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
int y_width = Y->dims()[1]; int y_width = Y->dims()[1];
int y_half_width = (y_width - 1) / 2; int y_half_width = (y_width - 1) / 2;
auto stream = context.cuda_device_context().stream(); auto &device_ctx = context.cuda_device_context();
math::SetConstant<platform::GPUPlace, T> zero;
const int x_per_block = 256; const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block); int num_x_blocks = DivUp(x_width, x_per_block);
dim3 grid_dim(num_x_blocks, y_width, batch_size); dim3 grid_dim(num_x_blocks, y_width, batch_size);
if (dX) { if (dX) {
T *dx_data = dX->mutable_data<T>(context.GetPlace()); T *dx_data = dX->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); zero(device_ctx, dX, static_cast<T>(0.0));
conv_shift_dx<T><<<grid_dim, x_per_block, 0, stream>>>( ConvShiftGradX<T><<<grid_dim, x_per_block, 0, device_ctx.stream()>>>(
dout_data, y_data, dx_data, x_width, y_width, y_half_width, dout_data, y_data, x_width, y_width, y_half_width, batch_size,
batch_size); dx_data);
} }
if (dY) { if (dY) {
T *dy_data = dY->mutable_data<T>(context.GetPlace()); T *dy_data = dY->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); zero(device_ctx, dY, static_cast<T>(0.0));
conv_shift_dy<T><<<grid_dim, x_per_block, 0, stream>>>( ConvShiftDy<T><<<grid_dim, x_per_block, 0, device_ctx.stream()>>>(
x_data, dout_data, dy_data, x_width, y_width, y_half_width, x_data, dout_data, x_width, y_width, y_half_width, batch_size,
batch_size); dy_data);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册