未验证 提交 95474815 编写于 作者: R Ruibiao Chen 提交者: GitHub

Move eigvals OP to PHI (#44183)

* Move eigvals OP to PHI

* Fix CI errors

* Fix CI errors
上级 0a5d625b
......@@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigvals_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -36,59 +37,17 @@ class EigvalsOpMaker : public framework::OpProtoAndCheckerMaker {
class EigvalsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigvals");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Eigvals");
DDim x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_GE(x_dims.size(),
2,
platform::errors::InvalidArgument(
"The dimensions of Input(X) for Eigvals operator "
"should be at least 2, "
"but received X's dimension = %d, X's shape = [%s].",
x_dims.size(),
x_dims));
if (ctx->IsRuntime() || !phi::contain_unknown_dim(x_dims)) {
int last_dim = x_dims.size() - 1;
PADDLE_ENFORCE_EQ(x_dims[last_dim],
x_dims[last_dim - 1],
platform::errors::InvalidArgument(
"The last two dimensions of Input(X) for Eigvals "
"operator should be equal, "
"but received X's shape = [%s].",
x_dims));
}
auto output_dims = vectorize(x_dims);
output_dims.resize(x_dims.size() - 1);
ctx->SetOutputDim("Out", phi::make_ddim(output_dims));
}
};
class EigvalsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext* ctx) const {
auto input_dtype = ctx->GetInputDataType("X");
auto output_dtype = framework::IsComplexType(input_dtype)
? input_dtype
: framework::ToComplexType(input_dtype);
ctx->SetOutputDataType("Out", output_dtype);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
DECLARE_INFER_SHAPE_FUNCTOR(eigvals,
EigvalsInferShapeFunctor,
PD_INFER_META(phi::EigvalsInferMeta));
REGISTER_OPERATOR(eigvals,
ops::EigvalsOp,
ops::EigvalsOpMaker,
ops::EigvalsOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(
eigvals,
ops::EigvalsKernel<phi::CPUContext, float>,
ops::EigvalsKernel<phi::CPUContext, double>,
ops::EigvalsKernel<phi::CPUContext, paddle::platform::complex<float>>,
ops::EigvalsKernel<phi::CPUContext, paddle::platform::complex<double>>);
EigvalsInferShapeFunctor);
......@@ -536,6 +536,14 @@
func : eigh
backward : eigh_grad
- api : eigvals
args : (Tensor x)
output : Tensor
infer_meta :
func : EigvalsInferMeta
kernel :
func : eigvals
- api : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
......
......@@ -80,4 +80,21 @@ inline void VisitDataTypeTiny(phi::DataType type, Visitor visitor) {
"Not supported phi::DataType(%d) as data type.", static_cast<int>(type)));
}
inline bool IsComplexType(const DataType& type) {
return (type == DataType::COMPLEX64 || type == DataType::COMPLEX128);
}
inline DataType ToComplexType(const DataType& type) {
switch (type) {
case DataType::FLOAT32:
return DataType::COMPLEX64;
case DataType::FLOAT64:
return DataType::COMPLEX128;
default:
PADDLE_THROW(errors::Unimplemented(
"Can not transform data type (%s) to complex type, now only support "
"float32 and float64 real value.",
type));
}
}
} // namespace phi
......@@ -399,6 +399,39 @@ void EighInferMeta(const MetaTensor& x,
out_v->set_dims(input_dim);
}
void EigvalsInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config) {
auto x_dims = x.dims();
PADDLE_ENFORCE_GE(x_dims.size(),
2,
errors::InvalidArgument(
"The dimensions of Input(X) for Eigvals operator "
"should be at least 2, "
"but received X's dimension = %d, X's shape = [%s].",
x_dims.size(),
x_dims));
if (config.is_runtime || !phi::contain_unknown_dim(x_dims)) {
int last_dim = x_dims.size() - 1;
PADDLE_ENFORCE_EQ(x_dims[last_dim],
x_dims[last_dim - 1],
errors::InvalidArgument(
"The last two dimensions of Input(X) for Eigvals "
"operator should be equal, "
"but received X's shape = [%s].",
x_dims));
}
auto out_dims = vectorize(x_dims);
out_dims.resize(x_dims.size() - 1);
const DataType& x_dtype = x.dtype();
const DataType& out_dtype =
IsComplexType(x_dtype) ? x_dtype : ToComplexType(x_dtype);
out->set_dims(make_ddim(out_dims));
out->set_dtype(out_dtype);
}
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
......
......@@ -80,6 +80,10 @@ void EighInferMeta(const MetaTensor& x,
MetaTensor* out_w,
MetaTensor* out_v);
void EigvalsInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void EinsumInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::string& equation,
MetaTensor* out,
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -12,23 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/eigvals_kernel.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
namespace phi {
template <typename T, typename enable = void>
struct PaddleComplex;
......@@ -37,79 +31,60 @@ template <typename T>
struct PaddleComplex<
T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
using type = paddle::platform::complex<T>;
using type = dtype::complex<T>;
};
template <typename T>
struct PaddleComplex<
T,
typename std::enable_if<
std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type> {
std::is_same<T, dtype::complex<float>>::value ||
std::is_same<T, dtype::complex<double>>::value>::type> {
using type = T;
};
template <typename T>
using PaddleCType = typename PaddleComplex<T>::type;
template <typename T>
using Real = typename phi::dtype::Real<T>;
static void SpiltBatchSquareMatrix(const Tensor& input,
std::vector<Tensor>* output) {
DDim input_dims = input.dims();
int last_dim = input_dims.size() - 1;
int n_dim = input_dims[last_dim];
DDim flattened_input_dims, flattened_output_dims;
if (input_dims.size() > 2) {
flattened_input_dims =
phi::flatten_to_3d(input_dims, last_dim - 1, last_dim);
} else {
flattened_input_dims = phi::make_ddim({1, n_dim, n_dim});
}
Tensor flattened_input;
flattened_input.ShareDataWith(input);
flattened_input.Resize(flattened_input_dims);
(*output) = flattened_input.Split(1, 0);
}
using Real = typename dtype::Real<T>;
static void CheckLapackEigResult(const int info, const std::string& name) {
PADDLE_ENFORCE_LE(info,
inline void CheckLapackEigResult(const int info, const std::string& name) {
PADDLE_ENFORCE_LE(
info,
0,
platform::errors::PreconditionNotMet(
"The QR algorithm failed to compute all the "
errors::PreconditionNotMet("The QR algorithm failed to compute all the "
"eigenvalues in function %s.",
name.c_str()));
PADDLE_ENFORCE_GE(
info,
0,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The %d-th argument has an illegal value in function %s.",
-info,
name.c_str()));
}
template <typename DeviceContext, typename T>
static typename std::enable_if<std::is_floating_point<T>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx,
const Tensor& input,
Tensor* output,
Tensor* work,
Tensor* rwork /*unused*/) {
Tensor a; // will be overwritten when lapackEig exit
framework::TensorCopy(input, input.place(), &a);
Tensor w;
template <typename T, typename Context>
typename std::enable_if<std::is_floating_point<T>::value>::type LapackEigvals(
const Context& ctx,
const DenseTensor& input,
DenseTensor* output,
DenseTensor* work,
DenseTensor* rwork /*unused*/) {
DenseTensor a; // will be overwritten when lapackEig exit
Copy(ctx, input, input.place(), /*blocking=*/true, &a);
DenseTensor w;
int64_t n_dim = input.dims()[1];
auto* w_data =
w.mutable_data<T>(phi::make_ddim({n_dim << 1}), ctx.GetPlace());
w.Resize(make_ddim({n_dim << 1}));
T* w_data = ctx.template Alloc<T>(&w);
int64_t work_mem = work->memory_size();
int64_t required_work_mem = 3 * n_dim * sizeof(T);
PADDLE_ENFORCE_GE(
work_mem,
3 * n_dim * sizeof(T),
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received work\'s memory size = %" PRId64 " bytes.",
......@@ -132,30 +107,28 @@ LapackEigvals(const framework::ExecutionContext& ctx,
static_cast<T*>(NULL),
&info);
std::string name = "framework::platform::dynload::dgeev_";
if (framework::TransToProtoVarType(input.dtype()) ==
framework::proto::VarType::FP64) {
name = "framework::platform::dynload::sgeev_";
std::string name = "phi::backend::dynload::dgeev_";
if (input.dtype() == DataType::FLOAT64) {
name = "phi::backend::dynload::sgeev_";
}
CheckLapackEigResult(info, name);
platform::ForRange<DeviceContext> for_range(
ctx.template device_context<DeviceContext>(), n_dim);
phi::funcs::RealImagToComplexFunctor<PaddleCType<T>> functor(
funcs::ForRange<Context> for_range(ctx, n_dim);
funcs::RealImagToComplexFunctor<PaddleCType<T>> functor(
w_data, w_data + n_dim, output->template data<PaddleCType<T>>(), n_dim);
for_range(functor);
}
template <typename DeviceContext, typename T>
typename std::enable_if<std::is_same<T, platform::complex<float>>::value ||
std::is_same<T, platform::complex<double>>::value>::type
LapackEigvals(const framework::ExecutionContext& ctx,
const Tensor& input,
Tensor* output,
Tensor* work,
Tensor* rwork) {
Tensor a; // will be overwritten when lapackEig exit
framework::TensorCopy(input, input.place(), &a);
template <typename T, typename Context>
typename std::enable_if<std::is_same<T, dtype::complex<float>>::value ||
std::is_same<T, dtype::complex<double>>::value>::type
LapackEigvals(const Context& ctx,
const DenseTensor& input,
DenseTensor* output,
DenseTensor* work,
DenseTensor* rwork) {
DenseTensor a; // will be overwritten when lapackEig exit
Copy(ctx, input, input.place(), /*blocking=*/true, &a);
int64_t work_mem = work->memory_size();
int64_t n_dim = input.dims()[1];
......@@ -163,7 +136,7 @@ LapackEigvals(const framework::ExecutionContext& ctx,
PADDLE_ENFORCE_GE(
work_mem,
3 * n_dim * sizeof(T),
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The memory size of the work tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received work\'s memory size = %" PRId64 " bytes.",
......@@ -171,11 +144,11 @@ LapackEigvals(const framework::ExecutionContext& ctx,
work_mem));
int64_t rwork_mem = rwork->memory_size();
int64_t required_rwork_mem = (n_dim << 1) * sizeof(phi::dtype::Real<T>);
int64_t required_rwork_mem = (n_dim << 1) * sizeof(dtype::Real<T>);
PADDLE_ENFORCE_GE(
rwork_mem,
required_rwork_mem,
platform::errors::InvalidArgument(
errors::InvalidArgument(
"The memory size of the rwork tensor in LapackEigvals function "
"should be at least %" PRId64 " bytes, "
"but received rwork\'s memory size = %" PRId64 " bytes.",
......@@ -183,7 +156,7 @@ LapackEigvals(const framework::ExecutionContext& ctx,
rwork_mem));
int info = 0;
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
phi::funcs::lapackEig<T, dtype::Real<T>>(
'N',
'N',
static_cast<int>(n_dim),
......@@ -196,42 +169,56 @@ LapackEigvals(const framework::ExecutionContext& ctx,
1,
work->template data<T>(),
static_cast<int>(work_mem / sizeof(T)),
rwork->template data<phi::dtype::Real<T>>(),
rwork->template data<dtype::Real<T>>(),
&info);
std::string name = "framework::platform::dynload::cgeev_";
if (framework::TransToProtoVarType(input.dtype()) ==
framework::proto::VarType::COMPLEX64) {
name = "framework::platform::dynload::zgeev_";
std::string name = "phi::backend::dynload::cgeev_";
if (input.dtype() == DataType::COMPLEX128) {
name = "phi::backend::dynload::zgeev_";
}
CheckLapackEigResult(info, name);
}
template <typename DeviceContext, typename T>
class EigvalsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out");
output->mutable_data<PaddleCType<T>>(ctx.GetPlace());
void SpiltBatchSquareMatrix(const DenseTensor& input,
std::vector<DenseTensor>* output) {
DDim input_dims = input.dims();
int last_dim = input_dims.size() - 1;
int n_dim = input_dims[last_dim];
std::vector<Tensor> input_matrices;
SpiltBatchSquareMatrix(*input, /*->*/ &input_matrices);
DDim flattened_input_dims, flattened_output_dims;
if (input_dims.size() > 2) {
flattened_input_dims =
phi::flatten_to_3d(input_dims, last_dim - 1, last_dim);
} else {
flattened_input_dims = phi::make_ddim({1, n_dim, n_dim});
}
DenseTensor flattened_input;
flattened_input.ShareDataWith(input);
flattened_input.Resize(flattened_input_dims);
(*output) = flattened_input.Split(1, 0);
}
int64_t n_dim = input_matrices[0].dims()[1];
int64_t n_batch = input_matrices.size();
DDim output_dims = output->dims();
output->Resize(phi::make_ddim({n_batch, n_dim}));
std::vector<Tensor> output_vectors = output->Split(1, 0);
template <typename T, typename Context>
void EigvalsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
ctx.template Alloc<PaddleCType<T>>(out);
std::vector<DenseTensor> x_matrices;
SpiltBatchSquareMatrix(x, /*->*/ &x_matrices);
int64_t n_dim = x_matrices[0].dims()[1];
int64_t n_batch = x_matrices.size();
DDim out_dims = out->dims();
out->Resize(make_ddim({n_batch, n_dim}));
std::vector<DenseTensor> out_vectors = out->Split(1, 0);
// query workspace size
T qwork;
int info;
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(
'N',
funcs::lapackEig<T, dtype::Real<T>>('N',
'N',
static_cast<int>(n_dim),
input_matrices[0].template data<T>(),
x_matrices[0].template data<T>(),
static_cast<int>(n_dim),
NULL,
NULL,
......@@ -240,34 +227,34 @@ class EigvalsKernel : public framework::OpKernel<T> {
1,
&qwork,
-1,
static_cast<phi::dtype::Real<T>*>(NULL),
static_cast<dtype::Real<T>*>(NULL),
&info);
int64_t lwork = static_cast<int64_t>(qwork);
Tensor work, rwork;
try {
work.mutable_data<T>(phi::make_ddim({lwork}), ctx.GetPlace());
} catch (memory::allocation::BadAlloc&) {
LOG(WARNING) << "Failed to allocate Lapack workspace with the optimal "
<< "memory size = " << lwork * sizeof(T) << " bytes, "
<< "try reallocating a smaller workspace with the minimum "
<< "required size = " << 3 * n_dim * sizeof(T) << " bytes, "
<< "this may lead to bad performance.";
lwork = 3 * n_dim;
work.mutable_data<T>(phi::make_ddim({lwork}), ctx.GetPlace());
}
if (framework::IsComplexType(
framework::TransToProtoVarType(input->dtype()))) {
rwork.mutable_data<phi::dtype::Real<T>>(phi::make_ddim({n_dim << 1}),
ctx.GetPlace());
DenseTensor work, rwork;
work.Resize(make_ddim({lwork}));
ctx.template Alloc<T>(&work);
if (IsComplexType(x.dtype())) {
rwork.Resize(make_ddim({n_dim << 1}));
ctx.template Alloc<dtype::Real<T>>(&rwork);
}
for (int64_t i = 0; i < n_batch; ++i) {
LapackEigvals<DeviceContext, T>(
ctx, input_matrices[i], &output_vectors[i], &work, &rwork);
LapackEigvals<T, Context>(
ctx, x_matrices[i], &out_vectors[i], &work, &rwork);
}
output->Resize(output_dims);
}
};
} // namespace operators
} // namespace paddle
out->Resize(out_dims);
}
} // namespace phi
PD_REGISTER_KERNEL(eigvals,
CPU,
ALL_LAYOUT,
phi::EigvalsKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void EigvalsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EigvalsOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("eigvals", {"X"}, {}, {"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eigvals, phi::EigvalsOpArgumentMapping);
......@@ -2339,7 +2339,9 @@ def eigvals(x, name=None):
"The last two dimensions of Input(x) should be equal, but received x's shape = {}"
.format(x_shape))
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_eigvals(x)
elif paddle.in_dynamic_mode():
return _C_ops.eigvals(x)
helper = LayerHelper('eigvals', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册