Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
be04fbff
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
be04fbff
编写于
7月 20, 2018
作者:
T
tensor-tang
提交者:
GitHub
7月 20, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #12233 from tensor-tang/refine/mkl/gemm
add option split mkl gemm
上级
7219676e
fc2b5788
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
98 addition
and
17 deletion
+98
-17
CMakeLists.txt
CMakeLists.txt
+6
-0
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+38
-17
paddle/fluid/operators/math/math_function_test.cc
paddle/fluid/operators/math/math_function_test.cc
+54
-0
未找到文件。
CMakeLists.txt
浏览文件 @
be04fbff
...
@@ -136,6 +136,12 @@ else()
...
@@ -136,6 +136,12 @@ else()
set
(
THIRD_PARTY_BUILD_TYPE Release
)
set
(
THIRD_PARTY_BUILD_TYPE Release
)
endif
()
endif
()
if
(
WITH_MKL
)
option
(
MKL_SPLIT_GEMM
"PaddlePaddle MKL gemm would split to small ones"
OFF
)
if
(
MKL_SPLIT_GEMM
)
add_definitions
(
-DPADDLE_MKL_SPLIT_GEMM
)
endif
()
endif
()
set
(
WITH_MKLML
${
WITH_MKL
}
)
set
(
WITH_MKLML
${
WITH_MKL
}
)
if
(
NOT DEFINED WITH_MKLDNN
)
if
(
NOT DEFINED WITH_MKLDNN
)
if
(
WITH_MKL AND AVX2_FOUND
)
if
(
WITH_MKL AND AVX2_FOUND
)
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
be04fbff
...
@@ -37,6 +37,7 @@ struct CBlas<float> {
...
@@ -37,6 +37,7 @@ struct CBlas<float> {
libxsmm_sgemm
(
args
...);
libxsmm_sgemm
(
args
...);
}
}
#endif
#endif
template
<
typename
...
ARGS
>
template
<
typename
...
ARGS
>
static
void
AXPY
(
ARGS
...
args
)
{
static
void
AXPY
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_saxpy
(
args
...);
platform
::
dynload
::
cblas_saxpy
(
args
...);
...
@@ -76,6 +77,7 @@ struct CBlas<double> {
...
@@ -76,6 +77,7 @@ struct CBlas<double> {
libxsmm_dgemm
(
args
...);
libxsmm_dgemm
(
args
...);
}
}
#endif
#endif
template
<
typename
...
ARGS
>
template
<
typename
...
ARGS
>
static
void
AXPY
(
ARGS
...
args
)
{
static
void
AXPY
(
ARGS
...
args
)
{
platform
::
dynload
::
cblas_daxpy
(
args
...);
platform
::
dynload
::
cblas_daxpy
(
args
...);
...
@@ -150,6 +152,7 @@ struct CBlas<double> {
...
@@ -150,6 +152,7 @@ struct CBlas<double> {
}
}
};
};
#endif
#endif
template
<
>
template
<
>
struct
CBlas
<
platform
::
float16
>
{
struct
CBlas
<
platform
::
float16
>
{
static
void
GEMM
(...)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
static
void
GEMM
(...)
{
PADDLE_THROW
(
"float16 GEMM not supported on CPU"
);
}
...
@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
...
@@ -190,30 +193,48 @@ inline bool UseXSMM<platform::float16>(const int &m, const int &n, const int &k,
return
false
;
return
false
;
}
}
template
<
>
template
<
typename
T
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
inline
void
GEMM_WARP
(
CBLAS_ORDER
order
,
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
const
T
*
A
,
int
lda
,
const
T
*
B
,
int
ldb
,
T
beta
,
T
*
C
,
const
T
*
B
,
T
beta
,
T
*
C
)
const
{
int
ldc
)
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
#ifdef PADDLE_WITH_LIBXSMM
#ifdef PADDLE_WITH_LIBXSMM
if
(
UseXSMM
(
M
,
N
,
K
,
transA
!=
CblasNoTrans
,
transB
!=
CblasNoTrans
,
alpha
,
if
(
UseXSMM
<
T
>
(
M
,
N
,
K
,
transA
!=
CblasNoTrans
,
transB
!=
CblasNoTrans
,
alpha
,
beta
))
{
beta
))
{
// Note: SMM use ColMajor
// Note: SMM use ColMajor
const
char
transa
=
'N'
;
const
char
transa
=
'N'
;
const
char
transb
=
'N'
;
const
char
transb
=
'N'
;
CBlas
<
T
>::
SMM_GEMM
(
&
transa
,
&
transb
,
&
N
,
&
M
,
&
K
,
&
alpha
,
B
,
&
ldb
,
A
,
&
lda
,
CBlas
<
T
>::
SMM_GEMM
(
&
transa
,
&
transb
,
&
N
,
&
M
,
&
K
,
&
alpha
,
B
,
&
ldb
,
A
,
&
lda
,
&
beta
,
C
,
&
ldc
);
&
beta
,
C
,
&
ldc
);
}
else
{
return
;
}
#endif
#endif
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
#ifdef PADDLE_MKL_SPLIT_GEMM
#ifdef PADDLE_WITH_LIBXSMM
constexpr
int
bs
=
2
;
if
(
M
%
bs
==
0
&&
transA
==
CblasNoTrans
&&
transB
==
CblasNoTrans
)
{
for
(
int
off
=
0
;
off
<
M
;
off
+=
bs
)
{
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
CblasNoTrans
,
CblasNoTrans
,
bs
,
N
,
K
,
alpha
,
A
+
off
*
lda
,
lda
,
B
,
ldb
,
beta
,
C
+
off
*
ldb
,
ldc
);
}
return
;
}
}
#endif
#endif
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CPUDeviceContext
>::
GEMM
(
CBLAS_TRANSPOSE
transA
,
CBLAS_TRANSPOSE
transB
,
int
M
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
const
T
*
B
,
T
beta
,
T
*
C
)
const
{
int
lda
=
(
transA
==
CblasNoTrans
)
?
K
:
M
;
int
ldb
=
(
transB
==
CblasNoTrans
)
?
N
:
K
;
int
ldc
=
N
;
GEMM_WARP
<
T
>
(
CblasRowMajor
,
transA
,
transB
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
}
template
<
>
template
<
>
...
@@ -222,7 +243,7 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
...
@@ -222,7 +243,7 @@ void Blas<platform::CPUDeviceContext>::GEMM(bool transA, bool transB, int M,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
int
N
,
int
K
,
T
alpha
,
const
T
*
A
,
int
lda
,
const
T
*
B
,
int
ldb
,
int
lda
,
const
T
*
B
,
int
ldb
,
T
beta
,
T
*
C
,
int
ldc
)
const
{
T
beta
,
T
*
C
,
int
ldc
)
const
{
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
GEMM_WARP
<
T
>
(
CblasRowMajor
,
transA
==
false
?
CblasNoTrans
:
CblasTrans
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
transB
==
false
?
CblasNoTrans
:
CblasTrans
,
M
,
N
,
K
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
lda
,
B
,
ldb
,
beta
,
C
,
ldc
);
}
}
...
...
paddle/fluid/operators/math/math_function_test.cc
浏览文件 @
be04fbff
...
@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
...
@@ -228,3 +228,57 @@ TEST(math_funciton, set_constant) {
}
}
delete
ctx
;
delete
ctx
;
}
}
template
<
typename
T
>
void
GemmWarpTest
(
int
m
,
int
n
,
int
k
,
T
alpha
,
T
beta
)
{
paddle
::
framework
::
Tensor
mat_a
;
paddle
::
framework
::
Tensor
mat_b
;
paddle
::
framework
::
Tensor
mat_c_ref
;
paddle
::
framework
::
Tensor
mat_c_mkl
;
auto
*
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
T
*
A
=
mat_a
.
mutable_data
<
T
>
({
m
,
k
},
*
cpu_place
);
T
*
B
=
mat_b
.
mutable_data
<
T
>
({
k
,
n
},
*
cpu_place
);
T
*
CREF
=
mat_c_ref
.
mutable_data
<
T
>
({
m
,
n
},
*
cpu_place
);
T
*
CMKL
=
mat_c_mkl
.
mutable_data
<
T
>
({
m
,
n
},
*
cpu_place
);
ASSERT_EQ
(
mat_c_mkl
.
numel
(),
mat_c_ref
.
numel
());
for
(
int
i
=
0
;
i
<
mat_a
.
numel
();
++
i
)
{
A
[
i
]
=
static_cast
<
T
>
(
i
);
}
for
(
int
i
=
0
;
i
<
mat_b
.
numel
();
++
i
)
{
B
[
i
]
=
static_cast
<
T
>
(
i
+
1
);
}
for
(
int
i
=
0
;
i
<
mat_c_ref
.
numel
();
++
i
)
{
CREF
[
i
]
=
static_cast
<
T
>
(
i
+
2
);
CMKL
[
i
]
=
CREF
[
i
];
}
// this would call gemm_warp
paddle
::
platform
::
CPUDeviceContext
context
(
*
cpu_place
);
GetBlas
<
T
>
(
context
).
GEMM
(
CblasNoTrans
,
CblasNoTrans
,
m
,
n
,
k
,
alpha
,
A
,
B
,
beta
,
CREF
);
// lda,ldb,ldc follow RowMajor
int
lda
=
k
;
int
ldb
=
n
;
int
ldc
=
n
;
paddle
::
operators
::
math
::
CBlas
<
T
>::
GEMM
(
CblasRowMajor
,
CblasNoTrans
,
CblasNoTrans
,
m
,
n
,
k
,
alpha
,
A
,
lda
,
B
,
ldb
,
beta
,
CMKL
,
ldc
);
for
(
int
i
=
0
;
i
<
mat_c_mkl
.
numel
();
++
i
)
{
EXPECT_FLOAT_EQ
(
CREF
[
i
],
CMKL
[
i
]);
}
}
TEST
(
math_function
,
gemm_warp
)
{
GemmWarpTest
<
float
>
(
3
,
2
,
5
,
1.
f
,
0.
f
);
GemmWarpTest
<
float
>
(
3
,
2
,
5
,
2.
f
,
1.
f
);
GemmWarpTest
<
float
>
(
8
,
5
,
6
,
1.
f
,
0.
f
);
GemmWarpTest
<
float
>
(
8
,
5
,
6
,
2.
f
,
1.
f
);
GemmWarpTest
<
double
>
(
3
,
2
,
5
,
1.0
,
0.0
);
GemmWarpTest
<
double
>
(
3
,
2
,
5
,
2.0
,
1.0
);
GemmWarpTest
<
double
>
(
8
,
5
,
6
,
1.0
,
0.0
);
GemmWarpTest
<
double
>
(
8
,
5
,
6
,
2.0
,
1.0
);
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录