未验证 提交 995b5f2c 编写于 作者: Z zhulei 提交者: GitHub

fix matrix_inverse_op with rocm (#32128)

* fix matrix_inverse_op with rocm

* fix matrix_inverse_op with rocm

* fix matrix_inverse_op with rocm

* fix matrix_inverse_op with rocm
上级 279b653c
...@@ -23,34 +23,10 @@ namespace math { ...@@ -23,34 +23,10 @@ namespace math {
template <typename T> template <typename T>
class MatrixInverseFunctor<platform::CPUDeviceContext, T> { class MatrixInverseFunctor<platform::CPUDeviceContext, T> {
using Matrix =
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
using EigenMatrixMap = Eigen::Map<Matrix>;
using ConstEigenMatrixMap = Eigen::Map<const Matrix>;
public: public:
void operator()(const platform::CPUDeviceContext& context, void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& a, framework::Tensor* a_inv) { const framework::Tensor& a, framework::Tensor* a_inv) {
const auto& mat_dims = a.dims(); compute_inverse_eigen<platform::CPUDeviceContext, T>(context, a, a_inv);
const int rank = mat_dims.size();
int n = mat_dims[rank - 1];
int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
const T* a_ptr = a.data<T>();
T* a_inv_ptr = a_inv->mutable_data<T>(context.GetPlace());
for (int i = 0; i < batch_size; ++i) {
ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n);
Eigen::PartialPivLU<Matrix> lu;
lu.compute(mat);
const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
PADDLE_ENFORCE_GT(
min_abs_pivot, static_cast<T>(0),
platform::errors::InvalidArgument("Input is not invertible."));
mat_inv.noalias() = lu.inverse();
}
} }
}; };
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/matrix_inverse.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
...@@ -32,6 +33,7 @@ class MatrixInverseFunctor<platform::CUDADeviceContext, T> { ...@@ -32,6 +33,7 @@ class MatrixInverseFunctor<platform::CUDADeviceContext, T> {
public: public:
void operator()(const platform::CUDADeviceContext& context, void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& a, framework::Tensor* a_inv) { const framework::Tensor& a, framework::Tensor* a_inv) {
#ifndef PADDLE_WITH_HIP
const auto& mat_dims = a.dims(); const auto& mat_dims = a.dims();
const int rank = mat_dims.size(); const int rank = mat_dims.size();
int n = mat_dims[rank - 1]; int n = mat_dims[rank - 1];
...@@ -111,6 +113,9 @@ class MatrixInverseFunctor<platform::CUDADeviceContext, T> { ...@@ -111,6 +113,9 @@ class MatrixInverseFunctor<platform::CUDADeviceContext, T> {
"non-singular matrix", "non-singular matrix",
i, info[i], info[i])); i, info[i], info[i]));
} }
#else
compute_inverse_eigen<platform::CUDADeviceContext, T>(context, a, a_inv);
#endif
} }
}; };
......
...@@ -15,6 +15,8 @@ limitations under the License. */ ...@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include "Eigen/Core"
#include "Eigen/LU"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
...@@ -22,6 +24,36 @@ namespace paddle { ...@@ -22,6 +24,36 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
/// Inverts every square matrix stored in `a` using Eigen's partial-pivoting LU
/// decomposition and writes the inverses into `a_inv`.
///
/// The last dimension of `a` is taken as the matrix order n. When `a` has more
/// than two dimensions it is read as numel / (n * n) consecutive row-major
/// n-by-n matrices; otherwise it is read as a single n-by-n matrix.
///
/// @param context  Device context; only `GetPlace()` is used, to obtain the
///                 output buffer.
/// @param a        Input tensor holding the matrices to invert.
/// @param a_inv    Output tensor; its buffer is obtained through
///                 `mutable_data<T>(context.GetPlace())`.
///
/// Fails through PADDLE_ENFORCE_GT with an InvalidArgument error when the
/// smallest absolute value on the LU diagonal of a matrix is not greater than
/// zero.
///
/// NOTE(review): the raw tensor pointers are read and written directly through
/// Eigen::Map, so they presumably have to be addressable by the calling
/// thread -- confirm for non-CPU device contexts.
template <typename DeviceContext, typename T>
void compute_inverse_eigen(const DeviceContext& context,
                           const framework::Tensor& a,
                           framework::Tensor* a_inv) {
  using Matrix =
      Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using EigenMatrixMap = Eigen::Map<Matrix>;
  using ConstEigenMatrixMap = Eigen::Map<const Matrix>;

  const auto& mat_dims = a.dims();
  const int rank = mat_dims.size();
  const int n = mat_dims[rank - 1];

  // Element count of one matrix, computed in Eigen::Index rather than int so
  // that neither n * n nor the per-batch offset i * n * n can overflow 32 bits
  // on large inputs.
  const Eigen::Index matrix_size = static_cast<Eigen::Index>(n) * n;

  // NOTE(review): n == 0 with rank > 2 would divide by zero here; presumably
  // such shapes are rejected before this point -- confirm against callers.
  const Eigen::Index batch_size =
      rank > 2 ? static_cast<Eigen::Index>(a.numel() / matrix_size) : 1;

  const T* a_ptr = a.data<T>();
  T* a_inv_ptr = a_inv->mutable_data<T>(context.GetPlace());

  for (Eigen::Index i = 0; i < batch_size; ++i) {
    const Eigen::Index offset = i * matrix_size;
    ConstEigenMatrixMap mat(a_ptr + offset, n, n);
    EigenMatrixMap mat_inv(a_inv_ptr + offset, n, n);

    Eigen::PartialPivLU<Matrix> lu;
    lu.compute(mat);

    // A zero on the diagonal of the LU factor is used as the singularity test.
    const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
    PADDLE_ENFORCE_GT(
        min_abs_pivot, static_cast<T>(0),
        platform::errors::InvalidArgument("Input is not invertible."));

    mat_inv.noalias() = lu.inverse();
  }
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatrixInverseFunctor { class MatrixInverseFunctor {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册