From 7024ade70597962aad8e7f7cf77b174fa821ee13 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Tue, 8 Mar 2022 15:54:32 +0800
Subject: [PATCH] [Phi] Move matrix inverse into phi (#40237)

* move matrix inverse into phi

* change license year
---
 paddle/fluid/operators/determinant_op.h       |   6 +-
 paddle/fluid/operators/inverse_op.h           |   4 +-
 paddle/fluid/operators/math/CMakeLists.txt    |   1 -
 paddle/fluid/operators/math/matrix_inverse.cc |  38 -----
 .../fluid/operators/math/matrix_inverse.cu.cc | 124 ---------------
 paddle/fluid/operators/matrix_power_op.h      |   6 +-
 paddle/phi/kernels/funcs/CMakeLists.txt       |   1 +
 paddle/phi/kernels/funcs/matrix_inverse.cc    |  37 +++++
 paddle/phi/kernels/funcs/matrix_inverse.cu.cc | 141 ++++++++++++++++++
 .../kernels/funcs}/matrix_inverse.h           |  41 ++---
 10 files changed, 208 insertions(+), 191 deletions(-)
 delete mode 100644 paddle/fluid/operators/math/matrix_inverse.cc
 delete mode 100644 paddle/fluid/operators/math/matrix_inverse.cu.cc
 create mode 100644 paddle/phi/kernels/funcs/matrix_inverse.cc
 create mode 100644 paddle/phi/kernels/funcs/matrix_inverse.cu.cc
 rename paddle/{fluid/operators/math => phi/kernels/funcs}/matrix_inverse.h (61%)

diff --git a/paddle/fluid/operators/determinant_op.h b/paddle/fluid/operators/determinant_op.h
index 375ef4344f4..463a707ccf1 100644
--- a/paddle/fluid/operators/determinant_op.h
+++ b/paddle/fluid/operators/determinant_op.h
@@ -19,11 +19,11 @@
 #include <cmath>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/matrix_inverse.h"
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/for_range.h"
 #include "paddle/phi/kernels/funcs/complex_functors.h"
+#include "paddle/phi/kernels/funcs/matrix_inverse.h"
 
 namespace paddle {
 namespace operators {
@@ -226,7 +226,7 @@ class DeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.Resize(input->dims());
     inverse_A.mutable_data<T>(context.GetPlace());
 
-    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
     mat_inv(dev_ctx, *input, &inverse_A);
 
     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
@@ -381,7 +381,7 @@ class SlogDeterminantGradKernel : public framework::OpKernel<T> {
     inverse_A.Resize(input->dims());
     inverse_A.mutable_data<T>(context.GetPlace());
 
-    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
     mat_inv(dev_ctx, *input, &inverse_A);
 
     VLOG(3) << "inverse(A) dims: " << inverse_A.dims();
diff --git a/paddle/fluid/operators/inverse_op.h b/paddle/fluid/operators/inverse_op.h
index 1e061d8b50a..31c22915ec5 100644
--- a/paddle/fluid/operators/inverse_op.h
+++ b/paddle/fluid/operators/inverse_op.h
@@ -15,8 +15,8 @@ limitations under the License. */
 
 #pragma once
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/math/matrix_inverse.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
+#include "paddle/phi/kernels/funcs/matrix_inverse.h"
 
 namespace paddle {
 namespace operators {
@@ -30,7 +30,7 @@ class InverseKernel : public framework::OpKernel<T> {
     output->mutable_data<T>(context.GetPlace());
 
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    math::MatrixInverseFunctor<DeviceContext, T> mat_inv;
+    phi::funcs::MatrixInverseFunctor<DeviceContext, T> mat_inv;
     mat_inv(dev_ctx, *input, output);
   }
 };
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index bce927c32dd..d5a86d62b41 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -46,7 +46,6 @@ math_library(vol2col)
 math_library(prelu)
 math_library(bert_encoder_functor)
 math_library(tree2col DEPS math_function)
-math_library(matrix_inverse)
 math_library(segment_pooling)
 math_library(matrix_solve)
 
diff --git a/paddle/fluid/operators/math/matrix_inverse.cc b/paddle/fluid/operators/math/matrix_inverse.cc
deleted file mode 100644
index 1b36e615c68..00000000000
--- a/paddle/fluid/operators/math/matrix_inverse.cc
+++ /dev/null
@@ -1,38 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/math/matrix_inverse.h"
-#include "Eigen/Core"
-#include "Eigen/LU"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename T>
-class MatrixInverseFunctor<platform::CPUDeviceContext, T> {
- public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& a, framework::Tensor* a_inv) {
-    compute_inverse_eigen<platform::CPUDeviceContext, T>(context, a, a_inv);
-  }
-};
-
-template class MatrixInverseFunctor<platform::CPUDeviceContext, float>;
-template class MatrixInverseFunctor<platform::CPUDeviceContext, double>;
-
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
diff --git a/paddle/fluid/operators/math/matrix_inverse.cu.cc b/paddle/fluid/operators/math/matrix_inverse.cu.cc
deleted file mode 100644
index 41335a69417..00000000000
--- a/paddle/fluid/operators/math/matrix_inverse.cu.cc
+++ /dev/null
@@ -1,124 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/math/matrix_inverse.h"
-#include "paddle/phi/kernels/funcs/blas/blas.h"
-
-namespace paddle {
-namespace platform {
-class CUDADeviceContext;
-}  // namespace platform
-}  // namespace paddle
-
-namespace paddle {
-namespace operators {
-namespace math {
-
-template <typename DeviceContext, typename T>
-class MatrixInverseFunctor;
-
-template <typename T>
-class MatrixInverseFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& a, framework::Tensor* a_inv) {
-#ifndef PADDLE_WITH_HIP
-    const auto& mat_dims = a.dims();
-    const int rank = mat_dims.size();
-    int n = mat_dims[rank - 1];
-    int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
-
-    memory::allocation::AllocationPtr tmp_gpu_mat_data;
-    const T* gpu_mat = a.data<T>();
-    if (n >= 32) {
-      // Copy all elements of input matrix A to a temporary memory space to
-      // avoid being overriden by getrf.
-      tmp_gpu_mat_data = memory::Alloc(context, a.numel() * sizeof(T));
-      memory::Copy(context.GetPlace(), tmp_gpu_mat_data->ptr(),
-                   context.GetPlace(), a.data<T>(), a.numel() * sizeof(T),
-                   context.stream());
-      gpu_mat = reinterpret_cast<const T*>(tmp_gpu_mat_data->ptr());
-    }
-
-    std::vector<const T*> cpu_ptrs(batch_size * 2);
-    for (int i = 0; i < batch_size; ++i) {
-      cpu_ptrs[i] = gpu_mat + i * n * n;
-      cpu_ptrs[i + batch_size] = a_inv->data<T>() + i * n * n;
-    }
-
-    // Copy the addresses of A and A_inv from host to device.
-    memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
-        memory::Alloc(context, cpu_ptrs.size() * sizeof(T*));
-    memory::Copy(context.GetPlace(), tmp_gpu_ptrs_data->ptr(),
-                 platform::CPUPlace(), static_cast<void*>(cpu_ptrs.data()),
-                 cpu_ptrs.size() * sizeof(T*), context.stream());
-    T** gpu_inv_ptrs =
-        reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
-
-    // Allocate device memory for info and pivots.
-    int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
-    memory::allocation::AllocationPtr tmp_gpu_info_data =
-        memory::Alloc(context, num_ints * sizeof(int));
-    int* gpu_info_ptr = reinterpret_cast<int*>(tmp_gpu_info_data->ptr());
-
-    auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(context);
-
-    std::vector<int> info;  // only for singular checking
-    info.resize(batch_size);
-    // This functions in cuBLAS is intended to be used for matrices of small
-    // sizes where the launch overhead is a significant factor.
-    // TODO(Xreki): call function in cusolver for large matrices.
-    if (n < 32) {
-      // cublasmatinvBatched is a short cut of cublasgetrfBatched
-      // plus cublasgetriBatched.
-      // However it only works if N is less than 32. If not, we need to
-      // go through cublasgetrfBatched and cublasgetriBatched.
-      blas.BatchedMatInv(n,
-                         reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
-                         gpu_inv_ptrs, gpu_info_ptr, batch_size);
-    } else {
-      // This function performs the LU factorization of each matrix A by the
-      // equation P * A = L * U. L and U are written back to original matrix A,
-      // and diagonal elements of L are discarded.
-      int* gpu_pivot_ptr =
-          reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;
-      blas.BatchedGETRF(n, reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
-                        gpu_pivot_ptr, gpu_info_ptr, batch_size);
-
-      blas.BatchedGETRI(n,
-                        reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
-                        gpu_pivot_ptr, gpu_inv_ptrs, gpu_info_ptr, batch_size);
-    }
-    memory::Copy(platform::CPUPlace(), info.data(), context.GetPlace(),
-                 gpu_info_ptr, sizeof(int) * batch_size, context.stream());
-    for (int i = 0; i < batch_size; ++i) {
-      PADDLE_ENFORCE_EQ(info[i], 0,
-                        platform::errors::PreconditionNotMet(
-                            "For batch [%d]: U(%d, %d) is zero, singular U. "
" - "Please check the matrix value and change it to a " - "non-singular matrix", - i, info[i], info[i])); - } -#else - compute_inverse_eigen(context, a, a_inv); -#endif - } -}; - -template class MatrixInverseFunctor; -template class MatrixInverseFunctor; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/matrix_power_op.h b/paddle/fluid/operators/matrix_power_op.h index d2c67d80b4f..8eb9c58513d 100644 --- a/paddle/fluid/operators/matrix_power_op.h +++ b/paddle/fluid/operators/matrix_power_op.h @@ -18,9 +18,9 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/math/matrix_inverse.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/matrix_inverse.h" namespace paddle { namespace operators { @@ -67,7 +67,7 @@ void MatrixPowerFunction(const Tensor* X, const int n, Tensor* Out, framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x); } else { // newX = X^{-1}, n = -n - math::MatrixInverseFunctor mat_inv; + phi::funcs::MatrixInverseFunctor mat_inv; mat_inv(dev_ctx, *X, &new_x); new_n = -n; } @@ -200,7 +200,7 @@ void MatrixPowerGradFunction(const Tensor* X, const Tensor* Out, framework::TensorCopy(*X, ctx.GetPlace(), dev_ctx, &new_x); } else { // newX = X^{-1}, n = -n - math::MatrixInverseFunctor mat_inv; + phi::funcs::MatrixInverseFunctor mat_inv; mat_inv(dev_ctx, *X, &new_x); new_n = -n; } diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 02cba6009c4..f0fbb7bf084 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -9,3 +9,4 @@ math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) math_library(concat_and_split_functor DEPS dense_tensor) math_library(matrix_reduce DEPS dense_tensor) +math_library(matrix_inverse DEPS dense_tensor eigen3 blas) diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cc b/paddle/phi/kernels/funcs/matrix_inverse.cc new file mode 100644 index 00000000000..c95e97f8ea8 --- /dev/null +++ b/paddle/phi/kernels/funcs/matrix_inverse.cc @@ -0,0 +1,37 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+
+#include "paddle/phi/kernels/funcs/matrix_inverse.h"
+
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename Context, typename T>
+void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,
+                                                  const DenseTensor& a,
+                                                  DenseTensor* a_inv) {
+  ComputeInverseEigen<Context, T>(dev_ctx, a, a_inv);
+}
+
+template class MatrixInverseFunctor<CPUContext, float>;
+template class MatrixInverseFunctor<CPUContext, double>;
+
+// TODO(chenweihang): remove these instantiations later
+template class MatrixInverseFunctor<paddle::platform::CPUDeviceContext, float>;
+template class MatrixInverseFunctor<paddle::platform::CPUDeviceContext, double>;
+
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/phi/kernels/funcs/matrix_inverse.cu.cc b/paddle/phi/kernels/funcs/matrix_inverse.cu.cc
new file mode 100644
index 00000000000..686b8405bf7
--- /dev/null
+++ b/paddle/phi/kernels/funcs/matrix_inverse.cu.cc
@@ -0,0 +1,141 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/kernels/funcs/matrix_inverse.h"
+
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/memory/memcpy.h"
+
+namespace phi {
+namespace funcs {
+
+template <typename Context, typename T>
+void MatrixInverseFunctor<Context, T>::operator()(const Context& dev_ctx,
+                                                  const DenseTensor& a,
+                                                  DenseTensor* a_inv) {
+#ifndef PADDLE_WITH_HIP
+  const auto& mat_dims = a.dims();
+  const int rank = mat_dims.size();
+  int n = mat_dims[rank - 1];
+  int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
+
+  paddle::memory::allocation::AllocationPtr tmp_gpu_mat_data;
+  const T* gpu_mat = a.data<T>();
+  if (n >= 32) {
+    // Copy all elements of input matrix A to a temporary memory space to
+    // avoid being overwritten by getrf.
+    tmp_gpu_mat_data = paddle::memory::Alloc(dev_ctx, a.numel() * sizeof(T));
+    paddle::memory::Copy(dev_ctx.GetPlace(),
+                         tmp_gpu_mat_data->ptr(),
+                         dev_ctx.GetPlace(),
+                         a.data<T>(),
+                         a.numel() * sizeof(T),
+                         dev_ctx.stream());
+    gpu_mat = reinterpret_cast<const T*>(tmp_gpu_mat_data->ptr());
+  }
+
+  std::vector<const T*> cpu_ptrs(batch_size * 2);
+  for (int i = 0; i < batch_size; ++i) {
+    cpu_ptrs[i] = gpu_mat + i * n * n;
+    cpu_ptrs[i + batch_size] = a_inv->data<T>() + i * n * n;
+  }
+
+  // Copy the addresses of A and A_inv from host to device.
+  paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
+      paddle::memory::Alloc(dev_ctx, cpu_ptrs.size() * sizeof(T*));
+  paddle::memory::Copy(dev_ctx.GetPlace(),
+                       tmp_gpu_ptrs_data->ptr(),
+                       phi::CPUPlace(),
+                       static_cast<void*>(cpu_ptrs.data()),
+                       cpu_ptrs.size() * sizeof(T*),
+                       dev_ctx.stream());
+  T** gpu_inv_ptrs =
+      reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()) + batch_size;
+
+  // Allocate device memory for info and pivots.
+  int num_ints = n < 32 ? batch_size : batch_size * (n + 1);
+  paddle::memory::allocation::AllocationPtr tmp_gpu_info_data =
+      paddle::memory::Alloc(dev_ctx, num_ints * sizeof(int));
+  int* gpu_info_ptr = reinterpret_cast<int*>(tmp_gpu_info_data->ptr());
+
+  auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
+
+  std::vector<int> info;  // only for singular checking
+  info.resize(batch_size);
+  // This function in cuBLAS is intended to be used for matrices of small
+  // sizes where the launch overhead is a significant factor.
+  // TODO(Xreki): call function in cusolver for large matrices.
+  if (n < 32) {
+    // cublasmatinvBatched is a shortcut for cublasgetrfBatched
+    // plus cublasgetriBatched.
+    // However, it only works if N is less than 32. If not, we need to
+    // go through cublasgetrfBatched and cublasgetriBatched.
+    blas.BatchedMatInv(n,
+                       reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
+                       gpu_inv_ptrs,
+                       gpu_info_ptr,
+                       batch_size);
+  } else {
+    // This function performs the LU factorization of each matrix A by the
+    // equation P * A = L * U. L and U are written back to the original
+    // matrix A, and the diagonal elements of L are discarded.
+    int* gpu_pivot_ptr =
+        reinterpret_cast<int*>(tmp_gpu_info_data->ptr()) + batch_size;
+    blas.BatchedGETRF(n,
+                      reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr()),
+                      gpu_pivot_ptr,
+                      gpu_info_ptr,
+                      batch_size);
+
+    blas.BatchedGETRI(n,
+                      reinterpret_cast<const T**>(tmp_gpu_ptrs_data->ptr()),
+                      gpu_pivot_ptr,
+                      gpu_inv_ptrs,
+                      gpu_info_ptr,
+                      batch_size);
+  }
+  paddle::memory::Copy(phi::CPUPlace(),
+                       info.data(),
+                       dev_ctx.GetPlace(),
+                       gpu_info_ptr,
+                       sizeof(int) * batch_size,
+                       dev_ctx.stream());
+  for (int i = 0; i < batch_size; ++i) {
+    PADDLE_ENFORCE_EQ(info[i],
+                      0,
+                      phi::errors::PreconditionNotMet(
+                          "For batch [%d]: U(%d, %d) is zero, singular U. "
+                          "Please check the matrix value and change it to a "
+                          "non-singular matrix",
+                          i,
+                          info[i],
+                          info[i]));
+  }
+#else
+  ComputeInverseEigen<Context, T>(dev_ctx, a, a_inv);
+#endif
+}
+
+template class MatrixInverseFunctor<GPUContext, float>;
+template class MatrixInverseFunctor<GPUContext, double>;
+
+// TODO(chenweihang): remove these instantiations later
+template class MatrixInverseFunctor<paddle::platform::CUDADeviceContext, float>;
+template class MatrixInverseFunctor<paddle::platform::CUDADeviceContext,
+                                    double>;
+
+}  // namespace funcs
+}  // namespace phi
diff --git a/paddle/fluid/operators/math/matrix_inverse.h b/paddle/phi/kernels/funcs/matrix_inverse.h
similarity index 61%
rename from paddle/fluid/operators/math/matrix_inverse.h
rename to paddle/phi/kernels/funcs/matrix_inverse.h
index fb58b483666..c5b04a81065 100644
--- a/paddle/fluid/operators/math/matrix_inverse.h
+++ b/paddle/phi/kernels/funcs/matrix_inverse.h
@@ -1,4 +1,4 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -17,17 +17,18 @@ limitations under the License. */
 #include <string>
 
 #include "Eigen/Core"
 #include "Eigen/LU"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device_context.h"
 
-namespace paddle {
-namespace operators {
-namespace math {
+#include "paddle/phi/backends/all_context.h"
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/enforce.h"
 
-template <typename DeviceContext, typename T>
-void compute_inverse_eigen(const DeviceContext& context,
-                           const framework::Tensor& a,
-                           framework::Tensor* a_inv) {
+namespace phi {
+namespace funcs {
+
+template <typename Context, typename T>
+void ComputeInverseEigen(const Context& dev_ctx,
+                         const DenseTensor& a,
+                         DenseTensor* a_inv) {
   using Matrix = Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>;
   using EigenMatrixMap = Eigen::Map<Matrix>;
@@ -38,7 +39,7 @@ void compute_inverse_eigen(const DeviceContext& context,
   int batch_size = rank > 2 ? a.numel() / (n * n) : 1;
 
   const T* a_ptr = a.data<T>();
-  T* a_inv_ptr = a_inv->mutable_data<T>(context.GetPlace());
+  T* a_inv_ptr = a_inv->mutable_data<T>(dev_ctx.GetPlace());
 
   for (int i = 0; i < batch_size; ++i) {
     ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n);
@@ -47,20 +48,20 @@ void compute_inverse_eigen(const DeviceContext& context,
     lu.compute(mat);
 
     const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff();
-    PADDLE_ENFORCE_GT(
-        min_abs_pivot, static_cast<T>(0),
-        platform::errors::InvalidArgument("Input is not invertible."));
+    PADDLE_ENFORCE_GT(min_abs_pivot,
+                      static_cast<T>(0),
+                      errors::InvalidArgument("Input is not invertible."));
     mat_inv.noalias() = lu.inverse();
   }
 }
 
-template <typename DeviceContext, typename T>
+template <typename Context, typename T>
 class MatrixInverseFunctor {
  public:
-  void operator()(const DeviceContext& context, const framework::Tensor& a,
-                  framework::Tensor* a_inv);
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& a,
+                  DenseTensor* a_inv);
 };
 
-}  // namespace math
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
-- 
GitLab
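
After this patch, callers reach the functor through phi::funcs instead of paddle::operators::math, as the updated determinant_op.h, inverse_op.h, and matrix_power_op.h show. Below is a minimal sketch of the call pattern from a phi-style kernel; the kernel name, signature, and variable names (MyInverseKernel, dev_ctx, x, out) are illustrative assumptions, not part of this patch:

    // Sketch only: assumes a phi-style kernel; names are illustrative.
    #include "paddle/phi/core/dense_tensor.h"
    #include "paddle/phi/kernels/funcs/matrix_inverse.h"

    namespace phi {

    template <typename T, typename Context>
    void MyInverseKernel(const Context& dev_ctx,
                         const DenseTensor& x,  // batched square matrices
                         DenseTensor* out) {    // receives per-batch x^{-1}
      // The output has the same shape as the input; allocate it on the
      // kernel's device before invoking the functor.
      out->Resize(x.dims());
      dev_ctx.template Alloc<T>(out);

      // One functor for all backends: Eigen's PartialPivLU on CPU, batched
      // cuBLAS on GPU (matinvBatched for n < 32, getrf/getri otherwise).
      funcs::MatrixInverseFunctor<Context, T> mat_inv;
      mat_inv(dev_ctx, x, out);
    }

    }  // namespace phi

Because the GPU specialization keeps the old singular-matrix check, a non-invertible batch entry still fails with PreconditionNotMet instead of silently returning garbage.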