Unverified commit 95474815, authored by Ruibiao Chen, committed by GitHub

Move eigvals OP to PHI (#44183)

* Move eigvals OP to PHI

* Fix CI errors

* Fix CI errors
Parent 0a5d625b
paddle/fluid/operators/eigvals_op.cc:

@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/eigvals_op.h"
+#include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/unary.h"
 
 namespace paddle {
 namespace operators {
@@ -36,59 +37,17 @@ class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker {
 class EigvalsOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvals");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Eigvals");
-
-    DDim x_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE_GE(x_dims.size(),
-                      2,
-                      platform::errors::InvalidArgument(
-                          "The dimensions of Input(X) for Eigvals operator "
-                          "should be at least 2, "
-                          "but received X's dimension = %d, X's shape = [%s].",
-                          x_dims.size(),
-                          x_dims));
-
-    if (ctx->IsRuntime() || !phi::contain_unknown_dim(x_dims)) {
-      int last_dim = x_dims.size() - 1;
-      PADDLE_ENFORCE_EQ(x_dims[last_dim],
-                        x_dims[last_dim - 1],
-                        platform::errors::InvalidArgument(
-                            "The last two dimensions of Input(X) for Eigvals "
-                            "operator should be equal, "
-                            "but received X's shape = [%s].",
-                            x_dims));
-    }
-
-    auto output_dims = vectorize(x_dims);
-    output_dims.resize(x_dims.size() - 1);
-    ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
-  }
 };
 
-class EigvalsOpVarTypeInference : public framework::VarTypeInference {
- public:
-  void operator()(framework::InferVarTypeContext* ctx) const {
-    auto input_dtype = ctx->GetInputDataType("X");
-    auto output_dtype = framework::IsComplexType(input_dtype)
-                            ? input_dtype
-                            : framework::ToComplexType(input_dtype);
-    ctx->SetOutputDataType("Out", output_dtype);
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-namespace plat = paddle::platform;
 
+DECLARE_INFER_SHAPE_FUNCTOR(eigvals,
+                            EigvalsInferShapeFunctor,
+                            PD_INFER_META(phi::EigvalsInferMeta));
 REGISTER_OPERATOR(eigvals,
                   ops::EigvalsOp,
                   ops::EigvalsOpMaker,
-                  ops::EigvalsOpVarTypeInference);
-
-REGISTER_OP_CPU_KERNEL(
-    eigvals,
-    ops::EigvalsKernel<phi::CPUContext, float>,
-    ops::EigvalsKernel<phi::CPUContext, double>,
-    ops::EigvalsKernel<phi::CPUContext, paddle::platform::complex<float>>,
-    ops::EigvalsKernel<phi::CPUContext, paddle::platform::complex<double>>);
+                  EigvalsInferShapeFunctor);
API definitions YAML:

@@ -536,6 +536,14 @@
     func : eigh
     backward : eigh_grad
 
+- api : eigvals
+  args : (Tensor x)
+  output : Tensor
+  infer_meta :
+    func : EigvalsInferMeta
+  kernel :
+    func : eigvals
+
 - api : einsum
   args : (Tensor[] x, str equation)
   output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
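The entry above wires `eigvals` into the generated final-state API: one dense Tensor in, one (complex) Tensor out, with shape and dtype decided by `EigvalsInferMeta` and the computation done by the `eigvals` PHI kernel. A minimal usage sketch of the public wrapper:

    import paddle

    # A batch of two random 4x4 real matrices.
    x = paddle.randn([2, 4, 4], dtype='float32')

    # One eigenvalue vector per matrix: shape [2, 4], dtype complex64.
    vals = paddle.linalg.eigvals(x)
    print(vals.shape, vals.dtype)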
paddle/phi/core/utils/data_type.h:

@@ -80,4 +80,21 @@ inline void VisitDataTypeTiny(phi::DataType type, Visitor visitor) {
       "Not supported phi::DataType(%d) as data type.", static_cast<int>(type)));
 }
 
+inline bool IsComplexType(const DataType& type) {
+  return (type == DataType::COMPLEX64 || type == DataType::COMPLEX128);
+}
+
+inline DataType ToComplexType(const DataType& type) {
+  switch (type) {
+    case DataType::FLOAT32:
+      return DataType::COMPLEX64;
+    case DataType::FLOAT64:
+      return DataType::COMPLEX128;
+    default:
+      PADDLE_THROW(errors::Unimplemented(
+          "Can not transform data type (%s) to complex type, now only support "
+          "float32 and float64 real value.",
+          type));
+  }
+}
+
 }  // namespace phi
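These two helpers replicate, on the PHI side, the `framework::IsComplexType`/`framework::ToComplexType` pair that the deleted `EigvalsOpVarTypeInference` relied on: float32 maps to complex64, float64 to complex128, and complex dtypes pass through. The resulting promotion is observable from Python (a sketch):

    import numpy as np
    import paddle

    for dt in ('float32', 'float64', 'complex64', 'complex128'):
        x = paddle.to_tensor(np.eye(3, dtype=dt))
        # Real inputs promote to the matching complex dtype;
        # complex inputs keep their dtype.
        print(dt, '->', paddle.linalg.eigvals(x).dtype)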
paddle/phi/infermeta/unary.cc:

@@ -399,6 +399,39 @@ void EighInferMeta(const MetaTensor& x,
   out_v->set_dims(input_dim);
 }
 
+void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
+  auto x_dims = x.dims();
+  PADDLE_ENFORCE_GE(x_dims.size(),
+                    2,
+                    errors::InvalidArgument(
+                        "The dimensions of Input(X) for Eigvals operator "
+                        "should be at least 2, "
+                        "but received X's dimension = %d, X's shape = [%s].",
+                        x_dims.size(),
+                        x_dims));
+
+  if (config.is_runtime || !phi::contain_unknown_dim(x_dims)) {
+    int last_dim = x_dims.size() - 1;
+    PADDLE_ENFORCE_EQ(x_dims[last_dim],
+                      x_dims[last_dim - 1],
+                      errors::InvalidArgument(
+                          "The last two dimensions of Input(X) for Eigvals "
+                          "operator should be equal, "
+                          "but received X's shape = [%s].",
+                          x_dims));
+  }
+
+  auto out_dims = vectorize(x_dims);
+  out_dims.resize(x_dims.size() - 1);
+  const DataType& x_dtype = x.dtype();
+  const DataType& out_dtype =
+      IsComplexType(x_dtype) ? x_dtype : ToComplexType(x_dtype);
+  out->set_dims(make_ddim(out_dims));
+  out->set_dtype(out_dtype);
+}
+
 void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                      const std::string& equation,
                      MetaTensor* out,
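`EigvalsInferMeta` folds the shape checks from the removed `EigvalsOp::InferShape` and the dtype rule from the removed `EigvalsOpVarTypeInference` into one function: an input of shape [..., N, N] yields an output of shape [..., N], and a non-square trailing pair is rejected. Both paths, sketched from Python (the non-square case is also pre-checked by the wrapper in linalg.py):

    import paddle

    x = paddle.randn([3, 5, 5])
    assert paddle.linalg.eigvals(x).shape == [3, 5]  # [..., N, N] -> [..., N]

    try:
        paddle.linalg.eigvals(paddle.randn([3, 4, 5]))  # last two dims differ
    except ValueError as e:
        print(e)  # "The last two dimensions of Input(x) should be equal ..."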
paddle/phi/infermeta/unary.h:

@@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x,
                    MetaTensor* out_w,
                    MetaTensor* out_v);
 
+void EigvalsInferMeta(const MetaTensor& x,
+                      MetaTensor* out,
+                      MetaConfig config = MetaConfig());
+
 void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
                      const std::string& equation,
                      MetaTensor* out,
paddle/fluid/operators/eigvals_op.h → paddle/phi/kernels/cpu/eigvals_kernel.cc (moved and ported):

@@ -1,4 +1,4 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,23 +12,17 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#pragma once
-
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/framework/data_type.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/memory/allocation/allocator.h"
-#include "paddle/fluid/platform/for_range.h"
-#include "paddle/phi/core/ddim.h"
+#include "paddle/phi/kernels/eigvals_kernel.h"
+
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/complex.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/for_range.h"
 #include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
 
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-using DDim = framework::DDim;
+namespace phi {
 
 template <typename T, typename enable = void>
 struct PaddleComplex;
@@ -37,79 +31,60 @@ template <typename T>
 struct PaddleComplex<
     T,
     typename std::enable_if<std::is_floating_point<T>::value>::type> {
-  using type = paddle::platform::complex<T>;
+  using type = dtype::complex<T>;
 };
 
 template <typename T>
 struct PaddleComplex<
     T,
     typename std::enable_if<
-        std::is_same<T, platform::complex<float>>::value ||
-        std::is_same<T, platform::complex<double>>::value>::type> {
+        std::is_same<T, dtype::complex<float>>::value ||
+        std::is_same<T, dtype::complex<double>>::value>::type> {
   using type = T;
 };
 
 template <typename T>
 using PaddleCType = typename PaddleComplex<T>::type;
 template <typename T>
-using Real = typename phi::dtype::Real<T>;
-
-static void SpiltBatchSquareMatrix(const Tensor& input,
-                                   std::vector<Tensor>* output) {
-  DDim input_dims = input.dims();
-  int last_dim = input_dims.size() - 1;
-  int n_dim = input_dims[last_dim];
-
-  DDim flattened_input_dims, flattened_output_dims;
-  if (input_dims.size() > 2) {
-    flattened_input_dims =
-        phi::flatten_to_3d(input_dims, last_dim - 1, last_dim);
-  } else {
-    flattened_input_dims = phi::make_ddim({1, n_dim, n_dim});
-  }
-
-  Tensor flattened_input;
-  flattened_input.ShareDataWith(input);
-  flattened_input.Resize(flattened_input_dims);
-  (*output) = flattened_input.Split(1, 0);
-}
-
-static void CheckLapackEigResult(const int info, const std::string& name) {
-  PADDLE_ENFORCE_LE(info,
-                    0,
-                    platform::errors::PreconditionNotMet(
-                        "The QR algorithm failed to compute all the "
-                        "eigenvalues in function %s.",
-                        name.c_str()));
+using Real = typename dtype::Real<T>;
+
+inline void CheckLapackEigResult(const int info, const std::string& name) {
+  PADDLE_ENFORCE_LE(
+      info,
+      0,
+      errors::PreconditionNotMet("The QR algorithm failed to compute all the "
+                                 "eigenvalues in function %s.",
+                                 name.c_str()));
   PADDLE_ENFORCE_GE(
       info,
       0,
-      platform::errors::InvalidArgument(
+      errors::InvalidArgument(
           "The %d-th argument has an illegal value in function %s.",
          -info,
          name.c_str()));
 }
 
-template <typename DeviceContext, typename T>
-static typename std::enable_if<std::is_floating_point<T>::value>::type
-LapackEigvals(const framework::ExecutionContext& ctx,
-              const Tensor& input,
-              Tensor* output,
-              Tensor* work,
-              Tensor* rwork /*unused*/) {
-  Tensor a;  // will be overwritten when lapackEig exit
-  framework::TensorCopy(input, input.place(), &a);
-
-  Tensor w;
+template <typename T, typename Context>
+typename std::enable_if<std::is_floating_point<T>::value>::type LapackEigvals(
+    const Context& ctx,
+    const DenseTensor& input,
+    DenseTensor* output,
+    DenseTensor* work,
+    DenseTensor* rwork /*unused*/) {
+  DenseTensor a;  // will be overwritten when lapackEig exit
+  Copy(ctx, input, input.place(), /*blocking=*/true, &a);
+
+  DenseTensor w;
   int64_t n_dim = input.dims()[1];
-  auto* w_data =
-      w.mutable_data<T>(phi::make_ddim({n_dim << 1}), ctx.GetPlace());
+  w.Resize(make_ddim({n_dim << 1}));
+  T* w_data = ctx.template Alloc<T>(&w);
 
   int64_t work_mem = work->memory_size();
   int64_t required_work_mem = 3 * n_dim * sizeof(T);
   PADDLE_ENFORCE_GE(
       work_mem,
      3 * n_dim * sizeof(T),
-      platform::errors::InvalidArgument(
+      errors::InvalidArgument(
          "The memory size of the work tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received work\'s memory size = %" PRId64 " bytes.",
@@ -132,30 +107,28 @@
          static_cast<T*>(NULL),
          &info);
 
-  std::string name = "framework::platform::dynload::dgeev_";
-  if (framework::TransToProtoVarType(input.dtype()) ==
-      framework::proto::VarType::FP64) {
-    name = "framework::platform::dynload::sgeev_";
+  std::string name = "phi::backend::dynload::dgeev_";
+  if (input.dtype() == DataType::FLOAT64) {
+    name = "phi::backend::dynload::sgeev_";
   }
   CheckLapackEigResult(info, name);
 
-  platform::ForRange<DeviceContext> for_range(
-      ctx.template device_context<DeviceContext>(), n_dim);
-  phi::funcs::RealImagToComplexFunctor<PaddleCType<T>> functor(
+  funcs::ForRange<Context> for_range(ctx, n_dim);
+  funcs::RealImagToComplexFunctor<PaddleCType<T>> functor(
      w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
  for_range(functor);
 }
 
-template <typename DeviceContext, typename T>
-typename std::enable_if<std::is_same<T, platform::complex<float>>::value ||
-                        std::is_same<T, platform::complex<double>>::value>::type
-LapackEigvals(const framework::ExecutionContext& ctx,
-              const Tensor& input,
-              Tensor* output,
-              Tensor* work,
-              Tensor* rwork) {
-  Tensor a;  // will be overwritten when lapackEig exit
-  framework::TensorCopy(input, input.place(), &a);
+template <typename T, typename Context>
+typename std::enable_if<std::is_same<T, dtype::complex<float>>::value ||
+                        std::is_same<T, dtype::complex<double>>::value>::type
+LapackEigvals(const Context& ctx,
+              const DenseTensor& input,
+              DenseTensor* output,
+              DenseTensor* work,
+              DenseTensor* rwork) {
+  DenseTensor a;  // will be overwritten when lapackEig exit
+  Copy(ctx, input, input.place(), /*blocking=*/true, &a);
 
   int64_t work_mem = work->memory_size();
   int64_t n_dim = input.dims()[1];
@@ -163,7 +136,7 @@
   PADDLE_ENFORCE_GE(
      work_mem,
      3 * n_dim * sizeof(T),
-      platform::errors::InvalidArgument(
+      errors::InvalidArgument(
          "The memory size of the work tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received work\'s memory size = %" PRId64 " bytes.",
@@ -171,11 +144,11 @@
          work_mem));
 
   int64_t rwork_mem = rwork->memory_size();
-  int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real<T>);
+  int64_t required_rwork_mem = (n_dim << 1) * sizeof(dtype::Real<T>);
   PADDLE_ENFORCE_GE(
      rwork_mem,
      required_rwork_mem,
-      platform::errors::InvalidArgument(
+      errors::InvalidArgument(
          "The memory size of the rwork tensor in LapackEigvals function "
          "should be at least %" PRId64 " bytes, "
          "but received rwork\'s memory size = %" PRId64 " bytes.",
@@ -183,7 +156,7 @@
          rwork_mem));
 
   int info = 0;
-  phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
+  phi::funcs::lapackEig<T, dtype::Real<T>>(
      'N',
      'N',
      static_cast<int>(n_dim),
@@ -196,78 +169,92 @@
      1,
      work->template data<T>(),
      static_cast<int>(work_mem / sizeof(T)),
-      rwork->template data<phi::dtype::Real<T>>(),
+      rwork->template data<dtype::Real<T>>(),
      &info);
 
-  std::string name = "framework::platform::dynload::cgeev_";
-  if (framework::TransToProtoVarType(input.dtype()) ==
-      framework::proto::VarType::COMPLEX64) {
-    name = "framework::platform::dynload::zgeev_";
+  std::string name = "phi::backend::dynload::cgeev_";
+  if (input.dtype() == DataType::COMPLEX128) {
+    name = "phi::backend::dynload::zgeev_";
  }
  CheckLapackEigResult(info, name);
 }
 
-template <typename DeviceContext, typename T>
-class EigvalsKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const Tensor* input = ctx.Input<Tensor>("X");
-    Tensor* output = ctx.Output<Tensor>("Out");
-    output->mutable_data<PaddleCType<T>>(ctx.GetPlace());
-
-    std::vector<Tensor> input_matrices;
-    SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices);
-
-    int64_t n_dim = input_matrices[0].dims()[1];
-    int64_t n_batch = input_matrices.size();
-    DDim output_dims = output->dims();
-    output->Resize(phi::make_ddim({n_batch, n_dim}));
-    std::vector<Tensor> output_vectors = output->Split(1, 0);
-
-    // query workspace size
-    T qwork;
-    int info;
-    phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
-        'N',
-        'N',
-        static_cast<int>(n_dim),
-        input_matrices[0].template data<T>(),
-        static_cast<int>(n_dim),
-        NULL,
-        NULL,
-        1,
-        NULL,
-        1,
-        &qwork,
-        -1,
-        static_cast<phi::dtype::Real<T>*>(NULL),
-        &info);
-    int64_t lwork = static_cast<int64_t>(qwork);
-
-    Tensor work, rwork;
-    try {
-      work.mutable_data<T>(phi::make_ddim({lwork}), ctx.GetPlace());
-    } catch (memory::allocation::BadAlloc&) {
-      LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal "
-                   << "memory size = " << lwork * sizeof(T) << " bytes, "
-                   << "try reallocating a smaller workspace with the minimum "
-                   << "required size = " << 3 * n_dim * sizeof(T) << " bytes, "
-                   << "this may lead to bad performance.";
-      lwork = 3 * n_dim;
-      work.mutable_data<T>(phi::make_ddim({lwork}), ctx.GetPlace());
-    }
-    if (framework::IsComplexType(
-            framework::TransToProtoVarType(input->dtype()))) {
-      rwork.mutable_data<phi::dtype::Real<T>>(phi::make_ddim({n_dim << 1}),
-                                              ctx.GetPlace());
-    }
-
-    for (int64_t i = 0; i < n_batch; ++i) {
-      LapackEigvals<DeviceContext, T>(
-          ctx, input_matrices[i], &output_vectors[i], &work, &rwork);
-    }
-    output->Resize(output_dims);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
+void SpiltBatchSquareMatrix(const DenseTensor& input,
+                            std::vector<DenseTensor>* output) {
+  DDim input_dims = input.dims();
+  int last_dim = input_dims.size() - 1;
+  int n_dim = input_dims[last_dim];
+
+  DDim flattened_input_dims, flattened_output_dims;
+  if (input_dims.size() > 2) {
+    flattened_input_dims =
+        phi::flatten_to_3d(input_dims, last_dim - 1, last_dim);
+  } else {
+    flattened_input_dims = phi::make_ddim({1, n_dim, n_dim});
+  }
+
+  DenseTensor flattened_input;
+  flattened_input.ShareDataWith(input);
+  flattened_input.Resize(flattened_input_dims);
+  (*output) = flattened_input.Split(1, 0);
+}
+
+template <typename T, typename Context>
+void EigvalsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
+  ctx.template Alloc<PaddleCType<T>>(out);
+
+  std::vector<DenseTensor> x_matrices;
+  SpiltBatchSquareMatrix(x, /*->*/ &x_matrices);
+
+  int64_t n_dim = x_matrices[0].dims()[1];
+  int64_t n_batch = x_matrices.size();
+  DDim out_dims = out->dims();
+  out->Resize(make_ddim({n_batch, n_dim}));
+  std::vector<DenseTensor> out_vectors = out->Split(1, 0);
+
+  // query workspace size
+  T qwork;
+  int info;
+  funcs::lapackEig<T, dtype::Real<T>>('N',
+                                      'N',
+                                      static_cast<int>(n_dim),
+                                      x_matrices[0].template data<T>(),
+                                      static_cast<int>(n_dim),
+                                      NULL,
+                                      NULL,
+                                      1,
+                                      NULL,
+                                      1,
+                                      &qwork,
+                                      -1,
+                                      static_cast<dtype::Real<T>*>(NULL),
+                                      &info);
+  int64_t lwork = static_cast<int64_t>(qwork);
+
+  DenseTensor work, rwork;
+  work.Resize(make_ddim({lwork}));
+  ctx.template Alloc<T>(&work);
+
+  if (IsComplexType(x.dtype())) {
+    rwork.Resize(make_ddim({n_dim << 1}));
+    ctx.template Alloc<dtype::Real<T>>(&rwork);
+  }
+
+  for (int64_t i = 0; i < n_batch; ++i) {
+    LapackEigvals<T, Context>(
+        ctx, x_matrices[i], &out_vectors[i], &work, &rwork);
+  }
+  out->Resize(out_dims);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(eigvals,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::EigvalsKernel,
+                   float,
+                   double,
+                   phi::dtype::complex<float>,
+                   phi::dtype::complex<double>) {}
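The kernel logic carries over largely intact (only the BadAlloc workspace-fallback path was dropped): it splits the batched input into square matrices, queries LAPACK `?geev` for the optimal workspace size (the `lwork = -1` convention), then computes eigenvalues only (`jobvl = jobvr = 'N'`) for each matrix. A hedged numerical sanity check against NumPy, sorting because neither backend guarantees eigenvalue order:

    import numpy as np
    import paddle

    a = np.random.rand(4, 4).astype('float32')
    pd_vals = paddle.linalg.eigvals(paddle.to_tensor(a)).numpy()
    np_vals = np.linalg.eigvals(a)

    # Compare as sorted sequences; the order is backend-dependent.
    np.testing.assert_allclose(np.sort_complex(pd_vals),
                               np.sort_complex(np_vals),
                               rtol=1e-4, atol=1e-4)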
new file: paddle/phi/kernels/eigvals_kernel.h

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"

namespace phi {

template <typename T, typename Context>
void EigvalsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);

}  // namespace phi
new file: paddle/phi/ops/compat/eigvals_sig.cc

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/core/compat/op_utils.h"

namespace phi {

KernelSignature EigvalsOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("eigvals", {"X"}, {}, {"Out"});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(eigvals, phi::EigvalsOpArgumentMapping);
python/paddle/tensor/linalg.py:

@@ -2339,7 +2339,9 @@ def eigvals(x, name=None):
             "The last two dimensions of Input(x) should be equal, but received x's shape = {}"
             .format(x_shape))
 
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        return _C_ops.final_state_eigvals(x)
+    elif paddle.in_dynamic_mode():
         return _C_ops.eigvals(x)
 
     helper = LayerHelper('eigvals', **locals())
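After this change the wrapper dispatches in order: the new final-state (PHI) op in dygraph mode, the legacy dygraph op as a fallback, and a static-graph node built through `LayerHelper` otherwise. A sketch of the static-graph path (assuming a CPU place):

    import numpy as np
    import paddle

    paddle.enable_static()
    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data('x', shape=[4, 4], dtype='float32')
        out = paddle.linalg.eigvals(x)  # built via LayerHelper here
        exe = paddle.static.Executor(paddle.CPUPlace())
        vals, = exe.run(feed={'x': np.random.rand(4, 4).astype('float32')},
                        fetch_list=[out])
    print(vals.shape, vals.dtype)  # (4,), complex64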