未验证 提交 3788f5e5 编写于 作者: F freeliuzc 提交者: GitHub

move eig operator from fluid to phi (#44398)

* move eig operator from fluid to phi

* add eig_grad unitest, upgrade IsComplexType() from fluid to phi
上级 9e307229
......@@ -17,7 +17,11 @@
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -25,37 +29,6 @@ namespace operators {
class EigOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eig");
OP_INOUT_CHECK(
ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues", "Eig");
OP_INOUT_CHECK(
ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors", "Eig");
auto x_dims = ctx->GetInputDim("X");
int rank = x_dims.size();
PADDLE_ENFORCE_GE(rank,
2,
platform::errors::InvalidArgument(
"Expects input tensor x to be not less than "
"2 dimentions, but got dimention %d",
rank));
PADDLE_ENFORCE_EQ(x_dims[rank - 2],
x_dims[rank - 1],
platform::errors::InvalidArgument(
"The input matrix must be a square matrix, "
"but receive a matrix with %d rows and %d colums",
x_dims[rank - 2],
x_dims[rank - 1]));
std::vector<int> batch_dims_vec{};
for (int i = 0; i < rank - 1; ++i) {
batch_dims_vec.emplace_back(x_dims[i]);
}
ctx->SetOutputDim("Eigenvectors", x_dims);
ctx->SetOutputDim("Eigenvalues", phi::make_ddim(batch_dims_vec));
}
protected:
// The output of eig is always complex-valued even for real-valued inputs
......@@ -100,26 +73,6 @@ This API processes eigen decomposition for general square matrices.
class EigGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(
ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues", "EigGrad");
OP_INOUT_CHECK(
ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors", "EigGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
"Input",
"Eigenvalues@GRAD",
"EigGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")),
"Input",
"Eigenvectors@GRAD",
"EigGrad");
auto dims = ctx->GetInputDim("Eigenvectors");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
......@@ -152,27 +105,20 @@ class EigGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
using complex64 = paddle::platform::complex<float>;
using complex128 = paddle::platform::complex<double>;
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eig,
EigInferShapeFunctor,
PD_INFER_META(phi::EigInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(eig_grad,
EigGradInferShapeFunctor,
PD_INFER_META(phi::EigGradInferMeta));
REGISTER_OPERATOR(eig,
ops::EigOp,
ops::EigOpMaker,
ops::EigGradOpMaker<paddle::framework::OpDesc>,
ops::EigGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(eig_grad, ops::EigGradOp);
REGISTER_OP_CPU_KERNEL(eig,
ops::EigKernel<phi::CPUContext, float, complex64>,
ops::EigKernel<phi::CPUContext, double, complex128>,
ops::EigKernel<phi::CPUContext, complex64, complex64>,
ops::EigKernel<phi::CPUContext, complex128, complex128>);
REGISTER_OP_CPU_KERNEL(
eig_grad,
ops::EigGradKernel<phi::CPUContext, float, complex64>,
ops::EigGradKernel<phi::CPUContext, double, complex128>,
ops::EigGradKernel<phi::CPUContext, complex64, complex64>,
ops::EigGradKernel<phi::CPUContext, complex128, complex128>);
ops::EigGradOpMaker<paddle::imperative::OpBase>,
EigInferShapeFunctor);
REGISTER_OPERATOR(eig_grad, ops::EigGradOp, EigGradInferShapeFunctor);
......@@ -57,341 +57,5 @@ inline int MatrixStride(const Tensor& matrix) {
return dims_list[num_dims - 1] * dims_list[num_dims - 2];
}
// Transpose two axis of a Tensor
template <typename DeviceContext, typename T>
void TransposeTwoAxis(const Tensor& input,
Tensor* transposed_input,
const int axis1,
const int axis2,
const framework::ExecutionContext& context) {
std::vector<int> permute(input.dims().size());
std::iota(permute.begin(), permute.end(), 0);
permute[axis1] = axis2;
permute[axis2] = axis1;
transposed_input->mutable_data<T>(input.dims(), context.GetPlace());
auto& dev_ctx = context.template device_context<phi::CPUContext>();
TransCompute<DeviceContext, T>(
input.dims().size(), dev_ctx, input, transposed_input, permute);
}
// Apply eig to a batch of matrices, values, vectors and (intermidiate
// tensor) info are overritten
template <typename T>
void LapackEig(Tensor* input,
Tensor* values,
Tensor* vectors,
int info,
const framework::ExecutionContext& context) {
char jobvl = 'N';
char jobvr = 'V'; // only right eigenvectors are computed
int num_dims = input->dims().size();
int order = input->dims()[num_dims - 1];
T* input_data = input->data<T>();
int lda = std::max<int>(1, order);
T* values_data = values->mutable_data<T>(context.GetPlace());
T* lvector_data = nullptr;
int ldvl = 1;
T* rvector_data = vectors->mutable_data<T>(context.GetPlace());
int ldvr = lda;
int lwork = -1;
int batch_count = BatchCount(*input);
int matrix_stride = MatrixStride(*input);
int values_stride = values->dims()[values->dims().size() - 1];
Tensor rwork;
phi::dtype::Real<T>* rwork_data = nullptr;
rwork.Resize(phi::make_ddim({lda * 2}));
rwork_data = rwork.mutable_data<phi::dtype::Real<T>>(context.GetPlace());
// call lapackEig once to compute the size of work;
T computed_work_size;
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(jobvl,
jobvr,
order,
input_data,
lda,
values_data,
lvector_data,
ldvl,
rvector_data,
ldvr,
&computed_work_size,
lwork,
rwork_data,
&info);
lwork = std::max<int>(
1, static_cast<int>(phi::dtype::Real<T>(computed_work_size)));
Tensor work;
work.Resize(phi::make_ddim({lwork}));
T* work_data = work.mutable_data<T>(context.GetPlace());
for (auto i = 0; i < batch_count; ++i) {
T* current_matrix = &input_data[i * matrix_stride];
T* current_values = &values_data[i * values_stride];
T* current_rvectors = &rvector_data[i * matrix_stride];
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(jobvl,
jobvr,
order,
current_matrix,
lda,
current_values,
lvector_data,
ldvl,
current_rvectors,
ldvr,
work_data,
lwork,
rwork_data,
&info);
PADDLE_ENFORCE_EQ(
info,
0,
platform::errors::PreconditionNotMet(
"current info is not 0, computation failed. "
"= 0: successful exit."
"< 0: if INFO = -i, the i-th argument had an illegal value."
"> 0: if INFO = i, the QR algorithm failed to compute all the "
"eigenvalues, and no eigenvectors have been computed; "
"elements i+1:N of WR and WI contain eigenvalues which "
"have converged."));
}
}
template <typename DeviceContext, typename T>
void ApplyEigKernel(const Tensor& input,
Tensor* values,
Tensor* vectors,
const framework::ExecutionContext& context) {
Tensor input_column_major;
Tensor vectors_row_major;
int num_dims = input.dims().size();
// transfer to column-major memory layout i.e. make_ddim from tranposed_input:
// [batch,row,col]->[batch,col,row]
TransposeTwoAxis<DeviceContext, T>(
input, &input_column_major, num_dims - 1, num_dims - 2, context);
// make sure 'vectors_row_major' holds memory before passed to LapackEig()
vectors_row_major.Resize(input.dims());
int info = 0;
LapackEig<T>(&input_column_major, values, &vectors_row_major, info, context);
// transfer column-major layout back
// vectors_row_major: column-major layout
// vector: original layout
TransposeTwoAxis<DeviceContext, T>(
vectors_row_major, vectors, num_dims - 1, num_dims - 2, context);
}
template <typename T, typename Tout>
void ConstructComplexVectors(Tensor* c_vectors,
const Tensor& c_values,
const Tensor& r_vectors,
const framework::ExecutionContext& ctx,
int batch_count,
int order) {
int matrix_stride = MatrixStride(r_vectors);
auto* c_vectors_data = c_vectors->mutable_data<Tout>(ctx.GetPlace());
auto* c_values_data = c_values.data<Tout>();
auto* r_v_data = r_vectors.data<T>();
for (int b = 0; b < batch_count; b++) {
auto* vecs = &r_v_data[b * matrix_stride];
auto* res = &c_vectors_data[b * matrix_stride];
auto* vals = &c_values_data[b * order];
for (int j = 0; j < order; j++) {
if (vals[j].imag < EPSILON) {
for (int i = 0; i < order; i++) {
res[j * order + i] = platform::complex<T>(vecs[j * order + i], 0);
}
} else {
for (int i = 0; i < order; i++) {
res[j * order + i] = platform::complex<T>(vecs[j * order + i],
vecs[(j + 1) * order + i]);
res[(j + 1) * order + i] = platform::complex<T>(
vecs[j * order + i], -vecs[(j + 1) * order + i]);
}
j++;
}
}
}
}
template <typename DeviceContext, typename T, typename Tout>
class EigKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out_values = context.Output<Tensor>("Eigenvalues");
auto* out_vectors = context.Output<Tensor>("Eigenvectors");
if (!framework::IsComplexType(framework::TransToProtoVarType(x->dtype()))) {
out_values->mutable_data<Tout>(context.GetPlace());
out_vectors->mutable_data<Tout>(context.GetPlace());
int batch_count = BatchCount(*x);
int order = x->dims()[x->dims().size() - 1];
Tensor real_values;
Tensor real_vectors;
// double the size of real_values, the first half stores the real part,
// the next half stores the imag part
std::vector<int> origin_dim = phi::vectorize<int>(out_values->dims());
int last_item = origin_dim.back();
origin_dim.pop_back();
origin_dim.push_back(last_item * 2);
framework::DDim big_dim = phi::make_ddim(origin_dim);
real_values.mutable_data<phi::dtype::Real<T>>(big_dim,
context.GetPlace());
real_vectors.mutable_data<phi::dtype::Real<T>>(x->dims(),
context.GetPlace());
ApplyEigKernel<DeviceContext, phi::dtype::Real<T>>(
*x, &real_values, &real_vectors, context);
auto& orig_dev_ctx = context.template device_context<DeviceContext>();
auto& dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
orig_dev_ctx);
// 1. extract real part & imag part from real_values
Tensor real_part =
phi::funcs::Slice<T>(dev_ctx, real_values, {-1}, {0}, {order});
Tensor imag_part = phi::funcs::Slice<T>(
dev_ctx, real_values, {-1}, {order}, {order * 2});
// 2. construct complex values
auto* real_part_data = real_part.data<phi::dtype::Real<T>>();
auto* imag_part_data = imag_part.data<phi::dtype::Real<T>>();
int out_values_numel = out_values->numel();
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(), out_values_numel);
phi::funcs::RealImagToComplexFunctor<Tout> functor(
real_part_data,
imag_part_data,
out_values->mutable_data<Tout>(context.GetPlace()),
out_values_numel);
for_range(functor);
// 3. construct complex vectors
Tensor real_vector_trans =
phi::TransposeLast2Dim<T>(dev_ctx, real_vectors);
Tensor out_vectors_trans;
out_vectors_trans.mutable_data<Tout>(x->dims(), context.GetPlace());
ConstructComplexVectors<phi::dtype::Real<T>, Tout>(&out_vectors_trans,
*out_values,
real_vector_trans,
context,
batch_count,
order);
TransposeTwoAxis<DeviceContext, Tout>(out_vectors_trans,
out_vectors,
x->dims().size() - 1,
x->dims().size() - 2,
context);
} else {
out_values->mutable_data<T>(context.GetPlace());
out_vectors->mutable_data<T>(context.GetPlace());
ApplyEigKernel<DeviceContext, T>(*x, out_values, out_vectors, context);
}
}
};
template <typename DeviceContext, typename T>
void ComputeBackwardForComplexInput(
const Tensor& V,
const Tensor& L,
const Tensor& gL,
const Tensor& gV,
T* x_grad_data,
int batch_count,
int order,
const framework::ExecutionContext& context) {
auto& orig_dev_ctx = context.template device_context<DeviceContext>();
auto& dev_ctx = static_cast<
const typename framework::ConvertToPhiContext<DeviceContext>::TYPE&>(
orig_dev_ctx);
Tensor trans_v = phi::TransposeLast2Dim<T>(dev_ctx, V);
Tensor Vh = phi::Conj<T>(dev_ctx, trans_v);
Tensor Lconj = phi::Conj<T>(dev_ctx, L);
Tensor Econj = phi::Subtract<T>(dev_ctx,
phi::funcs::Unsqueeze(Lconj, -2),
phi::funcs::Unsqueeze(Lconj, -1));
Tensor VhgV = phi::Matmul<T>(dev_ctx, Vh, gV);
Tensor diag_real = phi::Real<T>(dev_ctx, VhgV);
Tensor diag_res = phi::funcs::BatchDiag<T>(dev_ctx, diag_real, batch_count);
Tensor diag_unsqueezed = phi::funcs::Unsqueeze(diag_res, -2);
// turn diag_unsqueezed into complex
auto numel = diag_unsqueezed.numel();
Tensor diag_unsqueezed_complex;
auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<T>>();
auto* data_diag_un_com = diag_unsqueezed_complex.mutable_data<T>(
diag_unsqueezed.dims(),
context.GetPlace(),
static_cast<size_t>(numel * sizeof(T)));
platform::ForRange<DeviceContext> for_range(orig_dev_ctx, numel);
phi::funcs::RealToComplexFunctor<T> functor(
data_diag_un, data_diag_un_com, numel);
for_range(functor);
// real tensor multiply complex tensor in broadcast manner
Tensor res1 = phi::Multiply<T>(dev_ctx, V, diag_unsqueezed_complex);
Tensor res2 = phi::Matmul<T>(dev_ctx, Vh, res1);
Tensor result = phi::Subtract<T>(dev_ctx, VhgV, res2);
result.mutable_data<T>(V.dims(), context.GetPlace());
result = phi::Divide<T>(dev_ctx, result, Econj);
result =
phi::funcs::DiagFill<T, T>(dev_ctx, order, order, order, 0, gL, result);
Tensor rhs = phi::Matmul<T>(dev_ctx, result, Vh);
// solve linear system
// solve(Vh, rhs, out, m, k)
// Vh: matrix with shape [m,m]
// rhs: rhs with shape [m,k]
// x_grad: out
int m = Vh.dims()[Vh.dims().size() - 1];
int k = rhs.dims()[rhs.dims().size() - 1];
auto* matrix_data = Vh.data<T>();
auto* rhs_data = rhs.data<T>();
phi::funcs::SolveLinearSystem<T>(
matrix_data, rhs_data, x_grad_data, m, k, batch_count);
}
template <typename DeviceContext, typename T, typename Tout>
class EigGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& L = *context.Input<Tensor>("Eigenvalues");
auto& V = *context.Input<Tensor>("Eigenvectors");
auto& gL = *context.Input<Tensor>(framework::GradVarName("Eigenvalues"));
auto& gV = *context.Input<Tensor>(framework::GradVarName("Eigenvectors"));
auto& x_grad = *context.Output<Tensor>(framework::GradVarName("X"));
auto* x_grad_data = x_grad.mutable_data<Tout>(context.GetPlace());
auto& dims = V.dims();
framework::DDim dim_origin = dims;
int num_dims = dim_origin.size();
int batch_count = BatchCount(V);
const int order = dim_origin[num_dims - 1];
ComputeBackwardForComplexInput<DeviceContext, Tout>(
V, L, gL, gV, x_grad_data, batch_count, order, context);
}
};
} // namespace operators
} // namespace paddle
......@@ -2285,3 +2285,13 @@
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place = {})
output : Tensor
invoke : full_like(x, 0, dtype, place)
# eig
- api: eig
args: (Tensor x)
output: Tensor(out_w), Tensor(out_v)
infer_meta:
func: EigInferMeta
kernel:
func: eig
backward: eig_grad
......@@ -557,6 +557,19 @@
kernel :
func : dropout_grad
- backward_api : eig_grad
forward : eig (Tensor x) -> Tensor(out_w), Tensor(out_v)
args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_v]
kernel :
func : eig_grad
data_type : out_v
data_transform:
skip_transform : out_w, out_w_grad
- backward_api : eigh_grad
forward : eigh (Tensor x, str uplo) -> Tensor(out_w), Tensor(out_v)
args : (Tensor out_w, Tensor out_v, Tensor out_w_grad, Tensor out_v_grad)
......
......@@ -223,6 +223,18 @@ void DeformableConvGradInferMeta(const MetaTensor& x,
}
}
void EigGradInferMeta(const MetaTensor& out_w,
const MetaTensor& out_v,
const MetaTensor& dout_w,
const MetaTensor& dout_v,
MetaTensor* dx) {
auto dims = out_v.dims();
if (dx) {
dx->set_dims(dims);
}
}
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -106,6 +106,12 @@ void DeformableConvGradInferMeta(const MetaTensor& x,
MetaTensor* filter_grad,
MetaTensor* mask_grad);
void EigGradInferMeta(const MetaTensor& out_w,
const MetaTensor& out_v,
const MetaTensor& dout_w,
const MetaTensor& dout_v,
MetaTensor* dx);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -375,6 +375,32 @@ void DiagonalInferMeta(const MetaTensor& input,
out->set_dims(phi::make_ddim(out_dims));
}
void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v) {
auto x_dims = x.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_GE(
rank,
2,
phi::errors::InvalidArgument("Expects input tensor x to be not less than "
"2 dimentions, but got dimention %d",
rank));
PADDLE_ENFORCE_EQ(x_dims[rank - 2],
x_dims[rank - 1],
phi::errors::InvalidArgument(
"The input matrix must be a square matrix, "
"but receive a matrix with %d rows and %d colums",
x_dims[rank - 2],
x_dims[rank - 1]));
std::vector<int> batch_dims_vec{};
for (int i = 0; i < rank - 1; ++i) {
batch_dims_vec.emplace_back(x_dims[i]);
}
out_w->set_dims(phi::make_ddim(batch_dims_vec));
out_v->set_dims(x_dims);
}
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
......
......@@ -77,6 +77,8 @@ void DiagInferMeta(const MetaTensor& x,
void DiagonalInferMeta(
const MetaTensor& input, int offset, int axis1, int axis2, MetaTensor* out);
void EigInferMeta(const MetaTensor& x, MetaTensor* out_w, MetaTensor* out_v);
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <math.h>
#include <algorithm>
#include <complex>
#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/elementwise_divide_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#define EPSILON 1e-6
namespace phi {
inline int BatchCount(const DenseTensor& matrix) {
int count = 1;
int num_dims = matrix.dims().size();
for (int i = 0; i < num_dims - 2; ++i) {
count *= matrix.dims()[i];
}
return count;
}
inline int MatrixStride(const DenseTensor& matrix) {
phi::DDim dims_list = matrix.dims();
int num_dims = dims_list.size();
return dims_list[num_dims - 1] * dims_list[num_dims - 2];
}
// only used for complex input
template <typename T>
void SolveLinearSystem(T* matrix_data,
T* rhs_data,
T* out_data,
int order,
int rhs_cols,
int batch) {
using Treal = typename Eigen::NumTraits<T>::Real;
// cast paddle::complex into std::complex
std::complex<Treal>* matrix_data_ =
reinterpret_cast<std::complex<Treal>*>(matrix_data);
std::complex<Treal>* rhs_data_ =
reinterpret_cast<std::complex<Treal>*>(rhs_data);
std::complex<Treal>* out_data_ =
reinterpret_cast<std::complex<Treal>*>(out_data);
using Matrix = Eigen::Matrix<std::complex<Treal>,
Eigen::Dynamic,
Eigen::Dynamic,
Eigen::RowMajor>;
using InputMatrixMap = Eigen::Map<Matrix>;
using OutputMatrixMap = Eigen::Map<Matrix>;
for (int i = 0; i < batch; ++i) {
auto input_matrix =
InputMatrixMap(matrix_data_ + i * order * order, order, order);
auto input_rhs =
InputMatrixMap(rhs_data_ + i * order * rhs_cols, order, rhs_cols);
auto output =
OutputMatrixMap(out_data_ + i * order * rhs_cols, order, rhs_cols);
Eigen::PartialPivLU<Matrix> lu_decomposition(order);
lu_decomposition.compute(input_matrix);
const Treal min_abs_piv =
lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(
min_abs_piv,
Treal(0),
errors::InvalidArgument("Something's wrong with SolveLinearSystem. "));
output = lu_decomposition.solve(input_rhs);
}
}
template <typename T, typename Context>
void TransposeTwoAxis(const DenseTensor& input,
DenseTensor* transposed_input,
const int axis1,
const int axis2,
const Context& dev_ctx) {
std::vector<int> permute(input.dims().size());
std::iota(permute.begin(), permute.end(), 0);
permute[axis1] = axis2;
permute[axis2] = axis1;
transposed_input->Resize(input.dims());
dev_ctx.template Alloc<T>(transposed_input);
funcs::TransCompute<Context, T>(
input.dims().size(), dev_ctx, input, transposed_input, permute);
}
// Apply eig to a batch of matrices, values, vectors and (intermidiate
// DenseTensor) info are overritten
template <typename T, typename Context>
void LapackEig(DenseTensor* input,
DenseTensor* values,
DenseTensor* vectors,
int info,
const Context& dev_ctx) {
char jobvl = 'N';
char jobvr = 'V'; // only right eigenvectors are computed
int num_dims = input->dims().size();
int order = input->dims()[num_dims - 1];
T* input_data = input->data<T>();
int lda = std::max<int>(1, order);
T* values_data = dev_ctx.template Alloc<T>(values);
T* lvector_data = nullptr;
int ldvl = 1;
T* rvector_data = dev_ctx.template Alloc<T>(vectors);
int ldvr = lda;
int lwork = -1;
int batch_count = BatchCount(*input);
int matrix_stride = MatrixStride(*input);
int values_stride = values->dims()[values->dims().size() - 1];
DenseTensor rwork;
phi::dtype::Real<T>* rwork_data = nullptr;
rwork.Resize(phi::make_ddim({lda * 2}));
rwork_data = dev_ctx.template Alloc<phi::dtype::Real<T>>(&rwork);
// call lapackEig once to compute the size of work;
T computed_work_size;
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(jobvl,
jobvr,
order,
input_data,
lda,
values_data,
lvector_data,
ldvl,
rvector_data,
ldvr,
&computed_work_size,
lwork,
rwork_data,
&info);
lwork = std::max<int>(
1, static_cast<int>(phi::dtype::Real<T>(computed_work_size)));
DenseTensor work;
work.Resize(phi::make_ddim({lwork}));
T* work_data = dev_ctx.template Alloc<T>(&work);
for (auto i = 0; i < batch_count; ++i) {
T* current_matrix = &input_data[i * matrix_stride];
T* current_values = &values_data[i * values_stride];
T* current_rvectors = &rvector_data[i * matrix_stride];
phi::funcs::lapackEig<T, phi::dtype::Real<T>>(jobvl,
jobvr,
order,
current_matrix,
lda,
current_values,
lvector_data,
ldvl,
current_rvectors,
ldvr,
work_data,
lwork,
rwork_data,
&info);
PADDLE_ENFORCE_EQ(
info,
0,
errors::PreconditionNotMet(
"current info is not 0, computation failed. "
"= 0: successful exit."
"< 0: if INFO = -i, the i-th argument had an illegal value."
"> 0: if INFO = i, the QR algorithm failed to compute all the "
"eigenvalues, and no eigenvectors have been computed; "
"elements i+1:N of WR and WI contain eigenvalues which "
"have converged."));
}
}
template <typename T, typename Context>
void ApplyEigKernel(const DenseTensor& input,
DenseTensor* values,
DenseTensor* vectors,
const Context& dev_ctx) {
DenseTensor input_column_major;
DenseTensor vectors_row_major;
int num_dims = input.dims().size();
// transfer to column-major memory layout i.e. make_ddim from tranposed_input:
// [batch,row,col]->[batch,col,row]
TransposeTwoAxis<T, Context>(
input, &input_column_major, num_dims - 1, num_dims - 2, dev_ctx);
// make sure 'vectors_row_major' holds memory before passed to LapackEig()
vectors_row_major.Resize(input.dims());
int info = 0;
LapackEig<T, Context>(
&input_column_major, values, &vectors_row_major, info, dev_ctx);
// transfer column-major layout back
// vectors_row_major: column-major layout
// vector: original layout
TransposeTwoAxis<T, Context>(
vectors_row_major, vectors, num_dims - 1, num_dims - 2, dev_ctx);
}
// template <typename T, typename Tout>
template <typename T, typename Tout, typename Context>
void ConstructComplexVectors(DenseTensor* c_vectors,
const DenseTensor& c_values,
const DenseTensor& r_vectors,
const Context& dev_ctx,
int batch_count,
int order) {
int matrix_stride = MatrixStride(r_vectors);
auto* c_vectors_data = dev_ctx.template Alloc<Tout>(c_vectors);
auto* c_values_data = c_values.data<Tout>();
auto* r_v_data = r_vectors.data<T>();
for (int b = 0; b < batch_count; b++) {
auto* vecs = &r_v_data[b * matrix_stride];
auto* res = &c_vectors_data[b * matrix_stride];
auto* vals = &c_values_data[b * order];
for (int j = 0; j < order; j++) {
if (vals[j].imag < EPSILON) {
for (int i = 0; i < order; i++) {
res[j * order + i] = dtype::complex<T>(vecs[j * order + i], 0);
}
} else {
for (int i = 0; i < order; i++) {
res[j * order + i] =
dtype::complex<T>(vecs[j * order + i], vecs[(j + 1) * order + i]);
res[(j + 1) * order + i] = dtype::complex<T>(
vecs[j * order + i], -vecs[(j + 1) * order + i]);
}
j++;
}
}
}
}
template <typename T, typename Context>
void ComputeBackwardForComplexInput(const DenseTensor& L,
const DenseTensor& V,
const DenseTensor& gL,
const DenseTensor& gV,
T* x_grad_data,
int batch_count,
int order,
const Context& dev_ctx) {
DenseTensor trans_v = phi::TransposeLast2Dim<T>(dev_ctx, V);
DenseTensor Vh = phi::Conj<T>(dev_ctx, trans_v);
DenseTensor Lconj = phi::Conj<T>(dev_ctx, L);
DenseTensor Econj = phi::Subtract<T>(dev_ctx,
phi::funcs::Unsqueeze(Lconj, -2),
phi::funcs::Unsqueeze(Lconj, -1));
DenseTensor VhgV = phi::Matmul<T>(dev_ctx, Vh, gV);
DenseTensor diag_real = phi::Real<T>(dev_ctx, VhgV);
DenseTensor diag_res =
phi::funcs::BatchDiag<T>(dev_ctx, diag_real, batch_count);
DenseTensor diag_unsqueezed = phi::funcs::Unsqueeze(diag_res, -2);
// turn diag_unsqueezed into complex
auto numel = diag_unsqueezed.numel();
DenseTensor diag_unsqueezed_complex;
auto* data_diag_un = diag_unsqueezed.data<phi::dtype::Real<T>>();
diag_unsqueezed_complex.Resize(diag_unsqueezed.dims());
auto* data_diag_un_com = dev_ctx.template Alloc<T>(
&diag_unsqueezed_complex, static_cast<size_t>(numel * sizeof(T)));
phi::funcs::ForRange<Context> for_range(dev_ctx, numel);
phi::funcs::RealToComplexFunctor<T> functor(
data_diag_un, data_diag_un_com, numel);
for_range(functor);
// real tensor multiply complex tensor in broadcast manner
DenseTensor res1 = phi::Multiply<T>(dev_ctx, V, diag_unsqueezed_complex);
DenseTensor res2 = phi::Matmul<T>(dev_ctx, Vh, res1);
DenseTensor result = phi::Subtract<T>(dev_ctx, VhgV, res2);
result.Resize(V.dims());
dev_ctx.template Alloc<T>(&result);
result = phi::Divide<T>(dev_ctx, result, Econj);
result =
phi::funcs::DiagFill<T, T>(dev_ctx, order, order, order, 0, gL, result);
DenseTensor rhs = phi::Matmul<T>(dev_ctx, result, Vh);
// solve linear system
// solve(Vh, rhs, out, m, k)
// Vh: matrix with shape [m,m]
// rhs: rhs with shape [m,k]
// x_grad: out
int m = Vh.dims()[Vh.dims().size() - 1];
int k = rhs.dims()[rhs.dims().size() - 1];
auto* matrix_data = Vh.data<T>();
auto* rhs_data = rhs.data<T>();
SolveLinearSystem<T>(matrix_data, rhs_data, x_grad_data, m, k, batch_count);
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/eig_grad_kernel.h"
#include "paddle/phi/kernels/cpu/eig.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void EigGradKernel(const Context& dev_ctx,
const DenseTensor& out_w,
const DenseTensor& out_v,
const DenseTensor& dout_w,
const DenseTensor& dout_v,
DenseTensor* dx) {
auto* dx_data = dev_ctx.template Alloc<phi::dtype::Complex<T>>(dx);
auto& dims = out_v.dims();
phi::DDim dim_origin = dims;
int num_dims = dim_origin.size();
int batch_count = BatchCount(out_v);
const int order = dim_origin[num_dims - 1];
ComputeBackwardForComplexInput<phi::dtype::Complex<T>, Context>(
out_w, out_v, dout_w, dout_v, dx_data, batch_count, order, dev_ctx);
}
} // namespace phi
PD_REGISTER_KERNEL(eig_grad,
CPU,
ALL_LAYOUT,
phi::EigGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/eig_kernel.h"
#include "paddle/phi/kernels/cpu/eig.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void EigKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out_w,
DenseTensor* out_v) {
if (!IsComplexType(x.dtype())) {
dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_w);
dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_v);
int batch_count = BatchCount(x);
int order = x.dims()[x.dims().size() - 1];
DenseTensor real_w;
DenseTensor real_v;
// double the size of real_w, the first half stores the real part,
// the next half stores the imag part
std::vector<int> origin_dim = phi::vectorize<int>(out_w->dims());
int last_item = origin_dim.back();
origin_dim.pop_back();
origin_dim.push_back(last_item * 2);
phi::DDim big_dim = phi::make_ddim(origin_dim);
real_w.Resize(big_dim);
dev_ctx.template Alloc<phi::dtype::Real<T>>(&real_w);
real_v.Resize(x.dims());
dev_ctx.template Alloc<phi::dtype::Real<T>>(&real_v);
phi::ApplyEigKernel<phi::dtype::Real<T>, Context>(
x, &real_w, &real_v, dev_ctx);
// 1. extract real part & imag part from real_w
DenseTensor real_part =
phi::funcs::Slice<T>(dev_ctx, real_w, {-1}, {0}, {order});
DenseTensor imag_part =
phi::funcs::Slice<T>(dev_ctx, real_w, {-1}, {order}, {order * 2});
// 2. construct complex values
auto* real_part_data = real_part.data<phi::dtype::Real<T>>();
auto* imag_part_data = imag_part.data<phi::dtype::Real<T>>();
int out_w_numel = out_w->numel();
phi::funcs::ForRange<Context> for_range(dev_ctx, out_w_numel);
phi::funcs::RealImagToComplexFunctor<phi::dtype::Complex<T>> functor(
real_part_data,
imag_part_data,
dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_w),
out_w_numel);
for_range(functor);
// 3. construct complex vectors
DenseTensor real_vector_trans = phi::TransposeLast2Dim<T>(dev_ctx, real_v);
DenseTensor out_v_trans;
out_v_trans.Resize(x.dims());
dev_ctx.template Alloc<phi::dtype::Complex<T>>(&out_v_trans);
phi::ConstructComplexVectors<phi::dtype::Real<T>,
phi::dtype::Complex<T>,
Context>(
&out_v_trans, *out_w, real_vector_trans, dev_ctx, batch_count, order);
TransposeTwoAxis<phi::dtype::Complex<T>, Context>(
out_v_trans, out_v, x.dims().size() - 1, x.dims().size() - 2, dev_ctx);
} else {
dev_ctx.template Alloc<T>(out_w);
dev_ctx.template Alloc<T>(out_v);
phi::ApplyEigKernel<T, Context>(x, out_w, out_v, dev_ctx);
}
}
} // namespace phi
PD_REGISTER_KERNEL(eig,
CPU,
ALL_LAYOUT,
phi::EigKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EigGardKernel(const Context& dev_ctx,
const DenseTensor& out_w,
const DenseTensor& out_v,
const DenseTensor& dout_w,
const DenseTensor& dout_v,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EigKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out_w,
DenseTensor* out_v);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EigGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"eig_grad",
{"Eigenvalues", "Eigenvectors", "Eigenvalues@GRAD", "Eigenvectors@GRAD"},
{},
{"X@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eig_grad, phi::EigGradOpArgumentMapping);
......@@ -227,6 +227,52 @@ class TestEigStatic(TestEigOp):
str(np.abs(fetch_vec)))
class TestEigDyGraph(unittest.TestCase):
def test_check_output_with_place(self):
input_np = np.random.random([3, 3]).astype('complex')
expect_val, expect_vec = np.linalg.eig(input_np)
paddle.set_device("cpu")
paddle.disable_static()
input_tensor = paddle.to_tensor(input_np)
fetch_val, fetch_vec = paddle.linalg.eig(input_tensor)
self.assertTrue(
np.allclose(expect_val, fetch_val.numpy(), 1e-6,
1e-6), "The eigen values have diff: \nExpected " +
str(expect_val) + "\n" + "But got: " + str(fetch_val))
self.assertTrue(
np.allclose(np.abs(expect_vec), np.abs(fetch_vec.numpy()), 1e-6,
1e-6), "The eigen vectors have diff: \nExpected " +
str(np.abs(expect_vec)) + "\n" + "But got: " +
str(np.abs(fetch_vec.numpy())))
def test_check_grad(self):
test_shape = [3, 3]
test_type = 'float64'
paddle.set_device("cpu")
input_np = np.random.random(test_shape).astype(test_type)
real_w, real_v = np.linalg.eig(input_np)
grad_w = np.ones(real_w.shape, test_type)
grad_v = np.ones(real_v.shape, test_type)
grad_x = eig_backward(real_w, real_v, grad_w, grad_v)
with fluid.dygraph.guard():
x = fluid.dygraph.to_variable(input_np)
x.stop_gradient = False
w, v = paddle.linalg.eig(x)
(w.sum() + v.sum()).backward()
self.assertTrue(
np.allclose(np.abs(x.grad.numpy()), np.abs(grad_x), 1e-5, 1e-5),
"The grad x have diff: \nExpected " + str(np.abs(grad_x)) + "\n" +
"But got: " + str(np.abs(x.grad.numpy())))
class TestEigWrongDimsError(unittest.TestCase):
def test_error(self):
......@@ -254,7 +300,7 @@ class TestEigUnsupportedDtypeError(unittest.TestCase):
paddle.disable_static()
a = (np.random.random((3, 3)) * 10).astype('int64')
x = paddle.to_tensor(a)
self.assertRaises(ValueError, paddle.linalg.eig, x)
self.assertRaises(RuntimeError, paddle.linalg.eig, x)
if __name__ == "__main__":
......
......@@ -2269,7 +2269,9 @@ def eig(x, name=None):
# [ (16.50471283351188+0j) , (-5.5034820550763515+0j) ,
# (-0.21026087843552282+0j)])
"""
if paddle.in_dynamic_mode():
if in_dygraph_mode():
return _C_ops.final_state_eig(x)
elif paddle.in_dynamic_mode():
w, v = _C_ops.eig(x)
return w, v
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册