From ecfddebbef8331d2503384c15f8398030a7a73e8 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Mon, 27 Apr 2020 13:21:32 +0800 Subject: [PATCH] Add the implementation of inverse (#23310) --- paddle/fluid/framework/tensor_util.cc | 19 +-- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/inverse_op.cc | 129 ++++++++++++++++ paddle/fluid/operators/inverse_op.cu.cc | 25 +++ paddle/fluid/operators/inverse_op.h | 70 +++++++++ paddle/fluid/operators/math/CMakeLists.txt | 4 + paddle/fluid/operators/math/blas.h | 30 ++++ paddle/fluid/operators/math/blas_impl.cu.h | 74 +++++++++ paddle/fluid/operators/math/matrix_inverse.cc | 62 ++++++++ .../fluid/operators/math/matrix_inverse.cu.cc | 102 +++++++++++++ paddle/fluid/operators/math/matrix_inverse.h | 34 +++++ paddle/fluid/platform/dynload/cublas.h | 4 +- paddle/fluid/platform/dynload/mklml.h | 3 + python/paddle/__init__.py | 2 +- .../paddle/fluid/tests/unittests/op_test.py | 6 +- .../fluid/tests/unittests/test_inverse_op.py | 144 ++++++++++++++++++ .../white_list/op_threshold_white_list.py | 3 +- python/paddle/tensor/__init__.py | 2 +- python/paddle/tensor/math.py | 75 ++++++++- 19 files changed, 772 insertions(+), 18 deletions(-) create mode 100644 paddle/fluid/operators/inverse_op.cc create mode 100644 paddle/fluid/operators/inverse_op.cu.cc create mode 100644 paddle/fluid/operators/inverse_op.h create mode 100644 paddle/fluid/operators/math/matrix_inverse.cc create mode 100644 paddle/fluid/operators/math/matrix_inverse.cu.cc create mode 100644 paddle/fluid/operators/math/matrix_inverse.h create mode 100644 python/paddle/fluid/tests/unittests/test_inverse_op.py diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 8b2e3fc323c..b0828653313 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -1,16 +1,17 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include #include diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 596eb99e813..7d368966eba 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -85,7 +85,7 @@ endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper) if (WITH_GPU) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu bert_encoder_functor) diff --git a/paddle/fluid/operators/inverse_op.cc b/paddle/fluid/operators/inverse_op.cc new file mode 100644 index 00000000000..e73e74ab856 --- /dev/null +++ b/paddle/fluid/operators/inverse_op.cc @@ -0,0 +1,129 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/inverse_op.h" +#include +#include + +namespace paddle { +namespace operators { + +class InverseOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Inverse"); + OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "Inverse"); + + auto input_dims = ctx->GetInputDim("Input"); + int64_t input_rank = input_dims.size(); + PADDLE_ENFORCE_GE( + input_rank, 2, + platform::errors::InvalidArgument( + "The dimension of Input(Input) is expected to be no less than 2. " + "But recieved: Input(Input)'s dimension = %d, shape = [%s].", + input_rank, input_dims)); + if (input_dims[input_rank - 2] > 0 && input_dims[input_rank - 1] > 0) { + PADDLE_ENFORCE_EQ(input_dims[input_rank - 2], input_dims[input_rank - 1], + platform::errors::InvalidArgument( + "The last two dimensions are expected to be equal. " + "But recieved: %d and %d; " + "Input(Input)'s shape = [%s].", + input_dims[input_rank - 2], + input_dims[input_rank - 1], input_dims)); + } + + ctx->SetOutputDim("Output", input_dims); + ctx->ShareLoD("Input", /*->*/ "Output"); + } +}; + +class InverseOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{ + {"Input", /*->*/ "Output"}}; + return m; + } +}; + +class InverseGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + auto input_grad = framework::GradVarName("Input"); + auto output_grad = framework::GradVarName("Output"); + + OP_INOUT_CHECK(ctx->HasInput("Output"), "Input", "Output", "InverseGrad"); + OP_INOUT_CHECK(ctx->HasInput(output_grad), "Input", output_grad, + "InverseGrad"); + + if (ctx->HasOutput(input_grad)) { + ctx->SetOutputDim(input_grad, ctx->GetInputDim(output_grad)); + } + } +}; + +class InverseOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + "Input", + "(Tensor) A square matrix (2-D Tensor) or batches of square matrices" + " to inverse."); + AddOutput("Output", "(Tensor) The inverse of input matrix."); + AddComment(R"DOC( +Inverse Operator + +Takes the inverse of the square matrix. +)DOC"); + } +}; + +template +class InverseGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad) const override { + grad->SetType(this->ForwardOpType() + "_grad"); + grad->SetInput("Output", this->Output("Output")); + grad->SetInput(framework::GradVarName("Output"), + this->OutputGrad("Output")); + grad->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(inverse, ops::InverseOp, ops::InverseOpMaker, + ops::InverseOpInferVarType, + ops::InverseGradOpMaker, + ops::InverseGradOpMaker); + +REGISTER_OPERATOR(inverse_grad, ops::InverseGradOp); + +REGISTER_OP_CPU_KERNEL( + inverse, ops::InverseKernel, + ops::InverseKernel); +REGISTER_OP_CPU_KERNEL( + inverse_grad, + ops::InverseGradKernel, + ops::InverseGradKernel); diff --git a/paddle/fluid/operators/inverse_op.cu.cc b/paddle/fluid/operators/inverse_op.cu.cc new file mode 100644 index 00000000000..2ca61159f3a --- /dev/null +++ b/paddle/fluid/operators/inverse_op.cu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/inverse_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_CUDA_KERNEL( + inverse, ops::InverseKernel, + ops::InverseKernel); +REGISTER_OP_CUDA_KERNEL( + inverse_grad, + ops::InverseGradKernel, + ops::InverseGradKernel); diff --git a/paddle/fluid/operators/inverse_op.h b/paddle/fluid/operators/inverse_op.h new file mode 100644 index 00000000000..c1859a26f36 --- /dev/null +++ b/paddle/fluid/operators/inverse_op.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/matrix_inverse.h" + +namespace paddle { +namespace operators { + +template +class InverseKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* input = context.Input("Input"); + auto* output = context.Output("Output"); + output->mutable_data(context.GetPlace()); + + auto& dev_ctx = context.template device_context(); + math::MatrixInverseFunctor mat_inv; + mat_inv(dev_ctx, *input, output); + } +}; + +template +class InverseGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* a_inv = context.Input("Output"); + auto* a_inv_grad = + context.Input(framework::GradVarName("Output")); + auto* a_grad = + context.Output(framework::GradVarName("Input")); + + if (a_grad) { + a_grad->mutable_data(context.GetPlace()); + + auto blas = math::GetBlas(context); + auto& dev_ctx = context.template device_context(); + framework::Tensor tmp_out = + context.AllocateTmpTensor(a_inv->dims(), dev_ctx); + + auto mat_dim_a0 = + math::CreateMatrixDescriptor(a_inv_grad->dims(), 0, false); + auto mat_dim_b0 = math::CreateMatrixDescriptor(a_inv->dims(), 0, true); + blas.MatMul(*a_inv_grad, mat_dim_a0, *a_inv, mat_dim_b0, T(1), &tmp_out, + T(0)); + + auto mat_dim_a1 = math::CreateMatrixDescriptor(a_inv->dims(), 0, true); + auto mat_dim_b1 = math::CreateMatrixDescriptor(tmp_out.dims(), 0, false); + blas.MatMul(*a_inv, mat_dim_a1, tmp_out, mat_dim_b1, T(-1), a_grad, T(0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 77a8aa5d9c7..3a19c7edff3 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -21,6 +21,9 @@ function(math_library TARGET) if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu) list(APPEND cu_srcs ${TARGET}.cu) endif() + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc) + list(APPEND cu_srcs ${TARGET}.cu.cc) + endif() if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu) list(APPEND hip_srcs ${TARGET}.hip.cu) endif() @@ -68,6 +71,7 @@ math_library(vol2col) math_library(prelu) math_library(bert_encoder_functor) math_library(tree2col DEPS math_function) +math_library(matrix_inverse) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 5a96e6bb4a1..f8c971954fc 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -223,6 +223,19 @@ class Blas { CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B, int ldb) const; +#ifdef PADDLE_WITH_CUDA + template + void BatchedGETRF(int n, T** a, int* ipiv, int* info, int batch_size) const; + + template + void BatchedGETRI(int n, const T** a, const int* ipiv, T** a_inv, int* info, + int batch_size) const; + + template + void BatchedMatInv(int n, const T** a, T** a_inv, int* info, + int batch_size) const; +#endif + private: const DeviceContext& context_; }; @@ -361,6 +374,23 @@ class BlasT : private Blas { Base()->template TRSM(args...); } +#ifdef PADDLE_WITH_CUDA + template + void BatchedGETRF(ARGS... args) const { + Base()->template BatchedGETRF(args...); + } + + template + void BatchedGETRI(ARGS... args) const { + Base()->template BatchedGETRI(args...); + } + + template + void BatchedMatInv(ARGS... args) const { + Base()->template BatchedMatInv(args...); + } +#endif + private: const Blas* Base() const { return static_cast*>(this); diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index e7720a97699..39bddda6caa 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -93,6 +93,24 @@ struct CUBlas { static void TRSM(ARGS... args) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsm(args...)); } + + template + static void GETRF_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasSgetrfBatched(args...)); + } + + template + static void GETRI_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasSgetriBatched(args...)); + } + + template + static void MATINV_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasSmatinvBatched(args...)); + } }; template <> @@ -141,6 +159,24 @@ struct CUBlas { static void TRSM(ARGS... args) { PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsm(args...)); } + + template + static void GETRF_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDgetrfBatched(args...)); + } + + template + static void GETRI_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDgetriBatched(args...)); + } + + template + static void MATINV_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDmatinvBatched(args...)); + } }; template <> @@ -446,6 +482,44 @@ void Blas::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, }); } +template <> +template +void Blas::BatchedGETRF(int n, T **a, int *ipiv, + int *info, + int batch_size) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRF_BATCH(handle, n, a, n, ipiv, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedGETRI(int n, const T **a, + const int *ipiv, T **a_inv, + int *info, + int batch_size) const { + PADDLE_ENFORCE_NE( + a_inv, a, + platform::errors::InvalidArgument( + "cuBLAS fuction 'cublasgetrfBatched' cannot be executed " + "in-place. The memory space of output matrix (address: %p) cannot " + "overlap memory space of input matrix (address: %p).", + a_inv, a)); + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRI_BATCH(handle, n, a, n, ipiv, a_inv, n, info, batch_size); + }); +} + +template <> +template +void Blas::BatchedMatInv(int n, const T **a, + T **a_inv, int *info, + int batch_size) const { + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::MATINV_BATCH(handle, n, a, n, a_inv, n, info, batch_size); + }); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_inverse.cc b/paddle/fluid/operators/math/matrix_inverse.cc new file mode 100644 index 00000000000..25bc5d725e1 --- /dev/null +++ b/paddle/fluid/operators/math/matrix_inverse.cc @@ -0,0 +1,62 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/matrix_inverse.h" +#include "Eigen/Core" +#include "Eigen/LU" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class MatrixInverseFunctor { + using Matrix = + Eigen::Matrix; + using EigenMatrixMap = Eigen::Map; + using ConstEigenMatrixMap = Eigen::Map; + + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor& a, framework::Tensor* a_inv) { + const auto& mat_dims = a.dims(); + const int rank = mat_dims.size(); + int n = mat_dims[rank - 1]; + int batch_size = rank > 2 ? a.numel() / (n * n) : 1; + + const T* a_ptr = a.data(); + T* a_inv_ptr = a_inv->mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; ++i) { + ConstEigenMatrixMap mat(a_ptr + i * n * n, n, n); + EigenMatrixMap mat_inv(a_inv_ptr + i * n * n, n, n); + Eigen::PartialPivLU lu; + lu.compute(mat); + + const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff(); + PADDLE_ENFORCE_GT( + min_abs_pivot, static_cast(0), + platform::errors::InvalidArgument("Input is not invertible.")); + mat_inv.noalias() = lu.inverse(); + } + } +}; + +template class MatrixInverseFunctor; +template class MatrixInverseFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_inverse.cu.cc b/paddle/fluid/operators/math/matrix_inverse.cu.cc new file mode 100644 index 00000000000..8ea4e582ad1 --- /dev/null +++ b/paddle/fluid/operators/math/matrix_inverse.cu.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/matrix_inverse.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class MatrixInverseFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& a, framework::Tensor* a_inv) { + const auto& mat_dims = a.dims(); + const int rank = mat_dims.size(); + int n = mat_dims[rank - 1]; + int batch_size = rank > 2 ? a.numel() / (n * n) : 1; + + memory::allocation::AllocationPtr tmp_gpu_mat_data; + const T* gpu_mat = a.data(); + if (n >= 32) { + // Copy all elements of input matrix A to a temporary memory space to + // avoid being overriden by getrf. + tmp_gpu_mat_data = memory::Alloc(context, a.numel() * sizeof(T)); + memory::Copy(boost::get(context.GetPlace()), + tmp_gpu_mat_data->ptr(), + boost::get(context.GetPlace()), + a.data(), a.numel() * sizeof(T), context.stream()); + gpu_mat = reinterpret_cast(tmp_gpu_mat_data->ptr()); + } + + std::vector cpu_ptrs(batch_size * 2); + for (int i = 0; i < batch_size; ++i) { + cpu_ptrs[i] = gpu_mat + i * n * n; + cpu_ptrs[i + batch_size] = a_inv->data() + i * n * n; + } + + // Copy the addresses of A and A_inv from host to device. + memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + memory::Alloc(context, cpu_ptrs.size() * sizeof(T*)); + memory::Copy(boost::get(context.GetPlace()), + tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(), + static_cast(cpu_ptrs.data()), + cpu_ptrs.size() * sizeof(T*), context.stream()); + T** gpu_inv_ptrs = + reinterpret_cast(tmp_gpu_ptrs_data->ptr()) + batch_size; + + // Allocate device memory for info and pivots. + int num_ints = n < 32 ? batch_size : batch_size * (n + 1); + memory::allocation::AllocationPtr tmp_gpu_info_data = + memory::Alloc(context, num_ints * sizeof(int)); + int* gpu_info_ptr = reinterpret_cast(tmp_gpu_info_data->ptr()); + + auto blas = math::GetBlas(context); + + // This functions in cuBLAS is intended to be used for matrices of small + // sizes where the launch overhead is a significant factor. + // TODO(Xreki): call function in cusolver for large matrices. + if (n < 32) { + // cublasmatinvBatched is a short cut of cublasgetrfBatched + // plus cublasgetriBatched. + // However it only works if N is less than 32. If not, we need to + // go through cublasgetrfBatched and cublasgetriBatched. + blas.BatchedMatInv(n, + reinterpret_cast(tmp_gpu_ptrs_data->ptr()), + gpu_inv_ptrs, gpu_info_ptr, batch_size); + } else { + // This function performs the LU factorization of each matrix A by the + // equation P * A = L * U. L and U are written back to original matrix A, + // and diagonal elements of L are discarded. + int* gpu_pivot_ptr = + reinterpret_cast(tmp_gpu_info_data->ptr()) + batch_size; + blas.BatchedGETRF(n, reinterpret_cast(tmp_gpu_ptrs_data->ptr()), + gpu_pivot_ptr, gpu_info_ptr, batch_size); + + blas.BatchedGETRI(n, + reinterpret_cast(tmp_gpu_ptrs_data->ptr()), + gpu_pivot_ptr, gpu_inv_ptrs, gpu_info_ptr, batch_size); + } + } +}; + +template class MatrixInverseFunctor; +template class MatrixInverseFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_inverse.h b/paddle/fluid/operators/math/matrix_inverse.h new file mode 100644 index 00000000000..f0baf0b250e --- /dev/null +++ b/paddle/fluid/operators/math/matrix_inverse.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class MatrixInverseFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& a, + framework::Tensor* a_inv); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 141de2881d3..937e200924b 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -90,7 +90,9 @@ extern void *cublas_dso_handle; __macro(cublasSgetrfBatched); \ __macro(cublasSgetriBatched); \ __macro(cublasDgetrfBatched); \ - __macro(cublasDgetriBatched); + __macro(cublasDgetriBatched); \ + __macro(cublasSmatinvBatched); \ + __macro(cublasDmatinvBatched); CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 914d04e0486..2be95b113b2 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -1,8 +1,11 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 65984c600d3..897967b0c14 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -167,7 +167,7 @@ from .tensor.math import div #DEFINE_ALIAS from .tensor.math import add #DEFINE_ALIAS from .tensor.math import atan #DEFINE_ALIAS from .tensor.math import logsumexp #DEFINE_ALIAS -# from .tensor.math import inverse #DEFINE_ALIAS +from .tensor.math import inverse #DEFINE_ALIAS from .tensor.math import log1p #DEFINE_ALIAS from .tensor.math import erf #DEFINE_ALIAS from .tensor.math import addcmul #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index f13a4ceee69..1ef63a28bc6 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -1221,9 +1221,9 @@ class OpTest(unittest.TestCase): def err_msg(): offset = np.argmax(diff_mat > max_relative_error) - return ("%s error, %s variable %s max gradient diff %f over limit %f, " - "the first error element is %d, expected %f, but got %f.") \ - % (self.op_type, msg_prefix, name, max_diff, max_relative_error, + return ("Operator %s error, %s variable %s (shape: %s, dtype: %s) max gradient diff %e over limit %e, " + "the first error element is %d, expected %e, but got %e.") \ + % (self.op_type, msg_prefix, name, str(a.shape), self.dtype, max_diff, max_relative_error, offset, a.flatten()[offset], b.flatten()[offset]) self.assertLessEqual(max_diff, max_relative_error, err_msg()) diff --git a/python/paddle/fluid/tests/unittests/test_inverse_op.py b/python/paddle/fluid/tests/unittests/test_inverse_op.py new file mode 100644 index 00000000000..13cb2b1f8b1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_inverse_op.py @@ -0,0 +1,144 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle.fluid as fluid +import paddle.fluid.core as core +import paddle +from op_test import OpTest + + +class TestInverseOp(OpTest): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "float64" + + def setUp(self): + self.op_type = "inverse" + self.config() + + np.random.seed(123) + mat = np.random.random(self.matrix_shape).astype(self.dtype) + inverse = np.linalg.inv(mat) + + self.inputs = {'Input': mat} + self.outputs = {'Output': inverse} + + def test_check_output(self): + self.check_output() + + def test_grad(self): + self.check_grad(['Input'], 'Output') + + +class TestInverseOpBatched(TestInverseOp): + def config(self): + self.matrix_shape = [8, 4, 4] + self.dtype = "float64" + + +class TestInverseOpLarge(TestInverseOp): + def config(self): + self.matrix_shape = [32, 32] + self.dtype = "float64" + + def test_grad(self): + self.check_grad(['Input'], 'Output', max_relative_error=1e-6) + + +class TestInverseOpFP32(TestInverseOp): + def config(self): + self.matrix_shape = [10, 10] + self.dtype = "float32" + + def test_grad(self): + self.check_grad(['Input'], 'Output', max_relative_error=1e-2) + + +class TestInverseOpBatchedFP32(TestInverseOpFP32): + def config(self): + self.matrix_shape = [8, 4, 4] + self.dtype = "float32" + + +class TestInverseOpLargeFP32(TestInverseOpFP32): + def config(self): + self.matrix_shape = [32, 32] + self.dtype = "float32" + + +class TestInverseAPI(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place, with_out=False): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input = fluid.data(name="input", shape=[4, 4], dtype="float64") + if with_out: + out = fluid.data(name="output", shape=[4, 4], dtype="float64") + else: + out = None + result = paddle.inverse(input=input, out=out) + + input_np = np.random.random([4, 4]).astype("float64") + result_np = np.linalg.inv(input_np) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input": input_np}, + fetch_list=[result]) + self.assertTrue(np.allclose(fetches[0], np.linalg.inv(input_np))) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_np = np.random.random([4, 4]).astype("float64") + input = fluid.dygraph.to_variable(input_np) + result = paddle.inverse(input) + self.assertTrue( + np.allclose(result.numpy(), np.linalg.inv(input_np))) + + +class TestInverseAPIError(unittest.TestCase): + def test_errors(self): + input_np = np.random.random([4, 4]).astype("float64") + + # input must be Variable. + self.assertRaises(TypeError, paddle.inverse, input_np) + + # The data type of input must be float32 or float64. + for dtype in ["bool", "int32", "int64", "float16"]: + input = fluid.data(name='input_' + dtype, shape=[4, 4], dtype=dtype) + self.assertRaises(TypeError, paddle.inverse, input) + + # When out is set, the data type must be the same as input. + input = fluid.data(name='input_1', shape=[4, 4], dtype="float32") + out = fluid.data(name='output', shape=[4, 4], dtype="float64") + self.assertRaises(TypeError, paddle.inverse, input, out) + + # The number of dimensions of input must be >= 2. + input = fluid.data(name='input_2', shape=[4], dtype="float32") + self.assertRaises(ValueError, paddle.inverse, input) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index b83f21e26fd..fd3d5f3104f 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -39,7 +39,8 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'spp', \ 'teacher_student_sigmoid_loss', \ 'unpool', \ - 'yolov3_loss' + 'yolov3_loss', \ + 'inverse' ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp'] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 2232f252bcf..a3bbc4879e4 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -145,7 +145,7 @@ from .math import div #DEFINE_ALIAS from .math import add #DEFINE_ALIAS from .math import atan #DEFINE_ALIAS from .math import logsumexp #DEFINE_ALIAS -# from .math import inverse #DEFINE_ALIAS +from .math import inverse #DEFINE_ALIAS from .math import log1p #DEFINE_ALIAS from .math import erf #DEFINE_ALIAS from .math import addcmul #DEFINE_ALIAS diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 4c10063237d..b099c7f8653 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -105,7 +105,7 @@ __all__ = [ 'add', 'atan', 'logsumexp', -# 'inverse', + 'inverse', 'log1p', 'erf', 'addcmul', @@ -986,6 +986,7 @@ def mm(input, mat2, out=None, name=None): 'Y': mat2}, outputs={'Out': out}) return out + def addmm(input, x, y, alpha=1.0, beta=1.0, name=None): """ **addmm** @@ -1120,6 +1121,78 @@ def logsumexp(x, dim=None, keepdim=False, out=None, name=None): return layers.log(sum_out, name) +def inverse(input, out=None, name=None): + """ + Takes the inverse of the square matrix. A square matrix is a matrix with + the same number of rows and columns. The input can be a square matrix + (2-D Tensor) or batches of square matrices. + + Args: + input (Variable): The input Variable which holds a Tensor. The last two + dimensions should be equal. When the number of dimensions is + greater than 2, it is treated as batches of square matrix. The data + type can be float32 and float64. + out (Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + If out is None, a new Varibale will be create to store the result. + name (str, optional): The default value is None. Normally there is no need for + user to set this property. For more information, + please refer to :ref:`api_guide_Name` + + Returns: + Variable: A Tensor holds the inverse of input. The shape and data type + is the same as input. + + Examples: + .. code-block:: python + + import numpy as np + import paddle + import paddle.fluid as fluid + + mat_np = np.array([[2, 0], [0, 2]]).astype("float32") + + # example for static graph + input = fluid.data("input", shape=[2, 2], dtype="float32") + out = paddle.inverse(input) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + results = exe.run(feed={"input": mat_np }, + fetch_list=[out.name]) + print(results[0]) # [[0.5, 0], [0, 0.5]] + + # example for dynamic graph + with fluid.dygraph.guard(): + mat = fluid.dygraph.to_variable(mat_np) + inv = paddle.inverse(mat) + print(inv) # [[0.5, 0], [0, 0.5]] + """ + if in_dygraph_mode(): + return core.ops.inverse(input) + + def _check_input(input): + check_variable_and_dtype(input, 'input', + ['float32', 'float64'], 'inverse') + if len(input.shape) < 2: + raise ValueError( + "The input of inverse is expected to be a Tensor whose number " + "of dimensions is no less than 2. But reviced: %d, " + "input's shape: %s." % (len(input.shape), input.shape)) + + if out is not None: + check_variable_and_dtype(out, 'out', input.dtype, 'inverse') + + _check_input(input) + + helper = LayerHelper('inverse', **locals()) + if out is None: + out = helper.create_variable_for_type_inference(dtype=input.dtype) + helper.append_op( + type='inverse', inputs={'Input': [input] }, outputs={'Output': [out]}) + return out + + def max(input, dim=None, keep_dim=False, out=None, name=None): """ Computes the maximum of tensor elements over the given dimension. -- GitLab