提交 0a7516d1 编写于 作者: C chengduoZH

fix col2vol vol2col kernel

上级 43f6cdc8
...@@ -593,21 +593,28 @@ void hl_matrix_rotate( ...@@ -593,21 +593,28 @@ void hl_matrix_rotate(
CHECK_SYNC("hl_matrix_rotate failed"); CHECK_SYNC("hl_matrix_rotate failed");
} }
__global__ void keMatrixVol2Col(int num_kernels,
__global__ void keMatrixVol2Col( real* dataSrc,
int num_kernels, real*dataSrc, real* dataDst, real* dataDst,
int depth, int height, int width, int depth,
int filterD, int filterH, int filterW, int height,
int strideD, int strideH, int strideW, int width,
int paddingD, int paddingH, int paddingW, int filterD,
int depth_col, int height_col, int width_col){ int filterH,
int filterW,
for (int index = blockIdx.x * blockDim.x + threadIdx.x; int strideD,
index < num_kernels; int strideH,
index += blockDim.x * gridDim.x){ int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
int w_out = index % width_col; int w_out = index % width_col;
int h_out = (index / width_col ) % height_col; int h_out = (index / width_col) % height_col;
int d_out = (index / width_col / height_col) % depth_col; int d_out = (index / width_col / height_col) % depth_col;
int channel_in = index / width_col / height_col / depth_col; int channel_in = index / width_col / height_col / depth_col;
int channel_out = channel_in * filterD * filterH * filterW; int channel_out = channel_in * filterD * filterH * filterW;
...@@ -615,7 +622,9 @@ __global__ void keMatrixVol2Col( ...@@ -615,7 +622,9 @@ __global__ void keMatrixVol2Col(
int h_in = h_out * strideH - paddingH; int h_in = h_out * strideH - paddingH;
int d_in = d_out * strideD - paddingD; int d_in = d_out * strideD - paddingD;
dataDst += ((channel_out * depth_col + d_out) * height_col + h_out) * width_col + w_out; dataDst +=
((channel_out * depth_col + d_out) * height_col + h_out) * width_col +
w_out;
dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in; dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filterD; ++k) { for (int k = 0; k < filterD; ++k) {
for (int i = 0; i < filterH; ++i) { for (int i = 0; i < filterH; ++i) {
...@@ -623,8 +632,10 @@ __global__ void keMatrixVol2Col( ...@@ -623,8 +632,10 @@ __global__ void keMatrixVol2Col(
int d = d_in + k; int d = d_in + k;
int h = h_in + i; int h = h_in + i;
int w = w_in + j; int w = w_in + j;
*dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width ) ? *dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
dataSrc[(k * height + i) * width + j] : 0; w < width)
? dataSrc[(k * height + i) * width + j]
: 0;
dataDst += depth_col * height_col * width_col; dataDst += depth_col * height_col * width_col;
} }
} }
...@@ -633,11 +644,20 @@ __global__ void keMatrixVol2Col( ...@@ -633,11 +644,20 @@ __global__ void keMatrixVol2Col(
} }
void hl_matrix_vol2Col(real* dataSrc, void hl_matrix_vol2Col(real* dataSrc,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, real* dataDst){ int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1; int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1; int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
...@@ -646,34 +666,55 @@ void hl_matrix_vol2Col(real* dataSrc, ...@@ -646,34 +666,55 @@ void hl_matrix_vol2Col(real* dataSrc,
const int threads = 512; const int threads = 512;
const int blocks = DIVUP(num_kernels, threads); const int blocks = DIVUP(num_kernels, threads);
keMatrixVol2Col<<< blocks, threads >>>( keMatrixVol2Col<<<blocks, threads>>>(num_kernels,
num_kernels, dataSrc, dataDst, dataSrc,
depth, height, width, dataDst,
filterD, filterH, filterW, depth,
strideD, strideH, strideW, height,
paddingD, paddingH, paddingW, width,
depth_col, height_col, width_col); filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col);
CHECK_SYNC("hl_matrix_vol2Col failed"); CHECK_SYNC("hl_matrix_vol2Col failed");
} }
__global__ void keMatrixCol2Vol( __global__ void keMatrixCol2Vol(int num_kernels,
int num_kernels, real*dataDst, real* dataSrc, real* dataDst,
int depth, int height, int width, real* dataSrc,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
int depth_col, int height_col, int width_col, int filterD,
real alpha, real beta){ int filterH,
int filterW,
for (int index = blockIdx.x * blockDim.x + threadIdx.x; int strideD,
index < num_kernels; int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col,
real alpha,
real beta) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
real srcVal = 0;
real val = 0; real dstVal = dataDst[index];
int w = index % width + paddingW; int w = index % width + paddingW;
int h = (index / width) % height + paddingH; int h = (index / width) % height + paddingH;
int d = (index / width / height) % depth + paddingD; int d = (index / width / height) % depth + paddingD;
int c = index / (width * height * depth); int c = index / width / height / depth;
// compute the start and end of the output // compute the start and end of the output
int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1; int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1;
int w_col_end = min(w / strideW + 1, width_col); int w_col_end = min(w / strideW + 1, width_col);
...@@ -682,32 +723,45 @@ __global__ void keMatrixCol2Vol( ...@@ -682,32 +723,45 @@ __global__ void keMatrixCol2Vol(
int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1; int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1;
int d_col_end = min(d / strideD + 1, depth_col); int d_col_end = min(d / strideD + 1, depth_col);
int offset = (c * filterD * filterW * filterH + \ int offset = (c * filterD * filterW * filterH + d * filterW * filterH +
d * filterW * filterH + h * filterW + w) * depth_col * height_col * width_col; h * filterW + w) *
depth_col * height_col * width_col;
int coeff_d_col = (1 - strideD * filterW * filterH * depth_col) * height_col * width_col; int coeff_d_col =
int coeff_h_col = (1 - strideH * filterW * depth_col * height_col) * width_col; (1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
int coeff_h_col =
(1 - strideH * filterW * depth_col * height_col) * width_col;
int coeff_w_col = (1 - strideW * depth_col * height_col * width_col); int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
val += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col + w_col * coeff_w_col]; srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col +
w_col * coeff_w_col];
} }
} }
} }
dataDst[index] = val; dataDst[index] = alpha * srcVal + beta * dstVal;
} }
} }
void hl_matrix_col2Vol(real* dataDst, void hl_matrix_col2Vol(real* dataDst,
int channels, int depth, int height, int width, int channels,
int filterD, int filterH, int filterW, int depth,
int strideD, int strideH, int strideW, int height,
int paddingD, int paddingH, int paddingW, int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataSrc, real* dataSrc,
real alpha, real beta){ real alpha,
real beta) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1; int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1; int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1; int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
...@@ -716,14 +770,26 @@ void hl_matrix_col2Vol(real* dataDst, ...@@ -716,14 +770,26 @@ void hl_matrix_col2Vol(real* dataDst,
const int threads = 512; const int threads = 512;
const int blocks = DIVUP(num_kernels, threads); const int blocks = DIVUP(num_kernels, threads);
keMatrixCol2Vol<<< blocks, threads >>>( keMatrixCol2Vol<<<blocks, threads>>>(num_kernels,
num_kernels, dataDst, dataSrc, dataDst,
depth, height, width, dataSrc,
filterD, filterH, filterW, depth,
strideD, strideH, strideW, height,
paddingD, paddingH, paddingW, width,
depth_col, height_col, width_col, filterD,
alpha, beta); filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col,
alpha,
beta);
CHECK_SYNC("hl_matrix_col2Vol failed"); CHECK_SYNC("hl_matrix_col2Vol failed");
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册