Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b52039bd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b52039bd
编写于
9月 30, 2016
作者:
H
hedaoyuan
提交者:
qingqing01
9月 30, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
some bug fix for sparse matrix (#133)
* some bug fix for sparse matrix * a minor bug fix
上级
0276f15a
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
98 addition
and
88 deletion
+98
-88
paddle/cuda/include/hl_sparse.h
paddle/cuda/include/hl_sparse.h
+4
-0
paddle/cuda/src/hl_cuda_sparse.cu
paddle/cuda/src/hl_cuda_sparse.cu
+35
-72
paddle/cuda/src/hl_cuda_sparse.cuh
paddle/cuda/src/hl_cuda_sparse.cuh
+37
-11
paddle/gserver/layers/Layer.cpp
paddle/gserver/layers/Layer.cpp
+22
-5
未找到文件。
paddle/cuda/include/hl_sparse.h
浏览文件 @
b52039bd
...
...
@@ -223,6 +223,7 @@ extern void hl_matrix_csc2dense(hl_sparse_matrix_s A_d,
* @param[in] dimK width of op(A) & height of op(B)
* @param[in] alpha scalar used for multiplication.
* @param[in] beta scalar used for multiplication.
* If beta is zero, C does not have to be a valid input.
*
* @note transb is not support HPPL_OP_T.
*
...
...
@@ -251,6 +252,7 @@ extern void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d,
* @param[in] dimK width of op(A) & height of op(B)
* @param[in] alpha scalar used for multiplication.
* @param[in] beta scalar used for multiplication.
* If beta is zero, C does not have to be a valid input.
*
* @note transb is not support HPPL_OP_T.
*
...
...
@@ -275,6 +277,7 @@ extern void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d,
* @param[in] dimK width of op(A) & height of op(B)
* @param[in] alpha scalar used for multiplication.
* @param[in] beta scalar used for multiplication.
* If beta is zero, C does not have to be a valid input.
*
* @note transa is not support HPPL_OP_T.
*
...
...
@@ -327,6 +330,7 @@ extern void hl_sparse_matrix_mul(real* A_d, hl_trans_op_t transa,
* @param[in] dimK width of op(A) & height of op(B)
* @param[in] alpha scalar used for multiplication.
* @param[in] beta scalar used for multiplication.
* If beta is zero, C does not have to be a valid input.
*
*
* @note transa is not support HPPL_OP_T.
...
...
paddle/cuda/src/hl_cuda_sparse.cu
浏览文件 @
b52039bd
...
...
@@ -562,6 +562,22 @@ void hl_memcpy_sparse_matrix(hl_sparse_matrix_s dst,
}
}
/**
* Calculate beta * C, if beta is zero, C does not have to be a valid input.
*/
static
void
_beta_mul_c
(
real
*
c
,
int
dimM
,
int
dimN
,
real
beta
)
{
if
(
beta
==
0.0
)
{
hl_gpu_apply_unary_op
(
unary
::
Zero
<
real
>
(),
c
,
dimM
,
dimN
,
dimN
);
}
else
{
if
(
beta
!=
1.0
){
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
c
,
dimM
,
dimN
,
dimN
);
}
}
return
;
}
void
hl_matrix_csr_mul_dense
(
hl_sparse_matrix_s
A_d
,
hl_trans_op_t
transa
,
real
*
B_d
,
hl_trans_op_t
transb
,
real
*
C_d
,
...
...
@@ -580,16 +596,9 @@ void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
}
if
(
A_d
->
nnz
==
0
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
else
{
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
return
;
}
}
/* nnz != 0 */
hl_csr_matrix
A_d2
=
(
hl_csr_matrix
)(
A_d
->
matrix
);
...
...
@@ -633,13 +642,7 @@ void hl_matrix_csr_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
beta
);
}
}
else
if
(
HPPL_OP_T
==
transa
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
int
blocksX
=
(
dimN
+
CU_CSC_MUL_DENSE_BLOCK_N
-
1
)
/
CU_CSC_MUL_DENSE_BLOCK_N
;
...
...
@@ -699,16 +702,9 @@ void hl_matrix_dense_mul_csc(real *A_d, hl_trans_op_t transa,
<<
"matrix format error!"
;
if
(
B_d
->
nnz
==
0
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
else
{
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
return
;
}
}
/* nnz != 0 */
hl_csc_matrix
B_d2
=
(
hl_csc_matrix
)(
B_d
->
matrix
);
...
...
@@ -750,13 +746,7 @@ void hl_matrix_dense_mul_csc(real *A_d, hl_trans_op_t transa,
beta
);
}
}
else
if
(
transb
==
HPPL_OP_T
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
int
blocksX
=
1
+
(
dimK
-
1
)
/
CU_DM_CSR_THREAD_X
;
int
blocksY
=
1
+
(
dimM
-
1
)
/
CU_DM_CSR_BLOCK_M
;
dim3
threads
(
CU_DM_CSR_THREAD_X
,
CU_DM_CSR_THREAD_Y
);
...
...
@@ -813,16 +803,9 @@ void hl_matrix_dense_mul_csr(real *A_d, hl_trans_op_t transa,
<<
"matrix format error!"
;
if
(
B_d
->
nnz
==
0
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
else
{
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
return
;
}
}
/* nnz != 0 */
hl_csr_matrix
B_d2
=
(
hl_csr_matrix
)(
B_d
->
matrix
);
...
...
@@ -833,14 +816,7 @@ void hl_matrix_dense_mul_csr(real *A_d, hl_trans_op_t transa,
}
if
(
transb
==
HPPL_OP_N
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
int
blocksX
=
1
+
(
dimK
-
1
)
/
CU_DM_CSR_THREAD_X
;
int
blocksY
=
1
+
(
dimM
-
1
)
/
CU_DM_CSR_BLOCK_M
;
dim3
threads
(
CU_DM_CSR_THREAD_X
,
CU_DM_CSR_THREAD_Y
);
...
...
@@ -925,16 +901,9 @@ void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
}
if
(
A_d
->
nnz
==
0
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
else
{
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
return
;
}
}
/* nnz != 0 */
hl_csc_matrix
A_d2
=
(
hl_csc_matrix
)(
A_d
->
matrix
);
...
...
@@ -945,13 +914,7 @@ void hl_matrix_csc_mul_dense(hl_sparse_matrix_s A_d, hl_trans_op_t transa,
}
if
(
HPPL_OP_N
==
transa
)
{
if
(
beta
!=
1.0
)
{
hl_gpu_apply_unary_op
(
unary
::
mul_scalar
<
real
>
(
beta
),
C_d
,
dimM
,
dimN
,
dimN
);
}
_beta_mul_c
(
C_d
,
dimM
,
dimN
,
beta
);
int
blocksX
=
(
dimN
+
CU_CSC_MUL_DENSE_BLOCK_N
-
1
)
/
CU_CSC_MUL_DENSE_BLOCK_N
;
int
blocksY
=
(
dimK
+
CU_CSC_MUL_DENSE_BLOCK_K
-
1
)
/
CU_CSC_MUL_DENSE_BLOCK_K
;
...
...
@@ -1113,7 +1076,7 @@ void hl_sparse_matrix_mul(real *A_d, hl_trans_op_t transa,
CHECK
(
!
transA
)
<<
"Not supported A is trans and B is not trans!"
;
dim3
block
(
CU_BLOCK_SIZE
,
1
);
int
avgNnzPerRow
=
C_d
2
->
nnz_s
/
dimM
;
int
avgNnzPerRow
=
C_d
->
nnz
/
dimM
;
avgNnzPerRow
=
avgNnzPerRow
>
0
?
avgNnzPerRow
:
1
;
int
gridx
=
DIVUP
(
avgNnzPerRow
,
CU_BLOCK_SIZE
);
dim3
grid
(
gridx
,
dimM
);
...
...
@@ -1242,9 +1205,9 @@ void hl_matrix_csr_column_sum(real* A_d, hl_sparse_matrix_s B_d,
LOG
(
FATAL
)
<<
"parameter B is null!"
;
}
if
(
B_d
2
->
nnz_s
==
0
)
return
;
if
(
B_d
->
nnz
==
0
)
return
;
int
nnz
=
B_d
2
->
nnz_s
;
int
nnz
=
B_d
->
nnz
;
int
block
=
512
;
int
grid
=
DIVUP
(
nnz
,
512
);
KeSMatrixCsrColumnSum
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
(
...
...
@@ -1273,9 +1236,9 @@ void hl_matrix_csr_add_bias(hl_sparse_matrix_s A_d, real* B_d,
LOG
(
FATAL
)
<<
"parameter A_d is null!"
;
}
if
(
A_d
2
->
nnz_s
==
0
)
return
;
if
(
A_d
->
nnz
==
0
)
return
;
int
nnz
=
A_d
2
->
nnz_s
;
int
nnz
=
A_d
->
nnz
;
int
block
=
512
;
int
grid
=
DIVUP
(
nnz
,
512
);
KeSMatrixCsrAddBias
<<<
grid
,
block
,
0
,
STREAM_DEFAULT
>>>
(
...
...
@@ -1308,9 +1271,9 @@ void hl_matrix_csr_add_dense(hl_sparse_matrix_s A_d, real* B_d, int dimM,
LOG
(
FATAL
)
<<
"parameter A_d is null!"
;
}
if
(
A_d
2
->
nnz_s
==
0
)
return
;
if
(
A_d
->
nnz
==
0
)
return
;
int
gridX
=
DIVUP
((
A_d
2
->
nnz_s
/
dimM
),
512
);
int
gridX
=
DIVUP
((
A_d
->
nnz
/
dimM
),
512
);
gridX
=
gridX
>
0
?
gridX
:
1
;
dim3
block
(
512
,
1
);
dim3
grid
(
gridX
,
dimM
);
...
...
paddle/cuda/src/hl_cuda_sparse.cuh
浏览文件 @
b52039bd
...
...
@@ -85,6 +85,15 @@ __global__ void KeSMatrixCsc2Dense(real * csc_val,
C_d
[
row
*
dimN
+
col
]
=
sum
;
}
__device__
__forceinline__
void
_calculate_c
(
real
&
c
,
real
sum
)
{
c
=
sum
;
}
__device__
__forceinline__
void
_calculate_c
(
real
&
c
,
real
sum
,
real
beta
)
{
c
=
sum
+
beta
*
c
;
}
#define CU_CSRMM_N 4
#define CU_CSRMM_THREAD_X 32
#define CU_CSRMM_THREAD_Y 32
...
...
@@ -191,13 +200,21 @@ __global__ void KeSMatrixCsrMulDense(real *C_d,
}
C_d
+=
__mul24
(
index_m
,
dimN
);
#pragma unroll
if
(
beta
==
0.0
)
{
for
(
int
n
=
0
;
n
<
CU_CSRMM_N
;
n
++
)
{
if
(
index_n
<
dimN
)
{
_calculate_c
(
C_d
[
index_n
],
alpha
*
sum
[
n
]);
index_n
+=
CU_CSRMM_THREAD_X
;
}
}
}
else
{
for
(
int
n
=
0
;
n
<
CU_CSRMM_N
;
n
++
)
{
if
(
index_n
<
dimN
)
{
C_d
[
index_n
]
=
alpha
*
sum
[
n
]
+
beta
*
C_d
[
index_n
]
;
_calculate_c
(
C_d
[
index_n
],
alpha
*
sum
[
n
],
beta
)
;
index_n
+=
CU_CSRMM_THREAD_X
;
}
}
}
}
#define CU_CSC_MUL_DENSE_THREAD_N 1
...
...
@@ -544,14 +561,23 @@ TEMP_TEST:
int
index_m_c
=
ibx
+
idy
;
int
index_n_c
=
blockIdx
.
y
*
CU_CSCMM_BLOCK_N_BEST
+
idx
;
C_d
+=
index_n_c
+
__mul24
(
index_m_c
,
dimN
);
#pragma unroll
if
(
beta
==
0.0
)
{
for
(
int
m
=
0
;
m
<
CU_CSCMM_THREAD_M_BEST
;
m
++
)
{
if
(
index_m_c
<
dimM
&&
index_n_c
<
dimN
)
{
C_d
[
0
]
=
A_s
[
idy
+
m
*
32
][
idx
]
+
beta
*
C_d
[
0
]
;
_calculate_c
(
C_d
[
0
],
A_s
[
idy
+
m
*
32
][
idx
])
;
}
index_m_c
+=
32
;
C_d
+=
dimN
*
32
;
}
}
else
{
for
(
int
m
=
0
;
m
<
CU_CSCMM_THREAD_M_BEST
;
m
++
)
{
if
(
index_m_c
<
dimM
&&
index_n_c
<
dimN
)
{
_calculate_c
(
C_d
[
0
],
A_s
[
idy
+
m
*
32
][
idx
],
beta
);
}
index_m_c
+=
32
;
C_d
+=
dimN
*
32
;
}
}
}
#define CU_DM_CSR_THREAD_X 32
...
...
paddle/gserver/layers/Layer.cpp
浏览文件 @
b52039bd
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/utils/Util.h"
#include "paddle/utils/Logging.h"
#include "paddle/math/SparseMatrix.h"
#include "AddtoLayer.h"
#include "CosSimLayer.h"
...
...
@@ -290,14 +291,30 @@ void Layer::showOutputStats() {
<<
" is 0, skip to show the statistics"
;
return
;
}
real
mean
=
out
->
getSum
()
/
out
->
getElementCnt
();
MatrixPtr
outSquare
=
out
->
clone
();
outSquare
->
copyFrom
(
*
out
);
MatrixPtr
outSquare
;
if
(
dynamic_cast
<
GpuSparseMatrix
*>
(
out
.
get
()))
{
GpuSparseMatrix
*
tmp
=
dynamic_cast
<
GpuSparseMatrix
*>
(
out
.
get
());
outSquare
=
std
::
make_shared
<
CpuSparseMatrix
>
(
tmp
->
getHeight
(),
tmp
->
getWidth
(),
tmp
->
getElementCnt
(),
tmp
->
getValueType
(),
tmp
->
getFormat
());
}
else
{
outSquare
=
out
->
clone
();
}
outSquare
->
copyFrom
(
*
out
,
HPPL_STREAM_DEFAULT
);
hl_stream_synchronize
(
HPPL_STREAM_DEFAULT
);
real
mean
=
outSquare
->
getSum
()
/
out
->
getElementCnt
();
real
min
;
real
max
;
if
(
dynamic_cast
<
CpuSparseMatrix
*>
(
outSquare
.
get
()))
{
auto
tmpMat
=
dynamic_cast
<
CpuSparseMatrix
*>
(
outSquare
.
get
());
min
=
tmpMat
->
getMin
();
max
=
tmpMat
->
getMax
();
tmpMat
->
square
();
LOG
(
INFO
)
<<
"show statistics of [none zero values] in sparse matrix"
;
}
else
{
min
=
outSquare
->
getMin
();
max
=
outSquare
->
getMax
();
outSquare
->
square
();
}
real
std
=
(
outSquare
->
getSum
()
/
outSquare
->
getElementCnt
())
-
mean
*
mean
;
...
...
@@ -306,8 +323,8 @@ void Layer::showOutputStats() {
<<
", "
<<
"std="
<<
std
<<
", "
<<
"min="
<<
out
->
getMin
()
<<
", "
<<
"max="
<<
out
->
getMax
()
;
<<
"min="
<<
min
<<
", "
<<
"max="
<<
max
;
}
void
Layer
::
forwardActivation
()
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录