From 5f2059db05bbc0590f4971198e034a56d1aa5915 Mon Sep 17 00:00:00 2001 From: lzhao4ever Date: Thu, 3 Nov 2016 09:24:16 -0700 Subject: [PATCH] Add matrix inverse (#240) * Add matrix inverse --- cmake/cblas.cmake | 16 ++++- paddle/cuda/include/hl_cuda_cublas.h | 24 +++++-- .../cuda/include/stub/hl_cuda_cublas_stub.h | 6 ++ paddle/cuda/src/hl_cuda_cublas.cc | 55 ++++++++++++++++ paddle/math/MathFunctions.cpp | 40 ++++++++++++ paddle/math/MathFunctions.h | 15 +++++ paddle/math/Matrix.cpp | 65 +++++++++++++++++++ paddle/math/Matrix.h | 20 ++++++ paddle/math/tests/test_matrixCompare.cpp | 29 ++++++++- 9 files changed, 261 insertions(+), 9 deletions(-) diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index 57c32a54cd..685334c658 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -1,4 +1,4 @@ -# Find the CBlas libraries +# Find the CBlas and lapack libraries # # It will search MKL, atlas, OpenBlas, reference-cblas in order. # @@ -19,6 +19,8 @@ set(MKL_ROOT $ENV{MKL_ROOT} CACHE PATH "Folder contains MKL") find_path(MKL_INCLUDE_DIR mkl.h PATHS ${MKL_ROOT}/include) +find_path(MKL_INCLUDE_DIR mkl_lapacke.h PATHS + ${MKL_ROOT}/include) find_library(MKL_CORE_LIB NAMES mkl_core PATHS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64) @@ -37,6 +39,7 @@ if(MKL_INCLUDE_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64) ${MKL_SEQUENTIAL_LIB} ${MKL_CORE_LIB}) add_definitions(-DPADDLE_USE_MKL) + message(STATUS "Found MKL (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBS})") return() # return file. endif() @@ -55,15 +58,19 @@ set(ATLAS_LIB_SEARCH_PATHS ) find_path(ATLAS_INC_DIR NAMES cblas.h PATHS ${ATLAS_INCLUDE_SEARCH_PATHS}) +find_path(ATLAS_CLAPACK_INC_DIR NAMES clapack.h + PATHS ${ATLAS_INCLUDE_SEARCH_PATHS}) find_library(ATLAS_CBLAS_LIB NAMES cblas libcblas.so.3 PATHS ${ATLAS_LIB_SEARCH_PATHS}) -find_library(ATLAS_LIB NAMES atlas libatlas.so.3 +find_library(ATLAS_LIB NAMES lapack_atlas liblapack_atlas.so.3 PATHS ${ATLAS_LIB_SEARCH_PATHS}) if(ATLAS_INC_DIR AND ATLAS_CBLAS_LIB AND ATLAS_LIB) set(CBLAS_PROVIDER ATLAS) - set(CBLAS_INC_DIR ${ATLAS_INC_DIR}) + set(CBLAS_INC_DIR ${ATLAS_INC_DIR} ${ATLAS_CLAPACK_INC_DIR}) set(CBLAS_LIBS ${ATLAS_LIB} ${ATLAS_CBLAS_LIB}) + add_definitions(-DPADDLE_USE_ATLAS) + message(STATUS "Found Atlas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBS})") return() endif() @@ -83,6 +90,8 @@ set(OPENBLAS_LIB_SEARCH_PATHS find_path(OPENBLAS_INC_DIR NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) +find_path(OPENBLAS_LAPACKE_INC_DIR NAMES lapacke.h + PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) find_library(OPENBLAS_LIB NAMES openblas PATHS ${OPENBLAS_LIB_SEARCH_PATHS}) @@ -90,6 +99,7 @@ if(OPENBLAS_INC_DIR AND OPENBLAS_LIB) set(CBLAS_PROVIDER OPENBLAS) set(CBLAS_INC_DIR ${OPENBLAS_INC_DIR}) set(CBLAS_LIBS ${OPENBLAS_LIB}) + message(STATUS "Found OpenBlas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBS})") return() endif() diff --git a/paddle/cuda/include/hl_cuda_cublas.h b/paddle/cuda/include/hl_cuda_cublas.h index 0ffbed18b5..d757317eb4 100644 --- a/paddle/cuda/include/hl_cuda_cublas.h +++ b/paddle/cuda/include/hl_cuda_cublas.h @@ -21,8 +21,8 @@ limitations under the License. */ /** * @brief Matrix transpose: C_d = T(A_d) * - * @param[in] A_d input matrix (M x N). - * @param[out] C_d output matrix (N x M). + * @param[in] A_d input matrix (dimM x dimN). + * @param[out] C_d output matrix (dimN x dimM). * @param[in] dimM matrix height. * @param[in] dimN matrix width. * @param[in] lda the first dimension of A_d. @@ -39,8 +39,8 @@ extern void hl_matrix_transpose(real *A_d, /* * @brief Matrix transpose, while lda = dimN, ldc = dimM. * - * @param[in] A_d input matrix (M x N). - * @param[out] C_d output matrix (N x M). + * @param[in] A_d input matrix (dimM x dimN). + * @param[out] C_d output matrix (dimN x dimM). * @param[in] dimM matrix height. * @param[in] dimN matrix width. * @@ -50,6 +50,22 @@ extern void hl_matrix_transpose(real *A_d, int dimM, int dimN); +/* + * @brief Matrix inverse + * + * @param[in] A_d input matrix (dimN x dimN). + * @param[out] C_d output matrix (dimN x dimN). + * @param[in] dimN matrix height = matrix width + * @param[in] lda the first dimension of A_d + * @param[in] ldc the first dimension of C_d + * + */ +extern void hl_matrix_inverse(real *A_d, + real *C_d, + int dimN, + int lda, + int ldc); + /** * @brief C_d = alpha*(op(A_d) * op(B_d)) + beta*C_d * diff --git a/paddle/cuda/include/stub/hl_cuda_cublas_stub.h b/paddle/cuda/include/stub/hl_cuda_cublas_stub.h index 4a5e2a25a7..903dcbe835 100644 --- a/paddle/cuda/include/stub/hl_cuda_cublas_stub.h +++ b/paddle/cuda/include/stub/hl_cuda_cublas_stub.h @@ -30,6 +30,12 @@ inline void hl_matrix_transpose(real *A_d, int dimM, int dimN) {} +inline void hl_matrix_inverse(real *A_d, + real *C_d, + int dimN, + int lda, + int ldc) {} + inline void hl_matrix_mul(real *A_d, hl_trans_op_t transa, real *B_d, hl_trans_op_t transb, real *C_d, diff --git a/paddle/cuda/src/hl_cuda_cublas.cc b/paddle/cuda/src/hl_cuda_cublas.cc index b3c9001ba3..724ea490e8 100644 --- a/paddle/cuda/src/hl_cuda_cublas.cc +++ b/paddle/cuda/src/hl_cuda_cublas.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "hl_cuda.h" #include "hl_cuda_cublas.h" #include "hl_thread.ph" #include "hl_dso_loader.h" @@ -75,6 +76,8 @@ DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched) DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched) DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched) DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched) +DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched) CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) #undef DYNAMIC_LOAD_CUBLAS_WRAP @@ -88,10 +91,14 @@ CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP) #define CUBLAS_GEAM dynload::cublasSgeam #define CUBLAS_GEMV dynload::cublasSgemv #define CUBLAS_GEMM dynload::cublasSgemm +#define CUBLAS_GETRF dynload::cublasSgetrfBatched +#define CUBLAS_GETRI dynload::cublasSgetriBatched #else #define CUBLAS_GEAM dynload::cublasDgeam #define CUBLAS_GEMV dynload::cublasDgemv #define CUBLAS_GEMM dynload::cublasDgemm +#define CUBLAS_GETRF dynload::cublasDgetrfBatched +#define CUBLAS_GETRI dynload::cublasDgetriBatched #endif const char* hl_cublas_get_error_string(cublasStatus_t status) { @@ -162,6 +169,54 @@ void hl_matrix_transpose(real *A_d, real *C_d, int dimM, int dimN) { hl_matrix_transpose(A_d, C_d, dimM, dimN, dimN, dimM); } +void hl_matrix_inverse(real *A_d, real *C_d, int dimN, int lda, int ldc) { + /* Solve Ax = I */ + CHECK_NOTNULL(A_d); + CHECK_NOTNULL(C_d); + + /* Step 1: Compute the LU decomposition of matrix A */ + real **inout_h = &A_d; + real **inout_d = (real **)hl_malloc_device(sizeof(real *)); + hl_memcpy(inout_d, inout_h, sizeof(real *)); + + int *pivot_d = (int *)hl_malloc_device(dimN*sizeof(int)); + int *info_d = (int *)t_resource.gpu_mem; + + /* Note: cublasSgetrfBatched is used to calculate a number of + small-sized matrices. There may be a better way to reconstruct + the API for better performance. + */ + CHECK_CUBLAS(CUBLAS_GETRF(t_resource.handle, + dimN, inout_d, lda, pivot_d, + info_d, 1)); + + int info_h; + hl_memcpy(&info_h, info_d, sizeof(int)); + if (info_h != 0) { + LOG(FATAL) << "Factorization of matrix failed: matrix may be singular.\n"; + } + + /* Step 2: Compute the inverse of the matrix given its LU decomposition */ + real **out_h = &C_d; + real **out_d = (real **)hl_malloc_device(sizeof(real *)); + hl_memcpy(out_d, out_h, sizeof(real *)); + + CHECK_CUBLAS(CUBLAS_GETRI(t_resource.handle, + dimN, (const real **)inout_d, lda, pivot_d, + out_d, ldc, info_d, 1)); + + hl_memcpy(&info_h, info_d, sizeof(int)); + if (info_h != 0) { + LOG(FATAL) << "Inversion of matrix failed: matrix may be singular.\n"; + } + + hl_free_mem_device(inout_d); + hl_free_mem_device(pivot_d); + hl_free_mem_device(out_d); + + CHECK_SYNC("hl_matrix_inverse failed"); +} + void hl_matrix_mul(real *A_d, hl_trans_op_t transa, real *B_d, hl_trans_op_t transb, real *C_d, diff --git a/paddle/math/MathFunctions.cpp b/paddle/math/MathFunctions.cpp index da493379e3..f813206647 100644 --- a/paddle/math/MathFunctions.cpp +++ b/paddle/math/MathFunctions.cpp @@ -39,6 +39,46 @@ void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, beta, C, ldc); } +template<> +int getrf(const CBLAS_ORDER order, const int M, const int N, + float *A, const int lda, int *ipiv) { +#ifdef PADDLE_USE_ATLAS + return clapack_sgetrf(order, M, N, A, lda, ipiv); +#else + return LAPACKE_sgetrf(order, M, N, A, lda, ipiv); +#endif +} + +template<> +int getrf(const CBLAS_ORDER order, const int M, const int N, + double *A, const int lda, int *ipiv) { +#ifdef PADDLE_USE_ATLAS + return clapack_dgetrf(order, M, N, A, lda, ipiv); +#else + return LAPACKE_dgetrf(order, M, N, A, lda, ipiv); +#endif +} + +template<> +int getri(const CBLAS_ORDER order, const int N, float *A, + const int lda, const int *ipiv) { +#ifdef PADDLE_USE_ATLAS + return clapack_sgetri(order, N, A, lda, ipiv); +#else + return LAPACKE_sgetri(order, N, A, lda, ipiv); +#endif +} + +template<> +int getri(const CBLAS_ORDER order, const int N, double *A, + const int lda, const int *ipiv) { +#ifdef PADDLE_USE_ATLAS + return clapack_dgetri(order, N, A, lda, ipiv); +#else + return LAPACKE_dgetri(order, N, A, lda, ipiv); +#endif +} + template<> void axpy(const int n, const float alpha, const float* x, float* y) { cblas_saxpy(n, alpha, x, 1, y, 1); diff --git a/paddle/math/MathFunctions.h b/paddle/math/MathFunctions.h index 43075977dc..cad0e4740b 100644 --- a/paddle/math/MathFunctions.h +++ b/paddle/math/MathFunctions.h @@ -21,6 +21,13 @@ limitations under the License. */ extern "C" { #include } +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +} +#else +#include +#endif #endif #include @@ -34,6 +41,14 @@ void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const T* B, const int ldb, const T beta, T* C, const int ldc); +template +int getrf(const CBLAS_ORDER Order, const int M, const int N, + T *A, const int lda, int *ipiv); + +template +int getri(const CBLAS_ORDER Order, const int N, T *A, + const int lda, const int *ipiv); + template void axpy(const int n, const T alpha, const T* x, T* y); diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index aaeae98f0d..d901ba9349 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -335,6 +335,30 @@ void GpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) { hl_matrix_transpose(data, dataTrans, height_, width_, lda, ldc); } + +MatrixPtr GpuMatrix::getInverse() { + MatrixPtr matInv; + inverse(matInv, true); + return matInv; +} + +void GpuMatrix::inverse(MatrixPtr matInv, bool memAlloc) { + CHECK_EQ(height_, width_); + + if (memAlloc) { + matInv = std::make_shared(height_, width_); + } else { + CHECK(matInv != NULL); + } + + real* data = getData(); + real* dataInv = matInv->getData(); + int lda = getStride(); + int ldc = matInv->getStride(); + + hl_matrix_inverse(data, dataInv, height_, lda, ldc); +} + void GpuMatrix::addBias(Matrix& b, real scale) { CHECK(b.getHeight() == 1) << "the Bias should be a vector"; BaseMatrix::addBias(b, scale); @@ -1437,6 +1461,47 @@ void CpuMatrix::transpose(MatrixPtr matTrans, bool memAlloc) { } } + +MatrixPtr CpuMatrix::getInverse() { + MatrixPtr matInv; + inverse(matInv, true); + return matInv; +} + +void CpuMatrix::inverse(MatrixPtr matInv, bool memAlloc) { + CHECK_EQ(height_, width_); + + if (memAlloc) { + matInv = std::make_shared(height_, width_); + } else { + CHECK(matInv != NULL); + } + + CHECK_EQ(height_, matInv->getHeight()); + CHECK_EQ(width_, matInv->getWidth()); + matInv->copyFrom(*this); + + real* data = getData(); + real* dataInv = matInv->getData(); + int ldc = matInv->getStride(); + + if (height_ == 1) { + CHECK_NE(*data, 0); + *dataInv = 1.0 / (*data); + return; + } + + /* Compute the LU decomposition of the matrix */ + std::vector ipiv(height_); + CBLAS_ORDER order = (matInv->isTransposed() ? CblasColMajor : CblasRowMajor); + int info = getrf(order, height_, height_, dataInv, ldc, ipiv.data()); + CHECK_EQ(info, 0); + + /* Compute the inverse of the matrix given its LU decompsotion */ + info = getri(order, height_, dataInv, ldc, ipiv.data()); + CHECK_EQ(info, 0); +} + void CpuMatrix::convExpand(Matrix& feature, int feaImgHeight, int feaImgWidth, int channels, int blockH, int blockW, int strideH, int strideW, int paddingH, int paddingW, diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 52cbed528c..293d13f4d6 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -328,6 +328,20 @@ public: LOG(FATAL) << "Not implemented"; } + virtual MatrixPtr getInverse() { + LOG(FATAL) << "Not implemented"; + } + + /** + * @brief inverse. + * + * if allocate matInv's memory outside, then set memAlloc as false; + * else set as true. + */ + virtual void inverse(MatrixPtr matInv, bool memAlloc) { + LOG(FATAL) << "Not implemented"; + } + public: /// Only set all variables to 0 or NULL but not free them. virtual void clear() { @@ -1043,6 +1057,9 @@ public: MatrixPtr getTranspose(); void transpose(MatrixPtr matTrans, bool memAlloc); + MatrixPtr getInverse(); + void inverse(MatrixPtr matInv, bool memAlloc); + /// add b to each sample of this. void addBias(Matrix& b, real scale); void addSharedBias(Matrix& b, real scale); @@ -1282,6 +1299,9 @@ public: MatrixPtr getTranspose(); void transpose(MatrixPtr matTrans, bool memAlloc); + MatrixPtr getInverse(); + void inverse(MatrixPtr matInv, bool memAlloc); + void copyFrom(const Matrix& src); void copyFrom(const Matrix& src, hl_stream_t stream); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 0ddf7e0dfc..b887cccaaa 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -641,9 +641,32 @@ void testMatrixTranspose(int height, int width) { MatrixCheckEqual(*cpuT, *outputCheck); } +void testMatrixInverse(int height) { + MatrixPtr cpu = std::make_shared(height, height); + MatrixPtr gpu = std::make_shared(height, height); + MatrixPtr cpuI = std::make_shared(height, height); + MatrixPtr gpuI = std::make_shared(height, height); + + cpu->randomizeUniform(); + gpu->copyFrom(*cpu); + cpu->inverse(cpuI, false); + gpu->inverse(gpuI, false); + + MatrixPtr outputCheck = std::make_shared(height, height); + outputCheck->copyFrom(*gpuI); + MatrixCheckErr(*cpuI, *outputCheck); + + outputCheck->mul(cpu, cpuI); + cpu->zeroMem(); + for (int i = 0; i < height; i++) { + cpu->getRowBuf(i)[i] = 1.0; + } + MatrixCheckErr(*cpu, *outputCheck); +} + TEST(Matrix, unary) { - for (auto height : {1, 11, 73, 128, 200, 330}) { - for (auto width : {1, 32, 100, 512, 1000, 3210}) { + for (auto height : {1, 3, 11, 73, 128, 200, 330}) { + for (auto width : {1, 3, 32, 100, 512, 1000, 3210}) { VLOG(3) << " height=" << height << " width=" << width; // applyUnary @@ -675,6 +698,8 @@ TEST(Matrix, unary) { // transpose testMatrixTranspose(height, width); } + // inverse + testMatrixInverse(height); } } -- GitLab