提交 3dc88342 编写于 作者: M Markus Kliegl

conv shift op: change to CamelCase

上级 92b0c699
...@@ -22,7 +22,7 @@ using framework::Tensor; ...@@ -22,7 +22,7 @@ using framework::Tensor;
namespace { namespace {
inline int div_up(int x, int y) { return (x + y - 1) / y; } inline int DivUp(int x, int y) { return (x + y - 1) / y; }
// Some notes on the design: // Some notes on the design:
// //
...@@ -33,7 +33,7 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; } ...@@ -33,7 +33,7 @@ inline int div_up(int x, int y) { return (x + y - 1) / y; }
// y is fairly small. For large y, it would probably be more efficient // y is fairly small. For large y, it would probably be more efficient
// to also tile across y. // to also tile across y.
template <typename T> template <typename T>
__global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, __global__ void ConvShiftForward(const T *x, const T *y, T *out, int x_width,
int y_width, int y_half_width, int y_width, int y_half_width,
int batch_size) { int batch_size) {
extern __shared__ T mem[]; extern __shared__ T mem[];
...@@ -79,7 +79,7 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width, ...@@ -79,7 +79,7 @@ __global__ void conv_shift_forward(const T *x, const T *y, T *out, int x_width,
// Compute x gradient - initial naive implementation with atomic add. // Compute x gradient - initial naive implementation with atomic add.
template <typename T> template <typename T>
__global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, __global__ void ConvShiftGradX(const T *dout, const T *y, T *dx, int x_width,
int y_width, int y_half_width, int batch_size) { int y_width, int y_half_width, int batch_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index int j = blockIdx.y; // y index
...@@ -94,7 +94,7 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width, ...@@ -94,7 +94,7 @@ __global__ void conv_shift_dx(const T *dout, const T *y, T *dx, int x_width,
// Compute y gradient - initial naive implementation with atomic add. // Compute y gradient - initial naive implementation with atomic add.
template <typename T> template <typename T>
__global__ void conv_shift_dy(const T *x, const T *dout, T *dy, int x_width, __global__ void ConvShiftDy(const T *x, const T *dout, T *dy, int x_width,
int y_width, int y_half_width, int batch_size) { int y_width, int y_half_width, int batch_size) {
int i = blockIdx.x * blockDim.x + threadIdx.x; // x index int i = blockIdx.x * blockDim.x + threadIdx.x; // x index
int j = blockIdx.y; // y index int j = blockIdx.y; // y index
...@@ -125,14 +125,14 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> { ...@@ -125,14 +125,14 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
int y_half_width = (y_width - 1) / 2; int y_half_width = (y_width - 1) / 2;
const int x_per_block = 256; const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block); int num_x_blocks = DivUp(x_width, x_per_block);
int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T); int mem_per_block = (x_per_block + 2 * y_width) * sizeof(T);
dim3 grid_dim(num_x_blocks, batch_size); dim3 grid_dim(num_x_blocks, batch_size);
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>( ConvShiftForward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size); x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
} }
}; };
...@@ -160,20 +160,20 @@ class ConvShiftGradKernel<platform::GPUPlace, T> ...@@ -160,20 +160,20 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
const int x_per_block = 256; const int x_per_block = 256;
int num_x_blocks = div_up(x_width, x_per_block); int num_x_blocks = DivUp(x_width, x_per_block);
dim3 grid_dim(num_x_blocks, y_width, batch_size); dim3 grid_dim(num_x_blocks, y_width, batch_size);
if (dX) { if (dX) {
T *dx_data = dX->mutable_data<T>(context.GetPlace()); T *dx_data = dX->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream); cudaMemsetAsync(dx_data, 0, dX->numel() * sizeof(T), stream);
conv_shift_dx<T><<<grid_dim, x_per_block, 0, stream>>>( ConvShiftGradX<T><<<grid_dim, x_per_block, 0, stream>>>(
dout_data, y_data, dx_data, x_width, y_width, y_half_width, dout_data, y_data, dx_data, x_width, y_width, y_half_width,
batch_size); batch_size);
} }
if (dY) { if (dY) {
T *dy_data = dY->mutable_data<T>(context.GetPlace()); T *dy_data = dY->mutable_data<T>(context.GetPlace());
cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream); cudaMemsetAsync(dy_data, 0, dY->numel() * sizeof(T), stream);
conv_shift_dy<T><<<grid_dim, x_per_block, 0, stream>>>( ConvShiftDy<T><<<grid_dim, x_per_block, 0, stream>>>(
x_data, dout_data, dy_data, x_width, y_width, y_half_width, x_data, dout_data, dy_data, x_width, y_width, y_half_width,
batch_size); batch_size);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册