提交 0a7516d1 编写于 作者: C chengduoZH

fix col2vol vol2col kernel

上级 43f6cdc8
......@@ -593,21 +593,28 @@ void hl_matrix_rotate(
CHECK_SYNC("hl_matrix_rotate failed");
}
__global__ void keMatrixVol2Col(
int num_kernels, real*dataSrc, real* dataDst,
int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
int depth_col, int height_col, int width_col){
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < num_kernels;
index += blockDim.x * gridDim.x){
__global__ void keMatrixVol2Col(int num_kernels,
real* dataSrc,
real* dataDst,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
int w_out = index % width_col;
int h_out = (index / width_col ) % height_col;
int h_out = (index / width_col) % height_col;
int d_out = (index / width_col / height_col) % depth_col;
int channel_in = index / width_col / height_col / depth_col;
int channel_out = channel_in * filterD * filterH * filterW;
......@@ -615,7 +622,9 @@ __global__ void keMatrixVol2Col(
int h_in = h_out * strideH - paddingH;
int d_in = d_out * strideD - paddingD;
dataDst += ((channel_out * depth_col + d_out) * height_col + h_out) * width_col + w_out;
dataDst +=
((channel_out * depth_col + d_out) * height_col + h_out) * width_col +
w_out;
dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filterD; ++k) {
for (int i = 0; i < filterH; ++i) {
......@@ -623,8 +632,10 @@ __global__ void keMatrixVol2Col(
int d = d_in + k;
int h = h_in + i;
int w = w_in + j;
*dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width ) ?
dataSrc[(k * height + i) * width + j] : 0;
*dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 &&
w < width)
? dataSrc[(k * height + i) * width + j]
: 0;
dataDst += depth_col * height_col * width_col;
}
}
......@@ -633,11 +644,20 @@ __global__ void keMatrixVol2Col(
}
void hl_matrix_vol2Col(real* dataSrc,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW, real* dataDst){
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataDst) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
......@@ -646,34 +666,55 @@ void hl_matrix_vol2Col(real* dataSrc,
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixVol2Col<<< blocks, threads >>>(
num_kernels, dataSrc, dataDst,
depth, height, width,
filterD, filterH, filterW,
strideD, strideH, strideW,
paddingD, paddingH, paddingW,
depth_col, height_col, width_col);
keMatrixVol2Col<<<blocks, threads>>>(num_kernels,
dataSrc,
dataDst,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col);
CHECK_SYNC("hl_matrix_vol2Col failed");
}
__global__ void keMatrixCol2Vol(
int num_kernels, real*dataDst, real* dataSrc,
int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
int depth_col, int height_col, int width_col,
real alpha, real beta){
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < num_kernels;
__global__ void keMatrixCol2Vol(int num_kernels,
real* dataDst,
real* dataSrc,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
int depth_col,
int height_col,
int width_col,
real alpha,
real beta) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
real val = 0;
real srcVal = 0;
real dstVal = dataDst[index];
int w = index % width + paddingW;
int h = (index / width) % height + paddingH;
int d = (index / width / height) % depth + paddingD;
int c = index / (width * height * depth);
int c = index / width / height / depth;
// compute the start and end of the output
int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1;
int w_col_end = min(w / strideW + 1, width_col);
......@@ -682,32 +723,45 @@ __global__ void keMatrixCol2Vol(
int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1;
int d_col_end = min(d / strideD + 1, depth_col);
int offset = (c * filterD * filterW * filterH + \
d * filterW * filterH + h * filterW + w) * depth_col * height_col * width_col;
int offset = (c * filterD * filterW * filterH + d * filterW * filterH +
h * filterW + w) *
depth_col * height_col * width_col;
int coeff_d_col = (1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
int coeff_h_col = (1 - strideH * filterW * depth_col * height_col) * width_col;
int coeff_d_col =
(1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
int coeff_h_col =
(1 - strideH * filterW * depth_col * height_col) * width_col;
int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
val += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col + w_col * coeff_w_col];
srcVal += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col +
w_col * coeff_w_col];
}
}
}
dataDst[index] = val;
dataDst[index] = alpha * srcVal + beta * dstVal;
}
}
void hl_matrix_col2Vol(real* dataDst,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
int channels,
int depth,
int height,
int width,
int filterD,
int filterH,
int filterW,
int strideD,
int strideH,
int strideW,
int paddingD,
int paddingH,
int paddingW,
real* dataSrc,
real alpha, real beta){
real alpha,
real beta) {
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
......@@ -716,14 +770,26 @@ void hl_matrix_col2Vol(real* dataDst,
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixCol2Vol<<< blocks, threads >>>(
num_kernels, dataDst, dataSrc,
depth, height, width,
filterD, filterH, filterW,
strideD, strideH, strideW,
paddingD, paddingH, paddingW,
depth_col, height_col, width_col,
alpha, beta);
keMatrixCol2Vol<<<blocks, threads>>>(num_kernels,
dataDst,
dataSrc,
depth,
height,
width,
filterD,
filterH,
filterW,
strideD,
strideH,
strideW,
paddingD,
paddingH,
paddingW,
depth_col,
height_col,
width_col,
alpha,
beta);
CHECK_SYNC("hl_matrix_col2Vol failed");
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册