From 995b5f2c0f0135ea00119f4a247c9bd7385913c0 Mon Sep 17 00:00:00 2001 From: zhulei <563755780@qq.com> Date: Wed, 14 Apr 2021 11:14:54 +0800 Subject: [PATCH] fix matrix_inverse_op with rocm (#32128) * fix matrix_inverse_op with rocm * fix matrix_inverse_op with rocm * fix matrix_inverse_op with rocm * fix matrix_inverse_op with rocm --- paddle/fluid/operators/math/matrix_inverse.cc | 26 +-------------- .../fluid/operators/math/matrix_inverse.cu.cc | 5 +++ paddle/fluid/operators/math/matrix_inverse.h | 32 +++++++++++++++++++ 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/operators/math/matrix_inverse.cc b/paddle/fluid/operators/math/matrix_inverse.cc index 25bc5d725e1..60481491cb4 100644 --- a/paddle/fluid/operators/math/matrix_inverse.cc +++ b/paddle/fluid/operators/math/matrix_inverse.cc @@ -23,34 +23,10 @@ namespace math { template class MatrixInverseFunctor { - using Matrix = - Eigen::Matrix; - using EigenMatrixMap = Eigen::Map; - using ConstEigenMatrixMap = Eigen::Map; - public: void operator()(const platform::CPUDeviceContext& context, const framework::Tensor& a, framework::Tensor* a_inv) { - const auto& mat_dims = a.dims(); - const int rank = mat_dims.size(); - int n = mat_dims[rank - 1]; - int batch_size = rank > 2 ? a.numel() / (n * n) : 1; - - const T* a_ptr = a.data(); - T* a_inv_ptr = a_inv->mutable_data(context.GetPlace()); - - for (int i = 0; i < batch_size; ++i) { - ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n); - EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n); - Eigen::PartialPivLU lu; - lu.compute(mat); - - const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff(); - PADDLE_ENFORCE_GT( - min_abs_pivot, static_cast(0), - platform::errors::InvalidArgument("Input is not invertible.")); - mat_inv.noalias() = lu.inverse(); - } + compute_inverse_eigen(context, a, a_inv); } }; diff --git a/paddle/fluid/operators/math/matrix_inverse.cu.cc b/paddle/fluid/operators/math/matrix_inverse.cu.cc index 7f5df114680..5deedf084c6 100644 --- a/paddle/fluid/operators/math/matrix_inverse.cu.cc +++ b/paddle/fluid/operators/math/matrix_inverse.cu.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include "paddle/fluid/operators/math/matrix_inverse.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { @@ -32,6 +33,7 @@ class MatrixInverseFunctor { public: void operator()(const platform::CUDADeviceContext& context, const framework::Tensor& a, framework::Tensor* a_inv) { +#ifndef PADDLE_WITH_HIP const auto& mat_dims = a.dims(); const int rank = mat_dims.size(); int n = mat_dims[rank - 1]; @@ -111,6 +113,9 @@ class MatrixInverseFunctor { "non-singular matrix", i, info[i], info[i])); } +#else + compute_inverse_eigen(context, a, a_inv); +#endif } }; diff --git a/paddle/fluid/operators/math/matrix_inverse.h b/paddle/fluid/operators/math/matrix_inverse.h index f0baf0b250e..fb58b483666 100644 --- a/paddle/fluid/operators/math/matrix_inverse.h +++ b/paddle/fluid/operators/math/matrix_inverse.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include +#include "Eigen/Core" +#include "Eigen/LU" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -22,6 +24,36 @@ namespace paddle { namespace operators { namespace math { +template +void compute_inverse_eigen(const DeviceContext& context, + const framework::Tensor& a, + framework::Tensor* a_inv) { + using Matrix = + Eigen::Matrix; + using EigenMatrixMap = Eigen::Map; + using ConstEigenMatrixMap = Eigen::Map; + const auto& mat_dims = a.dims(); + const int rank = mat_dims.size(); + int n = mat_dims[rank - 1]; + int batch_size = rank > 2 ? a.numel() / (n * n) : 1; + + const T* a_ptr = a.data(); + T* a_inv_ptr = a_inv->mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; ++i) { + ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n); + EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n); + Eigen::PartialPivLU lu; + lu.compute(mat); + + const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff(); + PADDLE_ENFORCE_GT( + min_abs_pivot, static_cast(0), + platform::errors::InvalidArgument("Input is not invertible.")); + mat_inv.noalias() = lu.inverse(); + } +} + template class MatrixInverseFunctor { public: -- GitLab