From 39f7c41f54b0dbaed761a53e6929a7c9fdbd5c7f Mon Sep 17 00:00:00 2001 From: zhiboniu <31800336+zhiboniu@users.noreply.github.com> Date: Fri, 24 Dec 2021 12:01:53 +0800 Subject: [PATCH] Add new API cholesky_solve (#38167) --- cmake/operators.cmake | 1 + paddle/fluid/operators/cholesky_solve_op.cc | 172 ++++++++++++ paddle/fluid/operators/cholesky_solve_op.cu | 136 +++++++++ paddle/fluid/operators/cholesky_solve_op.h | 247 +++++++++++++++++ .../fluid/operators/math/lapack_function.cc | 32 +++ paddle/fluid/operators/math/lapack_function.h | 4 + paddle/fluid/operators/triangular_solve_op.cu | 3 +- paddle/fluid/platform/dynload/cusolver.h | 6 + paddle/fluid/platform/dynload/lapack.h | 15 +- .../tests/unittests/test_cholesky_solve_op.py | 262 ++++++++++++++++++ .../white_list/op_threshold_white_list.py | 1 + python/paddle/linalg.py | 18 +- python/paddle/tensor/__init__.py | 2 + python/paddle/tensor/linalg.py | 50 ++++ 14 files changed, 939 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/operators/cholesky_solve_op.cc create mode 100644 paddle/fluid/operators/cholesky_solve_op.cu create mode 100644 paddle/fluid/operators/cholesky_solve_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 673b33900d..ef25675d7d 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -196,6 +196,7 @@ function(op_library TARGET) list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc") list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc") list(REMOVE_ITEM hip_srcs "cholesky_op.cu") + list(REMOVE_ITEM hip_srcs "cholesky_solve_op.cu") list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu") list(REMOVE_ITEM hip_srcs "svd_op.cu") list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu") diff --git a/paddle/fluid/operators/cholesky_solve_op.cc b/paddle/fluid/operators/cholesky_solve_op.cc new file mode 100644 index 0000000000..577176e1ff --- /dev/null +++ b/paddle/fluid/operators/cholesky_solve_op.cc @@ -0,0 +1,172 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/cholesky_solve_op.h" +#include "paddle/fluid/operators/solve_op.h" + +namespace paddle { +namespace operators { + +class CholeskySolveOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddComment(R"DOC(Solves a linear system of equations with a positive " + "semidefinite matrix to be inverted given its Cholesky factor matrix uu." 
+ ")DOC"); + AddInput("X", "(Tensor) The input tensor, shape of (*,m,k)"); + AddInput("Y", + "(Tensor) The input tensor, shape of (*,m,m) composed of upper or " + "lower triangular Cholesky factor"); + AddOutput("Out", "(Tensor) The output tensor, shape same to X"); + AddAttr("upper", + "whether to consider the Cholesky factor " + "as a lower or upper triangular matrix") + .SetDefault(false); + } +}; + +class CholeskySolveOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *context) const override { + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "CholeskySolve"); + OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "CholeskySolve"); + OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "CholeskySolve"); + auto u_dims = context->GetInputDim("Y"); + auto b_dims = context->GetInputDim("X"); + int u_rank = u_dims.size(); + int b_rank = b_dims.size(); + PADDLE_ENFORCE_GE(u_rank, 2, + platform::errors::InvalidArgument( + "the rank of input Y must greater or equal to 2")); + PADDLE_ENFORCE_GE(b_rank, 2, + platform::errors::InvalidArgument( + "the rank of input X must greater or equal to 2")); + PADDLE_ENFORCE_EQ(u_dims[u_rank - 1], u_dims[u_rank - 2], + platform::errors::InvalidArgument( + "input Matrix Y should be square matrix," + "But Got last shape of %ld x %ld", + u_dims[u_rank - 1], u_dims[u_rank - 2])); + PADDLE_ENFORCE_EQ( + b_dims[b_rank - 2], u_dims[u_rank - 2], + platform::errors::InvalidArgument( + "the first dim of input X must equal to the dim of input Y," + "But Got %ld and %ld", + b_dims[b_rank - 2], u_dims[u_rank - 2])); + + std::vector u_dims_vec = paddle::framework::vectorize(u_dims); + std::vector b_dims_vec = paddle::framework::vectorize(b_dims); + + std::vector u_dims_vec_cut(u_dims_vec.begin(), + u_dims_vec.end() - 2); + std::vector b_dims_vec_cut(b_dims_vec.begin(), + b_dims_vec.end() - 2); + + std::vector expand_batch_portion = + get_broadcast_batch_portion(u_dims_vec_cut, b_dims_vec_cut); + + std::vector b_broadcast_dims({expand_batch_portion}); + b_broadcast_dims.insert(b_broadcast_dims.end(), + {b_dims_vec[b_rank - 2], b_dims_vec[b_rank - 1]}); + + // dim of 'Out' is the same with 'Y' after broadcast + context->SetOutputDim("Out", framework::make_ddim(b_broadcast_dims)); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace()); + } +}; + +class CholeskySolveOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto var_type = ctx->GetInputType("Y", 0); + auto data_type = ctx->GetInputDataType("Y", 0); + + ctx->SetOutputType("Out", var_type, framework::ALL_ELEMENTS); + ctx->SetOutputDataType("Out", data_type, framework::ALL_ELEMENTS); + } +}; + +template +class CholeskySolveOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("cholesky_solve_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput("Y", this->Input("Y")); + retv->SetInput("Out", this->Output("Out")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + 
retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + retv->SetAttrMap(this->Attrs()); + } +}; + +class CholeskySolveGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "cholesky_solve"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "cholesky_solve"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "cholesky_solve"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "cholesky_solve"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; +REGISTER_OPERATOR(cholesky_solve, ops::CholeskySolveOp, + ops::CholeskySolveOpMaker, + ops::CholeskySolveOpVarTypeInference, + ops::CholeskySolveOpGradMaker, + ops::CholeskySolveOpGradMaker); + +REGISTER_OPERATOR(cholesky_solve_grad, ops::CholeskySolveGradOp); + +REGISTER_OP_CPU_KERNEL( + cholesky_solve, + ops::CholeskySolveKernel, + ops::CholeskySolveKernel); + +REGISTER_OP_CPU_KERNEL( + cholesky_solve_grad, + ops::CholeskySolveGradKernel, + ops::CholeskySolveGradKernel); +// Complex<> is not supported because of TensorExpand, which used to boardcast +// input Tensor diff --git a/paddle/fluid/operators/cholesky_solve_op.cu b/paddle/fluid/operators/cholesky_solve_op.cu new file mode 100644 index 0000000000..f42364c961 --- /dev/null +++ b/paddle/fluid/operators/cholesky_solve_op.cu @@ -0,0 +1,136 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/cholesky_solve_op.h" +#include "paddle/fluid/platform/dynload/cusolver.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using CUDADeviceContext = paddle::platform::CUDADeviceContext; + +template +void cusolver_potrs(const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, + int n, int nrhs, T *Adata, int lda, T *Bdata, int ldb, + int *devInfo); + +template <> +void cusolver_potrs(const cusolverDnHandle_t &cusolverH, + cublasFillMode_t uplo, int n, int nrhs, float *Adata, + int lda, float *Bdata, int ldb, int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSpotrs( + cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo)); +} + +template <> +void cusolver_potrs(const cusolverDnHandle_t &cusolverH, + cublasFillMode_t uplo, int n, int nrhs, + double *Adata, int lda, double *Bdata, int ldb, + int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDpotrs( + cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo)); +} + +template <> +void cusolver_potrs>( + const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs, + platform::complex *Adata, int lda, platform::complex *Bdata, + int ldb, int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnCpotrs( + cusolverH, uplo, n, nrhs, reinterpret_cast(Adata), lda, + reinterpret_cast(Bdata), ldb, devInfo)); +} + +template <> +void cusolver_potrs>( + const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs, + platform::complex *Adata, int lda, platform::complex *Bdata, + int ldb, int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnZpotrs( + cusolverH, uplo, n, nrhs, + reinterpret_cast(Adata), lda, + reinterpret_cast(Bdata), ldb, devInfo)); +} + +template +class CholeskySolveFunctor { + public: + void operator()(const platform::CUDADeviceContext &dev_ctx, bool upper, int n, + int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) { + cublasFillMode_t uplo = + upper ? 
CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + + /* step 1: get cusolver handle*/ + auto cusolverH = dev_ctx.cusolver_dn_handle(); + + /* step 2: solve A0*X0 = B0 */ + cusolver_potrs(cusolverH, uplo, n, nrhs, Adata, lda, Bdata, lda, + devInfo); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); + } +}; + +template +class MatrixReduceSumFunctor { + public: + void operator()(const Tensor &in, Tensor *out, + const framework::ExecutionContext &ctx) { + // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] + // out_reduce_dim should be [0, 2] + const std::vector in_dims = framework::vectorize(in.dims()); + auto in_size = in_dims.size(); + const std::vector out_dims = + framework::vectorize(out->dims()); + auto out_size = out_dims.size(); + + std::vector out_bst_dims(in_size); + + std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); + std::copy(out_dims.data(), out_dims.data() + out_size, + out_bst_dims.data() + in_size - out_size); + + std::vector out_reduce_dims; + for (size_t idx = 0; idx <= in_size - 3; idx++) { + if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { + out_reduce_dims.push_back(idx); + } + } + gpuStream_t stream = ctx.cuda_device_context().stream(); + TensorReduceFunctorImpl>( + in, out, kps::IdentityFunctor(), out_reduce_dims, stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + cholesky_solve, + ops::CholeskySolveKernel, + ops::CholeskySolveKernel); + +REGISTER_OP_CUDA_KERNEL( + cholesky_solve_grad, + ops::CholeskySolveGradKernel, + ops::CholeskySolveGradKernel); + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h new file mode 100644 index 0000000000..f3b0056165 --- /dev/null +++ b/paddle/fluid/operators/cholesky_solve_op.h @@ -0,0 +1,247 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/lapack_function.h" +#include "paddle/fluid/operators/solve_op.h" +#include "paddle/fluid/operators/svd_helper.h" +#include "paddle/fluid/operators/triangular_solve_op.h" +#include "paddle/fluid/platform/complex.h" +#include "paddle/pten/include/math.h" + +namespace paddle { +namespace operators { // namespace operators + +template +class CholeskySolveFunctor { + public: + void operator()(const platform::DeviceContext &dev_ctx, bool upper, int n, + int nrhs, T *Adata, int lda, T *Bdata, int *devInfo); +}; + +template +class CholeskySolveFunctor { + public: + void operator()(const platform::CPUDeviceContext &dev_ctx, bool upper, int n, + int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) { + char uplo = upper ? 
'U' : 'L'; + math::lapackCholeskySolve(uplo, n, nrhs, Adata, lda, Bdata, lda, + devInfo); + } +}; + +template +void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx, + const framework::Tensor &uin, + const framework::Tensor &bin, framework::Tensor *out, + bool upper) { + const auto &dev_ctx = ctx.template device_context(); + // framework::Tensor broadcast + std::vector u_bst_dims_vec; + std::vector b_bst_dims_vec; + std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(uin, bin); + framework::Tensor u_bst(uin.type()); + TensorExpand(dev_ctx, uin, &u_bst, u_bst_dims_vec); + + framework::Tensor b_bst(bin.type()); + TensorExpand(dev_ctx, bin, &b_bst, b_bst_dims_vec); + + math::DeviceIndependenceTensorOperations helper(ctx); + + // calculate u's conjugate for complex + framework::Tensor u_conj(u_bst.type()); + platform::ForRange u_for_range(dev_ctx, u_bst.numel()); + math::ConjFunctor u_functor( + u_bst.data(), u_bst.numel(), + u_conj.mutable_data(u_bst.dims(), dev_ctx.GetPlace())); + u_for_range(u_functor); + u_conj = helper.Transpose(u_conj); + + // calculate b's conjugate for complex + framework::Tensor b_conj(b_bst.type()); + platform::ForRange b_for_range(dev_ctx, b_bst.numel()); + math::ConjFunctor b_functor( + b_bst.data(), b_bst.numel(), + b_conj.mutable_data(b_bst.dims(), dev_ctx.GetPlace())); + b_for_range(b_functor); + b_conj = helper.Transpose(b_conj); + + auto ut_data = u_conj.mutable_data(dev_ctx.GetPlace()); + auto uindims = u_bst.dims(); + auto bindims = b_bst.dims(); + int uinrank = uindims.size(); + int binrank = bindims.size(); + + int n = uindims[uinrank - 2]; + int nrhs = bindims[binrank - 1]; + int ldab = std::max(1, n); + + // framework::Tensor out_copy(b_conj.type()); + // out_copy.Resize(b_conj.dims()); + framework::TensorCopy(b_conj, dev_ctx.GetPlace(), out); + T *out_data = out->mutable_data(dev_ctx.GetPlace()); + + auto info_dims = slice_ddim(bindims, 0, binrank - 2); + auto batchsize = product(info_dims); + + framework::Tensor tmp; + std::vector tmpdim(1, batchsize); + tmp.Resize(framework::make_ddim(tmpdim)); + int *info = tmp.mutable_data(dev_ctx.GetPlace()); + + CholeskySolveFunctor functor; + for (int b = 0; b < batchsize; b++) { + auto uin_data_item = &ut_data[b * n * n]; + auto out_data_item = &out_data[b * n * nrhs]; + auto info_item = &info[b]; + functor(dev_ctx, upper, n, nrhs, uin_data_item, ldab, out_data_item, + info_item); + } + + // calculate out's conjugate for complex + platform::ForRange out_for_range(dev_ctx, out->numel()); + math::ConjFunctor out_functor( + out->data(), out->numel(), + out->mutable_data(out->dims(), dev_ctx.GetPlace())); + out_for_range(out_functor); + *out = helper.Transpose(*out); +} + +template +class CholeskySolveKernel : public framework::OpKernel { + public: + void Compute(const paddle::framework::ExecutionContext &ctx) const override { + auto *uin = ctx.Input("Y"); + auto *bin = ctx.Input("X"); + auto *out = ctx.Output("Out"); + auto upper = ctx.Attr("upper"); + cholesky_solve_fn(ctx, *uin, *bin, out, upper); + } +}; + +template +class CholeskySolveGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *bin = ctx.Input("X"); + auto *uin = ctx.Input("Y"); + auto *out = ctx.Input("Out"); + auto *dout = ctx.Input(framework::GradVarName("Out")); + auto *db = ctx.Output(framework::GradVarName("X")); + auto *du = ctx.Output(framework::GradVarName("Y")); + auto upper = ctx.Attr("upper"); + + const auto &dev_ctx = ctx.template 
device_context(); + math::DeviceIndependenceTensorOperations helper(ctx); + + std::vector u_bst_dims_vec; + std::vector b_bst_dims_vec; + std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(*uin, *bin); + framework::Tensor u_bst(uin->type()); + TensorExpand(dev_ctx, *uin, &u_bst, u_bst_dims_vec); + + framework::Tensor db_bst(bin->type()); + TensorExpand(dev_ctx, *bin, &db_bst, b_bst_dims_vec); + + if (dout) { + db->mutable_data(dev_ctx.GetPlace()); + cholesky_solve_fn(ctx, u_bst, *dout, &db_bst, upper); + + if (db_bst.dims() == db->dims()) { + framework::TensorCopy(db_bst, dev_ctx.GetPlace(), dev_ctx, db); + } else { + MatrixReduceSumFunctor functor; + functor(db_bst, db, ctx); + db->Resize(bin->dims()); + } + + auto blas = math::GetBlas(ctx); + + // calculate out's conjugate for complex + framework::Tensor out_conj(out->type()); + platform::ForRange out_for_range(dev_ctx, out->numel()); + math::ConjFunctor out_functor( + out->data(), out->numel(), + out_conj.mutable_data(out->dims(), dev_ctx.GetPlace())); + out_for_range(out_functor); + out_conj = helper.Transpose(out_conj); + + framework::Tensor commonterm(out->type()); + auto outdims = out_conj.dims(); + auto dbdims = db_bst.dims(); + auto mat_dim_a = math::CreateMatrixDescriptor(outdims, 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(dbdims, 0, false); + auto cmtdim = outdims; + cmtdim[cmtdim.size() - 2] = dbdims[dbdims.size() - 2]; + commonterm.Resize(cmtdim); + commonterm.mutable_data(dev_ctx.GetPlace()); + blas.MatMul(db_bst, mat_dim_b, out_conj, mat_dim_a, static_cast(1), + &commonterm, static_cast(0)); + + // calculate commonterm's conjugate for complex + framework::Tensor commonterm_conj(commonterm.type()); + platform::ForRange commonterm_for_range( + dev_ctx, commonterm.numel()); + math::ConjFunctor commonterm_functor( + commonterm.data(), commonterm.numel(), + commonterm_conj.mutable_data(commonterm.dims(), + dev_ctx.GetPlace())); + commonterm_for_range(commonterm_functor); + commonterm_conj = helper.Transpose(commonterm_conj); + + auto pt_x = paddle::experimental::MakePtenDenseTensor(commonterm); + auto pt_y = paddle::experimental::MakePtenDenseTensor(commonterm_conj); + auto pt_z = paddle::experimental::MakePtenDenseTensor(commonterm); + pten::Add(dev_ctx, *pt_x.get(), *pt_y.get(), -1, pt_z.get()); + + auto mat_dim_u = math::CreateMatrixDescriptor(u_bst.dims(), 0, false); + auto mat_dim_c = + math::CreateMatrixDescriptor(commonterm.dims(), 0, false); + + Tensor du_bst(uin->type()); + // get upper or lower triangular + du_bst.Resize(u_bst.dims()); + du_bst.mutable_data(dev_ctx.GetPlace()); + if (upper) { + blas.MatMul(u_bst, mat_dim_u, commonterm, mat_dim_c, static_cast(-1), + &du_bst, static_cast(0)); + } else { + blas.MatMul(commonterm, mat_dim_c, u_bst, mat_dim_u, static_cast(-1), + &du_bst, static_cast(0)); + } + + const auto &udims = u_bst.dims(); + const auto H = udims[udims.size() - 2]; + const auto W = udims[udims.size() - 1]; + platform::ForRange x_for_range(dev_ctx, u_bst.numel()); + TrilTriuCompute tril_triu_computer(du_bst.data(), 0, !upper, H, W, + u_bst.data()); + x_for_range(tril_triu_computer); + + du->mutable_data(dev_ctx.GetPlace()); + if (u_bst.dims() == du->dims()) { + framework::TensorCopy(u_bst, dev_ctx.GetPlace(), dev_ctx, du); + } else { + MatrixReduceSumFunctor functor; + functor(u_bst, du, ctx); + du->Resize(uin->dims()); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/lapack_function.cc 
b/paddle/fluid/operators/math/lapack_function.cc index 3ce2225420..450400e35c 100644 --- a/paddle/fluid/operators/math/lapack_function.cc +++ b/paddle/fluid/operators/math/lapack_function.cc @@ -125,6 +125,38 @@ void lapackEig, float>( reinterpret_cast *>(work), &lwork, rwork, info); } +template <> +void lapackCholeskySolve>( + char uplo, int n, int nrhs, platform::complex *a, int lda, + platform::complex *b, int ldb, int *info) { + platform::dynload::zpotrs_( + &uplo, &n, &nrhs, reinterpret_cast *>(a), &lda, + reinterpret_cast *>(b), &ldb, info); +} + +template <> +void lapackCholeskySolve>(char uplo, int n, int nrhs, + platform::complex *a, + int lda, + platform::complex *b, + int ldb, int *info) { + platform::dynload::cpotrs_( + &uplo, &n, &nrhs, reinterpret_cast *>(a), &lda, + reinterpret_cast *>(b), &ldb, info); +} + +template <> +void lapackCholeskySolve(char uplo, int n, int nrhs, double *a, int lda, + double *b, int ldb, int *info) { + platform::dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); +} + +template <> +void lapackCholeskySolve(char uplo, int n, int nrhs, float *a, int lda, + float *b, int ldb, int *info) { + platform::dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/lapack_function.h b/paddle/fluid/operators/math/lapack_function.h index a4c2c865c8..b3275d2ced 100644 --- a/paddle/fluid/operators/math/lapack_function.h +++ b/paddle/fluid/operators/math/lapack_function.h @@ -32,6 +32,10 @@ void lapackEig(char jobvl, char jobvr, int n, T1* a, int lda, T1* w, T1* vl, int ldvl, T1* vr, int ldvr, T1* work, int lwork, T2* rwork, int* info); +template +void lapackCholeskySolve(char uplo, int n, int nrhs, T* a, int lda, T* b, + int ldb, int* info); + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/triangular_solve_op.cu b/paddle/fluid/operators/triangular_solve_op.cu index dfd48fb47e..b7ea5cd953 100644 --- a/paddle/fluid/operators/triangular_solve_op.cu +++ b/paddle/fluid/operators/triangular_solve_op.cu @@ -19,7 +19,8 @@ namespace paddle { namespace operators { template -struct MatrixReduceSumFunctor { +class MatrixReduceSumFunctor { + public: void operator()(const Tensor& in, Tensor* out, const framework::ExecutionContext& ctx) { // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] diff --git a/paddle/fluid/platform/dynload/cusolver.h b/paddle/fluid/platform/dynload/cusolver.h index 4c018908b5..b4b6d50e55 100644 --- a/paddle/fluid/platform/dynload/cusolver.h +++ b/paddle/fluid/platform/dynload/cusolver.h @@ -49,6 +49,10 @@ extern void *cusolver_dso_handle; __macro(cusolverDnDpotrf_bufferSize); \ __macro(cusolverDnSpotrf); \ __macro(cusolverDnDpotrf); \ + __macro(cusolverDnSpotrs); \ + __macro(cusolverDnDpotrs); \ + __macro(cusolverDnCpotrs); \ + __macro(cusolverDnZpotrs); \ __macro(cusolverDnSsyevd_bufferSize); \ __macro(cusolverDnDsyevd_bufferSize); \ __macro(cusolverDnCheevd_bufferSize); \ @@ -64,6 +68,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP); #define CUSOLVER_ROUTINE_EACH_R1(__macro) \ __macro(cusolverDnSpotrfBatched); \ __macro(cusolverDnDpotrfBatched); \ + __macro(cusolverDnSpotrsBatched); \ + __macro(cusolverDnDpotrsBatched); \ __macro(cusolverDnSgesvdj_bufferSize); \ __macro(cusolverDnSgeqrf_bufferSize); \ __macro(cusolverDnDgeqrf_bufferSize); \ diff --git a/paddle/fluid/platform/dynload/lapack.h b/paddle/fluid/platform/dynload/lapack.h index 
9b4dd3d9e3..32d7461f42 100644 --- a/paddle/fluid/platform/dynload/lapack.h +++ b/paddle/fluid/platform/dynload/lapack.h @@ -66,6 +66,15 @@ extern "C" void cgeev_(char *jobvl, char *jobvr, int *n, std::complex *a, std::complex *work, int *lwork, float *rwork, int *info); +extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex *a, + int *lda, std::complex *b, int *ldb, int *info); +extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex *a, + int *lda, std::complex *b, int *ldb, int *info); +extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, + double *b, int *ldb, int *info); +extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, + float *b, int *ldb, int *info); + namespace paddle { namespace platform { namespace dynload { @@ -105,7 +114,11 @@ extern void *lapack_dso_handle; __macro(dgeev_); \ __macro(sgeev_); \ __macro(zgeev_); \ - __macro(cgeev_); + __macro(cgeev_); \ + __macro(zpotrs_); \ + __macro(cpotrs_); \ + __macro(dpotrs_); \ + __macro(spotrs_); LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP); diff --git a/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py b/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py new file mode 100644 index 0000000000..c31594b75e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_cholesky_solve_op.py @@ -0,0 +1,262 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.w + +from __future__ import print_function + +import unittest +import numpy as np +import scipy +import scipy.linalg + +import sys +sys.path.append("..") +import paddle +from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard, core + +paddle.enable_static() + + +def cholesky_solution(X, B, upper=True): + if upper: + A = np.triu(X) + L = A.T + U = A + else: + A = np.tril(X) + L = A + U = A.T + return scipy.linalg.solve_triangular( + U, scipy.linalg.solve_triangular( + L, B, lower=True)) + + +def scipy_cholesky_solution(X, B, upper=True): + if upper: + umat = np.triu(X) + A = umat.T @umat + else: + umat = np.tril(X) + A = umat @umat.T + K = scipy.linalg.cho_factor(A) + return scipy.linalg.cho_solve(K, B) + + +def boardcast_shape(matA, matB): + shapeA = matA.shape + shapeB = matB.shape + Boardshape = [] + for idx in range(len(shapeA) - 2): + if shapeA[idx] == shapeB[idx]: + Boardshape.append(shapeA[idx]) + continue + elif shapeA[idx] == 1 or shapeB[idx] == 1: + Boardshape.append(max(shapeA[idx], shapeB[idx])) + else: + raise Exception( + 'shapeA and shapeB should be boardcasted, but got {} and {}'. 
+ format(shapeA, shapeB)) + bsA = Boardshape + list(shapeA[-2:]) + bsB = Boardshape + list(shapeB[-2:]) + return np.broadcast_to(matA, bsA), np.broadcast_to(matB, bsB) + + +def scipy_cholesky_solution_batch(bumat, bB, upper=True): + bumat, bB = boardcast_shape(bumat, bB) + ushape = bumat.shape + bshape = bB.shape + bumat = bumat.reshape((-1, ushape[-2], ushape[-1])) + bB = bB.reshape((-1, bshape[-2], bshape[-1])) + batch = 1 + for d in ushape[:-2]: + batch *= d + bx = [] + for b in range(batch): + # x = scipy_cholesky_solution(bumat[b], bB[b], upper) #large matrix result error + x = cholesky_solution(bumat[b], bB[b], upper) + bx.append(x) + return np.array(bx).reshape(bshape) + + +# 2D + 2D , , upper=False +class TestCholeskySolveOp(OpTest): + """ + case 1 + """ + + def config(self): + self.y_shape = [15, 15] + self.x_shape = [15, 5] + self.upper = False + self.dtype = np.float64 + + def set_output(self): + umat = self.inputs['Y'] + self.output = scipy_cholesky_solution_batch( + umat, self.inputs['X'], upper=self.upper) + + def setUp(self): + self.op_type = "cholesky_solve" + self.config() + + if self.upper: + umat = np.triu(np.random.random(self.y_shape).astype(self.dtype)) + else: + umat = np.tril(np.random.random(self.y_shape).astype(self.dtype)) + + self.inputs = { + 'X': np.random.random(self.x_shape).astype(self.dtype), + 'Y': umat + } + self.attrs = {'upper': self.upper} + self.set_output() + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['Y'], 'Out', max_relative_error=0.01) + + +# 3D(broadcast) + 3D, upper=True +class TestCholeskySolveOp3(TestCholeskySolveOp): + """ + case 3 + """ + + def config(self): + self.y_shape = [1, 10, 10] + self.x_shape = [2, 10, 5] + self.upper = True + self.dtype = np.float64 + + +class TestCholeskySolveAPI(unittest.TestCase): + def setUp(self): + np.random.seed(2021) + self.place = [paddle.CPUPlace()] + # self.place = [paddle.CUDAPlace(0)] + self.dtype = "float64" + self.upper = True + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + paddle.enable_static() + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[10, 2], dtype=self.dtype) + y = fluid.data(name="y", shape=[10, 10], dtype=self.dtype) + z = paddle.linalg.cholesky_solve(x, y, upper=self.upper) + + x_np = np.random.random([10, 2]).astype(self.dtype) + y_np = np.random.random([10, 10]).astype(self.dtype) + if self.upper: + umat = np.triu(y_np) + else: + umat = np.tril(y_np) + z_np = cholesky_solution(umat, x_np, upper=self.upper) + z2_np = scipy_cholesky_solution(umat, x_np, upper=self.upper) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"x": x_np, + "y": umat}, + fetch_list=[z]) + self.assertTrue(np.allclose(fetches[0], z_np)) + + def test_static(self): + for place in self.place: + self.check_static_result(place=place) + + def test_dygraph(self): + def run(place): + paddle.disable_static(place) + x_np = np.random.random([20, 2]).astype(self.dtype) + y_np = np.random.random([20, 20]).astype(self.dtype) + z_np = scipy_cholesky_solution(y_np, x_np, upper=self.upper) + + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + z = paddle.linalg.cholesky_solve(x, y, upper=self.upper) + + self.assertTrue(np.allclose(z_np, z.numpy())) + self.assertEqual(z_np.shape, z.numpy().shape) + paddle.enable_static() + + for idx, place in 
enumerate(self.place): + run(place) + + def test_boardcast(self): + def run(place): + paddle.disable_static() + x_np = np.random.random([1, 30, 2]).astype(self.dtype) + y_np = np.random.random([2, 30, 30]).astype(self.dtype) + nx_np = np.concatenate((x_np, x_np), axis=0) + + z_sci = scipy_cholesky_solution_batch(y_np, nx_np, upper=self.upper) + + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + z = paddle.linalg.cholesky_solve(x, y, upper=self.upper) + self.assertEqual(z_sci.shape, z.numpy().shape) + self.assertTrue(np.allclose(z_sci, z.numpy())) + + for idx, place in enumerate(self.place): + run(place) + + +class TestCholeskySolveOpError(unittest.TestCase): + def test_errors(self): + paddle.enable_static() + with program_guard(Program(), Program()): + # The input type of solve_op must be Variable. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + y1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + self.assertRaises(TypeError, paddle.linalg.cholesky_solve, x1, y1) + + # The data type of input must be float32 or float64. + x2 = fluid.data(name="x2", shape=[30, 30], dtype="bool") + y2 = fluid.data(name="y2", shape=[30, 10], dtype="bool") + self.assertRaises(TypeError, paddle.linalg.cholesky_solve, x2, y2) + + x3 = fluid.data(name="x3", shape=[30, 30], dtype="int32") + y3 = fluid.data(name="y3", shape=[30, 10], dtype="int32") + self.assertRaises(TypeError, paddle.linalg.cholesky_solve, x3, y3) + + x4 = fluid.data(name="x4", shape=[30, 30], dtype="float16") + y4 = fluid.data(name="y4", shape=[30, 10], dtype="float16") + self.assertRaises(TypeError, paddle.linalg.cholesky_solve, x4, y4) + + # The number of dimensions of input'X must be >= 2. + x5 = fluid.data(name="x5", shape=[30], dtype="float64") + y5 = fluid.data(name="y5", shape=[30, 30], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.cholesky_solve, x5, y5) + + # The number of dimensions of input'Y must be >= 2. 
+ x6 = fluid.data(name="x6", shape=[30, 30], dtype="float64") + y6 = fluid.data(name="y6", shape=[30], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.cholesky_solve, x6, y6) + + # The inner-most 2 dimensions of input'X should be equal to each other + x7 = fluid.data(name="x7", shape=[2, 3, 4], dtype="float64") + y7 = fluid.data(name="y7", shape=[2, 4, 3], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.cholesky_solve, x7, y7) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index 1c8c89d13a..5deca1dc5a 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -49,6 +49,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'sparse_attention', \ 'svd', \ 'matrix_power', \ + 'cholesky_solve', \ 'solve', \ ] diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 119db0894f..6b83448d0b 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -18,18 +18,19 @@ from .tensor.linalg import eig # noqa: F401 from .tensor.linalg import cond # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 from .tensor.linalg import solve # noqa: F401 +from .tensor.linalg import cholesky_solve # noqa: F401 from .tensor import inverse as inv # noqa: F401 from .tensor.linalg import eigvals # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 -from .tensor.linalg import matrix_rank -from .tensor.linalg import svd -from .tensor.linalg import eigvalsh -from .tensor.linalg import qr +from .tensor.linalg import matrix_rank # noqa: F401 +from .tensor.linalg import svd # noqa: F401 +from .tensor.linalg import eigvalsh # noqa: F401 +from .tensor.linalg import qr # noqa: F401 from .tensor.linalg import eigh # noqa: F401 -from .tensor.linalg import det -from .tensor.linalg import slogdet -from .tensor.linalg import pinv -from .tensor.linalg import triangular_solve +from .tensor.linalg import det # noqa: F401 +from .tensor.linalg import slogdet # noqa: F401 +from .tensor.linalg import pinv # noqa: F401 +from .tensor.linalg import triangular_solve # noqa: F401 __all__ = [ 'cholesky', #noqa @@ -49,5 +50,6 @@ __all__ = [ 'eigvalsh', 'pinv', 'solve', + 'cholesky_solve', 'triangular_solve', ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index a5d119a8d1..f99b0cbbcc 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -60,6 +60,7 @@ from .linalg import eigvalsh # noqa: F401 from .linalg import eigh # noqa: F401 from .linalg import pinv # noqa: F401 from .linalg import solve # noqa: F401 +from .linalg import cholesky_solve # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 @@ -433,6 +434,7 @@ tensor_method_func = [ #noqa 'uniform_', 'multi_dot', 'solve', + 'cholesky_solve', 'triangular_solve', 'asinh', 'atanh', diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index 6757ce68b4..a8c565f336 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2388,6 +2388,56 @@ def triangular_solve(x, return out +def cholesky_solve(x, y, upper=False, name=None): + r""" + Solves a linear system of equations A @ X = B, given A's Cholesky factor matrix u and matrix B. 
+ + Input `x` and `y` is 2D matrices or batches of 2D matrices. If the inputs are batches, the outputs + is also batches. + + Args: + x (Tensor): The input matrix which is upper or lower triangular Cholesky factor of square matrix A. Its shape should be `[*, M, M]`, where `*` is zero or + more batch dimensions. Its data type should be float32 or float64. + y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is + zero or more batch dimensions. Its data type should be float32 or float64. + upper (bool, optional): whether to consider the Cholesky factor as a lower or upper triangular matrix. Default: False. + name(str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The solution of the system of equations. Its data type is the same as that of `x`. + + Examples: + .. code-block:: python + + import paddle + + u = paddle.to_tensor([[1, 1, 1], + [0, 2, 1], + [0, 0,-1]], dtype="float64") + b = paddle.to_tensor([[0], [-9], [5]], dtype="float64") + out = paddle.linalg.cholesky_solve(b, u, upper=True) + + print(out) + # [-2.5, -7, 9.5] + """ + if in_dygraph_mode(): + return _C_ops.cholesky_solve(x, y, 'upper', upper) + + helper = LayerHelper("cholesky_solve", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'cholesky_solve') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'cholesky_solve') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='cholesky_solve', + inputs={'X': x, + 'Y': y}, + outputs={'Out': out}, + attrs={'upper': upper}) + return out + + def eigvalsh(x, UPLO='L', name=None): """ Computes the eigenvalues of a -- GitLab
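For reference, a minimal NumPy/SciPy sketch of the forward computation the op performs. It mirrors the cholesky_solution helper in test_cholesky_solve_op.py (two triangular solves against the Cholesky factor); the helper name np_cholesky_solve is illustrative and not part of the patch.

import numpy as np
from scipy.linalg import solve_triangular

def np_cholesky_solve(b, u, upper=True):
    # Solve (U^T U) x = b for an upper factor U, or (L L^T) x = b for a
    # lower factor L, via two triangular solves.
    if upper:
        lower_fac, upper_fac = np.triu(u).T, np.triu(u)
    else:
        lower_fac, upper_fac = np.tril(u), np.tril(u).T
    y = solve_triangular(lower_fac, b, lower=True)
    return solve_triangular(upper_fac, y, lower=False)

u = np.array([[1., 1., 1.], [0., 2., 1.], [0., 0., -1.]])
b = np.array([[0.], [-9.], [5.]])
print(np_cholesky_solve(b, u, upper=True))  # [[-2.5], [-7.], [9.5]], matching the docstring example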
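A small usage sketch of the new Python API with broadcast batch dimensions, along the lines of test_boardcast in the unit test (shapes shrunk for brevity); the added 3 * np.eye(3) only keeps the lower factor well conditioned and is not taken from the patch.

import numpy as np
import paddle

x = paddle.to_tensor(np.random.random([1, 3, 2]))                           # right-hand sides, batch dim 1
y = paddle.to_tensor(np.tril(np.random.random([2, 3, 3]) + 3 * np.eye(3)))  # lower Cholesky factors, batch dim 2
out = paddle.linalg.cholesky_solve(x, y, upper=False)
print(out.shape)  # [2, 3, 2]: the batch dimensions of X and Y are broadcast together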
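A hedged NumPy reading of what CholeskySolveGradKernel computes for real inputs (the reduce-sum over broadcast batch dimensions is omitted); np_cholesky_solve is the helper from the first sketch, and the expressions are my transcription of the kernel rather than an authoritative derivation.

def np_cholesky_solve_grad(u, out, dout, upper=True):
    # dX: solve the same system again with dOut as the right-hand side.
    dx = np_cholesky_solve(dout, u, upper=upper)
    # commonterm = dX @ Out^T, then symmetrized, as in the kernel.
    common = dx @ out.T
    common = common + common.T
    # dY: -Y @ commonterm (upper) or -commonterm @ Y (lower), kept triangular;
    # Y is assumed to already hold the triangular factor.
    if upper:
        du = np.triu(-u @ common)
    else:
        du = np.tril(-common @ u)
    return dx, du

This appears consistent with the standard identity dA = -A^{-1} dOut Out^T for Out = A^{-1} X with symmetric A = U^T U (or L L^T), mapped onto the triangular factor.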