From a8c0fb4e866d58b176c7367242bd307e0d07b023 Mon Sep 17 00:00:00 2001 From: Guo Sheng Date: Fri, 24 Apr 2020 20:55:23 +0800 Subject: [PATCH] Add cholesky_op (#23543) * Add cholesky_op forward part. test=develop * Complete cholesky_op forward part. test=develop * Add cholesky_op backward part. test=develop * Complete cholesky_op backward part. test=develop * Refine cholesky_op error check and docs. test=develop * Add grad_check unit test for cholesky_op. test=develop * Fix sample code in cholesky doc. test=develop * Refine some error messages of cholesky_op. test=develop * Refine some error messages of cholesky_op. test=develop * Remove unused input in cholesky_grad. test=develop * Remove unused input in cholesky_grad. test=develop * Fix stream for cusolverDnSetStream. test=develop * Update PADDLE_ENFORCE_CUDA_SUCCESS from cholesky_op to adapt to latest code. test=develop * Add CUSOLVER ERROR in enforce.h test=develop * Fix the missing return value in cholesky. test=develop --- paddle/fluid/operators/cholesky_op.cc | 121 ++++++ paddle/fluid/operators/cholesky_op.cu | 153 +++++++ paddle/fluid/operators/cholesky_op.h | 372 ++++++++++++++++++ paddle/fluid/operators/math/blas.h | 10 + paddle/fluid/operators/math/blas_impl.cu.h | 35 ++ paddle/fluid/operators/math/blas_impl.h | 31 ++ paddle/fluid/operators/nll_loss_op.cu | 1 + paddle/fluid/platform/device_context.cc | 6 + paddle/fluid/platform/device_context.h | 22 ++ paddle/fluid/platform/dynload/CMakeLists.txt | 2 +- paddle/fluid/platform/dynload/cublas.h | 2 + paddle/fluid/platform/dynload/cusolver.cc | 30 ++ paddle/fluid/platform/dynload/cusolver.h | 75 ++++ .../fluid/platform/dynload/dynamic_loader.cc | 10 + .../fluid/platform/dynload/dynamic_loader.h | 1 + paddle/fluid/platform/dynload/mklml.h | 2 + paddle/fluid/platform/enforce.h | 40 ++ python/paddle/__init__.py | 2 +- .../fluid/tests/unittests/test_cholesky_op.py | 94 +++++ python/paddle/tensor/linalg.py | 55 ++- 20 files changed, 1061 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/operators/cholesky_op.cc create mode 100644 paddle/fluid/operators/cholesky_op.cu create mode 100644 paddle/fluid/operators/cholesky_op.h create mode 100644 paddle/fluid/platform/dynload/cusolver.cc create mode 100644 paddle/fluid/platform/dynload/cusolver.h create mode 100644 python/paddle/fluid/tests/unittests/test_cholesky_op.py diff --git a/paddle/fluid/operators/cholesky_op.cc b/paddle/fluid/operators/cholesky_op.cc new file mode 100644 index 00000000000..0902f5b6bc9 --- /dev/null +++ b/paddle/fluid/operators/cholesky_op.cc @@ -0,0 +1,121 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/cholesky_op.h" + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class CholeskyOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Cholesky"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Cholesky"); + auto dims = ctx->GetInputDim("X"); + auto rank = dims.size(); + PADDLE_ENFORCE_GE(rank, 2, + platform::errors::InvalidArgument( + "The Input(X) should have at least 2 dimensions. But " + "received a %d dimension tensor.", + rank)); + PADDLE_ENFORCE_EQ( + dims[rank - 2], dims[rank - 1], + platform::errors::InvalidArgument( + "The inner-most 2 dimensions of Input(X) all should be symmetric " + "positive-definite matrices and have the same size. But received " + "X's shape[-2] = %d and shape[-1] = %d.", + dims[rank - 2], dims[rank - 1])); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + } +}; + +class CholeskyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), The input tensor of cholesky op. Its shape should be " + "[*, M, M] where * is zero or more batch dimensions, and matrices " + "on the inner-most 2 dimensions all should be symmetric " + "positive-definite."); + AddOutput("Out", + "(Tensor), The output tensor of cholesky op. It has the same " + "shape as the input, and it is composed of upper-triangular or " + "lower-triangular Cholesky factors of each of the individual " + "matrices."); + AddAttr("upper", + "(bool, default false), flag indicating whether to return " + "upper or lower triangular matrices. Default: False") + .SetDefault(false); + AddComment(R"DOC( +Cholesky Operator. + +Computes the Cholesky decomposition of one symmetric positive-definite matrix +or batches of symmetric positive-definite matrices. 
+ +)DOC"); + } +}; + +class CholeskyGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "CholeskyGrad"); + OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Out")), "Input", + "Out@GRAD", "CholeskyGrad"); + auto dims = ctx->GetInputDim("Out"); + auto x_grad_name = framework::GradVarName("X"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, dims); + } + } +}; + +template +class CholeskyGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("Out", this->Output("Out")); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(cholesky, ops::CholeskyOp, ops::CholeskyOpMaker, + ops::CholeskyGradOpMaker, + ops::CholeskyGradOpMaker); +REGISTER_OPERATOR(cholesky_grad, ops::CholeskyGradOp); + +REGISTER_OP_CPU_KERNEL(cholesky, ops::CholeskyCPUKernel, + ops::CholeskyCPUKernel); + +REGISTER_OP_CPU_KERNEL( + cholesky_grad, + ops::CholeskyGradKernel, + ops::CholeskyGradKernel); diff --git a/paddle/fluid/operators/cholesky_op.cu b/paddle/fluid/operators/cholesky_op.cu new file mode 100644 index 00000000000..dabb2d2567e --- /dev/null +++ b/paddle/fluid/operators/cholesky_op.cu @@ -0,0 +1,153 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/cholesky_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" + +namespace paddle { +namespace operators { + +template +class CholeskyGPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto& dev_ctx = + context.template device_context(); + + const Tensor* x = context.Input("X"); + Tensor* out = context.Output("Out"); + + bool upper = context.Attr("upper"); + auto& dims = x->dims(); + int batch_count = 1; + for (int i = 0; i < dims.size() - 2; i++) { + batch_count *= dims[i]; + } + int m = dims[dims.size() - 1]; + int tensor_size = batch_count * m * m; + + const auto* x_data = x->data(); + auto* out_data = out->mutable_data(context.GetPlace()); + + // matrices are assumed to be stored in column-major order in cusolver + cublasFillMode_t uplo = + upper ? 
CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; + // potrf is in-place, thus copy the triangular part of the input matrices to + // the output and set the other triangular part to 0 first + platform::ForRange<platform::CUDADeviceContext> for_range(dev_ctx, + tensor_size); + if (upper) { + MatrixBandPartFunctor<T> matrix_band_part_functor( + m, m, /* num_lower_diags */ 0, /* num_upper_diags */ m, x_data, + out_data); + for_range(matrix_band_part_functor); + } else { + MatrixBandPartFunctor<T> matrix_band_part_functor( + m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, x_data, + out_data); + for_range(matrix_band_part_functor); + } + + // TODO(guosheng): Add callback to check info + auto info = memory::Alloc(dev_ctx, sizeof(int) * batch_count); + auto* info_ptr = reinterpret_cast<int*>(info->ptr()); + +#if CUDA_VERSION >= 9020 + if (batch_count > 1) { + std::vector<T*> output_ptrs; + for (int i = 0; i < batch_count; i++) { + output_ptrs.emplace_back(out_data + i * m * m); + } + thrust::device_vector<T*> dev_output_ptrs(output_ptrs.begin(), + output_ptrs.end()); + PotrfBatched(dev_ctx, uplo, m, + thrust::raw_pointer_cast(dev_output_ptrs.data()), m, + info_ptr, batch_count); + // TODO(guosheng): There seems to be a bug in cusolver potrfBatched, so we + // need to clear the upper triangle of the output. Remove this workaround + // once the bug is fixed. + if (!upper) { + MatrixBandPartFunctor<T> matrix_band_part_functor( + m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, out_data, + out_data); + for_range(matrix_band_part_functor); + } + } else { +#endif + for (int i = 0; i < batch_count; i++) { + Potrf(dev_ctx, uplo, m, out_data + i * m * m, m, info_ptr + i); + } + +#if CUDA_VERSION >= 9020 + } +#endif + } + + void Potrf(const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo, + int n, T* A, int lda, int* info) const; + + void PotrfBatched(const platform::CUDADeviceContext& dev_ctx, + cublasFillMode_t uplo, int n, T* Aarray[], int lda, + int* info_array, int batch_size) const; +}; + +#define FUNC_WITH_TYPES(m) m(float, S) m(double, D) + +#define POTRF_INSTANCE(T, C) \ + template <> \ + void CholeskyGPUKernel<T>::Potrf(const platform::CUDADeviceContext& dev_ctx, \ + cublasFillMode_t uplo, int n, T* A, \ + int lda, int* info) const { \ + auto handle = dev_ctx.cusolver_dn_handle(); \ + int workspace_size = 0; \ + PADDLE_ENFORCE_CUDA_SUCCESS( \ + platform::dynload::cusolverDn##C##potrf_bufferSize( \ + handle, uplo, n, A, lda, &workspace_size)); \ + auto workspace = memory::Alloc(dev_ctx, workspace_size); \ + T* workspace_ptr = reinterpret_cast<T*>(workspace->ptr()); \ + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cusolverDn##C##potrf( \ + handle, uplo, n, A, lda, workspace_ptr, workspace_size, info)); \ + } + +FUNC_WITH_TYPES(POTRF_INSTANCE); + +#if CUDA_VERSION >= 9020 +#define POTRF_BATCH_INSTANCE(T, C) \ + template <> \ + void CholeskyGPUKernel<T>::PotrfBatched( \ + const platform::CUDADeviceContext& dev_ctx, cublasFillMode_t uplo, \ + int n, T* Aarray[], int lda, int* info_array, int batch_size) const { \ + auto handle = dev_ctx.cusolver_dn_handle(); \ + PADDLE_ENFORCE_CUDA_SUCCESS( \ + platform::dynload::cusolverDn##C##potrfBatched( \ + handle, uplo, n, Aarray, lda, info_array, batch_size)); \ + } + +FUNC_WITH_TYPES(POTRF_BATCH_INSTANCE); +#endif + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(cholesky, ops::CholeskyGPUKernel<float>, + ops::CholeskyGPUKernel<double>); +REGISTER_OP_CUDA_KERNEL( + cholesky_grad, + ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, float>, + ops::CholeskyGradKernel<paddle::platform::CUDADeviceContext, double>); diff --git 
a/paddle/fluid/operators/cholesky_op.h b/paddle/fluid/operators/cholesky_op.h new file mode 100644 index 00000000000..b0280b00ecf --- /dev/null +++ b/paddle/fluid/operators/cholesky_op.h @@ -0,0 +1,372 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include <numeric> +#include <vector> +#include "Eigen/Cholesky" +#include "Eigen/Core" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/transpose_op.h" +#include "paddle/fluid/platform/for_range.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template <typename T> +class CholeskyCPUKernel : public framework::OpKernel<T> { + public: + // different from EigenMatrix in framework/eigen.h + using EigenMatrix = + Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; + using InputMatrixMap = Eigen::Map<const EigenMatrix>; + using OutputMatrixMap = Eigen::Map<EigenMatrix>; + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* x = context.Input<Tensor>("X"); + Tensor* out = context.Output<Tensor>("Out"); + + bool upper = context.Attr<bool>("upper"); + auto& dims = x->dims(); + int batch_count = 1; + for (int i = 0; i < dims.size() - 2; i++) { + batch_count *= dims[i]; + } + auto m = dims[dims.size() - 1]; + + const auto* x_data = x->data<T>(); + auto* out_data = out->mutable_data<T>(context.GetPlace()); + // Cholesky decomposition for each matrix, maybe can use multi threads + for (int i = 0; i < batch_count; i++) { + auto input = InputMatrixMap(x_data + i * m * m, m, m); + auto output = OutputMatrixMap(out_data + i * m * m, m, m); + if (upper) { + Eigen::LLT< + Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>, + Eigen::UpLoType::Upper> + llt_decomposition(input); + PADDLE_ENFORCE_EQ( + llt_decomposition.info(), Eigen::Success, + platform::errors::InvalidArgument( + "Cholesky decomposition was not successful. The input matrix " + "might not be positive definite.")); + output = llt_decomposition.matrixU(); + } else { + Eigen::LLT< + Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>, + Eigen::UpLoType::Lower> + llt_decomposition(input); + PADDLE_ENFORCE_EQ( + llt_decomposition.info(), Eigen::Success, + platform::errors::InvalidArgument( + "Cholesky decomposition was not successful. The input matrix " + "might not be positive definite.")); + output = llt_decomposition.matrixL(); + } + } + } +}; + +/*! Use these functors to implement tril, triu, diagonal and other operators */ +template <typename T> +struct EyeFunctor { + EyeFunctor(const int m, const int n, T* output) + : m_(m), n_(n), output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int global_row = index / n_; + const int col = index - global_row * n_; + const int batch = global_row / m_; + const int row = global_row - batch * m_; + output_[index] = col == row ? static_cast<T>(1) : static_cast<T>(0); + } + + const int m_, n_; + T* output_; +}; + +template <typename T> +struct MatrixBandPartFunctor { + /*! Set output to the input value inside a central band and 0 outside that band.
+ * That is: output[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n] + * where: in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) && (num_upper + * < 0 || (n-m) <= num_upper) + */ + MatrixBandPartFunctor(const int m, const int n, const int num_lower_diags, + const int num_upper_diags, const T* input, T* output) + : m_(m), + n_(n), + num_lower_diags_(num_lower_diags), + num_upper_diags_(num_upper_diags), + input_(input), + output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int col = index % n_; + const int row = (index / n_) % m_; + const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_); + const int band_end = + (num_upper_diags_ < 0 ? n_ : row + num_upper_diags_ + 1); + if (col < band_start || col >= band_end) { + output_[index] = static_cast<T>(0); + } else { + output_[index] = input_[index]; + } + } + + const int m_, n_, num_lower_diags_, num_upper_diags_; + const T* input_; + T* output_; +}; + +template <typename T> +struct MatrixSetDiagFunctor { + /*! Overwrite specified diagonals of output by the values in diagonal. + * diagonals can be a central band specified by num_diags and + * upper_diag_index, where upper_diag_index=0 refers to the main diagonal, + * a positive value means superdiagonal and a negative value means + * subdiagonal. When it is a band, `diag` has a shape + * [i, j, ..., num_diags, max_diag_len] and the num_diags diagonals have an + * up-to-down layout. Otherwise it has a shape [i, j, ..., max_diag_len]. + */ + MatrixSetDiagFunctor(const int m, const int n, const int num_diags, + const int max_diag_len, const int upper_diag_index, + const T* diag, T* output) + : m_(m), + n_(n), + num_diags_(num_diags), + max_diag_len_(max_diag_len), + upper_diag_index_(upper_diag_index), + diag_(diag), + output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int batch_and_diag_index = index / max_diag_len_; + const int index_in_the_diagonal = + index - batch_and_diag_index * max_diag_len_; + const int batch = batch_and_diag_index / num_diags_; + const int diag_index_in_input = batch_and_diag_index - batch * num_diags_; + // diag_index=0 refers to the main diagonal + const int diag_index = upper_diag_index_ - diag_index_in_input; + // shift down for subdiagonal if diag_index < 0 + const int y_index = + index_in_the_diagonal + (0 > -diag_index ? 0 : -diag_index); + // shift right for superdiagonal if diag_index > 0 + const int x_index = + index_in_the_diagonal + (0 > diag_index ? 0 : diag_index); + + // Upper-bound checks for diagonals shorter than max_diag_len. + // y_index and x_index are nonnegative by construction. + if (y_index < m_ && x_index < n_) { + const int out_index = batch * m_ * n_ + y_index * n_ + x_index; + output_[out_index] = diag_[index]; + } + } + + const int m_, n_, num_diags_, max_diag_len_, upper_diag_index_; + const T* diag_; + T* output_; +}; + +template <typename T> +struct MatrixDiagPartFunctor { + /*! Similar to MatrixSetDiagFunctor but returns the diagonals.
diag_index=0 + * refers to the main diagonal, a positive value means superdiagonal and + * a negative value means subdiagonal */ + MatrixDiagPartFunctor(const int m, const int n, const int num_diags, + const int max_diag_len, const int upper_diag_index, + const T padding, const T* input, T* output) + : m_(m), + n_(n), + num_diags_(num_diags), + max_diag_len_(max_diag_len), + upper_diag_index_(upper_diag_index), + padding_(padding), + input_(input), + output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int batch_and_mapped_diag_index = index / max_diag_len_; + const int index_in_the_diagonal = + index - batch_and_mapped_diag_index * max_diag_len_; + const int batch = batch_and_mapped_diag_index / num_diags_; + const int mapped_diag_index = + batch_and_mapped_diag_index - batch * num_diags_; + // diag_index=0 refers to the main diagonal + const int diag_index = upper_diag_index_ - mapped_diag_index; + // shift down for subdiagonal if diag_index < 0 + const int y_index = + index_in_the_diagonal + (0 > -diag_index ? 0 : -diag_index); + // shift right for superdiagonal if diag_index > 0 + const int x_index = + index_in_the_diagonal + (0 > diag_index ? 0 : diag_index); + if (y_index < m_ && x_index < n_) { + output_[index] = input_[batch * m_ * n_ + y_index * n_ + x_index]; + } else { + output_[index] = padding_; + } + } + + const int m_, n_, num_diags_, max_diag_len_, upper_diag_index_; + const T padding_; + const T* input_; + T* output_; +}; + +template <typename T> +struct MatrixBandPartScaleEndFunctor { + /*! Compared with MatrixBandPartFunctor, it scales up values at the end of + * the band. It can be used to fuse the following operations, which actually + * output a triangular matrix with the diagonal scaled up: + * 1. diag = matrix_diag_part(middle) + * 2. middle = matrix_set_diag(middle, diag * scalar) + * 3. middle = matrix_band_part(middle, -1, 0) + */ + MatrixBandPartScaleEndFunctor(const int m, const int n, + const int num_lower_diags, + const int num_upper_diags, const T scale, + const T* input, T* output) + : m_(m), + n_(n), + num_lower_diags_(num_lower_diags), + num_upper_diags_(num_upper_diags), + scale_(scale), + input_(input), + output_(output) {} + + HOSTDEVICE void operator()(size_t index) const { + const int col = index % n_; + const int row = (index / n_) % m_; + const int band_start = (num_lower_diags_ < 0 ? 0 : row - num_lower_diags_); + const int band_end = + (num_upper_diags_ < 0 ?
n_ : row + num_upper_diags_ + 1); + if (col < band_start || col >= band_end) { + output_[index] = 0; + } else if (col == band_end - 1) { + output_[index] = scale_ * input_[index]; + } else { + output_[index] = input_[index]; + } + } + + const int m_, n_, num_lower_diags_, num_upper_diags_; + const T scale_; + const T* input_; + T* output_; +}; + +template <typename T> +struct AddtoScaleFunctor { + AddtoScaleFunctor(const T scale, const T* input, T* output) + : scale_(scale), input_(input), output_(output) {} + HOSTDEVICE void operator()(size_t index) const { + output_[index] += input_[index]; + output_[index] *= scale_; + } + const T scale_; + const T* input_; + T* output_; +}; + +template <typename DeviceContext, typename T> +class CholeskyGradKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* out = context.Input<Tensor>("Out"); + auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out")); + auto* x_grad = context.Output<Tensor>(framework::GradVarName("X")); + auto* x_grad_data = x_grad->mutable_data<T>(context.GetPlace()); + + bool upper = context.Attr<bool>("upper"); + auto& dims = out->dims(); + int batch_count = 1; + for (int i = 0; i < dims.size() - 2; i++) { + batch_count *= dims[i]; + } + auto m = dims[dims.size() - 1]; + int tensor_size = batch_count * m * m; + + auto& dev_ctx = context.template device_context<DeviceContext>(); + + std::vector<int> axis(dims.size() - 2); + std::iota(axis.begin(), axis.end(), 0); + axis.insert(axis.end(), {dims.size() - 1, dims.size() - 2}); + Tensor l, l_grad; + if (upper) { + l.mutable_data<T>(dims, context.GetPlace()); + l_grad.mutable_data<T>(dims, context.GetPlace()); + TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *out, &l, axis); + TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *out_grad, &l_grad, + axis); + } else { + l = *out; + l_grad = *out_grad; + } + auto* l_data = l.data<T>(); + + /*! refer to Iain Murray (2016); arXiv 1602.07527 */ + /*! phi = matmul(L.transpose(-1, -2), grad) */ + Tensor middle; + auto* middle_data = middle.mutable_data<T>(dims, context.GetPlace()); + auto trans_desc = math::CreateMatrixDescriptor(dims, 0, true); + auto no_trans_desc = math::CreateMatrixDescriptor(dims, 0, false); + auto blas = math::GetBlas<DeviceContext, T>(context); + blas.MatMul(l, trans_desc, l_grad, no_trans_desc, T(1), &middle, T(0)); + + /*! phi.tril_().diagonal(0, -2, -1).mul_(0.5) */ + platform::ForRange<DeviceContext> for_range(dev_ctx, tensor_size); + MatrixBandPartScaleEndFunctor<T> matrix_band_part_scale_end_functor( + m, m, /* num_lower_diags */ m, /* num_upper_diags */ 0, + /* scale */ 0.5, middle_data, middle_data); + for_range(matrix_band_part_scale_end_functor); + + // Compute the inverse by solving the triangular linear system AX = B, + // where B is the identity matrix. The solution X overwrites B in place. + Tensor identity; + auto* identity_data = identity.mutable_data<T>(dims, context.GetPlace()); + EyeFunctor<T> eye_functor(m, m, identity_data); + for_range(eye_functor); + // TODO(guosheng): use trsmBatched for GPU + for (int i = 0; i < batch_count; i++) { + blas.TRSM(/*side*/ CblasLeft, /*uplo*/ CblasLower, + /*trans*/ CblasNoTrans, /*diag*/ CblasNonUnit, /*m*/ m, /*n*/ m, + /*alpha*/ T(1), l_data + i * m * m, /*lda*/ m, + identity_data + i * m * m, /*ldb*/ m); + } + Tensor& l_inverse = identity; + + /*!
x_grad = matmul(matmul(L_inverse.transpose(-1, -2), phi), L_inverse) */ + Tensor middle1; + middle1.mutable_data<T>(dims, context.GetPlace()); + blas.MatMul(l_inverse, trans_desc, middle, no_trans_desc, T(1), &middle1, + T(0)); + blas.MatMul(middle1, no_trans_desc, l_inverse, no_trans_desc, T(1), x_grad, + T(0)); + + /*! x_grad.add(x_grad.transpose(-1, -2)).mul_(0.5) */ + Tensor x_grad_trans; + auto* x_grad_trans_data = + x_grad_trans.mutable_data<T>(dims, context.GetPlace()); + TransCompute<DeviceContext, T>(dims.size(), dev_ctx, *x_grad, &x_grad_trans, + axis); + AddtoScaleFunctor<T> addto_scale_functor(0.5, x_grad_trans_data, + x_grad_data); + for_range(addto_scale_functor); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index b0148a70554..5a96e6bb4a1 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -218,6 +218,11 @@ class Blas { template <typename T> void VMERF(int n, const T* a, T* y, int64_t mode) const; + template <typename T> + void TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, int M, int N, T alpha, const T* A, int lda, T* B, + int ldb) const; + private: const DeviceContext& context_; }; @@ -351,6 +356,11 @@ class BlasT : private Blas<DeviceContext> { Base()->template VMERF<T>(args...); } + + template <typename... ARGS> + void TRSM(ARGS... args) const { + Base()->template TRSM<T>(args...); + } + private: const Blas<DeviceContext>* Base() const { return static_cast<const Blas<DeviceContext>*>(this); diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 8e903a4eccc..e7720a97699 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -88,6 +88,11 @@ struct CUBlas<float> { PADDLE_THROW("cublasSgemmEx is supported on cuda >= 8.0"); #endif } + + template <typename... ARGS> + static void TRSM(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsm(args...)); + } }; template <> @@ -131,6 +136,11 @@ struct CUBlas<double> { static void GEMM_EX(ARGS... args) { PADDLE_THROW("Currently there are not cublasDgemmEx."); } + + template <typename... ARGS> + static void TRSM(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsm(args...)); + } }; template <> @@ -411,6 +421,31 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM( #endif // CUDA_VERSION >= 9010 } + +template <> +template <typename T> +void Blas<platform::CUDADeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, int M, int N, + T alpha, const T *A, int lda, T *B, + int ldb) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + cublasSideMode_t cuSide = + (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; + cublasFillMode_t cuUplo = + (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasDiagType_t cuDiag = + (diag == CblasUnit) ?
CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas<T>::TRSM(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, &alpha, A, + lda, B, ldb); + }); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 011c4191a4e..cdaf53fea30 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -152,6 +152,11 @@ struct CBlas<float> { platform::dynload::mkl_scsrmm(args...); } #endif + + template <typename... ARGS> + static void TRSM(ARGS... args) { + platform::dynload::cblas_strsm(args...); + } }; template <> @@ -273,6 +278,11 @@ struct CBlas<double> { platform::dynload::mkl_dcsrmm(args...); } #endif + + template <typename... ARGS> + static void TRSM(ARGS... args) { + platform::dynload::cblas_dtrsm(args...); + } }; #else @@ -298,6 +308,11 @@ struct CBlas<float> { static void GEMV(ARGS... args) { cblas_sgemv(args...); } + + template <typename... ARGS> + static void TRSM(ARGS... args) { + cblas_strsm(args...); + } }; template <> @@ -321,6 +336,11 @@ struct CBlas<double> { static void GEMV(ARGS... args) { cblas_dgemv(args...); } + + template <typename... ARGS> + static void TRSM(ARGS... args) { + cblas_dtrsm(args...); + } }; #endif @@ -899,6 +919,17 @@ void Blas<platform::CPUDeviceContext>::CSRMM( } #endif + +template <> +template <typename T> +void Blas<platform::CPUDeviceContext>::TRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, int M, int N, + T alpha, const T *A, int lda, T *B, + int ldb) const { + CBlas<T>::TRSM(CblasRowMajor, side, uplo, transA, diag, M, N, alpha, A, lda, + B, ldb); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/nll_loss_op.cu b/paddle/fluid/operators/nll_loss_op.cu index ff7ac17a238..7b37239a339 100644 --- a/paddle/fluid/operators/nll_loss_op.cu +++ b/paddle/fluid/operators/nll_loss_op.cu @@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include #include "cub/cub.cuh" #include "paddle/fluid/operators/math.h" diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 46013f57677..3a1405e95c4 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -241,12 +241,14 @@ CUDAContext::CUDAContext(const CUDAPlace& place, InitEigenContext(); InitCuBlasContext(); InitCuDNNContext(); + InitCuSolverContext(); } CUDAContext::~CUDAContext() { CUDADeviceGuard guard(place_.device); DestoryCuDNNContext(); DestoryCuBlasContext(); + DestoryCuSolverContext(); } CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) { @@ -340,6 +342,10 @@ CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const { return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_); } + +cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const { + return context()->CusolverDnHandle(); +} + cudaStream_t CUDADeviceContext::stream() const { return context()->RawStream(); } diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 7615a0b7ea0..76fa9ee09b8 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -22,6 +22,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/cuda_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" +#include "paddle/fluid/platform/dynload/cusolver.h" #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) #include "paddle/fluid/platform/dynload/nccl.h" #endif @@ -105,6 +106,10 @@ class CUDAContext { const cudnnHandle_t& CudnnHandle() const { return cudnn_handle_; } + const cusolverDnHandle_t& CusolverDnHandle() const { + return cusolver_dn_handle_; + } + const std::unique_ptr& CublasHandle() const { return cublas_handle_; } @@ -170,6 +175,13 @@ class CUDAContext { } } + void InitCuSolverContext() { + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cusolverDnCreate(&cusolver_dn_handle_)); + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cusolverDnSetStream(cusolver_dn_handle_, RawStream())); + } + void DestoryCuDNNContext() { if (cudnn_handle_) { PADDLE_ENFORCE_CUDA_SUCCESS(dynload::cudnnDestroy(cudnn_handle_)); @@ -182,6 +194,13 @@ class CUDAContext { cublas_tensor_core_handle_.reset(); } + void DestoryCuSolverContext() { + if (cusolver_dn_handle_) { + PADDLE_ENFORCE_CUDA_SUCCESS( + dynload::cusolverDnDestroy(cusolver_dn_handle_)); + } + } + CUDAPlace place_; std::unique_ptr eigen_device_; std::unique_ptr eigen_stream_; @@ -189,6 +208,7 @@ class CUDAContext { cudnnHandle_t cudnn_handle_; std::unique_ptr cublas_handle_; std::unique_ptr cublas_tensor_core_handle_; + cusolverDnHandle_t cusolver_dn_handle_; DISABLE_COPY_AND_ASSIGN(CUDAContext); }; @@ -249,6 +269,8 @@ class CUDADeviceContext : public DeviceContext { * sequential cudnn function calls. */ CudnnWorkspaceHandle cudnn_workspace_handle() const; + cusolverDnHandle_t cusolver_dn_handle() const; + /*! \brief Return cuda stream in the device context. */ cudaStream_t stream() const; diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 04597830a84..9ea218907a4 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -1,6 +1,6 @@ cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce) -list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc) +list(APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc) # There is no macOS version of NCCL. # Disable nvrtc and cuda_driver api on MacOS and Windows, and only do a early test on Linux. diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 439a51dd695..141de2881d3 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -76,6 +76,8 @@ extern void *cublas_dso_handle; __macro(cublasSgemmEx); \ __macro(cublasSgeam); \ __macro(cublasDgeam); \ + __macro(cublasStrsm_v2); \ + __macro(cublasDtrsm_v2); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \ diff --git a/paddle/fluid/platform/dynload/cusolver.cc b/paddle/fluid/platform/dynload/cusolver.cc new file mode 100644 index 00000000000..84aecabf6e2 --- /dev/null +++ b/paddle/fluid/platform/dynload/cusolver.cc @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/dynload/cusolver.h" + +namespace paddle { +namespace platform { +namespace dynload { + +std::once_flag cusolver_dso_flag; +void *cusolver_dso_handle; + +#define DEFINE_WRAP(__name) DynLoad__##__name __name + +CUSOLVER_ROUTINE_EACH(DEFINE_WRAP); + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h new file mode 100644 index 00000000000..226e53369e8 --- /dev/null +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include <cusolverDn.h> + +#include <mutex> // NOLINT +#include "paddle/fluid/platform/port.h" + +#include "paddle/fluid/platform/dynload/dynamic_loader.h" + +namespace paddle { +namespace platform { +namespace dynload { +extern std::once_flag cusolver_dso_flag; +extern void *cusolver_dso_handle; +#ifdef PADDLE_USE_DSO +#define DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + cusolverStatus_t operator()(Args... args) { \ + using cusolverFunc = decltype(&::__name); \ + std::call_once(cusolver_dso_flag, []() { \ + cusolver_dso_handle = \ + paddle::platform::dynload::GetCusolverDsoHandle(); \ + }); \ + static void *p_##__name = dlsym(cusolver_dso_handle, #__name); \ + return reinterpret_cast<cusolverFunc>(p_##__name)(args...); \ + } \ + }; \ + extern DynLoad__##__name __name +#else +#define DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP(__name) \ + struct DynLoad__##__name { \ + template <typename... Args> \ + cusolverStatus_t operator()(Args... 
args) { \ + return ::__name(args...); \ + } \ + }; \ + extern DynLoad__##__name __name +#endif + +#define CUSOLVER_ROUTINE_EACH(__macro) \ + __macro(cusolverDnCreate); \ + __macro(cusolverDnDestroy); \ + __macro(cusolverDnSetStream); \ + __macro(cusolverDnSpotrf_bufferSize); \ + __macro(cusolverDnDpotrf_bufferSize); \ + __macro(cusolverDnSpotrf); \ + __macro(cusolverDnDpotrf); + +CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); + +#if CUDA_VERSION >= 9020 +#define CUSOLVER_ROUTINE_EACH_R1(__macro) \ + __macro(cusolverDnSpotrfBatched); \ + __macro(cusolverDnDpotrfBatched); + +CUSOLVER_ROUTINE_EACH_R1(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP) +#endif + +} // namespace dynload +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/dynamic_loader.cc b/paddle/fluid/platform/dynload/dynamic_loader.cc index 48848bd84fe..3eb2b21fcc7 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.cc +++ b/paddle/fluid/platform/dynload/dynamic_loader.cc @@ -224,6 +224,16 @@ void* GetCurandDsoHandle() { #endif } +void* GetCusolverDsoHandle() { +#if defined(__APPLE__) || defined(__OSX__) + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.dylib"); +#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA) + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, win_cusolver_lib); +#else + return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libcusolver.so"); +#endif +} + void* GetNVRTCDsoHandle() { #if defined(__APPLE__) || defined(__OSX__) return GetDsoHandleFromSearchPath(FLAGS_cuda_dir, "libnvrtc.dylib"); diff --git a/paddle/fluid/platform/dynload/dynamic_loader.h b/paddle/fluid/platform/dynload/dynamic_loader.h index 4940411ccf7..1136184ce1f 100644 --- a/paddle/fluid/platform/dynload/dynamic_loader.h +++ b/paddle/fluid/platform/dynload/dynamic_loader.h @@ -29,6 +29,7 @@ void* GetCublasDsoHandle(); void* GetCUDNNDsoHandle(); void* GetCUPTIDsoHandle(); void* GetCurandDsoHandle(); +void* GetCusolverDsoHandle(); void* GetNVRTCDsoHandle(); void* GetCUDADsoHandle(); void* GetWarpCTCDsoHandle(); diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index e5e818f5fba..914d04e0486 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -53,6 +53,8 @@ extern void* mklml_dso_handle; __macro(cblas_dcopy); \ __macro(cblas_sgemv); \ __macro(cblas_dgemv); \ + __macro(cblas_strsm); \ + __macro(cblas_dtrsm); \ __macro(cblas_sgemm_alloc); \ __macro(cblas_dgemm_alloc); \ __macro(cblas_sgemm_pack); \ diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h index f2e0c52170b..a70f30b4e21 100644 --- a/paddle/fluid/platform/enforce.h +++ b/paddle/fluid/platform/enforce.h @@ -56,6 +56,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/dynload/cublas.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/dynload/curand.h" +#include "paddle/fluid/platform/dynload/cusolver.h" #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) #include "paddle/fluid/platform/dynload/nccl.h" #endif // __APPLE__ @@ -680,6 +681,44 @@ inline void throw_on_error(cublasStatus_t stat, const std::string& msg) { #endif } +/***** CUSOLVER ERROR *****/ +inline bool is_error(cusolverStatus_t stat) { + return stat != CUSOLVER_STATUS_SUCCESS; +} + +inline const char* cusolverGetErrorString(cusolverStatus_t stat) { + switch (stat) { + case CUSOLVER_STATUS_NOT_INITIALIZED: + return "CUSOLVER_STATUS_NOT_INITIALIZED"; + case CUSOLVER_STATUS_ALLOC_FAILED: + return "CUSOLVER_STATUS_ALLOC_FAILED"; + case CUSOLVER_STATUS_INVALID_VALUE: + return "CUSOLVER_STATUS_INVALID_VALUE"; + case CUSOLVER_STATUS_ARCH_MISMATCH: + return "CUSOLVER_STATUS_ARCH_MISMATCH"; + case CUSOLVER_STATUS_EXECUTION_FAILED: + return "CUSOLVER_STATUS_EXECUTION_FAILED"; + case CUSOLVER_STATUS_INTERNAL_ERROR: + return "CUSOLVER_STATUS_INTERNAL_ERROR"; + case CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED: + return "CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED"; + default: + return "Unknown cusolver status"; + } +} +inline std::string build_nvidia_error_msg(cusolverStatus_t stat) { + std::string msg(" Cublas error, "); + return msg + cusolverGetErrorString(stat) + " "; +} + +inline void throw_on_error(cusolverStatus_t stat, const std::string& msg) { +#ifndef REPLACE_ENFORCE_GLOG + throw std::runtime_error(msg); +#else + LOG(FATAL) << msg; +#endif +} + /****** NCCL ERROR ******/ #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) inline bool is_error(ncclResult_t nccl_result) { @@ -716,6 +755,7 @@ DEFINE_CUDA_STATUS_TYPE(cudaError_t, cudaSuccess); DEFINE_CUDA_STATUS_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS); DEFINE_CUDA_STATUS_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS); DEFINE_CUDA_STATUS_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS); +DEFINE_CUDA_STATUS_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS); #if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL) DEFINE_CUDA_STATUS_TYPE(ncclResult_t, ncclSuccess); diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 7d22df8452f..6020ed68b57 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -160,7 +160,7 @@ from .tensor.linalg import norm #DEFINE_ALIAS from .tensor.linalg import dist #DEFINE_ALIAS from .tensor.linalg import t #DEFINE_ALIAS from .tensor.linalg import cross #DEFINE_ALIAS -# from .tensor.linalg import cholesky #DEFINE_ALIAS +from .tensor.linalg import cholesky #DEFINE_ALIAS # from .tensor.linalg import .tensordot #DEFINE_ALIAS # from .tensor.manipulation import cast #DEFINE_ALIAS # from .tensor.manipulation import concat #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_cholesky_op.py b/python/paddle/fluid/tests/unittests/test_cholesky_op.py new file mode 100644 index 00000000000..4e2280c0118 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cholesky_op.py @@ -0,0 +1,94 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci +from gradient_checker import grad_check +from decorator_helper import prog_scope + + +@skip_check_grad_ci( + reason="The input of cholesky_op should always be symmetric positive-definite. " + "However, OpTest calculates the numeric gradient of each element in input " + "via small finite differences, which makes the input no longer symmetric " + "positive-definite, thus the Cholesky decomposition can not be computed. " + "Instead, we use gradient_checker.grad_check to perform the gradient " + "check of cholesky_op, since it supports checking gradients with a program " + "and we can construct symmetric positive-definite matrices in the program") +class TestCholeskyOp(OpTest): + def setUp(self): + self.op_type = "cholesky" + self._input_shape = (2, 32, 32) + self._upper = True + self.init_config() + self.trans_dims = list(range(len(self._input_shape) - 2)) + [ + len(self._input_shape) - 1, len(self._input_shape) - 2 + ] + self.root_data = np.random.random(self._input_shape).astype("float64") + # construct symmetric positive-definite matrices + input_data = np.matmul( + self.root_data, self.root_data.transpose(self.trans_dims)) + 1e-05 + output_data = np.linalg.cholesky(input_data).astype("float64") + if self._upper: + output_data = output_data.transpose(self.trans_dims) + self.inputs = {"X": input_data} + self.attrs = {"upper": self._upper} + self.outputs = {"Out": output_data} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + @prog_scope() + def func(self, place): + # use a small size since computing Jacobian gradients is time consuming + root_data = self.root_data[..., :3, :3] + prog = fluid.Program() + with fluid.program_guard(prog): + root = layers.create_parameter( + dtype=root_data.dtype, shape=root_data.shape) + root_t = layers.transpose(root, self.trans_dims) + x = layers.matmul(x=root, y=root_t) + 1e-05 + out = paddle.cholesky(x, upper=self.attrs["upper"]) + grad_check(root, out, x_init=root_data, place=place) + + def init_config(self): + self._upper = True + + +class TestCholeskyOpLower(TestCholeskyOp): + def init_config(self): + self._upper = False + + +class TestCholeskyOp2D(TestCholeskyOp): + def init_config(self): + self._input_shape = (64, 64) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 256c5be92ab..d3264d68c60 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -26,7 +26,7 @@ __all__ = [ 'dist', 't', 'cross', - # 'cholesky', + 'cholesky', # 'tensordot', 'bmm' ] @@ -627,6 +627,59 @@ def cross(input, other, dim=None): return out + +def cholesky(x, upper=False): + """ + Computes the Cholesky decomposition of one symmetric positive-definite + matrix or batches of 
symmetric positive-definite matrices. + + If `upper` is `True`, the decomposition has the form :math:`A = U^{T}U` , + and the returned matrix :math:`U` is upper-triangular. Otherwise, the + decomposition has the form :math:`A = LL^{T}` , and the returned matrix + :math:`L` is lower-triangular. + + Args: + x (Variable): The input tensor. Its shape should be `[*, M, M]`, + where * is zero or more batch dimensions, and matrices on the + inner-most 2 dimensions all should be symmetric positive-definite. + Its data type should be float32 or float64. + upper (bool): The flag indicating whether to return upper or lower + triangular matrices. Default: False. + + Returns: + Variable: A Tensor with the same shape and data type as `x`. It represents \ + triangular matrices generated by the Cholesky decomposition. + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + import numpy as np + + with fluid.dygraph.guard(): + a = np.random.rand(3, 3) + a_t = np.transpose(a, [1, 0]) + x = np.matmul(a, a_t) + 1e-03 + x = fluid.dygraph.to_variable(x) + out = paddle.cholesky(x, upper=False) + print(out.numpy()) + # [[1.190523 0. 0. ] + # [0.9906703 0.27676893 0. ] + # [1.25450498 0.05600871 0.06400121]] + + """ + check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'cholesky') + check_type(upper, 'upper', bool, 'cholesky') + helper = LayerHelper('cholesky', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + helper.append_op( + type='cholesky', + inputs={'X': [x]}, + outputs={'Out': out}, + attrs={'upper': upper}) + return out + + def bmm(x, y, name=None): + """ + Applies batched matrix multiplication to two tensors. -- GitLab
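
Note on the backward pass: CholeskyGradKernel follows the Cholesky differentiation identities from Iain Murray (2016), arXiv:1602.07527 — phi = tril(L^T @ L_bar) with the diagonal halved, then A_bar = L^{-T} @ phi @ L^{-1}, symmetrized at the end. The following is a minimal NumPy sketch of that same computation for the lower-triangular case, useful for cross-checking the kernel on small inputs; the helper name cholesky_grad_reference is illustrative and not part of this patch.

import numpy as np

def cholesky_grad_reference(L, L_bar):
    # Lower-triangular Cholesky backward after Iain Murray (2016),
    # arXiv:1602.07527. L satisfies A = L @ L.T and L_bar is dLoss/dL;
    # returns the symmetrized dLoss/dA.
    m = L.shape[-1]
    # phi = matmul(L.transpose(-1, -2), grad), lower triangle kept,
    # diagonal scaled by 0.5 (MatrixBandPartScaleEndFunctor in the kernel)
    phi = np.tril(L.T @ L_bar)
    phi[np.diag_indices(m)] *= 0.5
    # the kernel computes L^{-1} with TRSM against an identity right-hand
    # side; np.linalg.inv is used here only for brevity
    L_inv = np.linalg.inv(L)
    A_bar = L_inv.T @ phi @ L_inv
    # x_grad.add(x_grad.transpose(-1, -2)).mul_(0.5)
    return 0.5 * (A_bar + A_bar.T)

# quick sanity check on a random symmetric positive-definite input
rng = np.random.default_rng(0)
B = rng.standard_normal((4, 4))
A = B @ B.T + 4.0 * np.eye(4)
L = np.linalg.cholesky(A)
L_bar = np.tril(rng.standard_normal((4, 4)))
print(cholesky_grad_reference(L, L_bar))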
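
Note on the TRSM wrapper in blas_impl.cu.h: it relies on the reinterpretation trick named in its comment — a row-major buffer read as column-major (which is how cuBLAS reads it) holds the transpose of the matrix, so the row-major problem op(A) X = alpha*B becomes the column-major problem X' op(A') = alpha*B', with side and uplo flipped and M, N swapped. A small NumPy sketch of that identity under the same setup the gradient kernel uses (left side, lower-triangular, no transpose); all names here are illustrative.

import numpy as np

rng = np.random.default_rng(1)
n = 4
# a well-conditioned lower-triangular A and a dense right-hand side B
A = np.tril(rng.standard_normal((n, n))) + n * np.eye(n)
B = rng.standard_normal((n, n))

# row-major problem: side=left, uplo=lower, solve A @ X = B
X = np.linalg.solve(A, B)

# the same buffers read column-major hold A.T (upper-triangular) and B.T,
# so the equivalent problem is X.T @ A.T = B.T: side=right, uplo=upper
X_t = B.T @ np.linalg.inv(A.T)

assert np.allclose(X_t, X.T)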