未验证 提交 2aca8d90 编写于 作者: C crystal 提交者: GitHub

【phi】migrate eigh op to phi (#40213)

* migrate eigh to phi

* optimize code

* modify code according to comment

* conflict resolution
上级 ec582895
......@@ -12,7 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigh_op.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -22,42 +25,9 @@ using framework::Tensor;
class EighOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvalues"), "Output", "Eigenvalues",
"Eigh");
OP_INOUT_CHECK(ctx->HasOutput("Eigenvectors"), "Output", "Eigenvectors",
"Eigh");
auto input_dim = ctx->GetInputDim("X");
auto rank = input_dim.size();
PADDLE_ENFORCE_GE(rank, 2,
platform::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2], input_dim[rank - 1],
platform::errors::InvalidArgument(
"Eigh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2], input_dim[rank - 1]));
std::vector<int64_t> values_dim;
for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
ctx->SetOutputDim("Eigenvalues", phi::make_ddim(values_dim));
ctx->SetOutputDim("Eigenvectors", input_dim);
}
};
class EignOpMaker : public framework::OpProtoAndCheckerMaker {
class EighOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
......@@ -140,24 +110,11 @@ class EighGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle
namespace ops = paddle::operators;
DECLARE_INFER_SHAPE_FUNCTOR(eigh, EighInferShapeFunctor,
PD_INFER_META(phi::EighInferMeta));
REGISTER_OPERATOR(eigh, ops::EighOp, ops::EignOpMaker,
REGISTER_OPERATOR(eigh, ops::EighOp, ops::EighOpMaker,
ops::EighGradOpMaker<paddle::framework::OpDesc>,
ops::EighGradOpMaker<paddle::imperative::OpBase>);
ops::EighGradOpMaker<paddle::imperative::OpBase>,
EighInferShapeFunctor);
REGISTER_OPERATOR(eigh_grad, ops::EighGradOp);
REGISTER_OP_CPU_KERNEL(
eigh, ops::EighKernel<paddle::platform::CPUDeviceContext, float>,
ops::EighKernel<paddle::platform::CPUDeviceContext, double>,
ops::EighKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CPU_KERNEL(
eigh_grad, ops::EighGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex<double>>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/eigh_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
eigh, ops::EighKernel<paddle::platform::CUDADeviceContext, float>,
ops::EighKernel<paddle::platform::CUDADeviceContext, double>,
ops::EighKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
REGISTER_OP_CUDA_KERNEL(
eigh_grad, ops::EighGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<float>>,
ops::EighGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex<double>>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class EighKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto input = ctx.Input<Tensor>("X");
auto output_w = ctx.Output<Tensor>("Eigenvalues");
auto output_v = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
bool is_lower = (lower == "L");
math::MatrixEighFunctor<DeviceContext, T> functor;
functor(ctx, *input, output_w, output_v, is_lower, true);
}
};
template <typename DeviceContext, typename T>
class EighGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using ValueType = phi::dtype::Real<T>;
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
x_grad.mutable_data<T>(ctx.GetPlace());
auto& output_w = *ctx.Input<Tensor>("Eigenvalues");
auto& output_v = *ctx.Input<Tensor>("Eigenvectors");
auto& output_w_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
auto& output_v_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvectors"));
auto& dims = output_v.dims();
const int m = dims[dims.size() - 1];
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
ctx);
auto tV = dito.Transpose(dito.Conj(output_v));
auto W = dito.template Sub<ValueType>(dito.Unsqueeze(output_w, -2),
dito.Unsqueeze(output_w, -1));
Tensor result = dito.Matmul(tV, output_v_grad);
result.mutable_data<T>(dims, ctx.GetPlace());
std::vector<int> out_shape = phi::vectorize<int>(dims);
auto constant = dito.Fill(out_shape, 0.5);
result = dito.Sub(result, dito.Conj(dito.Transpose(result)));
result = dito.Mul(result, constant);
result = dito.Div(result, W);
result = dito.DiagFill(m, m, m, 0, output_w_grad, result);
x_grad = dito.Matmul(output_v, dito.Matmul(result, tV));
}
};
} // namespace operators
} // namespace paddle
......@@ -1123,6 +1123,38 @@ void TransposeInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
MetaTensor* out_v) {
auto input_dim = x.dims();
auto rank = input_dim.size();
PADDLE_ENFORCE_GE(rank,
2,
phi::errors::InvalidArgument(
"The Input(X) should have at least 2 dimensions."
"But received a %d dimension tensor.",
rank));
PADDLE_ENFORCE_EQ(
input_dim[rank - 2],
input_dim[rank - 1],
phi::errors::InvalidArgument(
"Eigh op is designed for square matrix, consequently"
"inner-most 2 dimensions of Input(X) should be symmetric."
"But received X's shape[-2] = %d and shape[-1] = %d.",
input_dim[rank - 2],
input_dim[rank - 1]));
std::vector<int64_t> values_dim;
for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
out_w->set_dims(phi::make_ddim(values_dim));
out_v->set_dims(input_dim);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
......
......@@ -163,4 +163,9 @@ void TransposeInferMeta(const MetaTensor& x,
const std::vector<int>& axis,
MetaTensor* out);
void EighInferMeta(const MetaTensor& x,
const std::string& uplo,
MetaTensor* out_w,
MetaTensor* out_v);
} // namespace phi
......@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
# Some kernels depend on some targets that are not commonly used.
# These targets are not suitable for common dependencies.
# In this case, you need to manually generate them here.
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel)
set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel)
kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
......@@ -38,6 +38,7 @@ kernel_library(put_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_k
kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
# 4. auto parse and build kernel targets by cmake
register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS} )
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/eigh_grad_kernel.h"
#include "paddle/phi/kernels/impl/eigh_grad_kernel_impl.h"
#include "paddle/phi/common/complex.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(eigh_grad,
CPU,
ALL_LAYOUT,
phi::EighGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/eigh_kernel.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace phi {
template <typename T, typename Context>
void EighKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& uplo,
DenseTensor* out_w,
DenseTensor* out_v) {
bool is_lower = (uplo == "L");
phi::funcs::MatrixEighFunctor<Context, T> functor;
functor(dev_ctx, x, out_w, out_v, is_lower, true);
}
} // namespace phi
PD_REGISTER_KERNEL(eigh,
CPU,
ALL_LAYOUT,
phi::EighKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EighGardKernel(const Context& dev_ctx,
const DenseTensor& out_w,
const DenseTensor& out_v,
const DenseTensor& dout_w,
const DenseTensor& dout_v,
DenseTensor* dx);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
namespace phi {
template <typename T, typename Context>
void EighKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& uplo,
DenseTensor* out_w,
DenseTensor* out_v);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/memory/memory.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/dynload/cusolver.h"
#endif // PADDLE_WITH_CUDA
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
namespace funcs {
inline int64_t GetBatchSize(phi::DDim dims) {
int64_t batch_size = 1;
auto dim_size = dims.size();
for (int i = 0; i < dim_size - 2; i++) {
batch_size *= dims[i];
}
return batch_size;
}
static void CheckEighResult(const int batch, const int info) {
PADDLE_ENFORCE_LE(
info,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: the [%d] off-diagonal elements of an intermediate"
"tridiagonal form did not converge to zero",
batch,
info));
PADDLE_ENFORCE_GE(
info,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: the [%d] argument had an illegal value",
batch,
info));
}
template <typename DeviceContext, typename T>
struct MatrixEighFunctor {
void operator()(const DeviceContext &dev_ctx,
const DenseTensor &input,
DenseTensor *eigen_values,
DenseTensor *eigen_vectors,
bool is_lower,
bool has_vectors);
};
// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices, and uses the variable has_vectors to
// control whether to return the eigenvectors.
template <typename T>
struct MatrixEighFunctor<CPUContext, T> {
public:
void operator()(const CPUContext &dev_ctx,
const DenseTensor &input,
DenseTensor *eigen_values,
DenseTensor *eigen_vectors,
bool is_lower,
bool has_vectors) {
using ValueType = phi::dtype::Real<T>;
ValueType *out_value = dev_ctx.template Alloc<ValueType>(eigen_values);
DenseTensor input_trans;
// lapack is a column-major storge, transpose make the input to
// have a continuous memory layout
input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
T *input_vector = input_trans.data<T>();
auto dims = input.dims();
int dim_size = dims.size();
int64_t batch_size = GetBatchSize(dims);
int vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
int values_stride = dims[dim_size - 1];
char uplo = is_lower ? 'L' : 'U';
char jobz = has_vectors ? 'V' : 'N';
int n = dims[dim_size - 1];
int64_t lda = std::max<int64_t>(1, n);
// if work = -1, it means that you need to use the lapack function to query
// the optimal value
int lwork = -1; // The length of the array work
int lrwork = -1; // The dimension of the array rwork,rwork is REAL array
int liwork = -1; // The dimension of the array iwork
int iwork_opt = -1; // The optimal length of the array liwork
T lwork_opt = static_cast<T>(-1); // The optimal length of the array work
ValueType rwork_opt =
static_cast<ValueType>(-1); // The optimal length of the array rwork
int info = 0;
// Call lapackEigh to get the optimal size of work data
phi::funcs::lapackEigh<T, ValueType>(jobz,
uplo,
n,
input_vector,
lda,
out_value,
&lwork_opt,
lwork,
&rwork_opt,
lrwork,
&iwork_opt,
liwork,
&info);
lwork = std::max<int>(1, static_cast<int>(lwork_opt));
liwork = std::max<int>(1, iwork_opt);
DenseTensor rwork_tensor;
ValueType *rwork_data = nullptr;
// complex type
if (input.type() == phi::DataType::COMPLEX64 ||
input.type() == phi::DataType::COMPLEX128) {
lrwork = std::max<int>(1, static_cast<int>(rwork_opt));
rwork_tensor.Resize(phi::make_ddim({lrwork}));
rwork_data = dev_ctx.template Alloc<ValueType>(&rwork_tensor);
}
DenseTensor iwork_tensor, work_tensor;
iwork_tensor.Resize(phi::make_ddim({liwork}));
int *iwork_data = dev_ctx.template Alloc<int>(&iwork_tensor);
work_tensor.Resize(phi::make_ddim({lwork}));
T *work_data = dev_ctx.template Alloc<T>(&work_tensor);
for (auto i = 0; i < batch_size; i++) {
auto *value_data = out_value + i * values_stride;
auto *input_data = input_vector + i * vector_stride;
phi::funcs::lapackEigh<T, ValueType>(jobz,
uplo,
n,
input_data,
lda,
value_data,
work_data,
lwork,
rwork_data,
lrwork,
iwork_data,
liwork,
&info);
CheckEighResult(i, info);
}
if (has_vectors) {
PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
phi::errors::InvalidArgument(
"When has_vectors is true,"
"the eigenvectors needs to be calculated, "
"so the eigenvectors must be provided."));
input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input_trans);
eigen_vectors->ShareDataWith(input_trans);
}
}
};
#ifdef PADDLE_WITH_CUDA
// Calculates the eigenvalues ​​and eigenvectors of Hermitian or real
// symmetric matrices on GPU, and uses the variable has_vectors
// to control whether to return the eigenvectors.
template <typename T>
struct MatrixEighFunctor<GPUContext, T> {
public:
void operator()(const GPUContext &dev_ctx,
const DenseTensor &input,
DenseTensor *eigen_values,
DenseTensor *eigen_vectors,
bool is_lower,
bool has_vectors) {
using ValueType = phi::dtype::Real<T>;
ValueType *out_value = dev_ctx.template Alloc<ValueType>(eigen_values);
DenseTensor input_trans;
input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input);
T *input_vector = input_trans.data<T>();
auto &dims = input.dims();
int dim_size = dims.size();
int64_t batch_size = GetBatchSize(dims);
cublasFillMode_t uplo =
is_lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER;
cusolverEigMode_t jobz =
has_vectors ? CUSOLVER_EIG_MODE_VECTOR : CUSOLVER_EIG_MODE_NOVECTOR;
int n = dims[dim_size - 1];
int lda = std::max<int>(1, n);
auto vector_stride = dims[dim_size - 1] * dims[dim_size - 2];
auto values_stride = dims[dim_size - 1];
int lwork = 0;
auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batch_size);
auto *info_ptr = reinterpret_cast<int *>(info->ptr());
// When the input type is float32, and the feature value input dimension
// is greater than or equal to [*,32,32] and less than or equal to
// [*,512,512], Syevj has better performance.
bool use_syevj = (input.dtype() == phi::DataType::FLOAT32 &&
values_stride >= 32 && values_stride <= 512);
syevjInfo_t syevj_params;
if (use_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCreateSyevjInfo(&syevj_params));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize(
dev_ctx.cusolver_dn_handle(),
jobz,
uplo,
n,
reinterpret_cast<const float *>(input_vector),
lda,
reinterpret_cast<const float *>(out_value),
&lwork,
syevj_params));
} else {
EvdBuffer(dev_ctx.cusolver_dn_handle(),
jobz,
uplo,
n,
input_vector,
lda,
out_value,
&lwork);
}
auto work = paddle::memory::Alloc(dev_ctx, sizeof(T) * lwork);
auto *work_ptr = reinterpret_cast<T *>(work->ptr());
for (auto i = 0; i < batch_size; i++) {
auto *input_data = input_vector + i * vector_stride;
auto *value_data = out_value + i * values_stride;
auto handle = dev_ctx.cusolver_dn_handle();
if (use_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnSsyevj(handle,
jobz,
uplo,
n,
reinterpret_cast<float *>(input_data),
lda,
reinterpret_cast<float *>(value_data),
reinterpret_cast<float *>(work_ptr),
lwork,
info_ptr,
syevj_params));
} else {
Evd(handle,
jobz,
uplo,
n,
input_data,
lda,
value_data,
work_ptr,
lwork,
info_ptr);
}
int error_info = 0;
paddle::memory::Copy(phi::CPUPlace(),
&error_info,
dev_ctx.GetPlace(),
info_ptr,
sizeof(int),
dev_ctx.stream());
CheckEighResult(i, error_info);
}
if (use_syevj) {
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnDestroySyevjInfo(syevj_params));
}
if (has_vectors) {
PADDLE_ENFORCE_NOT_NULL(eigen_vectors,
phi::errors::InvalidArgument(
"When has_vectors is true,"
"the eigenvectors needs to be calculated,"
"so the eigenvectors must be provided."));
// input_trans = dito.Transpose(input_trans);
input_trans = phi::TransposeLast2Dim<T>(dev_ctx, input_trans);
eigen_vectors->ShareDataWith(input_trans);
}
}
using ValueType = phi::dtype::Real<T>;
inline void EvdBuffer(cusolverDnHandle_t handle,
cusolverEigMode_t jobz,
cublasFillMode_t uplo,
int n,
const T *A,
int lda,
const ValueType *W,
int *lwork) const;
inline void Evd(cusolverDnHandle_t handle,
cusolverEigMode_t jobz,
cublasFillMode_t uplo,
int n,
T *A,
int lda,
ValueType *W,
T *work,
int lwork,
int *devInfo) const;
};
using phi::dtype::complex;
#define FUNC_WITH_TYPES(m) \
m(float, Ssy, float) m(double, Dsy, double) m( \
complex<float>, Che, cuComplex) m(complex<double>, Zhe, cuDoubleComplex)
#define EVDBUFFER_INSTANCE(T, C, CastType) \
template <> \
inline void MatrixEighFunctor<GPUContext, T>::EvdBuffer( \
cusolverDnHandle_t handle, \
cusolverEigMode_t jobz, \
cublasFillMode_t uplo, \
int n, \
const T *A, \
int lda, \
const ValueType *W, \
int *lwork) const { \
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDn##C##evd_bufferSize( \
handle, \
jobz, \
uplo, \
n, \
reinterpret_cast<const CastType *>(A), \
lda, \
W, \
lwork)); \
}
FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);
#define EVD_INSTANCE(T, C, CastType) \
template <> \
inline void MatrixEighFunctor<GPUContext, T>::Evd(cusolverDnHandle_t handle, \
cusolverEigMode_t jobz, \
cublasFillMode_t uplo, \
int n, \
T *A, \
int lda, \
ValueType *W, \
T *work, \
int lwork, \
int *devInfo) const { \
PADDLE_ENFORCE_GPU_SUCCESS( \
dynload::cusolverDn##C##evd(handle, \
jobz, \
uplo, \
n, \
reinterpret_cast<CastType *>(A), \
lda, \
W, \
reinterpret_cast<CastType *>(work), \
lwork, \
devInfo)); \
}
FUNC_WITH_TYPES(EVD_INSTANCE);
#undef FUNC_WITH_TYPES
#undef EVDBUFFER_INSTANCE
#undef EVD_INSTANCE
#endif // PADDLE_WITH_CUDA
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/eigh_grad_kernel.h"
#include "paddle/phi/kernels/impl/eigh_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
PD_REGISTER_KERNEL(eigh_grad,
GPU,
ALL_LAYOUT,
phi::EighGradKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include "paddle/phi/kernels/eigh_kernel.h"
#include "paddle/phi/kernels/funcs/values_vectors_functor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
namespace phi {
template <typename T, typename Context>
void EighKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::string& uplo,
DenseTensor* out_w,
DenseTensor* out_v) {
bool is_lower = (uplo == "L");
phi::funcs::MatrixEighFunctor<Context, T> functor;
functor(dev_ctx, x, out_w, out_v, is_lower, true);
}
} // namespace phi
PD_REGISTER_KERNEL(eigh, // cuda_only
GPU,
ALL_LAYOUT,
phi::EighKernel,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/device_context.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/diag_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unsqueeze.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename Context>
void EighGradKernel(const Context& dev_ctx,
const DenseTensor& out_w,
const DenseTensor& out_v,
const DenseTensor& dout_w,
const DenseTensor& dout_v,
DenseTensor* dx) {
dev_ctx.template Alloc<T>(dx);
auto& dims = out_v.dims();
const int m = dims[dims.size() - 1];
DenseTensor tV =
phi::TransposeLast2Dim<T>(dev_ctx, phi::Conj<T>(dev_ctx, out_v));
DenseTensor W =
phi::Subtract<phi::dtype::Real<T>>(dev_ctx,
phi::funcs::Unsqueeze(out_w, -2),
phi::funcs::Unsqueeze(out_w, -1));
DenseTensor result = phi::Matmul<T>(dev_ctx, tV, dout_v);
result.Resize(dims);
dev_ctx.template Alloc<T>(&result);
std::vector<int> out_shape = phi::vectorize<int>(dims);
DenseTensor constant;
constant.Resize(phi::make_ddim(out_shape));
dev_ctx.template Alloc<T>(&constant);
phi::funcs::SetConstant<Context, T>()(dev_ctx, &constant, T(0.5));
result = phi::Subtract<T>(
dev_ctx,
result,
phi::Conj<T>(dev_ctx, phi::TransposeLast2Dim<T>(dev_ctx, result)));
result = phi::Multiply<T>(dev_ctx, result, constant);
if (result.type() != W.type()) {
auto x_vector = EigenVector<T>::Flatten(result);
auto y_vector = EigenVector<phi::dtype::Real<T>>::Flatten(W);
auto out_vector = EigenVector<T>::Flatten(result);
auto& place = *dev_ctx.eigen_device();
out_vector.device(place) = x_vector / y_vector;
} else {
result = phi::Divide<T>(dev_ctx, result, W);
}
result = phi::funcs::DiagFill<T, phi::dtype::Real<T>>(
dev_ctx, m, m, m, 0, dout_w, result);
*dx = phi::Matmul<T>(dev_ctx, out_v, phi::Matmul<T>(dev_ctx, result, tV));
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EighGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("eigh_grad",
{"Eigenvalues",
"Eigenvectors",
GradVarName("Eigenvalues"),
GradVarName("Eigenvectors")},
{},
{GradVarName("X")});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(eigh_grad, phi::EighGradOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册