Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
da441363
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
da441363
编写于
9月 18, 2021
作者:
C
crystal
提交者:
GitHub
9月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
FixEighOP; Unified MatrixEighFunctor function (#35812)
上级
a1b6ae26
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
68 addition
and
116 deletion
+68
-116
paddle/fluid/operators/eigh_op.cc
paddle/fluid/operators/eigh_op.cc
+5
-8
paddle/fluid/operators/eigh_op.cu
paddle/fluid/operators/eigh_op.cu
+6
-26
paddle/fluid/operators/eigh_op.h
paddle/fluid/operators/eigh_op.h
+13
-20
paddle/fluid/operators/math/eigen_values_vectors.h
paddle/fluid/operators/math/eigen_values_vectors.h
+18
-16
paddle/fluid/operators/svd_helper.h
paddle/fluid/operators/svd_helper.h
+24
-44
python/paddle/fluid/tests/unittests/test_eigh_op.py
python/paddle/fluid/tests/unittests/test_eigh_op.py
+2
-2
未找到文件。
paddle/fluid/operators/eigh_op.cc
浏览文件 @
da441363
...
...
@@ -47,12 +47,9 @@ class EighOp : public framework::OperatorWithKernel {
input_dim
[
rank
-
2
],
input_dim
[
rank
-
1
]));
std
::
vector
<
int64_t
>
values_dim
;
if
(
rank
>
2
)
{
for
(
auto
i
=
0
;
i
<
rank
-
1
;
i
++
)
{
values_dim
.
emplace_back
(
input_dim
[
i
]);
}
}
else
{
values_dim
=
{
input_dim
[
1
]};
for
(
auto
i
=
0
;
i
<
rank
-
1
;
i
++
)
{
values_dim
.
emplace_back
(
input_dim
[
i
]);
}
ctx
->
SetOutputDim
(
"Eigenvalues"
,
framework
::
make_ddim
(
values_dim
));
...
...
@@ -99,9 +96,9 @@ class EighGradOp : public framework::OperatorWithKernel {
"EighGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Eigenvectors"
),
"Input"
,
"Eigenvectors"
,
"EighGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
s
(
framework
::
GradVarName
(
"Eigenvalues"
)),
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Eigenvalues"
)),
"Input"
,
"Eigenvalues@GRAD"
,
"EighGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
s
(
framework
::
GradVarName
(
"Eigenvectors"
)),
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Eigenvectors"
)),
"Input"
,
"Eigenvectors@GRAD"
,
"EighGrad"
);
auto
dims
=
ctx
->
GetInputDim
(
"Eigenvectors"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"X"
);
...
...
paddle/fluid/operators/eigh_op.cu
浏览文件 @
da441363
...
...
@@ -14,34 +14,14 @@ limitations under the License. */
#include "paddle/fluid/operators/eigh_op.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
ValueType
,
typename
T
>
class
EighGPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_var
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
output_w_var
=
ctx
.
Output
<
Tensor
>
(
"Eigenvalues"
);
auto
output_v_var
=
ctx
.
Output
<
Tensor
>
(
"Eigenvectors"
);
std
::
string
lower
=
ctx
.
Attr
<
std
::
string
>
(
"UPLO"
);
bool
is_lower
=
(
lower
==
"L"
);
math
::
MatrixEighFunctor
<
ValueType
,
T
>
functor
;
functor
(
ctx
,
*
input_var
,
output_w_var
,
output_v_var
,
is_lower
,
true
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_CUDA_KERNEL
(
eigh
,
ops
::
EighGPUKernel
<
float
,
float
>
,
ops
::
EighGPUKernel
<
double
,
double
>
,
ops
::
EighGPUKernel
<
float
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
EighGPUKernel
<
double
,
paddle
::
platform
::
complex
<
double
>>
);
eigh
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
float
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
double
>
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
,
paddle
::
platform
::
complex
<
float
>>
,
ops
::
EighKernel
<
paddle
::
platform
::
CUDADeviceContext
,
double
,
paddle
::
platform
::
complex
<
double
>>
);
REGISTER_OP_CUDA_KERNEL
(
eigh_grad
,
...
...
paddle/fluid/operators/eigh_op.h
浏览文件 @
da441363
...
...
@@ -22,24 +22,17 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenVector
=
framework
::
EigenVector
<
T
,
MajorType
,
IndexType
>
;
template
<
typename
DeviceContext
,
typename
ValueType
,
typename
T
>
class
EighKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input
_var
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
output_w
_var
=
ctx
.
Output
<
Tensor
>
(
"Eigenvalues"
);
auto
output_v
_var
=
ctx
.
Output
<
Tensor
>
(
"Eigenvectors"
);
auto
input
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
output_w
=
ctx
.
Output
<
Tensor
>
(
"Eigenvalues"
);
auto
output_v
=
ctx
.
Output
<
Tensor
>
(
"Eigenvectors"
);
std
::
string
lower
=
ctx
.
Attr
<
std
::
string
>
(
"UPLO"
);
bool
is_lower
=
(
lower
==
"L"
);
math
::
MatrixEighFunctor
CPU
<
DeviceContext
,
ValueType
,
T
>
functor
;
functor
(
ctx
,
*
input
_var
,
output_w_var
,
output_v_var
,
is_lower
,
true
);
math
::
MatrixEighFunctor
<
DeviceContext
,
ValueType
,
T
>
functor
;
functor
(
ctx
,
*
input
,
output_w
,
output_v
,
is_lower
,
true
);
}
};
...
...
@@ -49,30 +42,30 @@ class EighGradKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
x_grad
=
*
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
x_grad
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
&
output_w
_var
=
*
ctx
.
Input
<
Tensor
>
(
"Eigenvalues"
);
auto
&
output_v
_var
=
*
ctx
.
Input
<
Tensor
>
(
"Eigenvectors"
);
auto
&
output_w
=
*
ctx
.
Input
<
Tensor
>
(
"Eigenvalues"
);
auto
&
output_v
=
*
ctx
.
Input
<
Tensor
>
(
"Eigenvectors"
);
auto
&
output_w_grad
=
*
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Eigenvalues"
));
auto
&
output_v_grad
=
*
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Eigenvectors"
));
auto
&
dims
=
output_v
_var
.
dims
();
auto
&
dims
=
output_v
.
dims
();
const
int
m
=
dims
[
dims
.
size
()
-
1
];
auto
dito
=
math
::
DeviceIndependenceTensorOperations
<
DeviceContext
,
T
,
ValueType
>
(
ctx
);
auto
tV
=
dito
.
Transpose
(
dito
.
Conj
(
output_v
_var
));
auto
W
=
dito
.
Sub_
(
dito
.
Unsqueeze
(
output_w_var
,
-
2
),
dito
.
Unsqueeze
(
output_w_var
,
-
1
));
auto
tV
=
dito
.
Transpose
(
dito
.
Conj
(
output_v
));
auto
W
=
dito
.
template
Sub
<
ValueType
>(
dito
.
Unsqueeze
(
output_w
,
-
2
),
dito
.
Unsqueeze
(
output_w
,
-
1
));
Tensor
result
=
dito
.
Matmul
(
tV
,
output_v_grad
);
result
.
mutable_data
<
T
>
(
dims
,
ctx
.
GetPlace
());
std
::
vector
<
int
>
out_shape
=
framework
::
vectorize
<
int
>
(
dims
);
auto
constant
=
dito
.
Fill
(
out_shape
,
0.5
);
result
=
dito
.
Sub
(
result
,
dito
.
Conj
(
dito
.
Transpose
(
result
)));
result
=
dito
.
Mul
(
result
,
constant
);
result
=
dito
.
Div
_
(
result
,
W
);
result
=
dito
.
Div
(
result
,
W
);
result
=
dito
.
DiagFill
(
m
,
m
,
m
,
0
,
output_w_grad
,
result
);
x_grad
=
dito
.
Matmul
(
output_v
_var
,
dito
.
Matmul
(
result
,
tV
));
x_grad
=
dito
.
Matmul
(
output_v
,
dito
.
Matmul
(
result
,
tV
));
}
};
...
...
paddle/fluid/operators/math/eigen_values_vectors.h
浏览文件 @
da441363
...
...
@@ -16,7 +16,6 @@
#include "Eigen/Core"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/svd_helper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cusolver.h"
...
...
@@ -26,10 +25,6 @@ namespace paddle {
namespace
operators
{
namespace
math
{
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
template
<
typename
T
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
InputMatrixMap
=
Eigen
::
Map
<
...
...
@@ -67,7 +62,7 @@ inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data,
eigenvalues
=
eigen_solver
.
eigenvalues
().
transpose
();
if
(
has_vectors
)
{
eigenvectors
=
eigen_solver
.
eigenvectors
()
.
transpose
()
;
eigenvectors
=
eigen_solver
.
eigenvectors
();
}
}
}
...
...
@@ -103,7 +98,7 @@ inline void ComputeComplexEigenvaluesAndVectors(T *x_data,
eigenvalues
=
eigen_solver
.
eigenvalues
().
transpose
();
if
(
has_vectors
)
{
eigenvectors
=
eigen_solver
.
eigenvectors
()
.
transpose
()
;
eigenvectors
=
eigen_solver
.
eigenvectors
();
}
}
}
...
...
@@ -117,11 +112,18 @@ inline int64_t GetBatchSize(framework::DDim dims) {
return
batch_size
;
}
template
<
typename
DeviceContext
,
typename
ValueType
,
typename
T
>
struct
MatrixEighFunctor
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
bool
has_vectors
);
};
// Calculates the eigenvalues and eigenvectors of Hermitian or real
// symmetric matrices, and uses the variable has_vectors to
// control whether to return the eigenvectors.
template
<
typename
DeviceContext
,
typename
ValueType
,
typename
T
>
struct
MatrixEighFunctor
CPU
{
template
<
typename
ValueType
,
typename
T
>
struct
MatrixEighFunctor
<
platform
::
CPUDeviceContext
,
ValueType
,
T
>
{
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
...
...
@@ -134,7 +136,8 @@ struct MatrixEighFunctorCPU {
for
(
int64_t
i
=
0
;
i
<
dim_size
-
2
;
i
++
)
{
batch_size
*=
dims
[
i
];
}
auto
dito
=
DeviceIndependenceTensorOperations
<
DeviceContext
,
T
>
(
ctx
);
auto
dito
=
DeviceIndependenceTensorOperations
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
);
Tensor
input_tensor
;
TensorCopy
(
input
,
ctx
.
GetPlace
(),
&
input_tensor
);
if
(
!
is_lower
)
{
...
...
@@ -157,9 +160,6 @@ struct MatrixEighFunctorCPU {
ComputeFloatEigenvaluesAndVectors
<
ValueType
>
(
x_data
,
value_data
,
vector_data
,
batch_size
,
rows
,
rows
,
has_vectors
);
}
if
(
has_vectors
)
{
*
eigen_vectors
=
dito
.
Transpose
(
*
eigen_vectors
);
}
}
};
...
...
@@ -169,7 +169,7 @@ struct MatrixEighFunctorCPU {
// symmetric matrices on GPU, and uses the variable has_vectors
// to control whether to return the eigenvectors.
template
<
typename
ValueType
,
typename
T
>
struct
MatrixEighFunctor
{
struct
MatrixEighFunctor
<
platform
::
CUDADeviceContext
,
ValueType
,
T
>
{
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
Tensor
&
input
,
Tensor
*
eigen_values
,
Tensor
*
eigen_vectors
,
bool
is_lower
,
...
...
@@ -278,7 +278,8 @@ struct MatrixEighFunctor {
#define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \
template <> \
inline void MatrixEighFunctor<ValueType, T>::EvdBuffer( \
inline void \
MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::EvdBuffer( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \
int *lwork) const { \
...
...
@@ -292,7 +293,8 @@ FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);
#define EVD_INSTANCE(ValueType, T, C, CastType) \
template <> \
inline void MatrixEighFunctor<ValueType, T>::Evd( \
inline void \
MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::Evd( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \
int lwork, int *devInfo) const { \
...
...
paddle/fluid/operators/svd_helper.h
浏览文件 @
da441363
...
...
@@ -289,10 +289,20 @@ struct DeviceIndependenceTensorOperations {
framework
::
Tensor
Div
(
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
)
{
framework
::
Tensor
ret
;
std
::
vector
<
int
>
out_shape
=
GetBroadcastShape
({
&
x
,
&
y
});
ret
.
Resize
(
framework
::
make_ddim
(
out_shape
));
ElementwiseComputeEx
<
DivFunctor
<
T
>
,
DeviceContext
,
T
>
(
context
,
&
x
,
&
y
,
-
1
,
DivFunctor
<
T
>
(),
&
ret
);
if
(
x
.
type
()
!=
y
.
type
())
{
ret
.
mutable_data
<
T
>
(
x
.
dims
(),
context
.
GetPlace
());
auto
x_vector
=
EigenVector
<
T
>::
Flatten
(
x
);
auto
y_vector
=
EigenVector
<
ValueType
>::
Flatten
(
y
);
auto
out_vector
=
EigenVector
<
T
>::
Flatten
(
ret
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
out_vector
.
device
(
place
)
=
x_vector
/
y_vector
;
}
else
{
std
::
vector
<
int
>
out_shape
=
GetBroadcastShape
({
&
x
,
&
y
});
ret
.
Resize
(
framework
::
make_ddim
(
out_shape
));
ElementwiseComputeEx
<
DivFunctor
<
T
>
,
DeviceContext
,
T
>
(
context
,
&
x
,
&
y
,
-
1
,
DivFunctor
<
T
>
(),
&
ret
);
}
return
ret
;
}
framework
::
Tensor
Add
(
const
framework
::
Tensor
&
x
,
...
...
@@ -330,7 +340,8 @@ struct DeviceIndependenceTensorOperations {
NameInTensorMap
inputs
({{
"X"
,
{
&
x
}}});
return
CreateOpRunAndReturnTensor
(
"reduce_max"
,
inputs
,
attrs
,
out_dim
);
}
// Support float and complex type subtraction,the default is T type
template
<
typename
InT
=
T
>
framework
::
Tensor
Sub
(
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
)
{
framework
::
Tensor
ret
;
...
...
@@ -340,18 +351,18 @@ struct DeviceIndependenceTensorOperations {
#if defined(__NVCC__) || defined(__HIPCC__)
// For GPU, there is no need to define XxxInverseFunctor and call
// ElementwiseComputeEx in two branches.
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
context
,
&
x
,
&
y
,
-
1
,
SubFunctor
<
T
>
(),
&
ret
);
ElementwiseComputeEx
<
SubFunctor
<
InT
>
,
DeviceContext
,
In
T
>
(
context
,
&
x
,
&
y
,
-
1
,
SubFunctor
<
In
T
>
(),
&
ret
);
#endif
}
else
{
if
(
x
.
dims
().
size
()
>=
y
.
dims
().
size
())
{
ElementwiseComputeEx
<
SubFunctor
<
T
>
,
DeviceContext
,
T
>
(
context
,
&
x
,
&
y
,
-
1
,
SubFunctor
<
T
>
(),
&
ret
);
ElementwiseComputeEx
<
SubFunctor
<
InT
>
,
DeviceContext
,
In
T
>
(
context
,
&
x
,
&
y
,
-
1
,
SubFunctor
<
In
T
>
(),
&
ret
);
}
else
{
ElementwiseComputeEx
<
InverseSubFunctor
<
T
>
,
DeviceContext
,
T
>
(
// This is copyed from elementwise_sub, which means we
// need reverse will xrank < yrank
context
,
&
x
,
&
y
,
-
1
,
InverseSubFunctor
<
T
>
(),
&
ret
);
// This is copyed from elementwise_sub, which means we
// need reverse will xrank < yrank
ElementwiseComputeEx
<
InverseSubFunctor
<
InT
>
,
DeviceContext
,
InT
>
(
context
,
&
x
,
&
y
,
-
1
,
InverseSubFunctor
<
In
T
>
(),
&
ret
);
}
}
return
ret
;
...
...
@@ -461,37 +472,6 @@ struct DeviceIndependenceTensorOperations {
return
out
;
}
// Support x and y are different data types
Tensor
Div_
(
const
Tensor
&
x
,
const
Tensor
&
y
)
{
Tensor
out
;
out
.
mutable_data
<
T
>
(
x
.
dims
(),
context
.
GetPlace
());
auto
x_vector
=
EigenVector
<
T
>::
Flatten
(
x
);
auto
y_vector
=
EigenVector
<
ValueType
>::
Flatten
(
y
);
auto
out_vector
=
EigenVector
<
T
>::
Flatten
(
out
);
auto
&
place
=
*
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
out_vector
.
device
(
place
)
=
x_vector
/
y_vector
;
return
out
;
}
framework
::
Tensor
Sub_
(
const
framework
::
Tensor
&
x
,
const
framework
::
Tensor
&
y
)
{
framework
::
Tensor
ret
;
std
::
vector
<
int
>
out_shape
=
GetBroadcastShape
({
&
x
,
&
y
});
ret
.
Resize
(
framework
::
make_ddim
(
out_shape
));
if
(
x
.
dims
().
size
()
>=
y
.
dims
().
size
())
{
ElementwiseComputeEx
<
SubFunctor
<
ValueType
>
,
DeviceContext
,
ValueType
>
(
context
,
&
x
,
&
y
,
-
1
,
SubFunctor
<
ValueType
>
(),
&
ret
);
}
else
{
ElementwiseComputeEx
<
InverseSubFunctor
<
ValueType
>
,
DeviceContext
,
ValueType
>
(
// This is copyed from elementwise_sub, which means we
// need reverse will xrank < yrank
context
,
&
x
,
&
y
,
-
1
,
InverseSubFunctor
<
ValueType
>
(),
&
ret
);
}
return
ret
;
}
private:
const
framework
::
ExecutionContext
&
context
;
BlasT
<
DeviceContext
,
T
>
GetBlas
()
{
...
...
python/paddle/fluid/tests/unittests/test_eigh_op.py
浏览文件 @
da441363
...
...
@@ -140,7 +140,7 @@ class TestEighAPI(unittest.TestCase):
self
.
check_static_complex_result
()
def
test_in_dynamic_mode
(
self
):
paddle
.
disable_static
(
self
.
place
)
paddle
.
disable_static
()
input_real_data
=
paddle
.
to_tensor
(
self
.
real_data
)
expected_w
,
expected_v
=
np
.
linalg
.
eigh
(
self
.
real_data
)
actual_w
,
actual_v
=
paddle
.
linalg
.
eigh
(
input_real_data
)
...
...
@@ -152,7 +152,7 @@ class TestEighAPI(unittest.TestCase):
self
.
compare_result
(
actual_w
,
actual_v
.
numpy
(),
expected_w
,
expected_v
)
def
test_eigh_grad
(
self
):
paddle
.
disable_static
(
self
.
place
)
paddle
.
disable_static
()
x
=
paddle
.
to_tensor
(
self
.
complex_data
,
stop_gradient
=
False
)
w
,
v
=
paddle
.
linalg
.
eigh
(
x
)
(
w
.
sum
()
+
paddle
.
abs
(
v
).
sum
()).
backward
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录