Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bafd8dec
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bafd8dec
编写于
6月 24, 2022
作者:
X
xiongkun
提交者:
GitHub
6月 24, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
change svd_cpu_kernel from Eigen to Lapack, speed up the compile from 120s -> 20s (#43784)
上级
23036031
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
133 addition
and
26 deletion
+133
-26
paddle/fluid/operators/svd_helper.h
paddle/fluid/operators/svd_helper.h
+28
-25
paddle/fluid/operators/svd_op.h
paddle/fluid/operators/svd_op.h
+21
-1
paddle/phi/backends/dynload/lapack.h
paddle/phi/backends/dynload/lapack.h
+30
-0
paddle/phi/kernels/funcs/lapack/lapack_function.cc
paddle/phi/kernels/funcs/lapack/lapack_function.cc
+38
-0
paddle/phi/kernels/funcs/lapack/lapack_function.h
paddle/phi/kernels/funcs/lapack/lapack_function.h
+16
-0
未找到文件。
paddle/fluid/operators/svd_helper.h
浏览文件 @
bafd8dec
...
...
@@ -30,6 +30,7 @@
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
paddle
{
...
...
@@ -44,40 +45,42 @@ template <typename T, int MajorType = Eigen::RowMajor,
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
T
>
void
Eigen
Svd
(
const
T
*
X
,
T
*
U
,
T
*
VH
,
T
*
S
,
int
rows
,
int
cols
,
void
Lapack
Svd
(
const
T
*
X
,
T
*
U
,
T
*
VH
,
T
*
S
,
int
rows
,
int
cols
,
int
full
=
false
)
{
auto
flag
=
Eigen
::
DecompositionOptions
::
ComputeThinU
|
Eigen
::
DecompositionOptions
::
ComputeThinV
;
if
(
full
)
{
flag
=
Eigen
::
DecompositionOptions
::
ComputeFullU
|
Eigen
::
DecompositionOptions
::
ComputeFullV
;
}
Eigen
::
BDCSVD
<
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
svd
(
2
,
2
,
flag
);
/*NOTE(xiongkun03) Eigen::Matrix API need non-const pointer.*/
T
*
input
=
const_cast
<
T
*>
(
X
);
auto
m
=
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
(
input
,
rows
,
cols
);
svd
.
compute
(
m
);
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>
V_trans
=
svd
.
matrixV
().
transpose
();
memcpy
(
U
,
svd
.
matrixU
().
data
(),
svd
.
matrixU
().
size
()
*
sizeof
(
T
));
memcpy
(
VH
,
V_trans
.
data
(),
V_trans
.
size
()
*
sizeof
(
T
));
memcpy
(
S
,
svd
.
singularValues
().
data
(),
svd
.
singularValues
().
size
()
*
sizeof
(
T
));
char
jobz
=
full
?
'A'
:
'S'
;
int
mx
=
std
::
max
(
rows
,
cols
);
int
mn
=
std
::
min
(
rows
,
cols
);
T
*
a
=
const_cast
<
T
*>
(
X
);
int
lda
=
rows
;
int
ldu
=
rows
;
int
ldvt
=
full
?
cols
:
mn
;
int
lwork
=
full
?
(
4
*
mn
*
mn
+
6
*
mn
+
mx
)
:
(
4
*
mn
*
mn
+
7
*
mn
);
std
::
vector
<
T
>
work
(
lwork
);
std
::
vector
<
int
>
iwork
(
8
*
mn
);
int
info
;
phi
::
funcs
::
lapackSvd
<
T
>
(
jobz
,
rows
,
cols
,
a
,
lda
,
S
,
U
,
ldu
,
VH
,
ldvt
,
work
.
data
(),
lwork
,
iwork
.
data
(),
&
info
);
if
(
info
<
0
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"This %s-th argument has an illegal value"
,
info
));
}
if
(
info
>
0
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"DBDSDC/SBDSDC did not converge, updating process failed. May be you "
"passes a invalid matrix."
));
}
}
template
<
typename
T
>
void
BatchSvd
(
const
T
*
X
,
T
*
U
,
T
*
VH
,
T
*
S
,
int
rows
,
int
cols
,
int
batches
,
int
full
=
false
)
{
// NOTE: this function is row major, because this function called the lapack.
int
stride
=
rows
*
cols
;
int
k
=
std
::
min
(
rows
,
cols
);
int
stride_u
=
full
?
rows
*
rows
:
k
*
rows
;
int
stride_v
=
full
?
cols
*
cols
:
k
*
cols
;
for
(
int
i
=
0
;
i
<
batches
;
++
i
)
{
Eigen
Svd
<
T
>
(
X
+
i
*
stride
,
U
+
i
*
stride_u
,
VH
+
i
*
stride_v
,
S
+
i
*
k
,
Lapack
Svd
<
T
>
(
X
+
i
*
stride
,
U
+
i
*
stride_u
,
VH
+
i
*
stride_v
,
S
+
i
*
k
,
rows
,
cols
,
full
);
}
return
;
...
...
paddle/fluid/operators/svd_op.h
浏览文件 @
bafd8dec
...
...
@@ -21,6 +21,7 @@
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -39,7 +40,12 @@ class SvdCPUKernel : public framework::OpKernel<T> {
/*Create Tensors and output, set the dim ...*/
auto
numel
=
x
->
numel
();
auto
*
x_data
=
x
->
data
<
T
>
();
auto
&
orig_dev_ctx
=
context
.
template
device_context
<
platform
::
CPUDeviceContext
>();
auto
&
dev_ctx
=
static_cast
<
const
typename
framework
::
ConvertToPhiContext
<
platform
::
CPUDeviceContext
>::
TYPE
&>
(
orig_dev_ctx
);
Tensor
trans_x
=
::
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
*
x
);
auto
*
x_data
=
trans_x
.
data
<
T
>
();
auto
x_dims
=
x
->
dims
();
int
rows
=
x_dims
[
x_dims
.
size
()
-
2
];
int
cols
=
x_dims
[
x_dims
.
size
()
-
1
];
...
...
@@ -57,6 +63,20 @@ class SvdCPUKernel : public framework::OpKernel<T> {
context
.
GetPlace
(),
size_t
(
batches
*
k
*
sizeof
(
phi
::
dtype
::
Real
<
T
>
)));
/*SVD Use the Eigen Library*/
math
::
BatchSvd
<
T
>
(
x_data
,
U_out
,
VH_out
,
S_out
,
rows
,
cols
,
batches
,
full
);
/* let C[m, n] as a col major matrix with m rows and n cols.
* let R[m, n] is row major matrix with m rows and n cols.
* then we have: R[m,n] = C[m, n].resize((n,m)).tranpose_last_two()
* */
auto
col_major_to_row_major
=
[
&
dev_ctx
](
Tensor
*
out
)
{
auto
origin_dim
=
out
->
dims
();
int64_t
&
x
=
origin_dim
[
origin_dim
.
size
()
-
1
];
int64_t
&
y
=
origin_dim
[
origin_dim
.
size
()
-
2
];
std
::
swap
(
x
,
y
);
out
->
Resize
(
origin_dim
);
return
::
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
*
out
);
};
*
U
=
col_major_to_row_major
(
U
);
*
VH
=
col_major_to_row_major
(
VH
);
}
};
...
...
paddle/phi/backends/dynload/lapack.h
浏览文件 @
bafd8dec
...
...
@@ -280,6 +280,34 @@ extern "C" void spotrs_(char *uplo,
float
*
b
,
int
*
ldb
,
int
*
info
);
extern
"C"
void
dgesdd_
(
char
*
,
int
*
,
int
*
,
double
*
,
int
*
,
double
*
,
double
*
,
int
*
,
double
*
,
int
*
,
double
*
,
int
*
,
int
*
,
int
*
);
extern
"C"
void
sgesdd_
(
char
*
,
int
*
,
int
*
,
float
*
,
int
*
,
float
*
,
float
*
,
int
*
,
float
*
,
int
*
,
float
*
,
int
*
,
int
*
,
int
*
);
namespace
phi
{
namespace
dynload
{
...
...
@@ -328,6 +356,8 @@ extern void *lapack_dso_handle;
__macro(sgelsy_); \
__macro(dgelss_); \
__macro(sgelss_); \
__macro(sgesdd_); \
__macro(dgesdd_); \
__macro(zpotrs_); \
__macro(cpotrs_); \
__macro(dpotrs_); \
...
...
paddle/phi/kernels/funcs/lapack/lapack_function.cc
浏览文件 @
bafd8dec
...
...
@@ -499,5 +499,43 @@ void lapackCholeskySolve<float>(char uplo,
dynload
::
spotrs_
(
&
uplo
,
&
n
,
&
nrhs
,
a
,
&
lda
,
b
,
&
ldb
,
info
);
}
template
<
>
void
lapackSvd
<
double
>
(
char
jobz
,
int
m
,
int
n
,
double
*
a
,
int
lda
,
double
*
s
,
double
*
u
,
int
ldu
,
double
*
vt
,
int
ldvt
,
double
*
work
,
int
lwork
,
int
*
iwork
,
int
*
info
)
{
dynload
::
dgesdd_
(
&
jobz
,
&
m
,
&
n
,
a
,
&
lda
,
s
,
u
,
&
ldu
,
vt
,
&
ldvt
,
work
,
&
lwork
,
iwork
,
info
);
}
template
<
>
void
lapackSvd
<
float
>
(
char
jobz
,
int
m
,
int
n
,
float
*
a
,
int
lda
,
float
*
s
,
float
*
u
,
int
ldu
,
float
*
vt
,
int
ldvt
,
float
*
work
,
int
lwork
,
int
*
iwork
,
int
*
info
)
{
dynload
::
sgesdd_
(
&
jobz
,
&
m
,
&
n
,
a
,
&
lda
,
s
,
u
,
&
ldu
,
vt
,
&
ldvt
,
work
,
&
lwork
,
iwork
,
info
);
}
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/funcs/lapack/lapack_function.h
浏览文件 @
bafd8dec
...
...
@@ -120,6 +120,22 @@ void lapackGelss(int m,
T2
*
rwork
,
int
*
info
);
template
<
typename
T
>
void
lapackSvd
(
char
jobz
,
int
m
,
int
n
,
T
*
a
,
int
lda
,
T
*
s
,
T
*
u
,
int
ldu
,
T
*
vt
,
int
ldvt
,
T
*
work
,
int
lwork
,
int
*
iwork
,
int
*
info
);
template
<
typename
T
>
void
lapackCholeskySolve
(
char
uplo
,
int
n
,
int
nrhs
,
T
*
a
,
int
lda
,
T
*
b
,
int
ldb
,
int
*
info
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录