提交 cfb86c4e 编写于 作者: C chengduoZH

Add vol2col and col2vol cuda kernel

上级 0f3a3e98
......@@ -224,4 +224,62 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
extern void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise);
/**
* @brief Matrix vol2Col: Convert 3D volume into col matrix
*
* @param[in] matSrc input matrix.
* @param[in] channel channel of matSrc.
* @param[in] depth depth of matSrc.
* @param[in] height height of matSrc.
* @param[in] width width of matSrc.
* @param[in] filterD depth of filter.
* @param[in] filterH height of filter.
* @param[in] filterW width of filter.
* @param[in] strideD stride in the depth.
* @param[in] strideH stride in the height.
* @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width.
* @param[out] matDst output matrix.
*
*/
extern void hl_matrix_vol2Col(real* matSrc,
int channel, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
real* matDst);
/**
* @brief Matrix col2Vol: Convert col matrix into 3D volume
*
* @param[out] matDst output matrix.
* @param[in] channel channel of matDst.
* @param[in] depth depth of matDst.
* @param[in] height height of matDst.
* @param[in] width width of matDst.
* @param[in] filterD depth of filter.
* @param[in] filterH height of filter.
* @param[in] filterW width of filter.
* @param[in] strideD stride in the depth.
* @param[in] strideH stride in the height.
* @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width.
* @param[in] matSrc input matrix.
* @param[in] beta input
* @param[in] alpha input
*
*/
extern void hl_matrix_col2Vol(real* matDst,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
real* matSrc,
real alpha, real beta);
#endif /* HL_MATRIX_H_ */
......@@ -99,4 +99,19 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
inline void hl_matrix_rotate(
real* mat, real* matRot, int dimM, int dimN, bool clockWise) {}
inline void hl_matrix_vol2Col(real* data,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
real* data_col) {}
inline void hl_matrix_col2Vol(real* data,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
real* data_Im,
real alpha, real beta) {}
#endif // HL_MATRIX_STUB_H_
......@@ -592,3 +592,138 @@ void hl_matrix_rotate(
mat, matRot, dimM, dimN, clockWise);
CHECK_SYNC("hl_matrix_rotate failed");
}
__global__ void keMatrixVol2Col(
int num_kernels, real*dataSrc, real* dataDst,
int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
int depth_col, int height_col, int width_col){
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < num_kernels;
index += blockDim.x * gridDim.x){
int w_out = index % width_col;
int h_out = (index / width_col ) % height_col;
int d_out = (index / width_col / height_col) % depth_col;
int channel_in = index / width_col / height_col / depth_col;
int channel_out = channel_in * filterD * filterH * filterW;
int w_in = w_out * strideW - paddingW;
int h_in = h_out * strideH - paddingH;
int d_in = d_out * strideD - paddingD;
dataDst += ((channel_out * depth_col + d_out) * height_col + h_out) * width_col + w_out;
dataSrc += ((channel_in * depth + d_in) * height + h_in) * width + w_in;
for (int k = 0; k < filterD; ++k) {
for (int i = 0; i < filterH; ++i) {
for (int j = 0; j < filterW; ++j) {
int d = d_in + k;
int h = h_in + i;
int w = w_in + j;
*dataDst = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && w < width ) ?
dataSrc[(k * height + i) * width + j] : 0;
dataDst += depth_col * height_col * width_col;
}
}
}
}
}
void hl_matrix_vol2Col(real* dataSrc,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW, real* dataDst){
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
int num_kernels = channels * depth_col * height_col * width_col;
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixVol2Col<<< blocks, threads >>>(
num_kernels, dataSrc, dataDst,
depth, height, width,
filterD, filterH, filterW,
strideD, strideH, strideW,
paddingD, paddingH, paddingW,
depth_col, height_col, width_col);
CHECK_SYNC("hl_matrix_vol2Col failed");
}
__global__ void keMatrixCol2Vol(
int num_kernels, real*dataDst, real* dataSrc,
int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
int depth_col, int height_col, int width_col,
real alpha, real beta){
for (int index = blockIdx.x * blockDim.x + threadIdx.x;
index < num_kernels;
index += blockDim.x * gridDim.x) {
real val = 0;
int w = index % width + paddingW;
int h = (index / width) % height + paddingH;
int d = (index / width / height) % depth + paddingD;
int c = index / (width * height * depth);
// compute the start and end of the output
int w_col_start = (w < filterW) ? 0 : (w - filterW) / strideW + 1;
int w_col_end = min(w / strideW + 1, width_col);
int h_col_start = (h < filterH) ? 0 : (h - filterH) / strideH + 1;
int h_col_end = min(h / strideH + 1, height_col);
int d_col_start = (d < filterD) ? 0 : (d - filterD) / strideD + 1;
int d_col_end = min(d / strideD + 1, depth_col);
int offset = (c * filterD * filterW * filterH + \
d * filterW * filterH + h * filterW + w) * depth_col * height_col * width_col;
int coeff_d_col = (1 - strideD * filterW * filterH * depth_col) * height_col * width_col;
int coeff_h_col = (1 - strideH * filterW * depth_col * height_col) * width_col;
int coeff_w_col = (1 - strideW * depth_col * height_col * width_col);
for (int d_col = d_col_start; d_col < d_col_end; ++d_col) {
for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
val += dataSrc[offset + d_col * coeff_d_col + h_col * coeff_h_col + w_col * coeff_w_col];
}
}
}
dataDst[index] = val;
}
}
void hl_matrix_col2Vol(real* dataDst,
int channels, int depth, int height, int width,
int filterD, int filterH, int filterW,
int strideD, int strideH, int strideW,
int paddingD, int paddingH, int paddingW,
real* dataSrc,
real alpha, real beta){
int depth_col = (depth + 2 * paddingD - filterD) / strideD + 1;
int height_col = (height + 2 * paddingH - filterH) / strideH + 1;
int width_col = (width + 2 * paddingW - filterW) / strideW + 1;
int num_kernels = channels * depth * height * width;
const int threads = 512;
const int blocks = DIVUP(num_kernels, threads);
keMatrixCol2Vol<<< blocks, threads >>>(
num_kernels, dataDst, dataSrc,
depth, height, width,
filterD, filterH, filterW,
strideD, strideH, strideW,
paddingD, paddingH, paddingW,
depth_col, height_col, width_col,
alpha, beta);
CHECK_SYNC("hl_matrix_col2Vol failed");
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册