未验证 提交 07d0b834 编写于 作者: C crystal 提交者: GitHub

Add CPU and GPU eigh op implementation (#34990)

上级 7d9ca164
......@@ -185,6 +185,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
list(REMOVE_ITEM hip_srcs "svd_op.cu")
list(REMOVE_ITEM hip_srcs "eigh_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigh_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class EighOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues",
"Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors",
"Eigh");
auto input_dim = ctx->GetInputDim("X");
auto rank = input_dim.size();
PADDLE_ENFORCE_GE(rank, 2,
platform::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2], input_dim[rank - 1],
platform::errors::InvalidArgument(
"Eigh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2], input_dim[rank - 1]));
std::vector<int64_t> values_dim;
if (rank > 2) {
for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
} else {
values_dim = {input_dim[1]};
}
ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim));
ctx->SetOutputDim("Eigenvectors", input_dim);
}
};
class EignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), Hermitian or real symmetric matrices."
"Its shape should be [*, N, N] where * is zero or"
"more batch dimensions. The data type is float32 ,"
"float64, complex64, complex128.");
AddOutput("Eigenvalues",
"(Tensor), The eigenvalues in ascending order."
"The data type is float32 or float64.");
AddOutput(
"Eigenvectors",
"(Tensor), The column is the normalized eigenvector "
"corresponding to the eigenvalue. The data type is the same as ``X``.");
AddAttr<std::string>(
"UPLO",
"(string, default 'L'), 'L' represents the lower triangular matrix,"
"'U' represents the upper triangular matrix.")
.SetDefault("L");
AddComment(R"DOC(
Eigh Operator.
Computes the eigenvalues and eigenvectors of a complex Hermitian
(conjugate symmetric) or a real symmetric matrix.
)DOC");
}
};
class EighGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Eigenvalues"), "Input", "Eigenvalues",
"EighGrad");
OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors",
"EighGrad");
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")),
"Input", "Eigenvalues@GRAD", "EighGrad");
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")),
"Input", "Eigenvectors@GRAD", "EighGrad");
auto dims = ctx->GetInputDim("Eigenvectors");
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, dims);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Eigenvectors")),
ctx.device_context());
}
};
template <typename T>
class EighGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType(this->ForwardOpType() + "_grad");
op->SetInput("Eigenvalues", this->Output("Eigenvalues"));
op->SetInput("Eigenvectors", this->Output("Eigenvectors"));
op->SetInput(framework::GradVarName("Eigenvalues"),
this->OutputGrad("Eigenvalues"));
op->SetInput(framework::GradVarName("Eigenvectors"),
this->OutputGrad("Eigenvectors"));
op->SetAttrMap(this->Attrs());
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
ops::EighGradOpMaker<paddle::framework::OpDesc>,
ops::EighGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);
REGISTER_OP_CPU_KERNEL(
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float, float>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double, double>,
ops::EighKernel<paddle::platform::CPUDeviceContext, float,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
eigh_grad,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float, float>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double, double>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, float,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double,
paddle::platform::complex<double>>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigh_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename ValueType, typename T>
class EighGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto input_var = ctx.Input<Tensor>("X");
auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
bool is_lower = (lower == "L");
math::MatrixEighFunctor<ValueType, T> functor;
functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
eigh, ops::EighGPUKernel<float, float>, ops::EighGPUKernel<double, double>,
ops::EighGPUKernel<float, paddle::platform::complex<float>>,
ops::EighGPUKernel<double, paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
eigh_grad,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float, float>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double, double>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, float,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename ValueType, typename T>
class EighKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto input_var = ctx.Input<Tensor>("X");
auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
bool is_lower = (lower == "L");
math::MatrixEighFunctorCPU<DeviceContext, ValueType, T> functor;
functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
}
};
template <typename DeviceContext, typename ValueType, typename T>
class EighGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
x_grad.mutable_data<T>(ctx.GetPlace());
auto& output_w_var = *ctx.Input<Tensor>("Eigenvalues");
auto& output_v_var = *ctx.Input<Tensor>("Eigenvectors");
auto& output_w_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
auto& output_v_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvectors"));
auto& dims = output_v_var.dims();
const int m = dims[dims.size() - 1];
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
ctx);
auto tV = dito.Transpose(dito.Conj(output_v_var));
auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2),
dito.Unsqueeze(output_w_var, -1));
Tensor result = dito.Matmul(tV, output_v_grad);
result.mutable_data<T>(dims, ctx.GetPlace());
std::vector<int> out_shape = framework::vectorize<int>(dims);
auto constant = dito.Fill(out_shape, 0.5);
result = dito.Sub(result, dito.Conj(dito.Transpose(result)));
result = dito.Mul(result, constant);
result = dito.Div_(result, W);
result = dito.DiagFill(m, m, m, 0, output_w_grad, result);
x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV));
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "Eigen/Core"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/svd_helper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cusolver.h"
#endif // PADDLE_WITH_CUDA
namespace paddle {
namespace operators {
namespace math {
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using InputMatrixMap = Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using OutputMatrixMap = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename ValueType>
inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data,
ValueType *eigenvalues_data,
ValueType *eigenvectors_data,
int batches, int rows, int cols,
bool has_vectors) {
int stride = rows * cols;
for (int i = 0; i < batches; i++) {
auto m = InputMatrixMap<ValueType>(x_data + i * stride, rows, cols);
auto eigenvalues =
OutputMatrixMap<ValueType>(eigenvalues_data + i * rows, 1, rows);
auto eigenvectors =
OutputMatrixMap<ValueType>(eigenvectors_data + i * stride, rows, cols);
Eigen::SelfAdjointEigenSolver<Eigen::Matrix<
ValueType, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors
: Eigen::EigenvaluesOnly);
PADDLE_ENFORCE_EQ(
eigen_solver.info(), Eigen::Success,
platform::errors::InvalidArgument(
"Self Adjoint Eigen decomposition is not successful. "
"The %d-th input matrice might not be not be positive definite.",
i));
eigenvalues = eigen_solver.eigenvalues().transpose();
if (has_vectors) {
eigenvectors = eigen_solver.eigenvectors().transpose();
}
}
}
template <typename T, typename ValueType>
inline void ComputeComplexEigenvaluesAndVectors(T *x_data,
ValueType *eigenvalues_data,
T *eigenvectors_data,
int batches, int rows, int cols,
bool has_vectors) {
using Complex = std::complex<ValueType>;
Complex *input = reinterpret_cast<Complex *>(x_data);
Complex *eigenvectors_data_ = reinterpret_cast<Complex *>(eigenvectors_data);
int stride = rows * cols;
for (int i = 0; i < batches; i++) {
auto m = InputMatrixMap<Complex>(input + i * stride, rows, cols);
auto eigenvalues =
OutputMatrixMap<ValueType>(eigenvalues_data + i * rows, 1, rows);
auto eigenvectors =
OutputMatrixMap<Complex>(eigenvectors_data_ + i * stride, rows, cols);
Eigen::SelfAdjointEigenSolver<
Eigen::Matrix<Complex, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
eigen_solver(m, has_vectors ? Eigen::ComputeEigenvectors
: Eigen::EigenvaluesOnly);
PADDLE_ENFORCE_EQ(
eigen_solver.info(), Eigen::Success,
platform::errors::InvalidArgument(
"Self Adjoint Eigen decomposition is not successful. "
"The %d-th input matrice might not be not be positive definite.",
i));
eigenvalues = eigen_solver.eigenvalues().transpose();
if (has_vectors) {
eigenvectors = eigen_solver.eigenvectors().transpose();
}
}
}
inline int64_t GetBatchSize(framework::DDim dims) {
int64_t batch_size = 1;
auto dim_size = dims.size();
for (int i = 0; i < dim_size - 2; i++) {
batch_size *= dims[i];
}
return batch_size;
}
// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices, and uses the variable has_vectors to
// control whether to return the eigenvectors.
template <typename DeviceContext, typename ValueType, typename T>
struct MatrixEighFunctorCPU {
public:
void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
bool has_vectors) {
auto dims = input.dims();
auto output_value_dim = eigen_values->dims();
int64_t batch_size = 1;
int dim_size = dims.size();
for (int64_t i = 0; i < dim_size - 2; i++) {
batch_size *= dims[i];
}
auto dito = DeviceIndependenceTensorOperations<DeviceContext, T>(ctx);
Tensor input_tensor;
TensorCopy(input, ctx.GetPlace(), &input_tensor);
if (!is_lower) {
input_tensor = dito.Transpose(input);
}
int rows = dims[dims.size() - 2];
auto *value_data =
eigen_values->mutable_data<ValueType>(output_value_dim, ctx.GetPlace());
if (framework::IsComplexType(input_tensor.type())) {
auto *x_data = input_tensor.data<T>();
auto *vector_data = eigen_vectors->mutable_data<T>(dims, ctx.GetPlace());
ComputeComplexEigenvaluesAndVectors<T, ValueType>(
x_data, value_data, vector_data, batch_size, rows, rows, has_vectors);
} else {
auto *x_data = input_tensor.data<ValueType>();
auto *vector_data =
eigen_vectors->mutable_data<ValueType>(dims, ctx.GetPlace());
ComputeFloatEigenvaluesAndVectors<ValueType>(
x_data, value_data, vector_data, batch_size, rows, rows, has_vectors);
}
if (has_vectors) {
*eigen_vectors = dito.Transpose(*eigen_vectors);
}
}
};
#ifdef PADDLE_WITH_CUDA
// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices on GPU, and uses the variable has_vectors
// to control whether to return the eigenvectors.
template <typename ValueType, typename T>
struct MatrixEighFunctor {
public:
void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
bool has_vectors) {
auto *out_value = eigen_values->mutable_data<ValueType>(ctx.GetPlace());
auto *out_vector = eigen_vectors->mutable_data<T>(ctx.GetPlace());
auto &dims = input.dims();
int dim_size = dims.size();
int64_t batch_size = GetBatchSize(dims);
cublasFillMode_t uplo =
is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
cusolverEigMode_t jobz =
has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
int n = dims[dim_size - 1];
int lda = std::max<int>(1, n);
auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
auto values_stride = dims[dim_size - 1];
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto dito =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
T>(ctx);
Tensor output_v_var_trans = dito.Transpose(input);
TensorCopy(output_v_var_trans, ctx.GetPlace(), eigen_vectors);
int lwork = 0;
auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_size);
auto *info_ptr = reinterpret_cast<int *>(info->ptr());
// When the input type is float32, and the feature value input dimension is
// greater than or equal to [*,32,32] and less than or equal to
// [*,512,512], Syevj has better performance.
bool use_syevj =
(eigen_vectors->type() == framework::proto::VarType::FP32 &&
values_stride >= 32 && values_stride <= 512);
syevjInfo_t syevj_params;
if (use_syevj) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cusolverDnCreateSyevjInfo(&syevj_params));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cusolverDnSsyevj_bufferSize(
dev_ctx.cusolver_dn_handle(), jobz, uplo, n,
reinterpret_cast<const float *>(out_vector), lda,
reinterpret_cast<const float *>(out_value), &lwork,
syevj_params));
} else {
EvdBuffer(dev_ctx.cusolver_dn_handle(), jobz, uplo, n, out_vector, lda,
out_value, &lwork);
}
auto work = memory::Alloc(dev_ctx, sizeof(T) * lwork);
auto *work_ptr = reinterpret_cast<T *>(work->ptr());
for (auto i = 0; i < batch_size; i++) {
auto vector_data = out_vector + i * vector_stride;
auto value_data = out_value + i * values_stride;
auto handle = dev_ctx.cusolver_dn_handle();
if (use_syevj) {
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDnSsyevj(
handle, jobz, uplo, n, reinterpret_cast<float *>(vector_data), lda,
reinterpret_cast<float *>(value_data),
reinterpret_cast<float *>(work_ptr), lwork, info_ptr,
syevj_params));
} else {
Evd(handle, jobz, uplo, n, vector_data, lda, value_data, work_ptr,
lwork, info_ptr);
}
int error_info;
memory::Copy(platform::CPUPlace(), &error_info,
BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace()),
info_ptr, sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: the [%d] argument had an illegal value", i,
error_info));
}
if (use_syevj) {
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cusolverDnDestroySyevjInfo(syevj_params));
}
if (has_vectors) {
*eigen_vectors = dito.Transpose(*eigen_vectors);
}
}
inline void EvdBuffer(cusolverDnHandle_t handle, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int n, const T *A, int lda,
const ValueType *W, int *lwork) const;
inline void Evd(cusolverDnHandle_t handle, cusolverEigMode_t jobz,
cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W,
T *work, int lwork, int *devInfo) const;
};
#define FUNC_WITH_TYPES(m) \
m(float, float, Ssy, float) m(double, double, Dsy, double) \
m(float, paddle::platform::complex<float>, Che, cuComplex) \
m(double, paddle::platform::complex<double>, Zhe, cuDoubleComplex)
#define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \
template <> \
inline void MatrixEighFunctor<ValueType, T>::EvdBuffer( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \
int *lwork) const { \
PADDLE_ENFORCE_CUDA_SUCCESS( \
platform::dynload::cusolverDn##C##evd_bufferSize( \
handle, jobz, uplo, n, reinterpret_cast<const CastType *>(A), lda, \
W, lwork)); \
}
FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);
#define EVD_INSTANCE(ValueType, T, C, CastType) \
template <> \
inline void MatrixEighFunctor<ValueType, T>::Evd( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \
int lwork, int *devInfo) const { \
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDn##C##evd( \
handle, jobz, uplo, n, reinterpret_cast<CastType *>(A), lda, W, \
reinterpret_cast<CastType *>(work), lwork, devInfo)); \
}
FUNC_WITH_TYPES(EVD_INSTANCE);
#undef FUNC_WITH_TYPES
#undef EVDBUFFER_INSTANCE
#undef EVD_INSTANCE
#endif // PADDLE_WITH_CUDA
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -25,6 +25,8 @@
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -36,6 +38,9 @@ using Tensor = framework::Tensor;
using InTensors = std::vector<const Tensor*>;
using OutTensors = std::vector<Tensor*>;
using OpName = std::string;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T>
void EigenSvd(const T* X, T* U, T* VH, T* S, int rows, int cols,
......@@ -140,7 +145,42 @@ static std::vector<int> GetBroadcastShape(InTensors ins) {
break; \
}
template <typename DeviceContext, typename T>
template <typename T, typename ValueType>
struct DiagAndFillFunctor {
DiagAndFillFunctor(const int m, const int n, const int num_lower_diags,
const int num_upper_diags, const ValueType* scale,
const T* input, T* output)
: m_(m),
n_(n),
num_lower_diags_(num_lower_diags),
num_upper_diags_(num_upper_diags),
scale_(scale),
input_(input),
output_(output) {}
HOSTDEVICE void operator()(size_t index) const {
const int col = index % n_;
const int row = (index / n_) % m_;
const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_);
const int band_end =
(num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1);
if (col < band_start || col >= band_end) {
output_[index] = input_[index];
} else if (col == band_end - 1) {
output_[index] = static_cast<T>(scale_[index % m_]);
} else {
output_[index] = input_[index];
}
}
private:
const int m_, n_, num_lower_diags_, num_upper_diags_;
const ValueType* scale_;
const T* input_;
T* output_;
};
template <typename DeviceContext, typename T, typename ValueType = T>
struct DeviceIndependenceTensorOperations {
// 1. Device indenpendence, for kernel reuse.
// 2. Input and output is always tensor type.
......@@ -398,6 +438,60 @@ struct DeviceIndependenceTensorOperations {
return ret;
}
Tensor Conj(const Tensor& x) {
Tensor out;
auto* out_data = out.mutable_data<T>(x.dims(), context.GetPlace());
auto* x_data = x.data<T>();
auto for_range = GetForRange(x.numel());
math::ConjFunctor<T> functor(x_data, x.numel(), out_data);
for_range(functor);
return out;
}
Tensor DiagFill(const int m, const int n, const int num_lower_diags,
const int num_upper_diags, const Tensor& scale,
const Tensor& input) {
Tensor out;
auto& dev_ctx = context.template device_context<DeviceContext>();
platform::ForRange<DeviceContext> for_range(dev_ctx, input.numel());
DiagAndFillFunctor<T, ValueType> diag_and_copy_functor(
m, n, num_lower_diags, num_upper_diags, scale.data<ValueType>(),
input.data<T>(), out.mutable_data<T>(input.dims(), input.place()));
for_range(diag_and_copy_functor);
return out;
}
// Support x and y are different data types
Tensor Div_(const Tensor& x, const Tensor& y) {
Tensor out;
out.mutable_data<T>(x.dims(), context.GetPlace());
auto x_vector = EigenVector<T>::Flatten(x);
auto y_vector = EigenVector<ValueType>::Flatten(y);
auto out_vector = EigenVector<T>::Flatten(out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_vector.device(place) = x_vector / y_vector;
return out;
}
framework::Tensor Sub_(const framework::Tensor& x,
const framework::Tensor& y) {
framework::Tensor ret;
std::vector<int> out_shape = GetBroadcastShape({&x, &y});
ret.Resize(framework::make_ddim(out_shape));
if (x.dims().size() >= y.dims().size()) {
ElementwiseComputeEx<SubFunctor<ValueType>, DeviceContext, ValueType>(
context, &x, &y, -1, SubFunctor<ValueType>(), &ret);
} else {
ElementwiseComputeEx<InverseSubFunctor<ValueType>, DeviceContext,
ValueType>(
// This is copyed from elementwise_sub, which means we
// need reverse will xrank < yrank
context, &x, &y, -1, InverseSubFunctor<ValueType>(), &ret);
}
return ret;
}
private:
const framework::ExecutionContext& context;
BlasT<DeviceContext, T> GetBlas() {
......
......@@ -48,7 +48,15 @@ extern void *cusolver_dso_handle;
__macro(cusolverDnSpotrf_bufferSize); \
__macro(cusolverDnDpotrf_bufferSize); \
__macro(cusolverDnSpotrf); \
__macro(cusolverDnDpotrf);
__macro(cusolverDnDpotrf); \
__macro(cusolverDnSsyevd_bufferSize); \
__macro(cusolverDnDsyevd_bufferSize); \
__macro(cusolverDnCheevd_bufferSize); \
__macro(cusolverDnZheevd_bufferSize); \
__macro(cusolverDnSsyevd); \
__macro(cusolverDnDsyevd); \
__macro(cusolverDnCheevd); \
__macro(cusolverDnZheevd);
CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
......
......@@ -101,6 +101,8 @@ from .tensor.linalg import histogram # noqa: F401
from .tensor.linalg import mv # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import matrix_power # noqa: F401
from .tensor.linalg import svd # noqa: F401
from .tensor.linalg import eigh # noqa: F401
from .tensor.logic import equal # noqa: F401
from .tensor.logic import greater_equal # noqa: F401
from .tensor.logic import greater_than # noqa: F401
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from op_test import OpTest
from gradient_checker import grad_check
class TestEighOp(OpTest):
def setUp(self):
paddle.enable_static()
self.op_type = "eigh"
self.init_input()
self.init_config()
np.random.seed(123)
out_w, out_v = np.linalg.eigh(self.x_np, self.UPLO)
self.inputs = {"X": self.x_np}
self.attrs = {"UPLO": self.UPLO}
self.outputs = {'Eigenvalues': out_w, "Eigenvectors": out_v}
def init_config(self):
self.UPLO = 'L'
def init_input(self):
self.x_shape = (10, 10)
self.x_type = np.float64
self.x_np = np.random.random(self.x_shape).astype(self.x_type)
def test_check_output(self):
self.check_output(no_check_set=['Eigenvectors'])
def test_grad(self):
self.check_grad(["X"], ["Eigenvalues"])
class TestEighUPLOCase(TestEighOp):
def init_config(self):
self.UPLO = 'U'
class TestEighGPUCase(unittest.TestCase):
def setUp(self):
self.x_shape = [32, 32]
self.dtype = "float32"
np.random.seed(123)
self.x_np = np.random.random(self.x_shape).astype(self.dtype)
self.rtol = 1e-5
self.atol = 1e-5
def test_check_output_gpu(self):
if paddle.is_compiled_with_cuda():
paddle.disable_static(place=paddle.CUDAPlace(0))
input_real_data = paddle.to_tensor(self.x_np)
expected_w, expected_v = np.linalg.eigh(self.x_np)
actual_w, actual_v = paddle.linalg.eigh(input_real_data)
np.testing.assert_allclose(
actual_w, expected_w, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(
abs(actual_v.numpy()),
abs(expected_v),
rtol=self.rtol,
atol=self.atol)
class TestEighAPI(unittest.TestCase):
def setUp(self):
self.init_input_shape()
self.dtype = "float32"
self.UPLO = 'L'
self.rtol = 1e-6
self.atol = 1e-6
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
np.random.seed(123)
self.real_data = np.random.random(self.x_shape).astype(self.dtype)
self.complex_data = np.random.random(self.x_shape).astype(
self.dtype) + 1J * np.random.random(self.x_shape).astype(self.dtype)
self.trans_dims = list(range(len(self.x_shape) - 2)) + [
len(self.x_shape) - 1, len(self.x_shape) - 2
]
def init_input_shape(self):
self.x_shape = [5, 5]
def compare_result(self, actual_w, actual_v, expected_w, expected_v):
np.testing.assert_allclose(
actual_w, expected_w, rtol=self.rtol, atol=self.atol)
np.testing.assert_allclose(
abs(actual_v), abs(expected_v), rtol=self.rtol, atol=self.atol)
def check_static_float_result(self):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
input_x = paddle.static.data(
'input_x', shape=self.x_shape, dtype=self.dtype)
output_w, output_v = paddle.linalg.eigh(input_x)
exe = paddle.static.Executor(self.place)
expected_w, expected_v = exe.run(main_prog,
feed={"input_x": self.real_data},
fetch_list=[output_w, output_v])
actual_w, actual_v = np.linalg.eigh(self.real_data)
self.compare_result(actual_w, actual_v, expected_w, expected_v)
def check_static_complex_result(self):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
x_dtype = np.complex64 if self.dtype == "float32" else np.complex128
input_x = paddle.static.data(
'input_x', shape=self.x_shape, dtype=x_dtype)
output_w, output_v = paddle.linalg.eigh(input_x)
exe = paddle.static.Executor(self.place)
expected_w, expected_v = exe.run(
main_prog,
feed={"input_x": self.complex_data},
fetch_list=[output_w, output_v])
actual_w, actual_v = np.linalg.eigh(self.complex_data)
self.compare_result(actual_w, actual_v, expected_w, expected_v)
def test_in_static_mode(self):
paddle.enable_static()
self.check_static_float_result()
self.check_static_complex_result()
def test_in_dynamic_mode(self):
paddle.disable_static(self.place)
input_real_data = paddle.to_tensor(self.real_data)
expected_w, expected_v = np.linalg.eigh(self.real_data)
actual_w, actual_v = paddle.linalg.eigh(input_real_data)
self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v)
input_complex_data = paddle.to_tensor(self.complex_data)
expected_w, expected_v = np.linalg.eigh(self.complex_data)
actual_w, actual_v = paddle.linalg.eigh(input_complex_data)
self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v)
def test_eigh_grad(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.complex_data, stop_gradient=False)
w, v = paddle.linalg.eigh(x)
(w.sum() + paddle.abs(v).sum()).backward()
np.testing.assert_allclose(
abs(x.grad.numpy()),
abs(x.grad.numpy().conj().transpose(self.trans_dims)),
rtol=self.rtol,
atol=self.atol)
class TestEighBatchAPI(TestEighAPI):
def init_input_shape(self):
self.x_shape = [2, 5, 5]
class TestEighAPIError(unittest.TestCase):
def test_error(self):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
#input maxtrix must greater than 2 dimensions
input_x = paddle.static.data(
name='x_1', shape=[12], dtype='float32')
self.assertRaises(ValueError, paddle.linalg.eigh, input_x)
#input matrix must be square matrix
input_x = paddle.static.data(
name='x_2', shape=[12, 32], dtype='float32')
self.assertRaises(ValueError, paddle.linalg.eigh, input_x)
#uplo must be in 'L' or 'U'
input_x = paddle.static.data(
name='x_3', shape=[4, 4], dtype="float32")
uplo = 'R'
self.assertRaises(ValueError, paddle.linalg.eigh, input_x, uplo)
#x_data cannot be integer
input_x = paddle.static.data(
name='x_4', shape=[4, 4], dtype="int32")
self.assertRaises(TypeError, paddle.linalg.eigh, input_x)
if __name__ == "__main__":
unittest.main()
......@@ -32,5 +32,6 @@ no_check_set_white_list = [
'fusion_lstm',
'softmax_with_cross_entropy',
'svd',
'eigh',
'class_center_sample',
]
......@@ -19,6 +19,7 @@ from .tensor import inverse as inv # noqa: F401
from .tensor.linalg import multi_dot # noqa: F401
from .tensor.linalg import matrix_rank
from .tensor.linalg import svd
from .tensor.linalg import eigh # noqa: F401
__all__ = [
'cholesky', #noqa
......@@ -27,5 +28,6 @@ __all__ = [
'multi_dot',
'matrix_rank',
'svd',
'matrix_power'
'matrix_power',
'eigh'
]
......@@ -47,6 +47,7 @@ from .linalg import mv # noqa: F401
from .linalg import matrix_power # noqa: F401
from .linalg import multi_dot # noqa: F401
from .linalg import svd # noqa: F401
from .linalg import eigh # noqa: F401
from .logic import equal # noqa: F401
from .logic import greater_equal # noqa: F401
from .logic import greater_than # noqa: F401
......
......@@ -1106,7 +1106,7 @@ def svd(x, full_matrices=False, name=None):
def matrix_power(x, n, name=None):
r"""
Computes the n-th power of a square matrix or a batch of square matrices.
Let :math:`X` be a sqaure matrix or a batch of square matrices, :math:`n` be
an exponent, the equation should be:
......@@ -1251,3 +1251,72 @@ def multi_dot(x, name=None):
out = helper.create_variable_for_type_inference(dtype)
helper.append_op(type='multi_dot', inputs={"X": x}, outputs={"Out": out})
return out
def eigh(x, UPLO='L', name=None):
"""
Compute the eigenvalues and eigenvectors of a
complex Hermitian (conjugate symmetric) or a real symmetric matrix.
Args:
x (Tensor): A tensor with shape :math:`[*, N, N]` , The data type of the input Tensor x
should be one of float32, float64, complex64, complex128.
UPLO(str, optional): (string, default 'L'), 'L' represents the lower triangular matrix,
"'U' represents the upper triangular matrix.".
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
out_value(Tensor): A Tensor with shape [*, N] and data type of float32 and float64. The eigenvalues of eigh op.
out_vector(Tensor): A Tensor with shape [*, N, N] and data type of float32,float64,complex64 and complex128. The eigenvectors of eigh op.
Examples:
.. code-block:: python
import numpy as np
import paddle
x_data = np.array([[1, -2j], [2j, 5]])
x = paddle.to_tensor(x_data)
out_value, out_vector = paddle.eigh(x, UPLO='L')
print(out_value)
#[0.17157288, 5.82842712]
print(out_vector)
#[(-0.9238795325112867+0j), (-0.3826834323650898+0j)],
#[ 0.3826834323650898j , -0.9238795325112867j ]]
"""
if in_dygraph_mode():
return _C_ops.eigh(x, 'UPLO', UPLO)
def __check_input(x, UPLO):
x_shape = list(x.shape)
if len(x.shape) < 2:
raise ValueError(
"Input(input) only support >=2 tensor, but received "
"length of Input(input) is %s." % len(x.shape))
if x_shape[-1] != x_shape[-2]:
raise ValueError(
"The input matrix must be batches of square matrices. But received x's dimention: {}".
format(x_shape))
if UPLO is not 'L' and UPLO is not 'U':
raise ValueError(
"UPLO must be L or U. But received UPLO is: {}".format(UPLO))
__check_input(x, UPLO)
helper = LayerHelper('eigh', **locals())
check_variable_and_dtype(
x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'eigh')
out_value = helper.create_variable_for_type_inference(dtype=x.dtype)
out_vector = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='eigh',
inputs={'X': x},
outputs={'Eigenvalues': out_value,
'Eigenvectors': out_vector},
attrs={'UPLO': UPLO})
return out_value, out_vector
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册