未验证 提交 b9d4285b 编写于 作者: C crystal 提交者: GitHub

【phi】migrate matrix_rank to phi (#40074)

* migrate matrix_rank to phi

* migrate eigh and matrix_rank to phi

* fix matrix_rank

* optimize code

* move matrix_rank to phi

* add max functor

* migrate matrix_rank to phi

* optimize code
上级 edd97f94
......@@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/matrix_rank_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
......@@ -70,9 +69,9 @@ class MatrixRankeOp : public framework::OperatorWithKernel {
std::vector<int> x_batch_dims_array(max_dim);
std::vector<int> tol_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(dim_x_batch, dim_tol, x_batch_dims_array.data(),
tol_dims_array.data(), out_dims_array.data(),
max_dim, axis);
phi::funcs::GetBroadcastDimsArrays(
dim_x_batch, dim_tol, x_batch_dims_array.data(),
tol_dims_array.data(), out_dims_array.data(), max_dim, axis);
ctx->SetOutputDim("Out", phi::make_ddim(out_dims_array));
}
} else {
......@@ -115,141 +114,9 @@ class MatrixRankeOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template <typename T>
void BatchEigenvalues(const T* x_data, T* eigenvalues_data, int batches,
int rows, int cols, int k) {
// Eigen::Matrix API need non-const pointer.
T* input = const_cast<T*>(x_data);
int stride = rows * cols;
for (int i = 0; i < batches; i++) {
auto m = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
input + i * stride, rows, rows);
Eigen::SelfAdjointEigenSolver<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
eigen_solver(m);
auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs();
for (int j = 0; j < k; j++) {
*(eigenvalues_data + i * k + j) = eigenvalues[j];
}
}
}
template <typename T>
void BatchSVD(const T* x_data, T* eigenvalues_data, int batches, int rows,
int cols, int k) {
// Eigen::Matrix API need non-const pointer.
T* input = const_cast<T*>(x_data);
int stride = rows * cols;
Eigen::BDCSVD<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
svd;
for (int i = 0; i < batches; i++) {
auto m = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
input + i * stride, rows, cols);
svd.compute(m);
auto res_s = svd.singularValues();
for (int j = 0; j < k; j++) {
eigenvalues_data[i * k + j] = res_s[j];
}
}
}
template <typename T>
class MatrixRankCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
auto* x_data = x->data<T>();
auto* out = context.Output<Tensor>("Out");
out->mutable_data<int64_t>(context.GetPlace());
bool hermitian = context.Attr<bool>("hermitian");
auto dim_x = x->dims();
auto dim_out = out->dims();
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
int k = std::min(rows, cols);
auto numel = x->numel();
int batches = numel / (rows * cols);
bool use_default_tol = context.Attr<bool>("use_default_tol");
const Tensor* atol_tensor = nullptr;
Tensor temp_tensor;
T rtol_T = 0;
if (use_default_tol) {
framework::TensorFromVector<T>(std::vector<T>{0},
context.device_context(), &temp_tensor);
atol_tensor = &temp_tensor;
rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
} else if (context.HasInput("TolTensor")) {
atol_tensor = context.Input<Tensor>("TolTensor");
} else {
framework::TensorFromVector<T>(std::vector<T>{context.Attr<float>("tol")},
context.device_context(), &temp_tensor);
atol_tensor = &temp_tensor;
}
Tensor eigenvalue_tensor;
auto* eigenvalue_data = eigenvalue_tensor.mutable_data<T>(
detail::GetEigenvalueDim(dim_x, k), context.GetPlace());
if (hermitian) {
BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k);
} else {
BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k);
}
auto dito_T =
math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(
context);
std::vector<int> max_eigenvalue_shape =
phi::vectorize<int>(detail::RemoveLastDim(eigenvalue_tensor.dims()));
Tensor max_eigenvalue_tensor =
dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape);
Tensor temp_rtol_tensor;
framework::TensorFromVector<T>(std::vector<T>{rtol_T}, &temp_rtol_tensor);
Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor);
Tensor tol_tensor;
tol_tensor.mutable_data<T>(dim_out, context.GetPlace());
ElementwiseComputeEx<GreaterElementFunctor<T>, platform::CPUDeviceContext,
T, T>(context, atol_tensor, &rtol_tensor, -1,
GreaterElementFunctor<T>(), &tol_tensor);
tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
Tensor compare_result;
compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
context.GetPlace());
int axis = -1;
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
ElementwiseComputeEx<phi::funcs::GreaterThanFunctor<T, int64_t>,
platform::CPUDeviceContext, T, int>(
context, &eigenvalue_tensor, &tol_tensor, axis,
phi::funcs::GreaterThanFunctor<T, int64_t>(), &compare_result);
} else {
ElementwiseComputeEx<phi::funcs::LessThanFunctor<T, int64_t>,
platform::CPUDeviceContext, T, int>(
context, &eigenvalue_tensor, &tol_tensor, axis,
phi::funcs::LessThanFunctor<T, int64_t>(), &compare_result);
}
auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CPUDeviceContext,
int64_t>(context);
std::vector<int> result_shape = phi::vectorize<int>(dim_out);
Tensor result = dito_int.ReduceSum(compare_result, result_shape);
out->ShareDataWith(result);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(matrix_rank, ops::MatrixRankeOp, ops::MatrixRankeOpMaker);
REGISTER_OP_CPU_KERNEL(matrix_rank, ops::MatrixRankCPUKernel<float>,
ops::MatrixRankCPUKernel<double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/matrix_rank_op.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle {
namespace operators {
namespace detail {
DDim GetUDDim(const DDim& x_dim, int k) {
auto x_vec = phi::vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
}
DDim GetVHDDim(const DDim& x_dim, int k) {
auto x_vec = phi::vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
}
} // namespace detail
template <typename T>
class MatrixRankGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx =
context.template device_context<platform::CUDADeviceContext>();
const Tensor* x = context.Input<Tensor>("X");
auto* x_data = x->data<T>();
auto* out = context.Output<Tensor>("Out");
out->mutable_data<int64_t>(context.GetPlace());
bool hermitian = context.Attr<bool>("hermitian");
auto dim_x = x->dims();
auto dim_out = out->dims();
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
int k = std::min(rows, cols);
auto numel = x->numel();
int batches = numel / (rows * cols);
bool use_default_tol = context.Attr<bool>("use_default_tol");
const Tensor* atol_tensor = nullptr;
Tensor temp_tensor;
T rtol_T = 0;
if (use_default_tol) {
framework::TensorFromVector<T>(std::vector<T>{0},
context.device_context(), &temp_tensor);
atol_tensor = &temp_tensor;
rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
} else if (context.HasInput("TolTensor")) {
atol_tensor = context.Input<Tensor>("TolTensor");
} else {
framework::TensorFromVector<T>(std::vector<T>{context.Attr<float>("tol")},
context.device_context(), &temp_tensor);
atol_tensor = &temp_tensor;
}
// Must Copy X once, because the gesvdj will destory the content when exit.
Tensor x_tmp;
paddle::framework::TensorCopy(*x, context.GetPlace(), &x_tmp);
auto info = memory::Alloc(dev_ctx, sizeof(int) * batches);
int* info_ptr = reinterpret_cast<int*>(info->ptr());
Tensor eigenvalue_tensor;
auto* eigenvalue_data = eigenvalue_tensor.mutable_data<T>(
detail::GetEigenvalueDim(dim_x, k), context.GetPlace());
if (hermitian) {
SyevjBatched(dev_ctx, batches, rows, x_tmp.data<T>(), eigenvalue_data,
info_ptr);
platform::ForRange<platform::CUDADeviceContext> for_range(
dev_ctx, eigenvalue_tensor.numel());
phi::funcs::AbsFunctor<T> functor(eigenvalue_data, eigenvalue_data,
eigenvalue_tensor.numel());
for_range(functor);
} else {
Tensor U, VH;
auto* u_data =
U.mutable_data<T>(detail::GetUDDim(dim_x, k), context.GetPlace());
auto* vh_data =
VH.mutable_data<T>(detail::GetVHDDim(dim_x, k), context.GetPlace());
GesvdjBatched(dev_ctx, batches, cols, rows, k, x_tmp.data<T>(), vh_data,
u_data, eigenvalue_data, info_ptr, 1);
}
auto dito_T =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
T>(context);
std::vector<int> max_eigenvalue_shape =
phi::vectorize<int>(detail::RemoveLastDim(eigenvalue_tensor.dims()));
Tensor max_eigenvalue_tensor =
dito_T.ReduceMax(eigenvalue_tensor, max_eigenvalue_shape);
Tensor temp_rtol_tensor;
framework::TensorFromVector<T>(std::vector<T>{rtol_T},
context.device_context(), &temp_rtol_tensor);
Tensor rtol_tensor = dito_T.Mul(temp_rtol_tensor, max_eigenvalue_tensor);
Tensor tol_tensor;
tol_tensor.mutable_data<T>(dim_out, context.GetPlace());
ElementwiseComputeEx<GreaterElementFunctor<T>, platform::CUDADeviceContext,
T, T>(context, atol_tensor, &rtol_tensor, -1,
GreaterElementFunctor<T>(), &tol_tensor);
tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
Tensor compare_result;
compare_result.mutable_data<int64_t>(detail::NewAxisDim(dim_out, k),
context.GetPlace());
int axis = -1;
ElementwiseComputeEx<phi::funcs::GreaterThanFunctor<T, int64_t>,
platform::CUDADeviceContext, T, int64_t>(
context, &eigenvalue_tensor, &tol_tensor, axis,
phi::funcs::GreaterThanFunctor<T, int64_t>(), &compare_result);
auto dito_int =
math::DeviceIndependenceTensorOperations<platform::CUDADeviceContext,
int64_t>(context);
std::vector<int> result_shape = phi::vectorize<int>(dim_out);
Tensor result = dito_int.ReduceSum(compare_result, result_shape);
out->ShareDataWith(result);
}
void GesvdjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize,
int m, int n, int k, T* A, T* U, T* V, T* S, int* info,
int thin_UV = 1) const;
void SyevjBatched(const platform::CUDADeviceContext& dev_ctx, int batchSize,
int n, T* A, T* W, int* info) const;
};
template <>
void MatrixRankGPUKernel<float>::GesvdjBatched(
const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n,
int k, float* A, float* U, float* V, float* S, int* info,
int thin_UV) const {
// do not compute singular vectors
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
gesvdjInfo_t gesvdj_params = NULL;
int lda = m;
int ldu = m;
int ldt = n;
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgesvdj_bufferSize(
handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork,
gesvdj_params));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
int stride_A = lda * n;
int stride_U = ldu * (thin_UV ? k : m);
int stride_V = ldt * (thin_UV ? k : n);
for (int i = 0; i < batchSize; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSgesvdj(
handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i,
U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork,
info, gesvdj_params));
int error_info;
memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info,
sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
template <>
void MatrixRankGPUKernel<double>::GesvdjBatched(
const platform::CUDADeviceContext& dev_ctx, int batchSize, int m, int n,
int k, double* A, double* U, double* V, double* S, int* info,
int thin_UV) const {
// do not compute singular vectors
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
gesvdjInfo_t gesvdj_params = NULL;
int lda = m;
int ldu = m;
int ldt = n;
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgesvdj_bufferSize(
handle, jobz, thin_UV, m, n, A, lda, S, U, ldu, V, ldt, &lwork,
gesvdj_params));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
int stride_A = lda * n;
int stride_U = ldu * (thin_UV ? k : m);
int stride_V = ldt * (thin_UV ? k : n);
for (int i = 0; i < batchSize; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDgesvdj(
handle, jobz, thin_UV, m, n, A + stride_A * i, lda, S + k * i,
U + stride_U * i, ldu, V + stride_V * i, ldt, workspace_ptr, lwork,
info, gesvdj_params));
// check the error info
int error_info;
memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info,
sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
template <>
void MatrixRankGPUKernel<float>::SyevjBatched(
const platform::CUDADeviceContext& dev_ctx, int batchSize, int n, float* A,
float* W, int* info) const {
auto handle = dev_ctx.cusolver_dn_handle();
// Compute eigenvalues only
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
// matrix is saved as column-major in cusolver.
// numpy and torch use lower triangle to compute eigenvalues, so here use
// upper triangle
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
int lda = n;
int stride_A = lda * n;
int lwork = 0;
syevjInfo_t params = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnCreateSyevjInfo(&params));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj_bufferSize(
handle, jobz, uplo, n, A, lda, W, &lwork, params));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
for (int i = 0; i < batchSize; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSsyevj(
handle, jobz, uplo, n, A + stride_A * i, lda, W + n * i, workspace_ptr,
lwork, info, params));
int error_info;
memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info,
sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver eigenvalues is not zero. [%d]", i,
error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDestroySyevjInfo(params));
}
template <>
void MatrixRankGPUKernel<double>::SyevjBatched(
const platform::CUDADeviceContext& dev_ctx, int batchSize, int n, double* A,
double* W, int* info) const {
auto handle = dev_ctx.cusolver_dn_handle();
// Compute eigenvalues only
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
// upper triangle of A is stored
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
int lda = n;
int stride_A = lda * n;
int lwork = 0;
syevjInfo_t params = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnCreateSyevjInfo(&params));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDsyevj_bufferSize(
handle, jobz, uplo, n, A, lda, W, &lwork, params));
auto workspace = memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
for (int i = 0; i < batchSize; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDsyevj(
handle, jobz, uplo, n, A + stride_A * i, lda, W + n * i, workspace_ptr,
lwork, info, params));
int error_info;
memory::Copy(platform::CPUPlace(), &error_info, dev_ctx.GetPlace(), info,
sizeof(int), dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: CUSolver eigenvalues is not zero. [%d]", i,
error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusolverDnDestroySyevjInfo(params));
}
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(matrix_rank, ops::MatrixRankGPUKernel<float>,
ops::MatrixRankGPUKernel<double>);
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/matrix_rank_kernel.h"
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
namespace phi {
template <typename T, typename Context>
void MatrixRankKernel(const Context& dev_ctx,
const DenseTensor& x,
float tol,
bool use_default_tol,
bool hermitian,
DenseTensor* out) {
DenseTensor atol_tensor;
if (use_default_tol) {
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(0));
} else {
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(tol));
}
MatrixRankTolKernel<T, Context>(
dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);
}
} // namespace phi
PD_REGISTER_KERNEL(
matrix_rank, CPU, ALL_LAYOUT, phi::MatrixRankKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include <Eigen/Dense>
#include <Eigen/SVD>
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/reduce.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/reduce_functor.h"
#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/reduce_max_kernel.h"
namespace phi {
template <typename T>
void BatchEigenvalues(const T* x_data,
T* eigenvalues_data,
int batches,
int rows,
int cols,
int k) {
// Eigen::Matrix API need non-const pointer.
T* input = const_cast<T*>(x_data);
int stride = rows * cols;
for (int i = 0; i < batches; i++) {
auto m = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
input + i * stride, rows, rows);
Eigen::SelfAdjointEigenSolver<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
eigen_solver(m);
auto eigenvalues = eigen_solver.eigenvalues().cwiseAbs();
for (int j = 0; j < k; j++) {
*(eigenvalues_data + i * k + j) = eigenvalues[j];
}
}
}
template <typename T>
void BatchSVD(const T* x_data,
T* eigenvalues_data,
int batches,
int rows,
int cols,
int k) {
// Eigen::Matrix API need non-const pointer.
T* input = const_cast<T*>(x_data);
int stride = rows * cols;
Eigen::BDCSVD<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
svd;
for (int i = 0; i < batches; i++) {
auto m = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>(
input + i * stride, rows, cols);
svd.compute(m);
auto res_s = svd.singularValues();
for (int j = 0; j < k; j++) {
eigenvalues_data[i * k + j] = res_s[j];
}
}
}
template <typename T, typename Context>
void MatrixRankTolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& atol_tensor,
bool use_default_tol,
bool hermitian,
DenseTensor* out) {
auto* x_data = x.data<T>();
dev_ctx.template Alloc<int64_t>(out);
auto dim_x = x.dims();
auto dim_out = out->dims();
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
int k = std::min(rows, cols);
auto numel = x.numel();
int batches = numel / (rows * cols);
T rtol_T = 0;
if (use_default_tol) {
rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
}
DenseTensor eigenvalue_tensor;
eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k));
auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
if (hermitian) {
BatchEigenvalues<T>(x_data, eigenvalue_data, batches, rows, cols, k);
} else {
BatchSVD<T>(x_data, eigenvalue_data, batches, rows, cols, k);
}
DenseTensor max_eigenvalue_tensor;
max_eigenvalue_tensor.Resize(detail::RemoveLastDim(eigenvalue_tensor.dims()));
dev_ctx.template Alloc<T>(&max_eigenvalue_tensor);
phi::MaxKernel<T, Context>(dev_ctx,
eigenvalue_tensor,
std::vector<int64_t>{-1},
false,
&max_eigenvalue_tensor);
DenseTensor temp_rtol_tensor;
temp_rtol_tensor =
phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(rtol_T));
DenseTensor rtol_tensor =
phi::Multiply<T>(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor);
DenseTensor tol_tensor;
tol_tensor.Resize(dim_out);
dev_ctx.template Alloc<T>(&tol_tensor);
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
dev_ctx,
atol_tensor,
rtol_tensor,
-1,
GreaterElementFunctor<T>(),
&tol_tensor);
tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
DenseTensor compare_result;
compare_result.Resize(detail::NewAxisDim(dim_out, k));
dev_ctx.template Alloc<int64_t>(&compare_result);
int axis = -1;
if (eigenvalue_tensor.dims().size() >= tol_tensor.dims().size()) {
funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int>(
dev_ctx,
eigenvalue_tensor,
tol_tensor,
axis,
funcs::GreaterThanFunctor<T, int64_t>(),
&compare_result);
} else {
funcs::ElementwiseCompute<funcs::LessThanFunctor<T, int64_t>, T, int>(
dev_ctx,
eigenvalue_tensor,
tol_tensor,
axis,
funcs::LessThanFunctor<T, int64_t>(),
&compare_result);
}
phi::SumKernel<int64_t>(dev_ctx,
compare_result,
std::vector<int64_t>{-1},
compare_result.dtype(),
false,
out);
}
} // namespace phi
PD_REGISTER_KERNEL(
matrix_rank_tol, CPU, ALL_LAYOUT, phi::MatrixRankTolKernel, float, double) {
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include "paddle/phi/kernels/matrix_rank_kernel.h"
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/full_kernel.h"
namespace phi {
template <typename T, typename Context>
void MatrixRankKernel(const Context& dev_ctx,
const DenseTensor& x,
float tol,
bool use_default_tol,
bool hermitian,
DenseTensor* out) {
DenseTensor atol_tensor;
if (use_default_tol) {
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(0));
} else {
atol_tensor = phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(tol));
}
MatrixRankTolKernel<T, Context>(
dev_ctx, x, atol_tensor, use_default_tol, hermitian, out);
}
} // namespace phi
PD_REGISTER_KERNEL(matrix_rank, // cuda_only
GPU,
ALL_LAYOUT,
phi::MatrixRankKernel,
float,
double) {}
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
// HIP not support cusolver
#include "paddle/phi/kernels/matrix_rank_tol_kernel.h"
#include <algorithm>
#include <vector>
#include "paddle/fluid/memory/memory.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/compare_functors.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/for_range.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/impl/matrix_rank_kernel_impl.h"
#include "paddle/phi/kernels/math_kernel.h"
#include "paddle/phi/kernels/reduce_max_kernel.h"
namespace phi {
template <typename T>
void GesvdjBatched(const phi::GPUContext& dev_ctx,
int batchSize,
int m,
int n,
int k,
T* A,
T* U,
T* V,
T* S,
int* info,
int thin_UV = 1);
template <typename T>
void SyevjBatched(const phi::GPUContext& dev_ctx,
int batchSize,
int n,
T* A,
T* W,
int* info);
template <>
void GesvdjBatched<float>(const phi::GPUContext& dev_ctx,
int batchSize,
int m,
int n,
int k,
float* A,
float* U,
float* V,
float* S,
int* info,
int thin_UV) {
// do not compute singular vectors
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
gesvdjInfo_t gesvdj_params = NULL;
int lda = m;
int ldu = m;
int ldt = n;
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnSgesvdj_bufferSize(handle,
jobz,
thin_UV,
m,
n,
A,
lda,
S,
U,
ldu,
V,
ldt,
&lwork,
gesvdj_params));
auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
int stride_A = lda * n;
int stride_U = ldu * (thin_UV ? k : m);
int stride_V = ldt * (thin_UV ? k : n);
for (int i = 0; i < batchSize; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSgesvdj(handle,
jobz,
thin_UV,
m,
n,
A + stride_A * i,
lda,
S + k * i,
U + stride_U * i,
ldu,
V + stride_V * i,
ldt,
workspace_ptr,
lwork,
info,
gesvdj_params));
int error_info;
paddle::memory::Copy(phi::CPUPlace(),
&error_info,
dev_ctx.GetPlace(),
info,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
template <>
void GesvdjBatched<double>(const phi::GPUContext& dev_ctx,
int batchSize,
int m,
int n,
int k,
double* A,
double* U,
double* V,
double* S,
int* info,
int thin_UV) {
// do not compute singular vectors
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
gesvdjInfo_t gesvdj_params = NULL;
int lda = m;
int ldu = m;
int ldt = n;
int lwork = 0;
auto handle = dev_ctx.cusolver_dn_handle();
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnCreateGesvdjInfo(&gesvdj_params));
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnDgesvdj_bufferSize(handle,
jobz,
thin_UV,
m,
n,
A,
lda,
S,
U,
ldu,
V,
ldt,
&lwork,
gesvdj_params));
auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
int stride_A = lda * n;
int stride_U = ldu * (thin_UV ? k : m);
int stride_V = ldt * (thin_UV ? k : n);
for (int i = 0; i < batchSize; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgesvdj(handle,
jobz,
thin_UV,
m,
n,
A + stride_A * i,
lda,
S + k * i,
U + stride_U * i,
ldu,
V + stride_V * i,
ldt,
workspace_ptr,
lwork,
info,
gesvdj_params));
// check the error info
int error_info;
paddle::memory::Copy(phi::CPUPlace(),
&error_info,
dev_ctx.GetPlace(),
info,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver SVD is not zero. [%d]", i, error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cusolverDnDestroyGesvdjInfo(gesvdj_params));
}
template <>
void SyevjBatched<float>(const phi::GPUContext& dev_ctx,
int batchSize,
int n,
float* A,
float* W,
int* info) {
auto handle = dev_ctx.cusolver_dn_handle();
// Compute eigenvalues only
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
// matrix is saved as column-major in cusolver.
// numpy and torch use lower triangle to compute eigenvalues, so here use
// upper triangle
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
int lda = n;
int stride_A = lda * n;
int lwork = 0;
syevjInfo_t params = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCreateSyevjInfo(&params));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj_bufferSize(
handle, jobz, uplo, n, A, lda, W, &lwork, params));
auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(float));
float* workspace_ptr = reinterpret_cast<float*>(workspace->ptr());
for (int i = 0; i < batchSize; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSsyevj(handle,
jobz,
uplo,
n,
A + stride_A * i,
lda,
W + n * i,
workspace_ptr,
lwork,
info,
params));
int error_info;
paddle::memory::Copy(phi::CPUPlace(),
&error_info,
dev_ctx.GetPlace(),
info,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver eigenvalues is not zero. [%d]",
i,
error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDestroySyevjInfo(params));
}
template <>
void SyevjBatched<double>(const phi::GPUContext& dev_ctx,
int batchSize,
int n,
double* A,
double* W,
int* info) {
auto handle = dev_ctx.cusolver_dn_handle();
// Compute eigenvalues only
const cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_NOVECTOR;
// upper triangle of A is stored
cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER;
int lda = n;
int stride_A = lda * n;
int lwork = 0;
syevjInfo_t params = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnCreateSyevjInfo(&params));
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevj_bufferSize(
handle, jobz, uplo, n, A, lda, W, &lwork, params));
auto workspace = paddle::memory::Alloc(dev_ctx, lwork * sizeof(double));
double* workspace_ptr = reinterpret_cast<double*>(workspace->ptr());
for (int i = 0; i < batchSize; i++) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDsyevj(handle,
jobz,
uplo,
n,
A + stride_A * i,
lda,
W + n * i,
workspace_ptr,
lwork,
info,
params));
int error_info;
paddle::memory::Copy(phi::CPUPlace(),
&error_info,
dev_ctx.GetPlace(),
info,
sizeof(int),
dev_ctx.stream());
PADDLE_ENFORCE_EQ(
error_info,
0,
phi::errors::PreconditionNotMet(
"For batch [%d]: CUSolver eigenvalues is not zero. [%d]",
i,
error_info));
}
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDestroySyevjInfo(params));
}
template <typename T, typename Context>
void MatrixRankTolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& atol_tensor,
bool use_default_tol,
bool hermitian,
DenseTensor* out) {
auto* x_data = x.data<T>();
dev_ctx.template Alloc<int64_t>(out);
auto dim_x = x.dims();
auto dim_out = out->dims();
int rows = dim_x[dim_x.size() - 2];
int cols = dim_x[dim_x.size() - 1];
int k = std::min(rows, cols);
auto numel = x.numel();
int batches = numel / (rows * cols);
T rtol_T = 0;
if (use_default_tol) {
rtol_T = std::numeric_limits<T>::epsilon() * std::max(rows, cols);
}
// Must Copy X once, because the gesvdj will destory the content when exit.
DenseTensor x_tmp;
paddle::framework::TensorCopy(x, dev_ctx.GetPlace(), &x_tmp);
auto info = paddle::memory::Alloc(dev_ctx, sizeof(int) * batches);
int* info_ptr = reinterpret_cast<int*>(info->ptr());
DenseTensor eigenvalue_tensor;
eigenvalue_tensor.Resize(detail::GetEigenvalueDim(dim_x, k));
auto* eigenvalue_data = dev_ctx.template Alloc<T>(&eigenvalue_tensor);
if (hermitian) {
SyevjBatched<T>(
dev_ctx, batches, rows, x_tmp.data<T>(), eigenvalue_data, info_ptr);
phi::AbsKernel<T, Context>(dev_ctx, eigenvalue_tensor, &eigenvalue_tensor);
} else {
DenseTensor U, VH;
U.Resize(detail::GetUDDim(dim_x, k));
VH.Resize(detail::GetVHDDim(dim_x, k));
auto* u_data = dev_ctx.template Alloc<T>(&U);
auto* vh_data = dev_ctx.template Alloc<T>(&VH);
GesvdjBatched<T>(dev_ctx,
batches,
cols,
rows,
k,
x_tmp.data<T>(),
vh_data,
u_data,
eigenvalue_data,
info_ptr,
1);
}
DenseTensor max_eigenvalue_tensor;
dev_ctx.template Alloc<T>(&max_eigenvalue_tensor);
max_eigenvalue_tensor.Resize(detail::RemoveLastDim(eigenvalue_tensor.dims()));
phi::MaxKernel<T, Context>(dev_ctx,
eigenvalue_tensor,
std::vector<int64_t>{-1},
false,
&max_eigenvalue_tensor);
DenseTensor temp_rtol_tensor;
temp_rtol_tensor =
phi::Full<T, Context>(dev_ctx, {1}, static_cast<T>(rtol_T));
DenseTensor rtol_tensor =
phi::Multiply<T>(dev_ctx, temp_rtol_tensor, max_eigenvalue_tensor);
DenseTensor tol_tensor;
tol_tensor.Resize(dim_out);
dev_ctx.template Alloc<T>(&tol_tensor);
funcs::ElementwiseCompute<GreaterElementFunctor<T>, T, T>(
dev_ctx,
atol_tensor,
rtol_tensor,
-1,
GreaterElementFunctor<T>(),
&tol_tensor);
tol_tensor.Resize(detail::NewAxisDim(tol_tensor.dims(), 1));
DenseTensor compare_result;
compare_result.Resize(detail::NewAxisDim(dim_out, k));
dev_ctx.template Alloc<int64_t>(&compare_result);
int axis = -1;
funcs::ElementwiseCompute<funcs::GreaterThanFunctor<T, int64_t>, T, int64_t>(
dev_ctx,
eigenvalue_tensor,
tol_tensor,
axis,
funcs::GreaterThanFunctor<T, int64_t>(),
&compare_result);
phi::SumKernel<int64_t>(dev_ctx,
compare_result,
std::vector<int64_t>{-1},
compare_result.dtype(),
false,
out);
}
} // namespace phi
PD_REGISTER_KERNEL(matrix_rank_tol, // cuda_only
GPU,
ALL_LAYOUT,
phi::MatrixRankTolKernel,
float,
double) {}
#endif // not PADDLE_WITH_HIP
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,14 +13,11 @@
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/phi/core/ddim.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/matrix_rank_kernel.h"
namespace phi {
namespace detail {
static DDim GetEigenvalueDim(const DDim& dim, int k) {
......@@ -44,6 +41,18 @@ static DDim RemoveLastDim(const DDim& dim) {
vec.erase(vec.end() - 1, vec.end());
return phi::make_ddim(vec);
}
static DDim GetUDDim(const DDim& x_dim, int k) {
auto x_vec = phi::vectorize(x_dim);
x_vec[x_vec.size() - 1] = k;
return phi::make_ddim(x_vec);
}
static DDim GetVHDDim(const DDim& x_dim, int k) {
auto x_vec = phi::vectorize(x_dim);
x_vec[x_vec.size() - 2] = k;
return phi::make_ddim(x_vec);
}
} // namespace detail
template <typename T>
......@@ -57,5 +66,4 @@ struct GreaterElementFunctor {
}
};
} // namespace operators
} // namespace paddle
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MatrixRankKernel(const Context& dev_ctx,
const DenseTensor& x,
float tol,
bool use_default_tol,
bool hermitian,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void MatrixRankTolKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& atol_tensor,
bool use_default_tol,
bool hermitian,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// we have to return every specific KernelSignature for infrt now
KernelSignature MatrixRankOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("TolTensor")) {
return KernelSignature("matrix_rank_tol",
{"X", "TolTensor"},
{"use_default_tol", "hermitian"},
{"Out"});
} else {
return KernelSignature("matrix_rank",
{"X"},
{
"tol", "use_default_tol", "hermitian",
},
{"Out"});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(matrix_rank, phi::MatrixRankOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册