Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
95474815
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
95474815
编写于
7月 13, 2022
作者:
R
Ruibiao Chen
提交者:
GitHub
7月 13, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move eigvals OP to PHI (#44183)
* Move eigvals OP to PHI * Fix CI errors * Fix CI errors
上级
0a5d625b
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
383 addition
and
50 deletion
+383
-50
paddle/fluid/operators/eigvals_op.cc
paddle/fluid/operators/eigvals_op.cc
+8
-49
paddle/phi/api/yaml/legacy_api.yaml
paddle/phi/api/yaml/legacy_api.yaml
+8
-0
paddle/phi/core/utils/data_type.h
paddle/phi/core/utils/data_type.h
+17
-0
paddle/phi/infermeta/unary.cc
paddle/phi/infermeta/unary.cc
+33
-0
paddle/phi/infermeta/unary.h
paddle/phi/infermeta/unary.h
+4
-0
paddle/phi/kernels/cpu/eigvals_kernel.cc
paddle/phi/kernels/cpu/eigvals_kernel.cc
+260
-0
paddle/phi/kernels/eigvals_kernel.h
paddle/phi/kernels/eigvals_kernel.h
+25
-0
paddle/phi/ops/compat/eigvals_sig.cc
paddle/phi/ops/compat/eigvals_sig.cc
+25
-0
python/paddle/tensor/linalg.py
python/paddle/tensor/linalg.py
+3
-1
未找到文件。
paddle/fluid/operators/eigvals_op.cc
浏览文件 @
95474815
...
...
@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigvals_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -36,59 +37,17 @@ class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker {
class
EigvalsOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"Eigvals"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"Eigvals"
);
DDim
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_GE
(
x_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The dimensions of Input(X) for Eigvals operator "
"should be at least 2, "
"but received X's dimension = %d, X's shape = [%s]."
,
x_dims
.
size
(),
x_dims
));
if
(
ctx
->
IsRuntime
()
||
!
phi
::
contain_unknown_dim
(
x_dims
))
{
int
last_dim
=
x_dims
.
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
x_dims
[
last_dim
],
x_dims
[
last_dim
-
1
],
platform
::
errors
::
InvalidArgument
(
"The last two dimensions of Input(X) for Eigvals "
"operator should be equal, "
"but received X's shape = [%s]."
,
x_dims
));
}
auto
output_dims
=
vectorize
(
x_dims
);
output_dims
.
resize
(
x_dims
.
size
()
-
1
);
ctx
->
SetOutputDim
(
"Out"
,
phi
::
make_ddim
(
output_dims
));
}
};
class
EigvalsOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
{
auto
input_dtype
=
ctx
->
GetInputDataType
(
"X"
);
auto
output_dtype
=
framework
::
IsComplexType
(
input_dtype
)
?
input_dtype
:
framework
::
ToComplexType
(
input_dtype
);
ctx
->
SetOutputDataType
(
"Out"
,
output_dtype
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
DECLARE_INFER_SHAPE_FUNCTOR
(
eigvals
,
EigvalsInferShapeFunctor
,
PD_INFER_META
(
phi
::
EigvalsInferMeta
));
REGISTER_OPERATOR
(
eigvals
,
ops
::
EigvalsOp
,
ops
::
EigvalsOpMaker
,
ops
::
EigvalsOpVarTypeInference
);
REGISTER_OP_CPU_KERNEL
(
eigvals
,
ops
::
EigvalsKernel
<
phi
::
CPUContext
,
float
>
,
ops
::
EigvalsKernel
<
phi
::
CPUContext
,
double
>
,
ops
::
EigvalsKernel
<
phi
::
CPUContext
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
EigvalsKernel
<
phi
::
CPUContext
,
paddle
::
platform
::
complex
<
double
>>
);
EigvalsInferShapeFunctor
);
paddle/phi/api/yaml/legacy_api.yaml
浏览文件 @
95474815
...
...
@@ -536,6 +536,14 @@
func
:
eigh
backward
:
eigh_grad
-
api
:
eigvals
args
:
(Tensor x)
output
:
Tensor
infer_meta
:
func
:
EigvalsInferMeta
kernel
:
func
:
eigvals
-
api
:
einsum
args
:
(Tensor[] x, str equation)
output
:
Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
...
...
paddle/phi/core/utils/data_type.h
浏览文件 @
95474815
...
...
@@ -80,4 +80,21 @@ inline void VisitDataTypeTiny(phi::DataType type, Visitor visitor) {
"Not supported phi::DataType(%d) as data type."
,
static_cast
<
int
>
(
type
)));
}
inline
bool
IsComplexType
(
const
DataType
&
type
)
{
return
(
type
==
DataType
::
COMPLEX64
||
type
==
DataType
::
COMPLEX128
);
}
inline
DataType
ToComplexType
(
const
DataType
&
type
)
{
switch
(
type
)
{
case
DataType
::
FLOAT32
:
return
DataType
::
COMPLEX64
;
case
DataType
::
FLOAT64
:
return
DataType
::
COMPLEX128
;
default:
PADDLE_THROW
(
errors
::
Unimplemented
(
"Can not transform data type (%s) to complex type, now only support "
"float32 and float64 real value."
,
type
));
}
}
}
// namespace phi
paddle/phi/infermeta/unary.cc
浏览文件 @
95474815
...
...
@@ -399,6 +399,39 @@ void EighInferMeta(const MetaTensor& x,
out_v
->
set_dims
(
input_dim
);
}
void
EigvalsInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
,
MetaConfig
config
)
{
auto
x_dims
=
x
.
dims
();
PADDLE_ENFORCE_GE
(
x_dims
.
size
(),
2
,
errors
::
InvalidArgument
(
"The dimensions of Input(X) for Eigvals operator "
"should be at least 2, "
"but received X's dimension = %d, X's shape = [%s]."
,
x_dims
.
size
(),
x_dims
));
if
(
config
.
is_runtime
||
!
phi
::
contain_unknown_dim
(
x_dims
))
{
int
last_dim
=
x_dims
.
size
()
-
1
;
PADDLE_ENFORCE_EQ
(
x_dims
[
last_dim
],
x_dims
[
last_dim
-
1
],
errors
::
InvalidArgument
(
"The last two dimensions of Input(X) for Eigvals "
"operator should be equal, "
"but received X's shape = [%s]."
,
x_dims
));
}
auto
out_dims
=
vectorize
(
x_dims
);
out_dims
.
resize
(
x_dims
.
size
()
-
1
);
const
DataType
&
x_dtype
=
x
.
dtype
();
const
DataType
&
out_dtype
=
IsComplexType
(
x_dtype
)
?
x_dtype
:
ToComplexType
(
x_dtype
);
out
->
set_dims
(
make_ddim
(
out_dims
));
out
->
set_dtype
(
out_dtype
);
}
void
EinsumInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
inputs
,
const
std
::
string
&
equation
,
MetaTensor
*
out
,
...
...
paddle/phi/infermeta/unary.h
浏览文件 @
95474815
...
...
@@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x,
MetaTensor
*
out_w
,
MetaTensor
*
out_v
);
void
EigvalsInferMeta
(
const
MetaTensor
&
x
,
MetaTensor
*
out
,
MetaConfig
config
=
MetaConfig
());
void
EinsumInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
inputs
,
const
std
::
string
&
equation
,
MetaTensor
*
out
,
...
...
paddle/
fluid/operators/eigvals_op.h
→
paddle/
phi/kernels/cpu/eigvals_kernel.cc
浏览文件 @
95474815
// Copyright (c) 202
1
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,23 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#
pragma once
#
include "paddle/phi/kernels/eigvals_kernel.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
DDim
=
framework
::
DDim
;
namespace
phi
{
template
<
typename
T
,
typename
enable
=
void
>
struct
PaddleComplex
;
...
...
@@ -37,79 +31,60 @@ template <typename T>
struct
PaddleComplex
<
T
,
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
>
{
using
type
=
paddle
::
platform
::
complex
<
T
>
;
using
type
=
dtype
::
complex
<
T
>
;
};
template
<
typename
T
>
struct
PaddleComplex
<
T
,
typename
std
::
enable_if
<
std
::
is_same
<
T
,
platform
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
platform
::
complex
<
double
>>::
value
>::
type
>
{
std
::
is_same
<
T
,
dtype
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
dtype
::
complex
<
double
>>::
value
>::
type
>
{
using
type
=
T
;
};
template
<
typename
T
>
using
PaddleCType
=
typename
PaddleComplex
<
T
>::
type
;
template
<
typename
T
>
using
Real
=
typename
phi
::
dtype
::
Real
<
T
>
;
static
void
SpiltBatchSquareMatrix
(
const
Tensor
&
input
,
std
::
vector
<
Tensor
>*
output
)
{
DDim
input_dims
=
input
.
dims
();
int
last_dim
=
input_dims
.
size
()
-
1
;
int
n_dim
=
input_dims
[
last_dim
];
DDim
flattened_input_dims
,
flattened_output_dims
;
if
(
input_dims
.
size
()
>
2
)
{
flattened_input_dims
=
phi
::
flatten_to_3d
(
input_dims
,
last_dim
-
1
,
last_dim
);
}
else
{
flattened_input_dims
=
phi
::
make_ddim
({
1
,
n_dim
,
n_dim
});
}
Tensor
flattened_input
;
flattened_input
.
ShareDataWith
(
input
);
flattened_input
.
Resize
(
flattened_input_dims
);
(
*
output
)
=
flattened_input
.
Split
(
1
,
0
);
}
using
Real
=
typename
dtype
::
Real
<
T
>
;
static
void
CheckLapackEigResult
(
const
int
info
,
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_LE
(
info
,
inline
void
CheckLapackEigResult
(
const
int
info
,
const
std
::
string
&
name
)
{
PADDLE_ENFORCE_LE
(
info
,
0
,
platform
::
errors
::
PreconditionNotMet
(
"The QR algorithm failed to compute all the "
errors
::
PreconditionNotMet
(
"The QR algorithm failed to compute all the "
"eigenvalues in function %s."
,
name
.
c_str
()));
PADDLE_ENFORCE_GE
(
info
,
0
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The %d-th argument has an illegal value in function %s."
,
-
info
,
name
.
c_str
()));
}
template
<
typename
DeviceContext
,
typename
T
>
static
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
LapackEigvals
(
const
framework
::
Execution
Context
&
ctx
,
const
Tensor
&
input
,
Tensor
*
output
,
Tensor
*
work
,
Tensor
*
rwork
/*unused*/
)
{
Tensor
a
;
// will be overwritten when lapackEig exit
framework
::
TensorCopy
(
input
,
input
.
place
()
,
&
a
);
Tensor
w
;
template
<
typename
T
,
typename
Context
>
typename
std
::
enable_if
<
std
::
is_floating_point
<
T
>::
value
>::
type
LapackEigvals
(
const
Context
&
ctx
,
const
Dense
Tensor
&
input
,
Dense
Tensor
*
output
,
Dense
Tensor
*
work
,
Dense
Tensor
*
rwork
/*unused*/
)
{
Dense
Tensor
a
;
// will be overwritten when lapackEig exit
Copy
(
ctx
,
input
,
input
.
place
(),
/*blocking=*/
true
,
&
a
);
Dense
Tensor
w
;
int64_t
n_dim
=
input
.
dims
()[
1
];
auto
*
w_data
=
w
.
mutable_data
<
T
>
(
phi
::
make_ddim
({
n_dim
<<
1
}),
ctx
.
GetPlace
()
);
w
.
Resize
(
make_ddim
({
n_dim
<<
1
}));
T
*
w_data
=
ctx
.
template
Alloc
<
T
>(
&
w
);
int64_t
work_mem
=
work
->
memory_size
();
int64_t
required_work_mem
=
3
*
n_dim
*
sizeof
(
T
);
PADDLE_ENFORCE_GE
(
work_mem
,
3
*
n_dim
*
sizeof
(
T
),
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %"
PRId64
" bytes, "
"but received work
\'
s memory size = %"
PRId64
" bytes."
,
...
...
@@ -132,30 +107,28 @@ LapackEigvals(const framework::ExecutionContext& ctx,
static_cast
<
T
*>
(
NULL
),
&
info
);
std
::
string
name
=
"framework::platform::dynload::dgeev_"
;
if
(
framework
::
TransToProtoVarType
(
input
.
dtype
())
==
framework
::
proto
::
VarType
::
FP64
)
{
name
=
"framework::platform::dynload::sgeev_"
;
std
::
string
name
=
"phi::backend::dynload::dgeev_"
;
if
(
input
.
dtype
()
==
DataType
::
FLOAT64
)
{
name
=
"phi::backend::dynload::sgeev_"
;
}
CheckLapackEigResult
(
info
,
name
);
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx
.
template
device_context
<
DeviceContext
>(),
n_dim
);
phi
::
funcs
::
RealImagToComplexFunctor
<
PaddleCType
<
T
>>
functor
(
funcs
::
ForRange
<
Context
>
for_range
(
ctx
,
n_dim
);
funcs
::
RealImagToComplexFunctor
<
PaddleCType
<
T
>>
functor
(
w_data
,
w_data
+
n_dim
,
output
->
template
data
<
PaddleCType
<
T
>
>
(),
n_dim
);
for_range
(
functor
);
}
template
<
typename
DeviceContext
,
typename
T
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
platform
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
platform
::
complex
<
double
>>::
value
>::
type
LapackEigvals
(
const
framework
::
Execution
Context
&
ctx
,
const
Tensor
&
input
,
Tensor
*
output
,
Tensor
*
work
,
Tensor
*
rwork
)
{
Tensor
a
;
// will be overwritten when lapackEig exit
framework
::
TensorCopy
(
input
,
input
.
place
()
,
&
a
);
template
<
typename
T
,
typename
Context
>
typename
std
::
enable_if
<
std
::
is_same
<
T
,
dtype
::
complex
<
float
>>::
value
||
std
::
is_same
<
T
,
dtype
::
complex
<
double
>>::
value
>::
type
LapackEigvals
(
const
Context
&
ctx
,
const
Dense
Tensor
&
input
,
Dense
Tensor
*
output
,
Dense
Tensor
*
work
,
Dense
Tensor
*
rwork
)
{
Dense
Tensor
a
;
// will be overwritten when lapackEig exit
Copy
(
ctx
,
input
,
input
.
place
(),
/*blocking=*/
true
,
&
a
);
int64_t
work_mem
=
work
->
memory_size
();
int64_t
n_dim
=
input
.
dims
()[
1
];
...
...
@@ -163,7 +136,7 @@ LapackEigvals(const framework::ExecutionContext& ctx,
PADDLE_ENFORCE_GE
(
work_mem
,
3
*
n_dim
*
sizeof
(
T
),
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %"
PRId64
" bytes, "
"but received work
\'
s memory size = %"
PRId64
" bytes."
,
...
...
@@ -171,11 +144,11 @@ LapackEigvals(const framework::ExecutionContext& ctx,
work_mem
));
int64_t
rwork_mem
=
rwork
->
memory_size
();
int64_t
required_rwork_mem
=
(
n_dim
<<
1
)
*
sizeof
(
phi
::
dtype
::
Real
<
T
>
);
int64_t
required_rwork_mem
=
(
n_dim
<<
1
)
*
sizeof
(
dtype
::
Real
<
T
>
);
PADDLE_ENFORCE_GE
(
rwork_mem
,
required_rwork_mem
,
platform
::
errors
::
InvalidArgument
(
errors
::
InvalidArgument
(
"The memory size of the rwork tensor in LapackEigvals function "
"should be at least %"
PRId64
" bytes, "
"but received rwork
\'
s memory size = %"
PRId64
" bytes."
,
...
...
@@ -183,7 +156,7 @@ LapackEigvals(const framework::ExecutionContext& ctx,
rwork_mem
));
int
info
=
0
;
phi
::
funcs
::
lapackEig
<
T
,
phi
::
dtype
::
Real
<
T
>>
(
phi
::
funcs
::
lapackEig
<
T
,
dtype
::
Real
<
T
>>
(
'N'
,
'N'
,
static_cast
<
int
>
(
n_dim
),
...
...
@@ -196,42 +169,56 @@ LapackEigvals(const framework::ExecutionContext& ctx,
1
,
work
->
template
data
<
T
>(),
static_cast
<
int
>
(
work_mem
/
sizeof
(
T
)),
rwork
->
template
data
<
phi
::
dtype
::
Real
<
T
>
>
(),
rwork
->
template
data
<
dtype
::
Real
<
T
>
>
(),
&
info
);
std
::
string
name
=
"framework::platform::dynload::cgeev_"
;
if
(
framework
::
TransToProtoVarType
(
input
.
dtype
())
==
framework
::
proto
::
VarType
::
COMPLEX64
)
{
name
=
"framework::platform::dynload::zgeev_"
;
std
::
string
name
=
"phi::backend::dynload::cgeev_"
;
if
(
input
.
dtype
()
==
DataType
::
COMPLEX128
)
{
name
=
"phi::backend::dynload::zgeev_"
;
}
CheckLapackEigResult
(
info
,
name
);
}
template
<
typename
DeviceContext
,
typename
T
>
class
EigvalsKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
Tensor
*
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
Tensor
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
output
->
mutable_data
<
PaddleCType
<
T
>>
(
ctx
.
GetPlace
());
void
SpiltBatchSquareMatrix
(
const
DenseTensor
&
input
,
std
::
vector
<
DenseTensor
>*
output
)
{
DDim
input_dims
=
input
.
dims
();
int
last_dim
=
input_dims
.
size
()
-
1
;
int
n_dim
=
input_dims
[
last_dim
];
std
::
vector
<
Tensor
>
input_matrices
;
SpiltBatchSquareMatrix
(
*
input
,
/*->*/
&
input_matrices
);
DDim
flattened_input_dims
,
flattened_output_dims
;
if
(
input_dims
.
size
()
>
2
)
{
flattened_input_dims
=
phi
::
flatten_to_3d
(
input_dims
,
last_dim
-
1
,
last_dim
);
}
else
{
flattened_input_dims
=
phi
::
make_ddim
({
1
,
n_dim
,
n_dim
});
}
DenseTensor
flattened_input
;
flattened_input
.
ShareDataWith
(
input
);
flattened_input
.
Resize
(
flattened_input_dims
);
(
*
output
)
=
flattened_input
.
Split
(
1
,
0
);
}
int64_t
n_dim
=
input_matrices
[
0
].
dims
()[
1
];
int64_t
n_batch
=
input_matrices
.
size
();
DDim
output_dims
=
output
->
dims
();
output
->
Resize
(
phi
::
make_ddim
({
n_batch
,
n_dim
}));
std
::
vector
<
Tensor
>
output_vectors
=
output
->
Split
(
1
,
0
);
template
<
typename
T
,
typename
Context
>
void
EigvalsKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
)
{
ctx
.
template
Alloc
<
PaddleCType
<
T
>
>
(
out
);
std
::
vector
<
DenseTensor
>
x_matrices
;
SpiltBatchSquareMatrix
(
x
,
/*->*/
&
x_matrices
);
int64_t
n_dim
=
x_matrices
[
0
].
dims
()[
1
];
int64_t
n_batch
=
x_matrices
.
size
();
DDim
out_dims
=
out
->
dims
();
out
->
Resize
(
make_ddim
({
n_batch
,
n_dim
}));
std
::
vector
<
DenseTensor
>
out_vectors
=
out
->
Split
(
1
,
0
);
// query workspace size
T
qwork
;
int
info
;
phi
::
funcs
::
lapackEig
<
T
,
phi
::
dtype
::
Real
<
T
>>
(
'N'
,
funcs
::
lapackEig
<
T
,
dtype
::
Real
<
T
>>
(
'N'
,
'N'
,
static_cast
<
int
>
(
n_dim
),
input
_matrices
[
0
].
template
data
<
T
>(),
x
_matrices
[
0
].
template
data
<
T
>(),
static_cast
<
int
>
(
n_dim
),
NULL
,
NULL
,
...
...
@@ -240,34 +227,34 @@ class EigvalsKernel : public framework::OpKernel<T> {
1
,
&
qwork
,
-
1
,
static_cast
<
phi
::
dtype
::
Real
<
T
>*>
(
NULL
),
static_cast
<
dtype
::
Real
<
T
>*>
(
NULL
),
&
info
);
int64_t
lwork
=
static_cast
<
int64_t
>
(
qwork
);
Tensor
work
,
rwork
;
try
{
work
.
mutable_data
<
T
>
(
phi
::
make_ddim
({
lwork
}),
ctx
.
GetPlace
());
}
catch
(
memory
::
allocation
::
BadAlloc
&
)
{
LOG
(
WARNING
)
<<
"Failed to allocate Lapack workspace with the optimal "
<<
"memory size = "
<<
lwork
*
sizeof
(
T
)
<<
" bytes, "
<<
"try reallocating a smaller workspace with the minimum "
<<
"required size = "
<<
3
*
n_dim
*
sizeof
(
T
)
<<
" bytes, "
<<
"this may lead to bad performance."
;
lwork
=
3
*
n_dim
;
work
.
mutable_data
<
T
>
(
phi
::
make_ddim
({
lwork
}),
ctx
.
GetPlace
());
}
if
(
framework
::
IsComplexType
(
framework
::
TransToProtoVarType
(
input
->
dtype
())))
{
rwork
.
mutable_data
<
phi
::
dtype
::
Real
<
T
>>
(
phi
::
make_ddim
({
n_dim
<<
1
}),
ctx
.
GetPlace
());
DenseTensor
work
,
rwork
;
work
.
Resize
(
make_ddim
({
lwork
}));
ctx
.
template
Alloc
<
T
>(
&
work
);
if
(
IsComplexType
(
x
.
dtype
()))
{
rwork
.
Resize
(
make_ddim
({
n_dim
<<
1
}));
ctx
.
template
Alloc
<
dtype
::
Real
<
T
>
>
(
&
rwork
);
}
for
(
int64_t
i
=
0
;
i
<
n_batch
;
++
i
)
{
LapackEigvals
<
DeviceContext
,
T
>
(
ctx
,
input_matrices
[
i
],
&
outp
ut_vectors
[
i
],
&
work
,
&
rwork
);
LapackEigvals
<
T
,
Context
>
(
ctx
,
x_matrices
[
i
],
&
o
ut_vectors
[
i
],
&
work
,
&
rwork
);
}
output
->
Resize
(
output_dims
);
}
};
}
// namespace operators
}
// namespace paddle
out
->
Resize
(
out_dims
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
eigvals
,
CPU
,
ALL_LAYOUT
,
phi
::
EigvalsKernel
,
float
,
double
,
phi
::
dtype
::
complex
<
float
>
,
phi
::
dtype
::
complex
<
double
>
)
{}
paddle/phi/kernels/eigvals_kernel.h
0 → 100644
浏览文件 @
95474815
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
EigvalsKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
}
// namespace phi
paddle/phi/ops/compat/eigvals_sig.cc
0 → 100644
浏览文件 @
95474815
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
EigvalsOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
"eigvals"
,
{
"X"
},
{},
{
"Out"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
eigvals
,
phi
::
EigvalsOpArgumentMapping
);
python/paddle/tensor/linalg.py
浏览文件 @
95474815
...
...
@@ -2339,7 +2339,9 @@ def eigvals(x, name=None):
"The last two dimensions of Input(x) should be equal, but received x's shape = {}"
.
format
(
x_shape
))
if
paddle
.
in_dynamic_mode
():
if
in_dygraph_mode
():
return
_C_ops
.
final_state_eigvals
(
x
)
elif
paddle
.
in_dynamic_mode
():
return
_C_ops
.
eigvals
(
x
)
helper
=
LayerHelper
(
'eigvals'
,
**
locals
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录