Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
effb70f4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
effb70f4
编写于
9月 26, 2021
作者:
C
crystal
提交者:
GitHub
9月 26, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[cherry-pick]CPU forward calculation replaces Eigen with Lapack (#35916) (#36091)
cherry-pick #35916,CPU前向计算将Eigen替换为Lapack,修改linalg暴露规则
上级
14cdcde7
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
217 addition
and
179 deletion
+217
-179
paddle/fluid/operators/eigh_op.cc
paddle/fluid/operators/eigh_op.cc
+8
-9
paddle/fluid/operators/eigh_op.cu
paddle/fluid/operators/eigh_op.cu
+8
-9
paddle/fluid/operators/eigh_op.h
paddle/fluid/operators/eigh_op.h
+4
-3
paddle/fluid/operators/math/eigen_values_vectors.h
paddle/fluid/operators/math/eigen_values_vectors.h
+122
-151
paddle/fluid/operators/math/lapack_function.cc
paddle/fluid/operators/math/lapack_function.cc
+43
-0
paddle/fluid/operators/math/lapack_function.h
paddle/fluid/operators/math/lapack_function.h
+9
-4
paddle/fluid/platform/dynload/lapack.h
paddle/fluid/platform/dynload/lapack.h
+21
-0
python/paddle/__init__.py
python/paddle/__init__.py
+0
-1
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+2
-2
未找到文件。
paddle/fluid/operators/eigh_op.cc
浏览文件 @
effb70f4
...
@@ -147,18 +147,17 @@ REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
...
@@ -147,18 +147,17 @@ REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
REGISTER_OPERATOR
(
eigh_grad
,
ops
::
EighGradOp
);
REGISTER_OPERATOR
(
eigh_grad
,
ops
::
EighGradOp
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
eigh
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
,
float
>
,
eigh
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
,
double
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
,
ops
::
EighKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CPU_KERNEL
(
REGISTER_OP_CPU_KERNEL
(
eigh_grad
,
eigh_grad
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
,
float
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
,
double
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
,
paddle
::
platform
::
complex
<
float
>>
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
);
paddle/fluid/operators/eigh_op.cu
浏览文件 @
effb70f4
...
@@ -16,18 +16,17 @@ limitations under the License. */
...
@@ -16,18 +16,17 @@ limitations under the License. */
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
eigh
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
float
>
,
eigh
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
double
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
float
>>
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CUDA_KERNEL
(
REGISTER_OP_CUDA_KERNEL
(
eigh_grad
,
eigh_grad
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
float
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
double
>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
paddle
::
platform
::
complex
<
float
>>
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
ops
::
EighGradKernel
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
complex
<
double
>>
);
paddle
::
platform
::
complex
<
double
>>
);
paddle/fluid/operators/eigh_op.h
浏览文件 @
effb70f4
...
@@ -22,7 +22,7 @@ namespace operators {
...
@@ -22,7 +22,7 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
ValueType
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
EighKernel
:
public
framework
::
OpKernel
<
T
>
{
class
EighKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
...
@@ -31,15 +31,16 @@ class EighKernel : public framework::OpKernel<T> {
...
@@ -31,15 +31,16 @@ class EighKernel : public framework::OpKernel<T> {
auto
output_v
=
ctx
.
Output
<
Tensor
>
(
"Eigenvectors"
);
auto
output_v
=
ctx
.
Output
<
Tensor
>
(
"Eigenvectors"
);
std
::
string
lower
=
ctx
.
Attr
<
std
::
string
>
(
"UPLO"
);
std
::
string
lower
=
ctx
.
Attr
<
std
::
string
>
(
"UPLO"
);
bool
is_lower
=
(
lower
==
"L"
);
bool
is_lower
=
(
lower
==
"L"
);
math
::
MatrixEighFunctor
<
DeviceContext
,
ValueType
,
T
>
functor
;
math
::
MatrixEighFunctor
<
DeviceContext
,
T
>
functor
;
functor
(
ctx
,
*
input
,
output_w
,
output_v
,
is_lower
,
true
);
functor
(
ctx
,
*
input
,
output_w
,
output_v
,
is_lower
,
true
);
}
}
};
};
template
<
typename
DeviceContext
,
typename
ValueType
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
EighGradKernel
:
public
framework
::
OpKernel
<
T
>
{
class
EighGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
using
ValueType
=
math
::
Real
<
T
>
;
auto
&
x_grad
=
*
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
&
x_grad
=
*
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
x_grad
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
x_grad
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
output_w
=
*
ctx
.
Input
<
Tensor
>
(
"Eigenvalues"
);
auto
&
output_w
=
*
ctx
.
Input
<
Tensor
>
(
"Eigenvalues"
);
...
...
paddle/fluid/operators/math/eigen_values_vectors.h
浏览文件 @
effb70f4
...
@@ -14,8 +14,8 @@
...
@@ -14,8 +14,8 @@
#pragma once
#pragma once
#include "Eigen/Core"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/math/lapack_function.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/svd_helper.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
...
@@ -25,84 +25,6 @@ namespace paddle {
...
@@ -25,84 +25,6 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
namespace
math
{
namespace
math
{
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
InputMatrixMap
=
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
OutputMatrixMap
=
Eigen
::
Map
<
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
;
template
<
typename
ValueType
>
inline
void
ComputeFloatEigenvaluesAndVectors
(
ValueType
*
x_data
,
ValueType
*
eigenvalues_data
,
ValueType
*
eigenvectors_data
,
int
batches
,
int
rows
,
int
cols
,
bool
has_vectors
)
{
int
stride
=
rows
*
cols
;
for
(
int
i
=
0
;
i
<
batches
;
i
++
)
{
auto
m
=
InputMatrixMap
<
ValueType
>
(
x_data
+
i
*
stride
,
rows
,
cols
);
auto
eigenvalues
=
OutputMatrixMap
<
ValueType
>
(
eigenvalues_data
+
i
*
rows
,
1
,
rows
);
auto
eigenvectors
=
OutputMatrixMap
<
ValueType
>
(
eigenvectors_data
+
i
*
stride
,
rows
,
cols
);
Eigen
::
SelfAdjointEigenSolver
<
Eigen
::
Matrix
<
ValueType
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
eigen_solver
(
m
,
has_vectors
?
Eigen
::
ComputeEigenvectors
:
Eigen
::
EigenvaluesOnly
);
PADDLE_ENFORCE_EQ
(
eigen_solver
.
info
(),
Eigen
::
Success
,
platform
::
errors
::
InvalidArgument
(
"Self Adjoint Eigen decomposition is not successful. "
"The %d-th input matrice might not be not be positive definite."
,
i
));
eigenvalues
=
eigen_solver
.
eigenvalues
().
transpose
();
if
(
has_vectors
)
{
eigenvectors
=
eigen_solver
.
eigenvectors
();
}
}
}
template
<
typename
T
,
typename
ValueType
>
inline
void
ComputeComplexEigenvaluesAndVectors
(
T
*
x_data
,
ValueType
*
eigenvalues_data
,
T
*
eigenvectors_data
,
int
batches
,
int
rows
,
int
cols
,
bool
has_vectors
)
{
using
Complex
=
std
::
complex
<
ValueType
>
;
Complex
*
input
=
reinterpret_cast
<
Complex
*>
(
x_data
);
Complex
*
eigenvectors_data_
=
reinterpret_cast
<
Complex
*>
(
eigenvectors_data
);
int
stride
=
rows
*
cols
;
for
(
int
i
=
0
;
i
<
batches
;
i
++
)
{
auto
m
=
InputMatrixMap
<
Complex
>
(
input
+
i
*
stride
,
rows
,
cols
);
auto
eigenvalues
=
OutputMatrixMap
<
ValueType
>
(
eigenvalues_data
+
i
*
rows
,
1
,
rows
);
auto
eigenvectors
=
OutputMatrixMap
<
Complex
>
(
eigenvectors_data_
+
i
*
stride
,
rows
,
cols
);
Eigen
::
SelfAdjointEigenSolver
<
Eigen
::
Matrix
<
Complex
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
,
Eigen
::
RowMajor
>>
eigen_solver
(
m
,
has_vectors
?
Eigen
::
ComputeEigenvectors
:
Eigen
::
EigenvaluesOnly
);
PADDLE_ENFORCE_EQ
(
eigen_solver
.
info
(),
Eigen
::
Success
,
platform
::
errors
::
InvalidArgument
(
"Self Adjoint Eigen decomposition is not successful. "
"The %d-th input matrice might not be not be positive definite."
,
i
));
eigenvalues
=
eigen_solver
.
eigenvalues
().
transpose
();
if
(
has_vectors
)
{
eigenvectors
=
eigen_solver
.
eigenvectors
();
}
}
}
inline
int64_t
GetBatchSize
(
framework
::
DDim
dims
)
{
inline
int64_t
GetBatchSize
(
framework
::
DDim
dims
)
{
int64_t
batch_size
=
1
;
int64_t
batch_size
=
1
;
auto
dim_size
=
dims
.
size
();
auto
dim_size
=
dims
.
size
();
...
@@ -112,7 +34,20 @@ inline int64_t GetBatchSize(framework::DDim dims) {
...
@@ -112,7 +34,20 @@ inline int64_t GetBatchSize(framework::DDim dims) {
return
batch_size
;
return
batch_size
;
}
}
template
<
typename
DeviceContext
,
typename
ValueType
,
typename
T
>
static
void
CheckEighResult
(
const
int
batch
,
const
int
info
)
{
PADDLE_ENFORCE_LE
(
info
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"For batch [%d]: the [%d] off-diagonal elements of an intermediate"
"tridiagonal form did not converge to zero"
,
batch
,
info
));
PADDLE_ENFORCE_GE
(
info
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"For batch [%d]: the [%d] argument had an illegal value"
,
batch
,
info
));
}
template
<
typename
DeviceContext
,
typename
T
>
struct
MatrixEighFunctor
{
struct
MatrixEighFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
...
@@ -122,43 +57,84 @@ struct MatrixEighFunctor {
...
@@ -122,43 +57,84 @@ struct MatrixEighFunctor {
// Calculates the eigenvalues and eigenvectors of Hermitian or real
// Calculates the eigenvalues and eigenvectors of Hermitian or real
// symmetric matrices, and uses the variable has_vectors to
// symmetric matrices, and uses the variable has_vectors to
// control whether to return the eigenvectors.
// control whether to return the eigenvectors.
template
<
typename
ValueType
,
typename
T
>
template
<
typename
T
>
struct
MatrixEighFunctor
<
platform
::
CPUDeviceContext
,
ValueType
,
T
>
{
struct
MatrixEighFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
bool
has_vectors
)
{
bool
has_vectors
)
{
auto
dims
=
input
.
dims
()
;
using
ValueType
=
math
::
Real
<
T
>
;
auto
output_value_dim
=
eigen_values
->
dims
(
);
auto
*
out_value
=
eigen_values
->
mutable_data
<
ValueType
>
(
ctx
.
GetPlace
()
);
int64_t
batch_size
=
1
;
int
dim_size
=
dims
.
size
();
for
(
int64_t
i
=
0
;
i
<
dim_size
-
2
;
i
++
)
{
batch_size
*=
dims
[
i
];
}
auto
dito
=
auto
dito
=
DeviceIndependenceTensorOperations
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
math
::
DeviceIndependenceTensorOperations
<
platform
::
CPUDeviceContext
,
T
>
(
Tensor
input_tensor
;
ctx
);
TensorCopy
(
input
,
ctx
.
GetPlace
(),
&
input_tensor
);
if
(
!
is_lower
)
{
input_tensor
=
dito
.
Transpose
(
input
);
}
int
rows
=
dims
[
dims
.
size
()
-
2
];
auto
*
value_data
=
Tensor
input_trans
;
eigen_values
->
mutable_data
<
ValueType
>
(
output_value_dim
,
ctx
.
GetPlace
());
// lapack is a column-major storge, transpose make the input to
// have a continuous memory layout
input_trans
=
dito
.
Transpose
(
input
);
auto
*
input_vector
=
input_trans
.
data
<
T
>
();
if
(
framework
::
IsComplexType
(
input_tensor
.
type
()))
{
auto
dims
=
input
.
dims
();
auto
*
x_data
=
input_tensor
.
data
<
T
>
();
int
dim_size
=
dims
.
size
();
auto
*
vector_data
=
eigen_vectors
->
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
int64_t
batch_size
=
GetBatchSize
(
dims
);
ComputeComplexEigenvaluesAndVectors
<
T
,
ValueType
>
(
x_data
,
value_data
,
vector_data
,
batch_size
,
rows
,
rows
,
has_vectors
);
int
vector_stride
=
dims
[
dim_size
-
1
]
*
dims
[
dim_size
-
2
];
}
else
{
int
values_stride
=
dims
[
dim_size
-
1
];
auto
*
x_data
=
input_tensor
.
data
<
ValueType
>
();
char
uplo
=
is_lower
?
'L'
:
'U'
;
auto
*
vector_data
=
char
jobz
=
has_vectors
?
'V'
:
'N'
;
eigen_vectors
->
mutable_data
<
ValueType
>
(
dims
,
ctx
.
GetPlace
());
auto
n
=
dims
[
dim_size
-
1
];
ComputeFloatEigenvaluesAndVectors
<
ValueType
>
(
auto
lda
=
std
::
max
<
int64_t
>
(
1
,
n
);
x_data
,
value_data
,
vector_data
,
batch_size
,
rows
,
rows
,
has_vectors
);
// if work = -1, it means that you need to use the lapack function to query
// the optimal value
int
lwork
=
-
1
;
// The length of the array work
int
lrwork
=
-
1
;
// The dimension of the array rwork,rwork is REAL array
int
liwork
=
-
1
;
// The dimension of the array iwork
int
iwork_opt
=
-
1
;
// The optimal length of the array liwork
T
lwork_opt
=
static_cast
<
T
>
(
-
1
);
// The optimal length of the array work
ValueType
rwork_opt
=
static_cast
<
ValueType
>
(
-
1
);
// The optimal length of the array rwork
int
info
=
0
;
// Call lapackEigh to get the optimal size of work data
math
::
lapackEigh
<
T
,
ValueType
>
(
jobz
,
uplo
,
n
,
input_vector
,
lda
,
out_value
,
&
lwork_opt
,
lwork
,
&
rwork_opt
,
lrwork
,
&
iwork_opt
,
liwork
,
&
info
);
lwork
=
std
::
max
<
int
>
(
1
,
static_cast
<
int
>
(
lwork_opt
));
liwork
=
std
::
max
<
int
>
(
1
,
iwork_opt
);
Tensor
rwork_tensor
;
ValueType
*
rwork_data
=
nullptr
;
// complex type
if
(
framework
::
IsComplexType
(
input
.
type
()))
{
lrwork
=
std
::
max
<
int
>
(
1
,
static_cast
<
int
>
(
rwork_opt
));
rwork_data
=
rwork_tensor
.
mutable_data
<
ValueType
>
(
framework
::
make_ddim
({
lrwork
}),
ctx
.
GetPlace
());
}
Tensor
iwork_tensor
,
work_tensor
;
auto
*
iwork_data
=
iwork_tensor
.
mutable_data
<
int
>
(
framework
::
make_ddim
({
liwork
}),
ctx
.
GetPlace
());
auto
*
work_data
=
work_tensor
.
mutable_data
<
T
>
(
framework
::
make_ddim
({
lwork
}),
ctx
.
GetPlace
());
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
auto
*
value_data
=
out_value
+
i
*
values_stride
;
auto
*
input_data
=
input_vector
+
i
*
vector_stride
;
math
::
lapackEigh
<
T
,
Real
<
T
>>
(
jobz
,
uplo
,
n
,
input_data
,
lda
,
value_data
,
work_data
,
lwork
,
rwork_data
,
lrwork
,
iwork_data
,
liwork
,
&
info
);
CheckEighResult
(
i
,
info
);
}
if
(
has_vectors
)
{
PADDLE_ENFORCE_NOT_NULL
(
eigen_vectors
,
platform
::
errors
::
InvalidArgument
(
"When has_vectors is true,"
"the eigenvectors needs to be calculated, "
"so the eigenvectors must be provided."
));
input_trans
=
dito
.
Transpose
(
input_trans
);
eigen_vectors
->
ShareDataWith
(
input_trans
);
}
}
}
}
};
};
...
@@ -168,15 +144,22 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, ValueType, T> {
...
@@ -168,15 +144,22 @@ struct MatrixEighFunctor<platform::CPUDeviceContext, ValueType, T> {
// Calculates the eigenvalues and eigenvectors of Hermitian or real
// Calculates the eigenvalues and eigenvectors of Hermitian or real
// symmetric matrices on GPU, and uses the variable has_vectors
// symmetric matrices on GPU, and uses the variable has_vectors
// to control whether to return the eigenvectors.
// to control whether to return the eigenvectors.
template
<
typename
ValueType
,
typename
T
>
template
<
typename
T
>
struct
MatrixEighFunctor
<
platform
::
CUDADeviceContext
,
ValueType
,
T
>
{
struct
MatrixEighFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
bool
has_vectors
)
{
bool
has_vectors
)
{
using
ValueType
=
math
::
Real
<
T
>
;
auto
*
out_value
=
eigen_values
->
mutable_data
<
ValueType
>
(
ctx
.
GetPlace
());
auto
*
out_value
=
eigen_values
->
mutable_data
<
ValueType
>
(
ctx
.
GetPlace
());
auto
*
out_vector
=
eigen_vectors
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
);
Tensor
input_trans
;
input_trans
=
dito
.
Transpose
(
input
);
auto
*
input_vector
=
input_trans
.
data
<
T
>
();
auto
&
dims
=
input
.
dims
();
auto
&
dims
=
input
.
dims
();
int
dim_size
=
dims
.
size
();
int
dim_size
=
dims
.
size
();
int64_t
batch_size
=
GetBatchSize
(
dims
);
int64_t
batch_size
=
GetBatchSize
(
dims
);
...
@@ -190,14 +173,6 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
...
@@ -190,14 +173,6 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
int
lda
=
std
::
max
<
int
>
(
1
,
n
);
int
lda
=
std
::
max
<
int
>
(
1
,
n
);
auto
vector_stride
=
dims
[
dim_size
-
1
]
*
dims
[
dim_size
-
2
];
auto
vector_stride
=
dims
[
dim_size
-
1
]
*
dims
[
dim_size
-
2
];
auto
values_stride
=
dims
[
dim_size
-
1
];
auto
values_stride
=
dims
[
dim_size
-
1
];
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
platform
::
CUDADeviceContext
,
T
>
(
ctx
);
Tensor
output_v_var_trans
=
dito
.
Transpose
(
input
);
TensorCopy
(
output_v_var_trans
,
ctx
.
GetPlace
(),
eigen_vectors
);
int
lwork
=
0
;
int
lwork
=
0
;
auto
info
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
int
)
*
batch_size
);
auto
info
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
int
)
*
batch_size
);
auto
*
info_ptr
=
reinterpret_cast
<
int
*>
(
info
->
ptr
());
auto
*
info_ptr
=
reinterpret_cast
<
int
*>
(
info
->
ptr
());
...
@@ -205,10 +180,8 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
...
@@ -205,10 +180,8 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
// When the input type is float32, and the feature value input dimension is
// When the input type is float32, and the feature value input dimension is
// greater than or equal to [*,32,32] and less than or equal to
// greater than or equal to [*,32,32] and less than or equal to
// [*,512,512], Syevj has better performance.
// [*,512,512], Syevj has better performance.
bool
use_syevj
=
bool
use_syevj
=
(
input
.
type
()
==
framework
::
proto
::
VarType
::
FP32
&&
(
eigen_vectors
->
type
()
==
framework
::
proto
::
VarType
::
FP32
&&
values_stride
>=
32
&&
values_stride
<=
512
);
values_stride
>=
32
&&
values_stride
<=
512
);
syevjInfo_t
syevj_params
;
syevjInfo_t
syevj_params
;
if
(
use_syevj
)
{
if
(
use_syevj
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
...
@@ -216,52 +189,52 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
...
@@ -216,52 +189,52 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cusolverDnSsyevj_bufferSize
(
platform
::
dynload
::
cusolverDnSsyevj_bufferSize
(
dev_ctx
.
cusolver_dn_handle
(),
jobz
,
uplo
,
n
,
dev_ctx
.
cusolver_dn_handle
(),
jobz
,
uplo
,
n
,
reinterpret_cast
<
const
float
*>
(
o
ut_vector
),
lda
,
reinterpret_cast
<
const
float
*>
(
inp
ut_vector
),
lda
,
reinterpret_cast
<
const
float
*>
(
out_value
),
&
lwork
,
reinterpret_cast
<
const
float
*>
(
out_value
),
&
lwork
,
syevj_params
));
syevj_params
));
}
else
{
}
else
{
EvdBuffer
(
dev_ctx
.
cusolver_dn_handle
(),
jobz
,
uplo
,
n
,
o
ut_vector
,
lda
,
EvdBuffer
(
dev_ctx
.
cusolver_dn_handle
(),
jobz
,
uplo
,
n
,
inp
ut_vector
,
lda
,
out_value
,
&
lwork
);
out_value
,
&
lwork
);
}
}
auto
work
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
T
)
*
lwork
);
auto
work
=
memory
::
Alloc
(
dev_ctx
,
sizeof
(
T
)
*
lwork
);
auto
*
work_ptr
=
reinterpret_cast
<
T
*>
(
work
->
ptr
());
auto
*
work_ptr
=
reinterpret_cast
<
T
*>
(
work
->
ptr
());
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
auto
i
=
0
;
i
<
batch_size
;
i
++
)
{
auto
vector_data
=
o
ut_vector
+
i
*
vector_stride
;
auto
*
input_data
=
inp
ut_vector
+
i
*
vector_stride
;
auto
value_data
=
out_value
+
i
*
values_stride
;
auto
*
value_data
=
out_value
+
i
*
values_stride
;
auto
handle
=
dev_ctx
.
cusolver_dn_handle
();
auto
handle
=
dev_ctx
.
cusolver_dn_handle
();
if
(
use_syevj
)
{
if
(
use_syevj
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cusolverDnSsyevj
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cusolverDnSsyevj
(
handle
,
jobz
,
uplo
,
n
,
reinterpret_cast
<
float
*>
(
vector
_data
),
lda
,
handle
,
jobz
,
uplo
,
n
,
reinterpret_cast
<
float
*>
(
input
_data
),
lda
,
reinterpret_cast
<
float
*>
(
value_data
),
reinterpret_cast
<
float
*>
(
value_data
),
reinterpret_cast
<
float
*>
(
work_ptr
),
lwork
,
info_ptr
,
reinterpret_cast
<
float
*>
(
work_ptr
),
lwork
,
info_ptr
,
syevj_params
));
syevj_params
));
}
else
{
}
else
{
Evd
(
handle
,
jobz
,
uplo
,
n
,
vector_data
,
lda
,
value_data
,
work_ptr
,
Evd
(
handle
,
jobz
,
uplo
,
n
,
input_data
,
lda
,
value_data
,
work_ptr
,
lwork
,
lwork
,
info_ptr
);
info_ptr
);
}
}
int
error_info
;
int
error_info
=
0
;
memory
::
Copy
(
platform
::
CPUPlace
(),
&
error_info
,
memory
::
Copy
(
platform
::
CPUPlace
(),
&
error_info
,
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
dev_ctx
.
GetPlace
()),
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
dev_ctx
.
GetPlace
()),
info_ptr
,
sizeof
(
int
),
dev_ctx
.
stream
());
info_ptr
,
sizeof
(
int
),
dev_ctx
.
stream
());
PADDLE_ENFORCE_EQ
(
CheckEighResult
(
i
,
error_info
);
error_info
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"For batch [%d]: the [%d] argument had an illegal value"
,
i
,
error_info
));
}
}
if
(
use_syevj
)
{
if
(
use_syevj
)
{
PADDLE_ENFORCE_CUDA_SUCCESS
(
PADDLE_ENFORCE_CUDA_SUCCESS
(
platform
::
dynload
::
cusolverDnDestroySyevjInfo
(
syevj_params
));
platform
::
dynload
::
cusolverDnDestroySyevjInfo
(
syevj_params
));
}
}
if
(
has_vectors
)
{
if
(
has_vectors
)
{
*
eigen_vectors
=
dito
.
Transpose
(
*
eigen_vectors
);
PADDLE_ENFORCE_NOT_NULL
(
eigen_vectors
,
platform
::
errors
::
InvalidArgument
(
"When has_vectors is true,"
"the eigenvectors needs to be calculated,"
"so the eigenvectors must be provided."
));
input_trans
=
dito
.
Transpose
(
input_trans
);
eigen_vectors
->
ShareDataWith
(
input_trans
);
}
}
}
}
using
ValueType
=
math
::
Real
<
T
>
;
inline
void
EvdBuffer
(
cusolverDnHandle_t
handle
,
cusolverEigMode_t
jobz
,
inline
void
EvdBuffer
(
cusolverDnHandle_t
handle
,
cusolverEigMode_t
jobz
,
cublasFillMode_t
uplo
,
int
n
,
const
T
*
A
,
int
lda
,
cublasFillMode_t
uplo
,
int
n
,
const
T
*
A
,
int
lda
,
const
ValueType
*
W
,
int
*
lwork
)
const
;
const
ValueType
*
W
,
int
*
lwork
)
const
;
...
@@ -271,15 +244,14 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
...
@@ -271,15 +244,14 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
T
*
work
,
int
lwork
,
int
*
devInfo
)
const
;
T
*
work
,
int
lwork
,
int
*
devInfo
)
const
;
};
};
#define FUNC_WITH_TYPES(m)
\
#define FUNC_WITH_TYPES(m) \
m(float,
float, Ssy, float) m(double, double, Dsy, double)
\
m(float,
Ssy, float) m(double, Dsy, double)
\
m(
float,
paddle::platform::complex<float>, Che, cuComplex) \
m(paddle::platform::complex<float>, Che, cuComplex) \
m(
double,
paddle::platform::complex<double>, Zhe, cuDoubleComplex)
m(paddle::platform::complex<double>, Zhe, cuDoubleComplex)
#define EVDBUFFER_INSTANCE(
ValueType, T, C, CastType)
\
#define EVDBUFFER_INSTANCE(
T, C, CastType)
\
template <> \
template <> \
inline void \
inline void MatrixEighFunctor<platform::CUDADeviceContext, T>::EvdBuffer( \
MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::EvdBuffer( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \
cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \
int *lwork) const { \
int *lwork) const { \
...
@@ -291,10 +263,9 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
...
@@ -291,10 +263,9 @@ struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
FUNC_WITH_TYPES
(
EVDBUFFER_INSTANCE
);
FUNC_WITH_TYPES
(
EVDBUFFER_INSTANCE
);
#define EVD_INSTANCE(
ValueType, T, C, CastType)
\
#define EVD_INSTANCE(
T, C, CastType)
\
template <> \
template <> \
inline void \
inline void MatrixEighFunctor<platform::CUDADeviceContext, T>::Evd( \
MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::Evd( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \
cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \
int lwork, int *devInfo) const { \
int lwork, int *devInfo) const { \
...
...
paddle/fluid/operators/math/lapack_function.cc
浏览文件 @
effb70f4
...
@@ -31,6 +31,49 @@ void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
...
@@ -31,6 +31,49 @@ void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
platform
::
dynload
::
sgetrf_
(
&
m
,
&
n
,
a
,
&
lda
,
ipiv
,
info
);
platform
::
dynload
::
sgetrf_
(
&
m
,
&
n
,
a
,
&
lda
,
ipiv
,
info
);
}
}
// eigh
template
<
>
void
lapackEigh
<
float
>
(
char
jobz
,
char
uplo
,
int
n
,
float
*
a
,
int
lda
,
float
*
w
,
float
*
work
,
int
lwork
,
float
*
rwork
,
int
lrwork
,
int
*
iwork
,
int
liwork
,
int
*
info
)
{
(
void
)
rwork
;
// unused
(
void
)
lrwork
;
// unused
platform
::
dynload
::
ssyevd_
(
&
jobz
,
&
uplo
,
&
n
,
a
,
&
lda
,
w
,
work
,
&
lwork
,
iwork
,
&
liwork
,
info
);
}
template
<
>
void
lapackEigh
<
double
>
(
char
jobz
,
char
uplo
,
int
n
,
double
*
a
,
int
lda
,
double
*
w
,
double
*
work
,
int
lwork
,
double
*
rwork
,
int
lrwork
,
int
*
iwork
,
int
liwork
,
int
*
info
)
{
(
void
)
rwork
;
// unused
(
void
)
lrwork
;
// unused
platform
::
dynload
::
dsyevd_
(
&
jobz
,
&
uplo
,
&
n
,
a
,
&
lda
,
w
,
work
,
&
lwork
,
iwork
,
&
liwork
,
info
);
}
template
<
>
void
lapackEigh
<
platform
::
complex
<
float
>
,
float
>
(
char
jobz
,
char
uplo
,
int
n
,
platform
::
complex
<
float
>
*
a
,
int
lda
,
float
*
w
,
platform
::
complex
<
float
>
*
work
,
int
lwork
,
float
*
rwork
,
int
lrwork
,
int
*
iwork
,
int
liwork
,
int
*
info
)
{
platform
::
dynload
::
cheevd_
(
&
jobz
,
&
uplo
,
&
n
,
reinterpret_cast
<
std
::
complex
<
float
>
*>
(
a
),
&
lda
,
w
,
reinterpret_cast
<
std
::
complex
<
float
>
*>
(
work
),
&
lwork
,
rwork
,
&
lrwork
,
iwork
,
&
liwork
,
info
);
}
template
<
>
void
lapackEigh
<
platform
::
complex
<
double
>
,
double
>
(
char
jobz
,
char
uplo
,
int
n
,
platform
::
complex
<
double
>
*
a
,
int
lda
,
double
*
w
,
platform
::
complex
<
double
>
*
work
,
int
lwork
,
double
*
rwork
,
int
lrwork
,
int
*
iwork
,
int
liwork
,
int
*
info
)
{
platform
::
dynload
::
zheevd_
(
&
jobz
,
&
uplo
,
&
n
,
reinterpret_cast
<
std
::
complex
<
double
>
*>
(
a
),
&
lda
,
w
,
reinterpret_cast
<
std
::
complex
<
double
>
*>
(
work
),
&
lwork
,
rwork
,
&
lrwork
,
iwork
,
&
liwork
,
info
);
}
// Eig
// Eig
template
<
>
template
<
>
void
lapackEig
<
double
>
(
char
jobvl
,
char
jobvr
,
int
n
,
double
*
a
,
int
lda
,
void
lapackEig
<
double
>
(
char
jobvl
,
char
jobvr
,
int
n
,
double
*
a
,
int
lda
,
...
...
paddle/fluid/operators/math/lapack_function.h
浏览文件 @
effb70f4
...
@@ -20,12 +20,17 @@ namespace math {
...
@@ -20,12 +20,17 @@ namespace math {
// LU (for example)
// LU (for example)
template
<
typename
T
>
template
<
typename
T
>
void
lapackLu
(
int
m
,
int
n
,
T
*
a
,
int
lda
,
int
*
ipiv
,
int
*
info
);
void
lapackLu
(
int
m
,
int
n
,
T
*
a
,
int
lda
,
int
*
ipiv
,
int
*
info
);
template
<
typename
T
,
typename
ValueType
=
T
>
void
lapackEigh
(
char
jobz
,
char
uplo
,
int
n
,
T
*
a
,
int
lda
,
ValueType
*
w
,
T
*
work
,
int
lwork
,
ValueType
*
rwork
,
int
lrwork
,
int
*
iwork
,
int
liwork
,
int
*
info
);
template
<
typename
T1
,
typename
T2
=
T1
>
template
<
typename
T1
,
typename
T2
=
T1
>
void
lapackEig
(
char
jobvl
,
char
jobvr
,
int
n
,
T1
*
a
,
int
lda
,
T1
*
w
,
T1
*
vl
,
void
lapackEig
(
char
jobvl
,
char
jobvr
,
int
n
,
T1
*
a
,
int
lda
,
T1
*
w
,
T1
*
vl
,
int
ldvl
,
T1
*
vr
,
int
ldvr
,
T1
*
work
,
int
lwork
,
T2
*
rwork
,
int
ldvl
,
T1
*
vr
,
int
ldvr
,
T1
*
work
,
int
lwork
,
T2
*
rwork
,
int
*
info
);
int
*
info
);
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/platform/dynload/lapack.h
浏览文件 @
effb70f4
...
@@ -16,6 +16,7 @@ limitations under the License. */
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <complex>
#include <complex>
#include <mutex>
#include <mutex>
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/port.h"
...
@@ -28,6 +29,22 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv,
...
@@ -28,6 +29,22 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv,
extern
"C"
void
sgetrf_
(
int
*
m
,
int
*
n
,
float
*
a
,
int
*
lda
,
int
*
ipiv
,
extern
"C"
void
sgetrf_
(
int
*
m
,
int
*
n
,
float
*
a
,
int
*
lda
,
int
*
ipiv
,
int
*
info
);
int
*
info
);
// evd
extern
"C"
void
zheevd_
(
char
*
jobz
,
char
*
uplo
,
int
*
n
,
std
::
complex
<
double
>
*
a
,
int
*
lda
,
double
*
w
,
std
::
complex
<
double
>
*
work
,
int
*
lwork
,
double
*
rwork
,
int
*
lrwork
,
int
*
iwork
,
int
*
liwork
,
int
*
info
);
extern
"C"
void
cheevd_
(
char
*
jobz
,
char
*
uplo
,
int
*
n
,
std
::
complex
<
float
>
*
a
,
int
*
lda
,
float
*
w
,
std
::
complex
<
float
>
*
work
,
int
*
lwork
,
float
*
rwork
,
int
*
lrwork
,
int
*
iwork
,
int
*
liwork
,
int
*
info
);
extern
"C"
void
dsyevd_
(
char
*
jobz
,
char
*
uplo
,
int
*
n
,
double
*
a
,
int
*
lda
,
double
*
w
,
double
*
work
,
int
*
lwork
,
int
*
iwork
,
int
*
liwork
,
int
*
info
);
extern
"C"
void
ssyevd_
(
char
*
jobz
,
char
*
uplo
,
int
*
n
,
float
*
a
,
int
*
lda
,
float
*
w
,
float
*
work
,
int
*
lwork
,
int
*
iwork
,
int
*
liwork
,
int
*
info
);
// geev
// geev
extern
"C"
void
dgeev_
(
char
*
jobvl
,
char
*
jobvr
,
int
*
n
,
double
*
a
,
int
*
lda
,
extern
"C"
void
dgeev_
(
char
*
jobvl
,
char
*
jobvr
,
int
*
n
,
double
*
a
,
int
*
lda
,
double
*
wr
,
double
*
wi
,
double
*
vl
,
int
*
ldvl
,
double
*
wr
,
double
*
wi
,
double
*
vl
,
int
*
ldvl
,
...
@@ -81,6 +98,10 @@ extern void *lapack_dso_handle;
...
@@ -81,6 +98,10 @@ extern void *lapack_dso_handle;
#define LAPACK_ROUTINE_EACH(__macro) \
#define LAPACK_ROUTINE_EACH(__macro) \
__macro(dgetrf_); \
__macro(dgetrf_); \
__macro(sgetrf_); \
__macro(sgetrf_); \
__macro(zheevd_); \
__macro(cheevd_); \
__macro(dsyevd_); \
__macro(ssyevd_); \
__macro(dgeev_); \
__macro(dgeev_); \
__macro(sgeev_); \
__macro(sgeev_); \
__macro(zgeev_); \
__macro(zgeev_); \
...
...
python/paddle/__init__.py
浏览文件 @
effb70f4
...
@@ -106,7 +106,6 @@ from .tensor.linalg import slogdet # noqa: F401
...
@@ -106,7 +106,6 @@ from .tensor.linalg import slogdet # noqa: F401
from
.tensor.linalg
import
multi_dot
# noqa: F401
from
.tensor.linalg
import
multi_dot
# noqa: F401
from
.tensor.linalg
import
matrix_power
# noqa: F401
from
.tensor.linalg
import
matrix_power
# noqa: F401
from
.tensor.linalg
import
svd
# noqa: F401
from
.tensor.linalg
import
svd
# noqa: F401
from
.tensor.linalg
import
eigh
# noqa: F401
from
.tensor.linalg
import
pinv
# noqa: F401
from
.tensor.linalg
import
pinv
# noqa: F401
from
.tensor.linalg
import
solve
# noqa: F401
from
.tensor.linalg
import
solve
# noqa: F401
from
.tensor.logic
import
equal
# noqa: F401
from
.tensor.logic
import
equal
# noqa: F401
...
...
python/paddle/tensor/linalg.py
浏览文件 @
effb70f4
...
@@ -1759,7 +1759,7 @@ def eigh(x, UPLO='L', name=None):
...
@@ -1759,7 +1759,7 @@ def eigh(x, UPLO='L', name=None):
x_data = np.array([[1, -2j], [2j, 5]])
x_data = np.array([[1, -2j], [2j, 5]])
x = paddle.to_tensor(x_data)
x = paddle.to_tensor(x_data)
out_value, out_vector = paddle.eigh(x, UPLO='L')
out_value, out_vector = paddle.
linalg.
eigh(x, UPLO='L')
print(out_value)
print(out_value)
#[0.17157288, 5.82842712]
#[0.17157288, 5.82842712]
print(out_vector)
print(out_vector)
...
@@ -1780,7 +1780,7 @@ def eigh(x, UPLO='L', name=None):
...
@@ -1780,7 +1780,7 @@ def eigh(x, UPLO='L', name=None):
raise
ValueError
(
raise
ValueError
(
"The input matrix must be batches of square matrices. But received x's dimention: {}"
.
"The input matrix must be batches of square matrices. But received x's dimention: {}"
.
format
(
x_shape
))
format
(
x_shape
))
if
UPLO
is
not
'L'
and
UPLO
is
not
'U'
:
if
UPLO
!=
'L'
and
UPLO
!=
'U'
:
raise
ValueError
(
raise
ValueError
(
"UPLO must be L or U. But received UPLO is: {}"
.
format
(
UPLO
))
"UPLO must be L or U. But received UPLO is: {}"
.
format
(
UPLO
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录