Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
82cd8d21
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
82cd8d21
编写于
6月 28, 2022
作者:
W
WangZhen
提交者:
GitHub
6月 28, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Speed up matrix_rank_tol_kernel.cc compile time (#43856)
上级
6d436f6e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
49 addition
and
47 deletion
+49
-47
paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc
paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc
+49
-47
未找到文件。
paddle/phi/kernels/cpu/matrix_rank_tol_kernel.cc
浏览文件 @
82cd8d21
...
@@ -14,67 +14,66 @@
...
@@ -14,67 +14,66 @@
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include <Eigen/Dense>
#include <Eigen/SVD>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h"
#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h"
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace
phi
{
namespace
phi
{
template
<
typename
T
>
template
<
typename
T
>
void
BatchEigenvalues
(
const
T
*
x_data
,
void
LapackSVD
(
const
T
*
x_data
,
T
*
eigenvalues_data
,
int
rows
,
int
cols
)
{
T
*
eigenvalues_data
,
char
jobz
=
'N'
;
int
batches
,
int
mx
=
std
::
max
(
rows
,
cols
);
int
rows
,
int
mn
=
std
::
min
(
rows
,
cols
);
int
cols
,
T
*
a
=
const_cast
<
T
*>
(
x_data
);
int
k
)
{
int
lda
=
rows
;
// Eigen::Matrix API need non-const pointer.
int
lwork
=
3
*
mn
+
std
::
max
(
mx
,
7
*
mn
);
T
*
input
=
const_cast
<
T
*>
(
x_data
);
std
::
vector
<
T
>
work
(
lwork
);
int
stride
=
rows
*
cols
;
std
::
vector
<
int
>
iwork
(
8
*
mn
);
for
(
int
i
=
0
;
i
<
batches
;
i
++
)
{
int
info
;
auto
m
=
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
(
phi
::
funcs
::
lapackSvd
<
T
>
(
jobz
,
input
+
i
*
stride
,
rows
,
rows
);
rows
,
Eigen
::
SelfAdjointEigenSolver
<
cols
,
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
a
,
eigen_solver
(
m
);
lda
,
auto
eigenvalues
=
eigen_solver
.
eigenvalues
().
cwiseAbs
();
eigenvalues_data
,
for
(
int
j
=
0
;
j
<
k
;
j
++
)
{
nullptr
,
*
(
eigenvalues_data
+
i
*
k
+
j
)
=
eigenvalues
[
j
];
1
,
}
nullptr
,
1
,
work
.
data
(),
lwork
,
iwork
.
data
(),
&
info
);
if
(
info
<
0
)
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"This %s-th argument has an illegal value"
,
info
));
}
if
(
info
>
0
)
{
PADDLE_THROW
(
phi
::
errors
::
InvalidArgument
(
"DBDSDC/SBDSDC did not converge, updating process failed. May be you "
"passes a invalid matrix."
));
}
}
}
}
template
<
typename
T
>
template
<
typename
T
>
void
BatchSVD
(
const
T
*
x_data
,
void
BatchSVD
(
T
*
eigenvalues_data
,
const
T
*
x_data
,
T
*
eigenvalues_data
,
int
batches
,
int
rows
,
int
cols
)
{
int
batches
,
int
rows
,
int
cols
,
int
k
)
{
// Eigen::Matrix API need non-const pointer.
T
*
input
=
const_cast
<
T
*>
(
x_data
);
int
stride
=
rows
*
cols
;
int
stride
=
rows
*
cols
;
Eigen
::
BDCSVD
<
int
k
=
std
::
min
(
rows
,
cols
);
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
for
(
int
i
=
0
;
i
<
batches
;
++
i
)
{
svd
;
LapackSVD
<
T
>
(
x_data
+
i
*
stride
,
eigenvalues_data
+
i
*
k
,
rows
,
cols
);
for
(
int
i
=
0
;
i
<
batches
;
i
++
)
{
auto
m
=
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
(
input
+
i
*
stride
,
rows
,
cols
);
svd
.
compute
(
m
);
auto
res_s
=
svd
.
singularValues
();
for
(
int
j
=
0
;
j
<
k
;
j
++
)
{
eigenvalues_data
[
i
*
k
+
j
]
=
res_s
[
j
];
}
}
}
}
}
...
@@ -85,7 +84,6 @@ void MatrixRankTolKernel(const Context& dev_ctx,
...
@@ -85,7 +84,6 @@ void MatrixRankTolKernel(const Context& dev_ctx,
bool
use_default_tol
,
bool
use_default_tol
,
bool
hermitian
,
bool
hermitian
,
DenseTensor
*
out
)
{
DenseTensor
*
out
)
{
auto
*
x_data
=
x
.
data
<
T
>
();
dev_ctx
.
template
Alloc
<
int64_t
>(
out
);
dev_ctx
.
template
Alloc
<
int64_t
>(
out
);
auto
dim_x
=
x
.
dims
();
auto
dim_x
=
x
.
dims
();
auto
dim_out
=
out
->
dims
();
auto
dim_out
=
out
->
dims
();
...
@@ -106,9 +104,13 @@ void MatrixRankTolKernel(const Context& dev_ctx,
...
@@ -106,9 +104,13 @@ void MatrixRankTolKernel(const Context& dev_ctx,
auto
*
eigenvalue_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
eigenvalue_tensor
);
auto
*
eigenvalue_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
eigenvalue_tensor
);
if
(
hermitian
)
{
if
(
hermitian
)
{
BatchEigenvalues
<
T
>
(
x_data
,
eigenvalue_data
,
batches
,
rows
,
cols
,
k
);
phi
::
funcs
::
MatrixEighFunctor
<
Context
,
T
>
functor
;
functor
(
dev_ctx
,
x
,
&
eigenvalue_tensor
,
nullptr
,
true
,
false
);
phi
::
AbsKernel
<
T
,
Context
>
(
dev_ctx
,
eigenvalue_tensor
,
&
eigenvalue_tensor
);
}
else
{
}
else
{
BatchSVD
<
T
>
(
x_data
,
eigenvalue_data
,
batches
,
rows
,
cols
,
k
);
DenseTensor
trans_x
=
phi
::
TransposeLast2Dim
<
T
>
(
dev_ctx
,
x
);
auto
*
x_data
=
trans_x
.
data
<
T
>
();
BatchSVD
<
T
>
(
x_data
,
eigenvalue_data
,
batches
,
rows
,
cols
);
}
}
DenseTensor
max_eigenvalue_tensor
;
DenseTensor
max_eigenvalue_tensor
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录