Unverified commit cdbfeff4 authored by wuyefeilin and committed by GitHub

[PHI] Move eigvalsh op to phi (#44559)

* mv eigvalsh op
Parent 15ce2c1b
@@ -12,7 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigvalsh_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
@@ -22,43 +26,6 @@ using framework::Tensor;
class EigvalshOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvalsh");
OP_INOUT_CHECK(
ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", "Eigvalsh");
auto input_dim = ctx->GetInputDim("X");
auto rank = input_dim.size();
PADDLE_ENFORCE_GE(rank,
2,
platform::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2],
input_dim[rank - 1],
platform::errors::InvalidArgument(
"Eigvalsh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2],
input_dim[rank - 1]));
std::vector<int64_t> values_dim;
for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
ctx->SetOutputDim("Eigenvalues", phi::make_ddim(values_dim));
if (ctx->HasOutput("Eigenvectors")) {
ctx->SetOutputDim("Eigenvectors", input_dim);
}
}
};
class EigvalshOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -100,20 +67,6 @@ class EigvalshGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EigvalshGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
"Input",
"Eigenvalues@GRAD",
"EigvalshGrad");
auto dims = ctx->GetInputDim("Eigenvectors");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
@@ -144,30 +97,19 @@ class EigvalshGradOpMaker : public framework::SingleGradOpMaker<T> {
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eigvalsh,
EigvalshInferShapeFunctor,
PD_INFER_META(phi::EigvalshInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(eigvalsh_grad,
EigvalshGradInferShapeFunctor,
PD_INFER_META(phi::EigvalshGradInferMeta));
REGISTER_OPERATOR(eigvalsh,
ops::EigvalshOp,
ops::EigvalshOpMaker,
ops::EigvalshGradOpMaker<paddle::framework::OpDesc>,
ops::EigvalshGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(eigvalsh_grad, ops::EigvalshGradOp);
REGISTER_OP_CPU_KERNEL(eigvalsh,
ops::EigvalshKernel<phi::CPUContext, float, float>,
ops::EigvalshKernel<phi::CPUContext, double, double>,
ops::EigvalshKernel<phi::CPUContext,
float,
paddle::platform::complex<float>>,
ops::EigvalshKernel<phi::CPUContext,
double,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
eigvalsh_grad,
ops::EigvalshGradKernel<phi::CPUContext, float, float>,
ops::EigvalshGradKernel<phi::CPUContext, double, double>,
ops::EigvalshGradKernel<phi::CPUContext,
float,
paddle::platform::complex<float>>,
ops::EigvalshGradKernel<phi::CPUContext,
double,
paddle::platform::complex<double>>);
ops::EigvalshGradOpMaker<paddle::imperative::OpBase>,
EigvalshInferShapeFunctor);
REGISTER_OPERATOR(eigvalsh_grad,
ops::EigvalshGradOp,
EigvalshGradInferShapeFunctor);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T,
int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename ValueType, typename T>
class EigvalshKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto output_w = ctx.Output<Tensor>("Eigenvalues");
std::string lower = ctx.Attr<std::string>("UPLO");
bool is_lower = (lower == "L");
bool is_test = ctx.Attr<bool>("is_test");
math::MatrixEighFunctor<DeviceContext, T> functor;
if (is_test) {
functor(ctx, *input, output_w, nullptr, is_lower, false);
} else {
auto output_v = ctx.Output<Tensor>("Eigenvectors");
functor(ctx, *input, output_w, output_v, is_lower, true);
}
}
};
template <typename DeviceContext, typename ValueType, typename T>
class EigvalshGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto& output_v = *ctx.Input<Tensor>("Eigenvectors");
auto& output_w_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
ctx);
auto tV = dito.Transpose(dito.Conj(output_v));
// compute elementwise multiply of output_v and output_w_grad
x_grad.mutable_data<T>(output_v.dims(), ctx.GetPlace());
auto output_v_vector = EigenVector<T>::Flatten(output_v);
auto output_w_grad_vector = EigenVector<ValueType>::Flatten(output_w_grad);
auto result_vector = EigenVector<T>::Flatten(x_grad);
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
std::vector<int> broadcast_factor;
broadcast_factor.push_back(output_v.dims().at(output_v.dims().size() - 1));
result_vector.device(place) =
output_v_vector * output_w_grad_vector.broadcast(broadcast_factor);
x_grad = dito.Matmul(x_grad, tV);
}
};
} // namespace operators
} // namespace paddle
@@ -682,6 +682,15 @@
kernel :
func : eigvals
- api : eigvalsh
args : (Tensor x, str uplo, bool is_test)
output : Tensor(eigenvalues), Tensor(eigenvectors)
infer_meta :
func : EigvalshInferMeta
kernel :
func : eigvalsh
backward : eigvalsh_grad
- api : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
......
@@ -667,6 +667,18 @@
data_transform:
skip_transform : out_w, out_w_grad
- backward_api : eigvalsh_grad
forward : eigvalsh (Tensor x, str uplo, bool is_test) -> Tensor(eigenvalues), Tensor(eigenvectors)
args : (Tensor eigenvectors, Tensor eigenvalues_grad, str uplo, bool is_test)
output : Tensor(x_grad)
infer_meta :
func : EigvalshGradInferMeta
kernel :
func : eigvalsh_grad
data_type : eigenvectors
data_transform :
skip_transform : eigenvalues_grad
- backward_api : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
......
@@ -263,6 +263,18 @@ void EigGradInferMeta(const MetaTensor& out_w,
}
}
void EigvalshGradInferMeta(const MetaTensor& out_v,
const MetaTensor& out_w_grad,
const std::string& uplo,
bool is_test,
MetaTensor* x_grad) {
auto dims = out_v.dims();
if (x_grad != nullptr) {
x_grad->set_dims(dims);
x_grad->set_dtype(out_v.dtype());
}
}
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
@@ -126,6 +126,12 @@ void EigGradInferMeta(const MetaTensor& out_w,
const MetaTensor& dout_v,
MetaTensor* dx);
void EigvalshGradInferMeta(const MetaTensor& out_v,
const MetaTensor& out_w_grad,
const std::string& uplo,
bool is_test,
MetaTensor* x_grad);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
@@ -622,6 +622,46 @@ void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
out->set_dtype(out_dtype);
}
void EigvalshInferMeta(const MetaTensor& x,
const std::string& uplo,
bool is_test,
MetaTensor* out_w,
MetaTensor* out_v) {
auto input_dim = x.dims();
auto rank = input_dim.size();
PADDLE_ENFORCE_GE(
rank,
2,
errors::InvalidArgument("The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2],
input_dim[rank - 1],
errors::InvalidArgument(
"Eigvalsh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2],
input_dim[rank - 1]));
std::vector<int64_t> values_dim;
for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
if (out_w != nullptr) {
out_w->set_dims(phi::make_ddim(values_dim));
out_w->set_dtype(dtype::ToReal(x.dtype()));
}
if (out_v != nullptr) {
out_v->set_dims(input_dim);
out_v->set_dtype(x.dtype());
}
}
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out) {
......
@@ -103,6 +103,12 @@ void EigvalsInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void EigvalshInferMeta(const MetaTensor& x,
const std::string& uplo,
bool is_test,
MetaTensor* out_w,
MetaTensor* out_v);
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/eigvalsh_grad_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/eigvalsh_grad_kernel_impl.h"
PD_REGISTER_KERNEL(eigvalsh_grad,
CPU,
ALL_LAYOUT,
phi::EigvalshGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/eigvalsh_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/eigvalsh_kernel_impl.h"
PD_REGISTER_KERNEL(eigvalsh,
CPU,
ALL_LAYOUT,
phi::EigvalshKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EigvalshGradKernel(const Context& dev_ctx,
const DenseTensor& out_v,
const DenseTensor& out_w_grad,
const std::string& uplo,
bool is_test,
DenseTensor* x_grad);
} // namespace phi
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,27 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigvalsh_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(eigvalsh,
ops::EigvalshKernel<phi::GPUContext, float, float>,
ops::EigvalshKernel<phi::GPUContext, double, double>,
ops::EigvalshKernel<phi::GPUContext,
float,
paddle::platform::complex<float>>,
ops::EigvalshKernel<phi::GPUContext,
double,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
eigvalsh_grad,
ops::EigvalshGradKernel<phi::GPUContext, float, float>,
ops::EigvalshGradKernel<phi::GPUContext, double, double>,
ops::EigvalshGradKernel<phi::GPUContext,
float,
paddle::platform::complex<float>>,
ops::EigvalshGradKernel<phi::GPUContext,
double,
paddle::platform::complex<double>>);
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EigvalshKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& uplo,
bool is_test,
DenseTensor* out_w,
DenseTensor* out_v);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/eigvalsh_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/eigvalsh_grad_kernel_impl.h"
PD_REGISTER_KERNEL(eigvalsh_grad,
GPU,
ALL_LAYOUT,
phi::EigvalshGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/eigvalsh_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/eigvalsh_kernel_impl.h"
PD_REGISTER_KERNEL(eigvalsh, // cuda_only
GPU,
ALL_LAYOUT,
phi::EigvalshKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/eigvalsh_grad_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
template <typename T, typename Context>
void EigvalshGradKernel(const Context& dev_ctx,
const DenseTensor& out_v,
const DenseTensor& out_w_grad,
const std::string& uplo,
bool is_test,
DenseTensor* x_grad) {
auto tV = phi::TransposeLast2Dim<T>(dev_ctx, phi::Conj<T>(dev_ctx, out_v));
x_grad->Resize(out_v.dims());
dev_ctx.template Alloc<T>(x_grad);
auto output_v_vector = EigenVector<T>::Flatten(out_v);
auto output_w_grad_vector =
EigenVector<phi::dtype::Real<T>>::Flatten(out_w_grad);
auto result_vector = EigenVector<T>::Flatten(*x_grad);
auto& place = *dev_ctx.eigen_device();
std::vector<int> broadcast_factor;
broadcast_factor.push_back(out_v.dims().at(out_v.dims().size() - 1));
result_vector.device(place) =
output_v_vector * output_w_grad_vector.broadcast(broadcast_factor);
*x_grad = phi::Matmul<T>(dev_ctx, *x_grad, tV);
}
} // namespace phi
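For reference, the backward pass above amounts to x_grad = V * diag(eigenvalues_grad) * conj(V)^T. A minimal NumPy sketch of that formula (illustrative only; the helper name and the use of NumPy are assumptions, not part of this change):

import numpy as np

def eigvalsh_grad_reference(eigenvectors, eigenvalues_grad):
    # Illustrative NumPy sketch: x_grad = V @ diag(dw) @ conj(V)^T, where V are
    # the eigenvectors and dw is the gradient w.r.t. the eigenvalues.
    v = np.asarray(eigenvectors)
    dw = np.asarray(eigenvalues_grad)
    # Scaling column j of V by dw[j] is V @ diag(dw); then project back with V^H.
    return (v * dw[..., np.newaxis, :]) @ np.conj(np.swapaxes(v, -1, -2))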
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/kernels/eigvalsh_kernel.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
namespace phi {
template <typename T, typename Context>
void EigvalshKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& uplo,
bool is_test,
DenseTensor* out_w,
DenseTensor* out_v) {
bool is_lower = (uplo == "L");
phi::funcs::MatrixEighFunctor<Context, T> functor;
if (is_test) {
functor(dev_ctx, x, out_w, nullptr, is_lower, false);
} else {
functor(dev_ctx, x, out_w, out_v, is_lower, true);
}
}
} // namespace phi
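The forward kernel above computes only the eigenvalues when is_test is true, and otherwise also keeps the eigenvectors for the backward pass. A rough NumPy equivalent for the eigenvalue output (a sketch for illustration, not the registered kernel; the helper name is assumed):

import numpy as np

def eigvalsh_reference(x, uplo="L"):
    # numpy.linalg.eigvalsh reads the lower triangle for UPLO='L' and the
    # upper triangle for UPLO='U', mirroring the is_lower flag above.
    return np.linalg.eigvalsh(np.asarray(x), UPLO=uplo)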
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EigvalshOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"eigvalsh", {"X"}, {"UPLO", "is_test"}, {"Eigenvalues", "Eigenvectors"});
}
KernelSignature EigvalshGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("eigvalsh_grad",
{"Eigenvectors", "Eigenvalues@GRAD"},
{"UPLO", "is_test"},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eigvalsh, phi::EigvalshOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(eigvalsh_grad, phi::EigvalshGradOpArgumentMapping);
@@ -51,6 +51,8 @@ class TestEigvalshOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "eigvalsh"
self.python_api = paddle.linalg.eigvalsh
self.python_out_sig = ['Eigenvalues']
self.init_input()
self.init_config()
np.random.seed(123)
@@ -69,10 +71,10 @@ def test_check_output(self):
def test_check_output(self):
# Vectors in positive or negative direction are equivalent
self.check_output(no_check_set=['Eigenvectors'])
self.check_output(no_check_set=['Eigenvectors'], check_eager=True)
def test_grad(self):
self.check_grad(["X"], ["Eigenvalues"])
self.check_grad(["X"], ["Eigenvalues"], check_eager=True)
class TestEigvalshUPLOCase(TestEigvalshOp):
......
@@ -3052,7 +3052,11 @@ def eigvalsh(x, UPLO='L', name=None):
print(out_value)
#[0.17157288, 5.82842712]
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
values, _ = _C_ops.final_state_eigvalsh(x, UPLO, x.stop_gradient)
return values
elif paddle.in_dynamic_mode():
is_test = x.stop_gradient
values, _ = _C_ops.eigvalsh(x, 'UPLO', UPLO, 'is_test', is_test)
return values
......
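For context, a minimal dygraph usage sketch of the public API touched above (the example matrix is an assumption, chosen so its eigenvalues match the docstring excerpt):

import paddle

# A 2x2 symmetric example; its eigenvalues match the values shown in the
# docstring excerpt above (the matrix itself is chosen here for illustration).
x = paddle.to_tensor([[1.0, -2.0], [-2.0, 5.0]])
out_value = paddle.linalg.eigvalsh(x, UPLO='L')  # dispatches to the PHI kernel in dygraph mode
print(out_value)  # approximately [0.17157288, 5.82842712]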