Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
171eaff2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
171eaff2
编写于
1月 18, 2017
作者:
X
xutianbing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
clean the code a little bit.
上级
4751cc8f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
166 addition
and
264 deletion
+166
-264
paddle/function/MulOp.cpp
paddle/function/MulOp.cpp
+37
-125
paddle/function/MulOpGpu.cu
paddle/function/MulOpGpu.cu
+123
-132
paddle/function/MulOpTest.cpp
paddle/function/MulOpTest.cpp
+6
-7
未找到文件。
paddle/function/MulOp.cpp
浏览文件 @
171eaff2
...
...
@@ -38,13 +38,6 @@ inline void vecAddTo(real* a, const real* b, real scaleB, size_t len) {
}
}
inline
void
colVecAddTo
(
real
*
a
,
const
real
*
b
,
size_t
len
,
size_t
aWidth
,
size_t
bWidth
)
{
for
(
unsigned
int
i
=
0
;
i
<
len
;
++
i
)
{
a
[
i
*
aWidth
]
+=
b
[
i
*
bWidth
];
}
}
inline
void
colVecAddTo
(
real
*
a
,
real
*
b
,
real
c
,
size_t
len
,
size_t
aWidth
,
size_t
bWidth
)
{
for
(
unsigned
int
i
=
0
;
i
<
len
;
++
i
)
{
...
...
@@ -336,140 +329,59 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
const
CpuSparseMatrix
&
b
,
real
scaleAB
,
real
scaleT
)
{
/// todo(tianbing), clean the code
CHECK
(
!
out
.
trans_
)
<<
"Not supported"
;
CHECK
(
!
a
.
isTransposed
())
<<
"Not supported"
;
CHECK
(
scaleT
==
0
||
scaleT
==
1
);
CHECK_EQ
(
scaleAB
,
static_cast
<
real
>
(
1.0
));
if
(
!
b
.
isTransposed
())
{
/// b is not Transpose
CHECK
(
b
.
getHeight
()
==
a
.
getWidth
()
&&
a
.
getHeight
()
==
out
.
getHeight
()
&&
b
.
getWidth
()
==
out
.
getWidth
());
}
else
{
CHECK
(
b
.
getHeight
()
==
out
.
getWidth
()
&&
a
.
getHeight
()
==
out
.
getHeight
()
&&
b
.
getWidth
()
==
a
.
getWidth
());
}
if
(
scaleT
==
0
)
{
out
.
zeroMem
();
}
real
*
A
=
const_cast
<
real
*>
(
a
.
getData
());
real
*
B
=
const_cast
<
real
*>
(
b
.
getValue
());
real
*
C
=
out
.
getData
();
int
*
rows
=
b
.
getRows
();
int
*
cols
=
b
.
getCols
();
if
(
scaleT
==
0
)
{
out
.
zeroMem
();
}
/// todo(tianbing), clean the code
/// b.getFormat() == SPARSE_CSC
if
(
b
.
getFormat
()
==
SPARSE_CSC
)
{
if
(
!
b
.
isTransposed
())
{
size_t
m
=
a
.
getWidth
();
CHECK_EQ
(
b
.
getHeight
(),
m
);
CHECK_EQ
(
a
.
getHeight
(),
out
.
height_
);
CHECK_EQ
(
b
.
getWidth
(),
out
.
width_
);
if
(
b
.
getValueType
()
==
NO_VALUE
)
{
for
(
size_t
j
=
0
;
j
<
b
.
getWidth
();
++
j
)
{
int
start
=
b
.
getColStartIdx
(
j
);
int
end
=
b
.
getColStartIdx
(
j
+
1
);
for
(
int
i
=
start
;
i
<
end
;
++
i
)
{
colVecAddTo
(
C
+
j
,
A
+
rows
[
i
],
out
.
height_
,
out
.
width_
,
a
.
getWidth
());
}
}
}
else
if
(
b
.
getValueType
()
==
FLOAT_VALUE
)
{
for
(
size_t
j
=
0
;
j
<
b
.
getWidth
();
++
j
)
{
int
start
=
b
.
getColStartIdx
(
j
);
int
end
=
b
.
getColStartIdx
(
j
+
1
);
for
(
int
i
=
start
;
i
<
end
;
++
i
)
{
colVecAddTo
(
C
+
j
,
A
+
rows
[
i
],
B
[
i
],
out
.
height_
,
out
.
width_
,
a
.
getWidth
());
}
}
}
}
else
/*if (b.isTransposed())*/
{
size_t
m
=
a
.
getWidth
();
CHECK_EQ
(
b
.
getHeight
(),
out
.
width_
);
CHECK_EQ
(
a
.
getHeight
(),
out
.
height_
);
CHECK_EQ
(
b
.
getWidth
(),
m
);
if
(
b
.
getValueType
()
==
NO_VALUE
)
{
for
(
size_t
i
=
0
;
i
<
b
.
getWidth
();
++
i
)
{
int
start
=
b
.
getColStartIdx
(
i
);
int
end
=
b
.
getColStartIdx
(
i
+
1
);
for
(
int
j
=
start
;
j
<
end
;
++
j
)
{
colVecAddTo
(
C
+
rows
[
j
],
A
+
i
,
out
.
height_
,
out
.
width_
,
a
.
getWidth
());
}
}
}
else
if
(
b
.
getValueType
()
==
FLOAT_VALUE
)
{
for
(
size_t
i
=
0
;
i
<
b
.
getWidth
();
++
i
)
{
int
start
=
b
.
getColStartIdx
(
i
);
int
end
=
b
.
getColStartIdx
(
i
+
1
);
for
(
int
j
=
start
;
j
<
end
;
++
j
)
{
colVecAddTo
(
C
+
rows
[
j
],
A
+
i
,
B
[
j
],
out
.
height_
,
out
.
width_
,
colVecAddTo
(
!
b
.
isTransposed
()
?
C
+
j
:
C
+
rows
[
i
],
!
b
.
isTransposed
()
?
A
+
rows
[
i
]
:
A
+
j
,
(
b
.
getValueType
()
==
NO_VALUE
)
?
(
real
)
1.0
:
B
[
i
],
out
.
getHeight
(),
out
.
getWidth
(),
a
.
getWidth
());
}
}
return
;
}
}
}
else
{
if
(
!
b
.
isTransposed
())
{
size_t
m
=
a
.
getWidth
();
CHECK_EQ
(
b
.
getHeight
(),
m
);
CHECK_EQ
(
a
.
getHeight
(),
out
.
height_
);
CHECK_EQ
(
b
.
getWidth
(),
out
.
width_
);
if
(
b
.
getValueType
()
==
NO_VALUE
)
{
/// b.getFormat() == SPARSE_CSR
if
(
b
.
getFormat
()
==
SPARSE_CSR
)
{
for
(
size_t
j
=
0
;
j
<
b
.
getHeight
();
++
j
)
{
int
start
=
b
.
getRowStartIdx
(
j
);
int
end
=
b
.
getRowStartIdx
(
j
+
1
);
for
(
int
i
=
start
;
i
<
end
;
++
i
)
{
colVecAddTo
(
C
+
cols
[
i
],
A
+
j
,
out
.
height_
,
out
.
width_
,
a
.
getWidth
());
}
}
}
else
if
(
b
.
getValueType
()
==
FLOAT_VALUE
)
{
for
(
size_t
j
=
0
;
j
<
b
.
getHeight
();
++
j
)
{
int
start
=
b
.
getRowStartIdx
(
j
);
int
end
=
b
.
getRowStartIdx
(
j
+
1
);
for
(
int
i
=
start
;
i
<
end
;
++
i
)
{
colVecAddTo
(
C
+
cols
[
i
],
A
+
j
,
B
[
i
],
out
.
height_
,
out
.
width_
,
a
.
getWidth
());
}
}
}
}
else
/*if (b.isTransposed())*/
{
size_t
m
=
a
.
getWidth
();
CHECK_EQ
(
b
.
getHeight
(),
out
.
width_
);
CHECK_EQ
(
a
.
getHeight
(),
out
.
height_
);
CHECK_EQ
(
b
.
getWidth
(),
m
);
if
(
b
.
getValueType
()
==
NO_VALUE
)
{
for
(
size_t
i
=
0
;
i
<
b
.
getHeight
();
++
i
)
{
int
start
=
b
.
getRowStartIdx
(
i
);
int
end
=
b
.
getRowStartIdx
(
i
+
1
);
for
(
int
j
=
start
;
j
<
end
;
++
j
)
{
colVecAddTo
(
C
+
i
,
A
+
cols
[
j
],
out
.
height_
,
out
.
width_
,
a
.
getWidth
());
}
}
}
else
if
(
b
.
getValueType
()
==
FLOAT_VALUE
)
{
for
(
size_t
i
=
0
;
i
<
b
.
getHeight
();
++
i
)
{
int
start
=
b
.
getRowStartIdx
(
i
);
int
end
=
b
.
getRowStartIdx
(
i
+
1
);
for
(
int
j
=
start
;
j
<
end
;
++
j
)
{
colVecAddTo
(
C
+
i
,
A
+
cols
[
j
],
B
[
j
],
out
.
height_
,
out
.
width_
,
colVecAddTo
(
!
b
.
isTransposed
()
?
C
+
cols
[
i
]
:
C
+
j
,
!
b
.
isTransposed
()
?
A
+
j
:
A
+
cols
[
i
],
(
b
.
getValueType
()
==
NO_VALUE
)
?
(
real
)
1.0
:
B
[
i
],
out
.
getHeight
(),
out
.
getWidth
(),
a
.
getWidth
());
}
}
}
}
return
;
}
}
...
...
paddle/function/MulOpGpu.cu
浏览文件 @
171eaff2
...
...
@@ -19,154 +19,147 @@ limitations under the License. */
namespace
paddle
{
/**
* out = scale
_t * out + scale_ab
* (a * b)
* out = scale
T * out + scaleAB
* (a * b)
* out : output matrix, M * N
*/
template
<
>
void
MulOp
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out
,
const
GpuMatrix
&
a
,
const
GpuMatrix
&
b
,
real
scale_ab
,
real
scale_t
)
{
CHECK
(
!
out
.
isTransposed
())
<<
"Not supported"
;
real
scaleAB
,
real
scaleT
)
{
CHECK
(
!
out
.
isTransposed
())
<<
"Transpose not supported for out matrix"
;
if
(
!
a
.
isTransposed
()
&&
!
b
.
isTransposed
())
{
/// a : M * K, b: K * N
CHECK_EQ
(
out
.
width_
,
b
.
width_
);
CHECK_EQ
(
out
.
height_
,
a
.
height_
);
CHECK_EQ
(
a
.
width_
,
b
.
height_
);
CHECK
(
out
.
getWidth
()
==
b
.
getWidth
()
&&
out
.
getHeight
()
==
a
.
getHeight
()
&&
a
.
getWidth
()
==
b
.
getHeight
()
);
}
else
if
(
a
.
isTransposed
()
&&
!
b
.
isTransposed
())
{
/// a : K * M, b : K * N
CHECK_EQ
(
out
.
width_
,
b
.
width_
);
CHECK_EQ
(
out
.
height_
,
a
.
width_
);
CHECK_EQ
(
a
.
height_
,
b
.
height_
);
CHECK
(
out
.
getWidth
()
==
b
.
getWidth
()
&&
out
.
getHeight
()
==
a
.
getWidth
()
&&
a
.
getHeight
()
==
b
.
getHeight
()
);
}
else
if
(
!
a
.
isTransposed
()
&&
b
.
isTransposed
())
{
/// a: M * K, b : N * K
CHECK_EQ
(
out
.
width_
,
b
.
height_
);
CHECK_EQ
(
out
.
height_
,
a
.
height_
);
CHECK_EQ
(
a
.
width_
,
b
.
width_
);
CHECK
(
out
.
getWidth
()
==
b
.
getHeight
()
&&
out
.
getHeight
()
==
a
.
getHeight
()
&&
a
.
getWidth
()
==
b
.
getWidth
()
);
}
else
{
LOG
(
FATAL
)
<<
"
Is not supported
"
;
LOG
(
FATAL
)
<<
"
Not support for both a and b are Transposed Matrices
"
;
}
real
*
a_data
=
a
.
data_
;
real
*
b_data
=
b
.
data_
;
real
*
out_data
=
out
.
data_
;
int
dim_m
=
out
.
getHeight
();
int
dim_n
=
out
.
getWidth
();
int
dim_k
=
!
a
.
isTransposed
()
?
a
.
width_
:
a
.
height_
;
int
lda
=
a
.
getStride
();
int
ldb
=
b
.
getStride
();
int
ldc
=
out
.
getStride
();
hl_trans_op_t
trans_a
=
!
a
.
isTransposed
()
?
HPPL_OP_N
:
HPPL_OP_T
;
hl_trans_op_t
trans_b
=
!
b
.
isTransposed
()
?
HPPL_OP_N
:
HPPL_OP_T
;
hl_matrix_mul
(
a_data
,
trans_a
,
b_data
,
trans_b
,
out_data
,
dim_m
,
dim_n
,
dim_k
,
scale_ab
,
scale_t
,
lda
,
ldb
,
ldc
);
real
*
aData
=
const_cast
<
real
*>
(
a
.
getData
());
real
*
bData
=
const_cast
<
real
*>
(
b
.
getData
());
real
*
outData
=
const_cast
<
real
*>
(
out
.
getData
());
hl_matrix_mul
(
aData
,
!
a
.
isTransposed
()
?
HPPL_OP_N
:
HPPL_OP_T
,
bData
,
!
b
.
isTransposed
()
?
HPPL_OP_N
:
HPPL_OP_T
,
outData
,
out
.
getHeight
(),
out
.
getWidth
(),
!
a
.
isTransposed
()
?
a
.
getWidth
()
:
a
.
getHeight
(),
scaleAB
,
scaleT
,
a
.
getStride
(),
b
.
getStride
(),
out
.
getStride
());
}
/**
* out = scale
_t * out + scale_ab
* (a * b)
* out = scale
T * out + scaleAB
* (a * b)
* out : M * N
*/
template
<
>
void
MulOp
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out
,
const
GpuSparseMatrix
&
a
,
const
GpuMatrix
&
b
,
real
scale
_ab
,
real
scale
_t
)
{
real
scale
AB
,
real
scale
T
)
{
CHECK
(
out
.
isContiguous
());
CHECK
(
b
.
isContiguous
());
CHECK
(
b
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
CHECK
(
!
out
.
trans_
&&
!
b
.
trans_
)
<<
"not supported"
;
if
(
!
a
.
trans_
)
{
CHECK
(
b
.
useGpu_
)
<<
"Matrix type are not equal"
;
CHECK
(
!
out
.
isTransposed
()
&&
!
b
.
isTransposed
()
)
<<
"not supported"
;
if
(
!
a
.
isTransposed
()
)
{
/// a: M * K, b: K * N
CHECK
(
out
.
width_
==
b
.
width_
&&
out
.
height_
==
a
.
height_
&&
a
.
width_
==
b
.
height_
)
<<
"Matrix dimensions are not equal"
;
CHECK
(
out
.
getWidth
()
==
b
.
getWidth
()
&&
out
.
getHeight
()
==
a
.
getHeight
()
&&
a
.
getWidth
()
==
b
.
getHeight
()
)
<<
"Matrix dimensions are not equal"
;
}
else
{
/// a: K * M, transpose, b: K * N
CHECK
(
out
.
width_
==
b
.
width_
&&
out
.
height_
==
a
.
width_
&&
a
.
height_
==
b
.
height_
)
<<
"Matrix dimensions are not equal"
;
CHECK
(
out
.
getWidth
()
==
b
.
getWidth
()
&&
out
.
getHeight
()
==
a
.
getWidth
()
&&
a
.
getHeight
()
==
b
.
getHeight
()
)
<<
"Matrix dimensions are not equal"
;
}
hl_trans_op_t
a
_trans
=
a
.
trans_
?
HPPL_OP_T
:
HPPL_OP_N
;
hl_sparse_matrix_s
a
_d
ata
=
a
.
sMatrix_
.
get
();
real
*
b
_data
=
b
.
data_
;
real
*
out
_data
=
out
.
data_
;
hl_matrix_csr_mul_dense
(
a
_d
ata
,
a
_t
rans
,
b
_d
ata
,
hl_trans_op_t
a
Trans
=
a
.
isTransposed
()
?
HPPL_OP_T
:
HPPL_OP_N
;
hl_sparse_matrix_s
a
D
ata
=
a
.
sMatrix_
.
get
();
real
*
b
Data
=
const_cast
<
real
*>
(
b
.
getData
())
;
real
*
out
Data
=
const_cast
<
real
*>
(
out
.
getData
())
;
hl_matrix_csr_mul_dense
(
a
D
ata
,
a
T
rans
,
b
D
ata
,
HPPL_OP_N
,
out
_d
ata
,
out
.
height_
,
out
.
width_
,
b
.
height_
,
scale
_ab
,
scale
_t
);
out
D
ata
,
out
.
getHeight
()
,
out
.
getWidth
()
,
b
.
getHeight
()
,
scale
AB
,
scale
T
);
}
/**
* out = scale
_t * out + scale_ab
* (a * b)
* out = scale
T * out + scaleAB
* (a * b)
* out : M * N
*/
template
<
>
void
MulOp
<
DEVICE_TYPE_GPU
>
(
GpuMatrix
&
out
,
const
GpuMatrix
&
a
,
const
GpuSparseMatrix
&
b
,
real
scale
_ab
,
real
scale
_t
)
{
real
scale
AB
,
real
scale
T
)
{
CHECK
(
out
.
isContiguous
());
CHECK
(
a
.
isContiguous
());
CHECK
(
a
.
useGpu_
==
true
)
<<
"Matrix type are not equal"
;
hl_sparse_matrix_s
b_data
=
b
.
sMatrix_
.
get
();
real
*
a_data
=
a
.
data_
;
real
*
out_data
=
out
.
data_
;
hl_trans_op_t
trans_b
=
b
.
trans_
?
HPPL_OP_T
:
HPPL_OP_N
;
if
(
!
b
.
trans_
)
{
CHECK
(
a
.
useGpu_
)
<<
"Matrix type are not equal"
;
if
(
!
b
.
isTransposed
())
{
/// a : M * K, b : K * N
CHECK
(
out
.
width_
==
b
.
width_
&&
out
.
height_
==
a
.
height_
&&
a
.
width_
==
b
.
height_
)
CHECK
(
out
.
getWidth
()
==
b
.
getWidth
()
&&
out
.
getHeight
()
==
a
.
getHeight
()
&&
a
.
getWidth
()
==
b
.
getHeight
())
<<
"Matrix dimensions are not equal"
;
}
else
{
/// a : M * K, b : N * K, transpose
CHECK
(
out
.
width_
==
b
.
height_
&&
out
.
height_
==
a
.
height_
&&
a
.
width_
==
b
.
width_
)
CHECK
(
out
.
getWidth
()
==
b
.
getHeight
()
&&
out
.
getHeight
()
==
a
.
getHeight
()
&&
a
.
getWidth
()
==
b
.
getWidth
())
<<
"Matrix dimensions are not equal"
;
}
hl_trans_op_t
bTrans
=
b
.
isTransposed
()
?
HPPL_OP_T
:
HPPL_OP_N
;
hl_sparse_matrix_s
bData
=
b
.
sMatrix_
.
get
();
real
*
aData
=
const_cast
<
real
*>
(
a
.
getData
());
real
*
outData
=
const_cast
<
real
*>
(
out
.
getData
());
if
(
b
.
format_
==
SPARSE_CSC
)
{
hl_matrix_dense_mul_csc
(
a
_d
ata
,
hl_matrix_dense_mul_csc
(
a
D
ata
,
HPPL_OP_N
,
b
_d
ata
,
trans_b
,
out
_d
ata
,
out
.
height_
,
out
.
width_
,
a
.
width_
,
scale
_ab
,
scale
_t
);
b
D
ata
,
bTrans
,
out
D
ata
,
out
.
getHeight
()
,
out
.
getWidth
()
,
a
.
getWidth
()
,
scale
AB
,
scale
T
);
}
else
{
hl_matrix_dense_mul_csr
(
a
_d
ata
,
hl_matrix_dense_mul_csr
(
a
D
ata
,
HPPL_OP_N
,
b
_d
ata
,
trans_b
,
out
_d
ata
,
out
.
height_
,
out
.
width_
,
a
.
width_
,
scale
_ab
,
scale
_t
);
b
D
ata
,
bTrans
,
out
D
ata
,
out
.
getHeight
()
,
out
.
getWidth
()
,
a
.
getWidth
()
,
scale
AB
,
scale
T
);
}
}
...
...
@@ -174,38 +167,36 @@ template <>
void
MulOp
<
DEVICE_TYPE_GPU
>
(
GpuSparseMatrix
&
out
,
const
GpuMatrix
&
a
,
const
GpuMatrix
&
b
,
real
scale_ab
,
real
scale_t
)
{
/// todo(tianbing), clean the code
CHECK
(
a
.
useGpu_
&&
b
.
useGpu_
)
<<
"type not match"
;
CHECK
(
!
out
.
trans_
)
<<
"trans not supported"
;
real
*
a_data
=
const_cast
<
real
*>
(
a
.
getData
());
real
*
b_data
=
const_cast
<
real
*>
(
b
.
getData
());
hl_sparse_matrix_s
out_data
=
out
.
sMatrix_
.
get
();
hl_trans_op_t
a_trans
=
a
.
trans_
?
HPPL_OP_T
:
HPPL_OP_N
;
hl_trans_op_t
b_trans
=
b
.
trans_
?
HPPL_OP_T
:
HPPL_OP_N
;
if
(
!
a
.
trans_
&&
!
b
.
trans_
)
{
CHECK
(
out
.
height_
==
a
.
getHeight
());
CHECK
(
out
.
width_
==
b
.
getWidth
());
CHECK
(
a
.
getWidth
()
==
b
.
getHeight
());
}
else
if
(
a
.
trans_
&&
!
b
.
trans_
)
{
CHECK
(
out
.
height_
==
a
.
getWidth
());
CHECK
(
out
.
width_
==
b
.
getWidth
());
CHECK
(
a
.
getHeight
()
==
b
.
getHeight
());
}
else
if
(
!
a
.
trans_
&&
b
.
trans_
)
{
CHECK
(
out
.
height_
==
a
.
getHeight
());
CHECK
(
out
.
width_
==
b
.
getHeight
());
CHECK
(
a
.
getWidth
()
==
b
.
getWidth
());
real
scaleAB
,
real
scaleT
)
{
CHECK
(
a
.
useGpu_
&&
b
.
useGpu_
)
<<
"matrix device type not match"
;
CHECK
(
!
out
.
isTransposed
())
<<
"Transpose is not supported for out matrix"
;
if
(
!
a
.
isTransposed
()
&&
!
b
.
isTransposed
())
{
CHECK
(
out
.
getHeight
()
==
a
.
getHeight
()
&&
out
.
getWidth
()
==
b
.
getWidth
()
&&
a
.
getWidth
()
==
b
.
getHeight
());
}
else
if
(
a
.
isTransposed
()
&&
!
b
.
isTransposed
())
{
CHECK
(
out
.
getHeight
()
==
a
.
getWidth
()
&&
out
.
getWidth
()
==
b
.
getWidth
()
&&
a
.
getHeight
()
==
b
.
getHeight
());
}
else
if
(
!
a
.
isTransposed
()
&&
b
.
isTransposed
())
{
CHECK
(
out
.
getHeight
()
==
a
.
getHeight
()
&&
out
.
getWidth
()
==
b
.
getHeight
()
&&
a
.
getWidth
()
==
b
.
getWidth
());
}
else
{
LOG
(
INFO
)
<<
"Not support
"
;
LOG
(
FATAL
)
<<
"Not support for both a and b are Transposed Matrices
"
;
}
int
dim_m
=
out
.
height_
;
int
dim_n
=
out
.
width_
;
int
dim_k
=
!
b
.
trans_
?
b
.
getHeight
()
:
b
.
getWidth
();
hl_sparse_matrix_mul
(
a_data
,
a_trans
,
b_data
,
b_trans
,
out_data
,
dim_m
,
dim_n
,
dim_k
,
scale_ab
,
scale_t
);
hl_trans_op_t
aTrans
=
a
.
isTransposed
()
?
HPPL_OP_T
:
HPPL_OP_N
;
hl_trans_op_t
bTrans
=
b
.
isTransposed
()
?
HPPL_OP_T
:
HPPL_OP_N
;
int
dimK
=
!
b
.
isTransposed
()
?
b
.
getHeight
()
:
b
.
getWidth
();
real
*
aData
=
const_cast
<
real
*>
(
a
.
getData
());
real
*
bData
=
const_cast
<
real
*>
(
b
.
getData
());
hl_sparse_matrix_s
outData
=
out
.
sMatrix_
.
get
();
hl_sparse_matrix_mul
(
aData
,
aTrans
,
bData
,
bTrans
,
outData
,
out
.
getHeight
(),
out
.
getWidth
(),
dimK
,
scaleAB
,
scaleT
);
}
}
// namespace paddle
paddle/function/MulOpTest.cpp
浏览文件 @
171eaff2
...
...
@@ -76,12 +76,12 @@ void testDDDMatrix(bool transa, bool transb, int dimM, int dimN, int dimK) {
TEST
(
Matrix
,
DDDMul
)
{
LOG
(
INFO
)
<<
"test for dense = dense * dense matrix"
;
for
(
auto
transa
:
{
false
,
true
})
{
for
(
auto
transb
:
{
false
,
true
})
{
for
(
auto
dimM
:
{
1
,
10
,
100
})
{
for
(
auto
dimN
:
{
1
,
10
})
{
for
(
auto
dimK
:
{
8
})
{
if
(
tr
ue
==
transa
&&
true
==
transb
)
{
for
(
const
auto
transa
:
{
false
,
true
})
{
for
(
const
auto
transb
:
{
false
,
true
})
{
for
(
const
auto
dimM
:
{
1
,
10
,
100
})
{
for
(
const
auto
dimN
:
{
1
,
10
})
{
for
(
const
auto
dimK
:
{
8
})
{
if
(
tr
ansa
&&
transb
)
{
continue
;
}
VLOG
(
3
)
<<
setiosflags
(
std
::
ios
::
left
)
<<
std
::
setfill
(
' '
)
...
...
@@ -89,7 +89,6 @@ TEST(Matrix, DDDMul) {
<<
" dimM="
<<
std
::
setw
(
5
)
<<
dimM
<<
" dimN="
<<
std
::
setw
(
5
)
<<
dimN
<<
" dimK="
<<
std
::
setw
(
5
)
<<
dimK
;
testDDDMatrix
(
transa
,
transb
,
dimM
,
dimN
,
dimK
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录