Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
cfb86c4e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cfb86c4e
编写于
8月 13, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add vol2col and col2vol cuda kernel
上级
0f3a3e98
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
208 addition
and
0 deletion
+208
-0
paddle/cuda/include/hl_matrix.h
paddle/cuda/include/hl_matrix.h
+58
-0
paddle/cuda/include/stub/hl_matrix_stub.h
paddle/cuda/include/stub/hl_matrix_stub.h
+15
-0
paddle/cuda/src/hl_cuda_matrix.cu
paddle/cuda/src/hl_cuda_matrix.cu
+135
-0
未找到文件。
paddle/cuda/include/hl_matrix.h
浏览文件 @
cfb86c4e
...
...
@@ -224,4 +224,62 @@ extern void hl_matrix_collect_shared_bias(real* B_d,
extern
void
hl_matrix_rotate
(
real
*
mat
,
real
*
matRot
,
int
dimM
,
int
dimN
,
bool
clockWise
);
/**
* @brief Matrix vol2Col: Convert 3D volume into col matrix
*
* @param[in] matSrc input matrix.
* @param[in] channel channel of matSrc.
* @param[in] depth depth of matSrc.
* @param[in] height height of matSrc.
* @param[in] width width of matSrc.
* @param[in] filterD depth of filter.
* @param[in] filterH height of filter.
* @param[in] filterW width of filter.
* @param[in] strideD stride in the depth.
* @param[in] strideH stride in the height.
* @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width.
* @param[out] matDst output matrix.
*
*/
extern
void
hl_matrix_vol2Col
(
real
*
matSrc
,
int
channel
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
matDst
);
/**
* @brief Matrix col2Vol: Convert col matrix into 3D volume
*
* @param[out] matDst output matrix.
* @param[in] channel channel of matDst.
* @param[in] depth depth of matDst.
* @param[in] height height of matDst.
* @param[in] width width of matDst.
* @param[in] filterD depth of filter.
* @param[in] filterH height of filter.
* @param[in] filterW width of filter.
* @param[in] strideD stride in the depth.
* @param[in] strideH stride in the height.
* @param[in] strideW stride in the width.
* @param[in] paddingD padding in the depth.
* @param[in] paddingH padding in the height.
* @param[in] paddingW padding in the width.
* @param[in] matSrc input matrix.
* @param[in] beta input
* @param[in] alpha input
*
*/
extern
void
hl_matrix_col2Vol
(
real
*
matDst
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
matSrc
,
real
alpha
,
real
beta
);
#endif
/* HL_MATRIX_H_ */
paddle/cuda/include/stub/hl_matrix_stub.h
浏览文件 @
cfb86c4e
...
...
@@ -99,4 +99,19 @@ inline void hl_matrix_collect_shared_bias(real* B_d,
inline
void
hl_matrix_rotate
(
real
*
mat
,
real
*
matRot
,
int
dimM
,
int
dimN
,
bool
clockWise
)
{}
inline
void
hl_matrix_vol2Col
(
real
*
data
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
data_col
)
{}
inline
void
hl_matrix_col2Vol
(
real
*
data
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
data_Im
,
real
alpha
,
real
beta
)
{}
#endif // HL_MATRIX_STUB_H_
paddle/cuda/src/hl_cuda_matrix.cu
浏览文件 @
cfb86c4e
...
...
@@ -592,3 +592,138 @@ void hl_matrix_rotate(
mat
,
matRot
,
dimM
,
dimN
,
clockWise
);
CHECK_SYNC
(
"hl_matrix_rotate failed"
);
}
__global__
void
keMatrixVol2Col
(
int
num_kernels
,
real
*
dataSrc
,
real
*
dataDst
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
int
depth_col
,
int
height_col
,
int
width_col
){
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
num_kernels
;
index
+=
blockDim
.
x
*
gridDim
.
x
){
int
w_out
=
index
%
width_col
;
int
h_out
=
(
index
/
width_col
)
%
height_col
;
int
d_out
=
(
index
/
width_col
/
height_col
)
%
depth_col
;
int
channel_in
=
index
/
width_col
/
height_col
/
depth_col
;
int
channel_out
=
channel_in
*
filterD
*
filterH
*
filterW
;
int
w_in
=
w_out
*
strideW
-
paddingW
;
int
h_in
=
h_out
*
strideH
-
paddingH
;
int
d_in
=
d_out
*
strideD
-
paddingD
;
dataDst
+=
((
channel_out
*
depth_col
+
d_out
)
*
height_col
+
h_out
)
*
width_col
+
w_out
;
dataSrc
+=
((
channel_in
*
depth
+
d_in
)
*
height
+
h_in
)
*
width
+
w_in
;
for
(
int
k
=
0
;
k
<
filterD
;
++
k
)
{
for
(
int
i
=
0
;
i
<
filterH
;
++
i
)
{
for
(
int
j
=
0
;
j
<
filterW
;
++
j
)
{
int
d
=
d_in
+
k
;
int
h
=
h_in
+
i
;
int
w
=
w_in
+
j
;
*
dataDst
=
(
d
>=
0
&&
d
<
depth
&&
h
>=
0
&&
h
<
height
&&
w
>=
0
&&
w
<
width
)
?
dataSrc
[(
k
*
height
+
i
)
*
width
+
j
]
:
0
;
dataDst
+=
depth_col
*
height_col
*
width_col
;
}
}
}
}
}
void
hl_matrix_vol2Col
(
real
*
dataSrc
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
dataDst
){
int
depth_col
=
(
depth
+
2
*
paddingD
-
filterD
)
/
strideD
+
1
;
int
height_col
=
(
height
+
2
*
paddingH
-
filterH
)
/
strideH
+
1
;
int
width_col
=
(
width
+
2
*
paddingW
-
filterW
)
/
strideW
+
1
;
int
num_kernels
=
channels
*
depth_col
*
height_col
*
width_col
;
const
int
threads
=
512
;
const
int
blocks
=
DIVUP
(
num_kernels
,
threads
);
keMatrixVol2Col
<<<
blocks
,
threads
>>>
(
num_kernels
,
dataSrc
,
dataDst
,
depth
,
height
,
width
,
filterD
,
filterH
,
filterW
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
depth_col
,
height_col
,
width_col
);
CHECK_SYNC
(
"hl_matrix_vol2Col failed"
);
}
__global__
void
keMatrixCol2Vol
(
int
num_kernels
,
real
*
dataDst
,
real
*
dataSrc
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
int
depth_col
,
int
height_col
,
int
width_col
,
real
alpha
,
real
beta
){
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
num_kernels
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
real
val
=
0
;
int
w
=
index
%
width
+
paddingW
;
int
h
=
(
index
/
width
)
%
height
+
paddingH
;
int
d
=
(
index
/
width
/
height
)
%
depth
+
paddingD
;
int
c
=
index
/
(
width
*
height
*
depth
);
// compute the start and end of the output
int
w_col_start
=
(
w
<
filterW
)
?
0
:
(
w
-
filterW
)
/
strideW
+
1
;
int
w_col_end
=
min
(
w
/
strideW
+
1
,
width_col
);
int
h_col_start
=
(
h
<
filterH
)
?
0
:
(
h
-
filterH
)
/
strideH
+
1
;
int
h_col_end
=
min
(
h
/
strideH
+
1
,
height_col
);
int
d_col_start
=
(
d
<
filterD
)
?
0
:
(
d
-
filterD
)
/
strideD
+
1
;
int
d_col_end
=
min
(
d
/
strideD
+
1
,
depth_col
);
int
offset
=
(
c
*
filterD
*
filterW
*
filterH
+
\
d
*
filterW
*
filterH
+
h
*
filterW
+
w
)
*
depth_col
*
height_col
*
width_col
;
int
coeff_d_col
=
(
1
-
strideD
*
filterW
*
filterH
*
depth_col
)
*
height_col
*
width_col
;
int
coeff_h_col
=
(
1
-
strideH
*
filterW
*
depth_col
*
height_col
)
*
width_col
;
int
coeff_w_col
=
(
1
-
strideW
*
depth_col
*
height_col
*
width_col
);
for
(
int
d_col
=
d_col_start
;
d_col
<
d_col_end
;
++
d_col
)
{
for
(
int
h_col
=
h_col_start
;
h_col
<
h_col_end
;
++
h_col
)
{
for
(
int
w_col
=
w_col_start
;
w_col
<
w_col_end
;
++
w_col
)
{
val
+=
dataSrc
[
offset
+
d_col
*
coeff_d_col
+
h_col
*
coeff_h_col
+
w_col
*
coeff_w_col
];
}
}
}
dataDst
[
index
]
=
val
;
}
}
void
hl_matrix_col2Vol
(
real
*
dataDst
,
int
channels
,
int
depth
,
int
height
,
int
width
,
int
filterD
,
int
filterH
,
int
filterW
,
int
strideD
,
int
strideH
,
int
strideW
,
int
paddingD
,
int
paddingH
,
int
paddingW
,
real
*
dataSrc
,
real
alpha
,
real
beta
){
int
depth_col
=
(
depth
+
2
*
paddingD
-
filterD
)
/
strideD
+
1
;
int
height_col
=
(
height
+
2
*
paddingH
-
filterH
)
/
strideH
+
1
;
int
width_col
=
(
width
+
2
*
paddingW
-
filterW
)
/
strideW
+
1
;
int
num_kernels
=
channels
*
depth
*
height
*
width
;
const
int
threads
=
512
;
const
int
blocks
=
DIVUP
(
num_kernels
,
threads
);
keMatrixCol2Vol
<<<
blocks
,
threads
>>>
(
num_kernels
,
dataDst
,
dataSrc
,
depth
,
height
,
width
,
filterD
,
filterH
,
filterW
,
strideD
,
strideH
,
strideW
,
paddingD
,
paddingH
,
paddingW
,
depth_col
,
height_col
,
width_col
,
alpha
,
beta
);
CHECK_SYNC
(
"hl_matrix_col2Vol failed"
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录