未验证 提交 995b5f2c 编写于 作者: Z zhulei 提交者: GitHub

fix matrix_inverse_op with rocm (#32128)

* fix matrix_inverse_op with rocm

* fix matrix_inverse_op with rocm

* fix matrix_inverse_op with rocm

* fix matrix_inverse_op with rocm
上级 279b653c
......@@ -23,34 +23,10 @@ namespace math {
template <typename T>
class MatrixInverseFunctor<platform::CPUDeviceContext, T> {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& a, framework::Tensor* a_inv) {
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
const T* a_ptr = a.data<T>();
T* a_inv_ptr = a_inv->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);
const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(
min_abs_pivot, static_cast<T>(0),
platform::errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
}
compute_inverse_eigen<platform::CPUDeviceContext, T>(context, a, a_inv);
}
};
......
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/matrix_inverse.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......@@ -32,6 +33,7 @@ class MatrixInverseFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& a, framework::Tensor* a_inv) {
#ifndef PADDLE_WITH_HIP
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
......@@ -111,6 +113,9 @@ class MatrixInverseFunctor<platform::CUDADeviceContext, T> {
"non-singular matrix",
i, info[i], info[i]));
}
#else
compute_inverse_eigen<platform::CUDADeviceContext, T>(context, a, a_inv);
#endif
}
};
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#include <string>
#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -22,6 +24,36 @@ namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext, typename T>
void compute_inverse_eigen(const DeviceContext& context,
const framework::Tensor& a,
framework::Tensor* a_inv) {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
const auto& mat_dims = a.dims();
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
const T* a_ptr = a.data<T>();
T* a_inv_ptr = a_inv->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);
const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(
min_abs_pivot, static_cast<T>(0),
platform::errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
}
}
template <typename DeviceContext, typename T>
class MatrixInverseFunctor {
public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册