From e24ca55ef58dac35c3d1e9a3dd0c950724bc22a8 Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Fri, 11 Mar 2022 17:01:55 +0800
Subject: [PATCH] [Phi]migrate cholesky_solve op to phi (#40387)

---
 paddle/fluid/operators/cholesky_solve_op.cc   |  68 +----
 paddle/fluid/operators/cholesky_solve_op.cu   | 136 ---------
 paddle/fluid/operators/cholesky_solve_op.h    | 252 ----------------
 paddle/fluid/operators/triangular_solve_op.h  |  40 ---
 paddle/fluid/platform/dynload/CMakeLists.txt  |   2 -
 paddle/fluid/platform/dynload/lapack.cc       |  27 --
 paddle/fluid/platform/dynload/lapack.h        |  68 -----
 paddle/phi/backends/dynload/lapack.h          |   4 +-
 paddle/phi/infermeta/binary.cc                |  54 ++++
 paddle/phi/infermeta/binary.h                 |   5 +
 .../phi/kernels/cholesky_solve_grad_kernel.h  |  31 ++
 paddle/phi/kernels/cholesky_solve_kernel.h    |  28 ++
 .../kernels/cpu/cholesky_solve_grad_kernel.cc |  25 ++
 .../phi/kernels/cpu/cholesky_solve_kernel.cc  |  42 +++
 .../phi/kernels/funcs/lapack/CMakeLists.txt   |   2 +-
 .../kernels/funcs/lapack/lapack_function.cc   | 285 +++++++++---------
 .../kernels/gpu/cholesky_solve_grad_kernel.cu |  30 ++
 .../phi/kernels/gpu/cholesky_solve_kernel.cu  | 141 +++++++++
 .../impl/cholesky_solve_grad_kernel_impl.h    | 134 ++++++++
 .../kernels/impl/cholesky_solve_kernel_impl.h | 104 +++++++
 paddle/phi/ops/compat/cholesky_solve_sig.cc   |  30 ++
 21 files changed, 775 insertions(+), 733 deletions(-)
 delete mode 100644 paddle/fluid/operators/cholesky_solve_op.cu
 delete mode 100644 paddle/fluid/operators/cholesky_solve_op.h
 delete mode 100644 paddle/fluid/platform/dynload/lapack.cc
 delete mode 100644 paddle/fluid/platform/dynload/lapack.h
 create mode 100644 paddle/phi/kernels/cholesky_solve_grad_kernel.h
 create mode 100644 paddle/phi/kernels/cholesky_solve_kernel.h
 create mode 100644 paddle/phi/kernels/cpu/cholesky_solve_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/cholesky_solve_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/cholesky_solve_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/cholesky_solve_kernel.cu
 create mode 100644 paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h
 create mode 100644 paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h
 create mode 100644 paddle/phi/ops/compat/cholesky_solve_sig.cc

diff --git a/paddle/fluid/operators/cholesky_solve_op.cc b/paddle/fluid/operators/cholesky_solve_op.cc
index 6b5bae8fc7..5403e2440e 100644
--- a/paddle/fluid/operators/cholesky_solve_op.cc
+++ b/paddle/fluid/operators/cholesky_solve_op.cc
@@ -12,8 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
*/ -#include "paddle/fluid/operators/cholesky_solve_op.h" -#include "paddle/fluid/operators/solve_op.h" +#include "paddle/fluid/framework/infershape_utils.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/infermeta/binary.h" namespace paddle { namespace operators { @@ -39,50 +40,6 @@ class CholeskySolveOpMaker : public framework::OpProtoAndCheckerMaker { class CholeskySolveOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *context) const override { - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "CholeskySolve"); - OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "CholeskySolve"); - OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "CholeskySolve"); - auto u_dims = context->GetInputDim("Y"); - auto b_dims = context->GetInputDim("X"); - int u_rank = u_dims.size(); - int b_rank = b_dims.size(); - PADDLE_ENFORCE_GE(u_rank, 2, - platform::errors::InvalidArgument( - "the rank of input Y must greater or equal to 2")); - PADDLE_ENFORCE_GE(b_rank, 2, - platform::errors::InvalidArgument( - "the rank of input X must greater or equal to 2")); - PADDLE_ENFORCE_EQ(u_dims[u_rank - 1], u_dims[u_rank - 2], - platform::errors::InvalidArgument( - "input Matrix Y should be square matrix," - "But Got last shape of %ld x %ld", - u_dims[u_rank - 1], u_dims[u_rank - 2])); - PADDLE_ENFORCE_EQ( - b_dims[b_rank - 2], u_dims[u_rank - 2], - platform::errors::InvalidArgument( - "the first dim of input X must equal to the dim of input Y," - "But Got %ld and %ld", - b_dims[b_rank - 2], u_dims[u_rank - 2])); - - std::vector u_dims_vec = phi::vectorize(u_dims); - std::vector b_dims_vec = phi::vectorize(b_dims); - - std::vector u_dims_vec_cut(u_dims_vec.begin(), - u_dims_vec.end() - 2); - std::vector b_dims_vec_cut(b_dims_vec.begin(), - b_dims_vec.end() - 2); - - std::vector expand_batch_portion = - get_broadcast_batch_portion(u_dims_vec_cut, b_dims_vec_cut); - - std::vector b_broadcast_dims({expand_batch_portion}); - b_broadcast_dims.insert(b_broadcast_dims.end(), - {b_dims_vec[b_rank - 2], b_dims_vec[b_rank - 1]}); - - // dim of 'Out' is the same with 'Y' after broadcast - context->SetOutputDim("Out", phi::make_ddim(b_broadcast_dims)); - } protected: framework::OpKernelType GetExpectedKernelType( @@ -151,22 +108,15 @@ class CholeskySolveGradOp : public framework::OperatorWithKernel { } // namespace operators } // namespace paddle namespace ops = paddle::operators; + +DECLARE_INFER_SHAPE_FUNCTOR(cholesky_solve, CholeskySolveInferShapeFunctor, + PD_INFER_META(phi::CholeskySolveInferMeta)); + REGISTER_OPERATOR(cholesky_solve, ops::CholeskySolveOp, ops::CholeskySolveOpMaker, ops::CholeskySolveOpVarTypeInference, ops::CholeskySolveOpGradMaker, - ops::CholeskySolveOpGradMaker); + ops::CholeskySolveOpGradMaker, + CholeskySolveInferShapeFunctor); REGISTER_OPERATOR(cholesky_solve_grad, ops::CholeskySolveGradOp); - -REGISTER_OP_CPU_KERNEL( - cholesky_solve, - ops::CholeskySolveKernel, - ops::CholeskySolveKernel); - -REGISTER_OP_CPU_KERNEL( - cholesky_solve_grad, - ops::CholeskySolveGradKernel, - ops::CholeskySolveGradKernel); -// Complex<> is not supported because of TensorExpand, which used to boardcast -// input Tensor diff --git a/paddle/fluid/operators/cholesky_solve_op.cu b/paddle/fluid/operators/cholesky_solve_op.cu deleted file mode 100644 index 1b551a7cd0..0000000000 --- a/paddle/fluid/operators/cholesky_solve_op.cu +++ /dev/null @@ -1,136 +0,0 @@ -/* Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef PADDLE_WITH_HIP -// HIP not support cusolver - -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/operators/cholesky_solve_op.h" -#include "paddle/fluid/platform/dynload/cusolver.h" -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using CUDADeviceContext = paddle::platform::CUDADeviceContext; - -template -void cusolver_potrs(const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, - int n, int nrhs, T *Adata, int lda, T *Bdata, int ldb, - int *devInfo); - -template <> -void cusolver_potrs(const cusolverDnHandle_t &cusolverH, - cublasFillMode_t uplo, int n, int nrhs, float *Adata, - int lda, float *Bdata, int ldb, int *devInfo) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnSpotrs( - cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo)); -} - -template <> -void cusolver_potrs(const cusolverDnHandle_t &cusolverH, - cublasFillMode_t uplo, int n, int nrhs, - double *Adata, int lda, double *Bdata, int ldb, - int *devInfo) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnDpotrs( - cusolverH, uplo, n, nrhs, Adata, lda, Bdata, ldb, devInfo)); -} - -template <> -void cusolver_potrs>( - const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs, - platform::complex *Adata, int lda, platform::complex *Bdata, - int ldb, int *devInfo) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnCpotrs( - cusolverH, uplo, n, nrhs, reinterpret_cast(Adata), lda, - reinterpret_cast(Bdata), ldb, devInfo)); -} - -template <> -void cusolver_potrs>( - const cusolverDnHandle_t &cusolverH, cublasFillMode_t uplo, int n, int nrhs, - platform::complex *Adata, int lda, platform::complex *Bdata, - int ldb, int *devInfo) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusolverDnZpotrs( - cusolverH, uplo, n, nrhs, - reinterpret_cast(Adata), lda, - reinterpret_cast(Bdata), ldb, devInfo)); -} - -template -class CholeskySolveFunctor { - public: - void operator()(const platform::CUDADeviceContext &dev_ctx, bool upper, int n, - int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) { - cublasFillMode_t uplo = - upper ? 
CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; - - /* step 1: get cusolver handle*/ - auto cusolverH = dev_ctx.cusolver_dn_handle(); - - /* step 2: solve A0*X0 = B0 */ - cusolver_potrs(cusolverH, uplo, n, nrhs, Adata, lda, Bdata, lda, - devInfo); - - PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); - } -}; - -template -class MatrixReduceSumFunctor { - public: - void operator()(const Tensor &in, Tensor *out, - const framework::ExecutionContext &ctx) { - // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] - // out_reduce_dim should be [0, 2] - const std::vector in_dims = phi::vectorize(in.dims()); - auto in_size = in_dims.size(); - const std::vector out_dims = phi::vectorize(out->dims()); - auto out_size = out_dims.size(); - - std::vector out_bst_dims(in_size); - - std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); - std::copy(out_dims.data(), out_dims.data() + out_size, - out_bst_dims.data() + in_size - out_size); - - std::vector out_reduce_dims; - for (size_t idx = 0; idx <= in_size - 3; idx++) { - if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { - out_reduce_dims.push_back(idx); - } - } - gpuStream_t stream = ctx.cuda_device_context().stream(); - TensorReduceImpl>( - ctx.cuda_device_context(), in, out, kps::IdentityFunctor(), - out_reduce_dims, stream); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - cholesky_solve, - ops::CholeskySolveKernel, - ops::CholeskySolveKernel); - -REGISTER_OP_CUDA_KERNEL( - cholesky_solve_grad, - ops::CholeskySolveGradKernel, - ops::CholeskySolveGradKernel); - -#endif // not PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/cholesky_solve_op.h b/paddle/fluid/operators/cholesky_solve_op.h deleted file mode 100644 index 74b961d4e5..0000000000 --- a/paddle/fluid/operators/cholesky_solve_op.h +++ /dev/null @@ -1,252 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/solve_op.h" -#include "paddle/fluid/operators/triangular_solve_op.h" -#include "paddle/fluid/platform/complex.h" -#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" -#include "paddle/phi/kernels/math_kernel.h" -#include "paddle/phi/kernels/transpose_kernel.h" - -namespace paddle { -namespace operators { // namespace operators - -template -class CholeskySolveFunctor { - public: - void operator()(const platform::DeviceContext &dev_ctx, bool upper, int n, - int nrhs, T *Adata, int lda, T *Bdata, int *devInfo); -}; - -template -class CholeskySolveFunctor { - public: - void operator()(const platform::CPUDeviceContext &dev_ctx, bool upper, int n, - int nrhs, T *Adata, int lda, T *Bdata, int *devInfo) { - char uplo = upper ? 
'U' : 'L'; - phi::funcs::lapackCholeskySolve(uplo, n, nrhs, Adata, lda, Bdata, lda, - devInfo); - } -}; - -template -void cholesky_solve_fn(const paddle::framework::ExecutionContext &ctx, - const framework::Tensor &uin, - const framework::Tensor &bin, framework::Tensor *out, - bool upper) { - const auto &dev_ctx = ctx.template device_context(); - // framework::Tensor broadcast - std::vector u_bst_dims_vec; - std::vector b_bst_dims_vec; - std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(uin, bin); - framework::Tensor u_bst(uin.type()); - TensorExpand(dev_ctx, uin, &u_bst, u_bst_dims_vec); - - framework::Tensor b_bst(bin.type()); - TensorExpand(dev_ctx, bin, &b_bst, b_bst_dims_vec); - - auto &phi_dev_ctx = static_cast< - const typename framework::ConvertToPhiContext::TYPE &>( - dev_ctx); - - // calculate u's conjugate for complex - framework::Tensor u_conj(u_bst.type()); - platform::ForRange u_for_range(dev_ctx, u_bst.numel()); - phi::funcs::ConjFunctor u_functor( - u_bst.data(), u_bst.numel(), - u_conj.mutable_data(u_bst.dims(), dev_ctx.GetPlace())); - u_for_range(u_functor); - u_conj = phi::TransposeLast2Dim(phi_dev_ctx, u_conj); - - // calculate b's conjugate for complex - framework::Tensor b_conj(b_bst.type()); - platform::ForRange b_for_range(dev_ctx, b_bst.numel()); - phi::funcs::ConjFunctor b_functor( - b_bst.data(), b_bst.numel(), - b_conj.mutable_data(b_bst.dims(), dev_ctx.GetPlace())); - b_for_range(b_functor); - b_conj = phi::TransposeLast2Dim(phi_dev_ctx, b_conj); - - auto ut_data = u_conj.mutable_data(dev_ctx.GetPlace()); - auto uindims = u_bst.dims(); - auto bindims = b_bst.dims(); - int uinrank = uindims.size(); - int binrank = bindims.size(); - - int n = uindims[uinrank - 2]; - int nrhs = bindims[binrank - 1]; - int ldab = std::max(1, n); - - // framework::Tensor out_copy(b_conj.type()); - // out_copy.Resize(b_conj.dims()); - framework::TensorCopy(b_conj, dev_ctx.GetPlace(), out); - T *out_data = out->mutable_data(dev_ctx.GetPlace()); - - auto info_dims = phi::slice_ddim(bindims, 0, binrank - 2); - auto batchsize = product(info_dims); - - framework::Tensor tmp; - std::vector tmpdim(1, batchsize); - tmp.Resize(phi::make_ddim(tmpdim)); - int *info = tmp.mutable_data(dev_ctx.GetPlace()); - - CholeskySolveFunctor functor; - for (int b = 0; b < batchsize; b++) { - auto uin_data_item = &ut_data[b * n * n]; - auto out_data_item = &out_data[b * n * nrhs]; - auto info_item = &info[b]; - functor(dev_ctx, upper, n, nrhs, uin_data_item, ldab, out_data_item, - info_item); - } - - // calculate out's conjugate for complex - platform::ForRange out_for_range(dev_ctx, out->numel()); - phi::funcs::ConjFunctor out_functor( - out->data(), out->numel(), - out->mutable_data(out->dims(), dev_ctx.GetPlace())); - out_for_range(out_functor); - *out = phi::TransposeLast2Dim(phi_dev_ctx, *out); -} - -template -class CholeskySolveKernel : public framework::OpKernel { - public: - void Compute(const paddle::framework::ExecutionContext &ctx) const override { - auto *uin = ctx.Input("Y"); - auto *bin = ctx.Input("X"); - auto *out = ctx.Output("Out"); - auto upper = ctx.Attr("upper"); - cholesky_solve_fn(ctx, *uin, *bin, out, upper); - } -}; - -template -class CholeskySolveGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *bin = ctx.Input("X"); - auto *uin = ctx.Input("Y"); - auto *out = ctx.Input("Out"); - auto *dout = ctx.Input(framework::GradVarName("Out")); - auto *db = ctx.Output(framework::GradVarName("X")); - 
auto *du = ctx.Output(framework::GradVarName("Y")); - auto upper = ctx.Attr("upper"); - - const auto &dev_ctx = ctx.template device_context(); - auto &phi_dev_ctx = static_cast< - const typename framework::ConvertToPhiContext::TYPE &>( - dev_ctx); - - std::vector u_bst_dims_vec; - std::vector b_bst_dims_vec; - std::tie(u_bst_dims_vec, b_bst_dims_vec) = get_broadcast_dims(*uin, *bin); - framework::Tensor u_bst(uin->type()); - TensorExpand(dev_ctx, *uin, &u_bst, u_bst_dims_vec); - - framework::Tensor db_bst(bin->type()); - TensorExpand(dev_ctx, *bin, &db_bst, b_bst_dims_vec); - - if (dout) { - db->mutable_data(dev_ctx.GetPlace()); - cholesky_solve_fn(ctx, u_bst, *dout, &db_bst, upper); - - if (db_bst.dims() == db->dims()) { - framework::TensorCopy(db_bst, dev_ctx.GetPlace(), dev_ctx, db); - } else { - MatrixReduceSumFunctor functor; - functor(db_bst, db, ctx); - db->Resize(bin->dims()); - } - - auto blas = phi::funcs::GetBlas(ctx); - - // calculate out's conjugate for complex - framework::Tensor out_conj(out->type()); - platform::ForRange out_for_range(dev_ctx, out->numel()); - phi::funcs::ConjFunctor out_functor( - out->data(), out->numel(), - out_conj.mutable_data(out->dims(), dev_ctx.GetPlace())); - out_for_range(out_functor); - out_conj = phi::TransposeLast2Dim(phi_dev_ctx, out_conj); - - framework::Tensor commonterm(out->type()); - auto outdims = out_conj.dims(); - auto dbdims = db_bst.dims(); - auto mat_dim_a = phi::funcs::CreateMatrixDescriptor(outdims, 0, false); - auto mat_dim_b = phi::funcs::CreateMatrixDescriptor(dbdims, 0, false); - auto cmtdim = outdims; - cmtdim[cmtdim.size() - 2] = dbdims[dbdims.size() - 2]; - commonterm.Resize(cmtdim); - commonterm.mutable_data(dev_ctx.GetPlace()); - blas.MatMul(db_bst, mat_dim_b, out_conj, mat_dim_a, static_cast(1), - &commonterm, static_cast(0)); - - // calculate commonterm's conjugate for complex - framework::Tensor commonterm_conj(commonterm.type()); - platform::ForRange commonterm_for_range( - dev_ctx, commonterm.numel()); - phi::funcs::ConjFunctor commonterm_functor( - commonterm.data(), commonterm.numel(), - commonterm_conj.mutable_data(commonterm.dims(), - dev_ctx.GetPlace())); - commonterm_for_range(commonterm_functor); - commonterm_conj = phi::TransposeLast2Dim(phi_dev_ctx, commonterm_conj); - - phi::AddRawKernel( - static_cast::TYPE &>(dev_ctx), - commonterm, commonterm_conj, -1, &commonterm); - - auto mat_dim_u = - phi::funcs::CreateMatrixDescriptor(u_bst.dims(), 0, false); - auto mat_dim_c = - phi::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false); - - Tensor du_bst(uin->type()); - // get upper or lower triangular - du_bst.Resize(u_bst.dims()); - du_bst.mutable_data(dev_ctx.GetPlace()); - if (upper) { - blas.MatMul(u_bst, mat_dim_u, commonterm, mat_dim_c, static_cast(-1), - &du_bst, static_cast(0)); - } else { - blas.MatMul(commonterm, mat_dim_c, u_bst, mat_dim_u, static_cast(-1), - &du_bst, static_cast(0)); - } - - const auto &udims = u_bst.dims(); - const auto H = udims[udims.size() - 2]; - const auto W = udims[udims.size() - 1]; - platform::ForRange x_for_range(dev_ctx, u_bst.numel()); - TrilTriuCompute tril_triu_computer(du_bst.data(), 0, !upper, H, W, - u_bst.data()); - x_for_range(tril_triu_computer); - - du->mutable_data(dev_ctx.GetPlace()); - if (u_bst.dims() == du->dims()) { - framework::TensorCopy(u_bst, dev_ctx.GetPlace(), dev_ctx, du); - } else { - MatrixReduceSumFunctor functor; - functor(u_bst, du, ctx); - du->Resize(uin->dims()); - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git 
a/paddle/fluid/operators/triangular_solve_op.h b/paddle/fluid/operators/triangular_solve_op.h index 315847b4d8..fd46aca456 100644 --- a/paddle/fluid/operators/triangular_solve_op.h +++ b/paddle/fluid/operators/triangular_solve_op.h @@ -60,45 +60,5 @@ static void triangular_solve(const DeviceContext &context, const Tensor &x, unitriangular); } -template -class MatrixReduceSumFunctor { - public: - void operator()(const Tensor &input, Tensor *output, - const framework::ExecutionContext &ctx); -}; - -template -class MatrixReduceSumFunctor { - public: - void operator()(const Tensor &in, Tensor *out, - const framework::ExecutionContext &ctx) { - // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] - // out_reduce_dim should be [0, 2] - const std::vector in_dims = phi::vectorize(in.dims()); - auto in_size = in_dims.size(); - const std::vector out_dims = phi::vectorize(out->dims()); - auto out_size = out_dims.size(); - - std::vector out_bst_dims(in_size); - - std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); - std::copy(out_dims.data(), out_dims.data() + out_size, - out_bst_dims.data() + in_size - out_size); - out->Resize(phi::make_ddim(out_bst_dims)); - - std::vector out_reduce_dims; - for (size_t idx = 0; idx <= in_size - 3; idx++) { - if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { - out_reduce_dims.push_back(idx); - } - } - - ReduceKernelFunctor( - &in, out, out_reduce_dims, true, false, ctx) - .template apply(); - out->Resize(phi::make_ddim(out_dims)); - } -}; - } // namespace operators } // namespace paddle diff --git a/paddle/fluid/platform/dynload/CMakeLists.txt b/paddle/fluid/platform/dynload/CMakeLists.txt index 87aa5dcde6..1f95e12127 100644 --- a/paddle/fluid/platform/dynload/CMakeLists.txt +++ b/paddle/fluid/platform/dynload/CMakeLists.txt @@ -46,8 +46,6 @@ if (WITH_MKLML) cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml phi_dynload_mklml) endif() -cc_library(dynload_lapack SRCS lapack.cc DEPS dynamic_loader phi_dynload_lapack) -add_dependencies(dynload_lapack extern_lapack) # TODO(TJ): add iomp, mkldnn? if (MKL_FOUND AND WITH_ONEMKL) diff --git a/paddle/fluid/platform/dynload/lapack.cc b/paddle/fluid/platform/dynload/lapack.cc deleted file mode 100644 index 5a21bb4d04..0000000000 --- a/paddle/fluid/platform/dynload/lapack.cc +++ /dev/null @@ -1,27 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/platform/dynload/lapack.h" - -namespace paddle { -namespace platform { -namespace dynload { - -#define DEFINE_WRAP(__name) DynLoad__##__name __name - -LAPACK_ROUTINE_EACH(DEFINE_WRAP); - -} // namespace dynload -} // namespace platform -} // namespace paddle diff --git a/paddle/fluid/platform/dynload/lapack.h b/paddle/fluid/platform/dynload/lapack.h deleted file mode 100644 index 59e04dbd2a..0000000000 --- a/paddle/fluid/platform/dynload/lapack.h +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include -#include -#include "paddle/phi/backends/dynload/lapack.h" -#include "paddle/phi/common/complex.h" - -namespace paddle { -namespace platform { -namespace dynload { - -/** - * The following macro definition can generate structs - * (for each function) to dynamic load lapack routine - * via operator overloading. - */ -#define DYNAMIC_LOAD_LAPACK_WRAP(__name) \ - using DynLoad__##__name = phi::dynload::DynLoad__##__name; \ - extern DynLoad__##__name __name - -#define DECLARE_DYNAMIC_LOAD_LAPACK_WRAP(__name) \ - DYNAMIC_LOAD_LAPACK_WRAP(__name) - -#define LAPACK_ROUTINE_EACH(__macro) \ - __macro(dgetrf_); \ - __macro(sgetrf_); \ - __macro(zheevd_); \ - __macro(cheevd_); \ - __macro(dsyevd_); \ - __macro(ssyevd_); \ - __macro(dgeev_); \ - __macro(sgeev_); \ - __macro(zgeev_); \ - __macro(cgeev_); \ - __macro(dgels_); \ - __macro(sgels_); \ - __macro(dgelsd_); \ - __macro(sgelsd_); \ - __macro(dgelsy_); \ - __macro(sgelsy_); \ - __macro(dgelss_); \ - __macro(sgelss_); \ - __macro(zpotrs_); \ - __macro(cpotrs_); \ - __macro(dpotrs_); \ - __macro(spotrs_); - -LAPACK_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_LAPACK_WRAP); - -#undef DYNAMIC_LOAD_LAPACK_WRAP - -} // namespace dynload -} // namespace platform -} // namespace paddle diff --git a/paddle/phi/backends/dynload/lapack.h b/paddle/phi/backends/dynload/lapack.h index 75fc8fd9a3..c81c66c692 100644 --- a/paddle/phi/backends/dynload/lapack.h +++ b/paddle/phi/backends/dynload/lapack.h @@ -20,8 +20,8 @@ limitations under the License. */ #include "paddle/phi/backends/dynload/dynamic_loader.h" #include "paddle/phi/backends/dynload/port.h" -// Note(zhouwei): because lapack doesn't provide appropriate header file. -// should expose API statement yourself. +// Because lapack doesn't provide appropriate header file, +// we should expose API statement yourself. 
// getrf_(For example) extern "C" void dgetrf_( diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index ff73829c47..641956c4d9 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -274,6 +274,60 @@ void HuberLossInferMeta(const MetaTensor& input, out->share_lod(input); } +void CholeskySolveInferMeta(const MetaTensor& x, + const MetaTensor& y, + bool upper, + MetaTensor* out) { + auto x_dims = x.dims(); + auto y_dims = y.dims(); + + auto x_dims_n = x_dims.size(); + auto y_dims_n = y_dims.size(); + + PADDLE_ENFORCE_GE(x_dims_n, + 2, + phi::errors::InvalidArgument( + "the rank of input Y must greater or equal to 2")); + PADDLE_ENFORCE_GE(y_dims_n, + 2, + phi::errors::InvalidArgument( + "the rank of input X must greater or equal to 2")); + PADDLE_ENFORCE_EQ( + y_dims[y_dims_n - 1], + y_dims[y_dims_n - 2], + phi::errors::InvalidArgument("input Matrix Y should be square matrix," + "But Got last shape of %ld x %ld", + y_dims[y_dims_n - 1], + y_dims[y_dims_n - 2])); + PADDLE_ENFORCE_EQ( + x_dims[x_dims_n - 2], + y_dims[y_dims_n - 2], + phi::errors::InvalidArgument("the first dim of Matrix X must be equal to " + "the fisrt dim of Matrix Y," + "But Got %ld and %ld", + x_dims[x_dims_n - 2], + y_dims[y_dims_n - 2])); + + std::vector x_dims_vec = phi::vectorize(x_dims); + std::vector y_dims_vec = phi::vectorize(y_dims); + + std::vector x_dims_vec_cut(x_dims_vec.begin(), x_dims_vec.end() - 2); + std::vector y_dims_vec_cut(y_dims_vec.begin(), y_dims_vec.end() - 2); + + std::vector expand_batch_portion = + funcs::MatrixGetBroadcastBatchPortion(x_dims_vec_cut, y_dims_vec_cut); + + std::vector x_broadcast_dims({expand_batch_portion}); + x_broadcast_dims.insert(x_broadcast_dims.end(), + {x_dims_vec[x_dims_n - 2], x_dims_vec[x_dims_n - 1]}); + + // dim of 'out' is the same with 'X' after broadcast + out->set_dims(phi::make_ddim(x_broadcast_dims)); + out->set_dtype(x.dtype()); + out->set_layout(x.layout()); + out->share_lod(x); +} + void TriangularSolveInferMeta(const MetaTensor& x, const MetaTensor& y, bool upper, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index bc5cb887f2..d2b16e557b 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -62,6 +62,11 @@ void HuberLossInferMeta(const MetaTensor& input_meta, MetaTensor* residual, MetaConfig config = MetaConfig()); +void CholeskySolveInferMeta(const MetaTensor& x, + const MetaTensor& y, + bool upper, + MetaTensor* out); + void TriangularSolveInferMeta(const MetaTensor& x, const MetaTensor& y, bool upper, diff --git a/paddle/phi/kernels/cholesky_solve_grad_kernel.h b/paddle/phi/kernels/cholesky_solve_grad_kernel.h new file mode 100644 index 0000000000..e2ce67abae --- /dev/null +++ b/paddle/phi/kernels/cholesky_solve_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CholeskySolveGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + bool upper, + DenseTensor* dx, + DenseTensor* dy); + +} // namespace phi diff --git a/paddle/phi/kernels/cholesky_solve_kernel.h b/paddle/phi/kernels/cholesky_solve_kernel.h new file mode 100644 index 0000000000..b304a20e61 --- /dev/null +++ b/paddle/phi/kernels/cholesky_solve_kernel.h @@ -0,0 +1,28 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void CholeskySolveKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool upper, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/cholesky_solve_grad_kernel.cc b/paddle/phi/kernels/cpu/cholesky_solve_grad_kernel.cc new file mode 100644 index 0000000000..b6f5dd29ba --- /dev/null +++ b/paddle/phi/kernels/cpu/cholesky_solve_grad_kernel.cc @@ -0,0 +1,25 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(cholesky_solve_grad, + CPU, + ALL_LAYOUT, + phi::CholeskySolveGradKernel, + float, + double) {} diff --git a/paddle/phi/kernels/cpu/cholesky_solve_kernel.cc b/paddle/phi/kernels/cpu/cholesky_solve_kernel.cc new file mode 100644 index 0000000000..02597560a7 --- /dev/null +++ b/paddle/phi/kernels/cpu/cholesky_solve_kernel.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" + +namespace phi { + +template +class CholeskySolveFunctor { + public: + void operator()(const CPUContext &dev_ctx, + bool upper, + int M, + int N, + T *Adata, + int lda, + T *Bdata, + int *devInfo) { + char uplo = upper ? 'U' : 'L'; + funcs::lapackCholeskySolve(uplo, M, N, Adata, lda, Bdata, lda, devInfo); + } +}; + +} // namespace phi + +PD_REGISTER_KERNEL( + cholesky_solve, CPU, ALL_LAYOUT, phi::CholeskySolveKernel, float, double) {} diff --git a/paddle/phi/kernels/funcs/lapack/CMakeLists.txt b/paddle/phi/kernels/funcs/lapack/CMakeLists.txt index ffff5ae8ab..1a53470b2e 100644 --- a/paddle/phi/kernels/funcs/lapack/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/lapack/CMakeLists.txt @@ -1 +1 @@ -math_library(lapack_function DEPS dynload_lapack) +math_library(lapack_function DEPS phi_dynload_lapack) diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.cc b/paddle/phi/kernels/funcs/lapack/lapack_function.cc index 0407b8fd48..0f887dce4b 100644 --- a/paddle/phi/kernels/funcs/lapack/lapack_function.cc +++ b/paddle/phi/kernels/funcs/lapack/lapack_function.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/funcs/lapack/lapack_function.h" -#include "paddle/fluid/platform/dynload/lapack.h" +#include "paddle/phi/backends/dynload/lapack.h" #include "paddle/phi/common/complex.h" namespace phi { @@ -22,12 +22,12 @@ namespace funcs { // LU (for example) template <> void lapackLu(int m, int n, double *a, int lda, int *ipiv, int *info) { - paddle::platform::dynload::dgetrf_(&m, &n, a, &lda, ipiv, info); + dynload::dgetrf_(&m, &n, a, &lda, ipiv, info); } template <> void lapackLu(int m, int n, float *a, int lda, int *ipiv, int *info) { - paddle::platform::dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); + dynload::sgetrf_(&m, &n, a, &lda, ipiv, info); } // eigh @@ -47,7 +47,7 @@ void lapackEigh(char jobz, int *info) { (void)rwork; // unused (void)lrwork; // unused - paddle::platform::dynload::ssyevd_( + dynload::ssyevd_( &jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); } @@ -67,7 +67,7 @@ void lapackEigh(char jobz, int *info) { (void)rwork; // unused (void)lrwork; // unused - paddle::platform::dynload::dsyevd_( + dynload::dsyevd_( &jobz, &uplo, &n, a, &lda, w, work, &lwork, iwork, &liwork, info); } @@ -86,20 +86,19 @@ void lapackEigh, float>( int *iwork, int liwork, int *info) { - paddle::platform::dynload::cheevd_( - &jobz, - &uplo, - &n, - reinterpret_cast *>(a), - &lda, - w, - reinterpret_cast *>(work), - &lwork, - rwork, - &lrwork, - iwork, - &liwork, - info); + dynload::cheevd_(&jobz, + &uplo, + &n, + reinterpret_cast *>(a), + &lda, + w, + reinterpret_cast *>(work), + &lwork, + rwork, + &lrwork, + iwork, + &liwork, + info); } template <> @@ -117,20 +116,19 @@ void lapackEigh, double>( int *iwork, int liwork, int *info) { - paddle::platform::dynload::zheevd_( - &jobz, - &uplo, - &n, - reinterpret_cast *>(a), - &lda, - w, - reinterpret_cast *>(work), - &lwork, - rwork, - &lrwork, - iwork, - &liwork, - info); + dynload::zheevd_(&jobz, + &uplo, + &n, + reinterpret_cast *>(a), + &lda, + w, + reinterpret_cast *>(work), + &lwork, + rwork, + &lrwork, + iwork, + &liwork, + info); } // Eig @@ -152,20 +150,20 @@ void lapackEig(char jobvl, double *wr = w; double *wi = w + n; (void)rwork; // unused - 
paddle::platform::dynload::dgeev_(&jobvl, - &jobvr, - &n, - a, - &lda, - wr, - wi, - vl, - &ldvl, - vr, - &ldvr, - work, - &lwork, - info); + dynload::dgeev_(&jobvl, + &jobvr, + &n, + a, + &lda, + wr, + wi, + vl, + &ldvl, + vr, + &ldvr, + work, + &lwork, + info); } template <> @@ -186,20 +184,20 @@ void lapackEig(char jobvl, float *wr = w; float *wi = w + n; (void)rwork; // unused - paddle::platform::dynload::sgeev_(&jobvl, - &jobvr, - &n, - a, - &lda, - wr, - wi, - vl, - &ldvl, - vr, - &ldvr, - work, - &lwork, - info); + dynload::sgeev_(&jobvl, + &jobvr, + &n, + a, + &lda, + wr, + wi, + vl, + &ldvl, + vr, + &ldvr, + work, + &lwork, + info); } template <> @@ -218,21 +216,20 @@ void lapackEig, double>( int lwork, double *rwork, int *info) { - paddle::platform::dynload::zgeev_( - &jobvl, - &jobvr, - &n, - reinterpret_cast *>(a), - &lda, - reinterpret_cast *>(w), - reinterpret_cast *>(vl), - &ldvl, - reinterpret_cast *>(vr), - &ldvr, - reinterpret_cast *>(work), - &lwork, - rwork, - info); + dynload::zgeev_(&jobvl, + &jobvr, + &n, + reinterpret_cast *>(a), + &lda, + reinterpret_cast *>(w), + reinterpret_cast *>(vl), + &ldvl, + reinterpret_cast *>(vr), + &ldvr, + reinterpret_cast *>(work), + &lwork, + rwork, + info); } template <> @@ -251,21 +248,20 @@ void lapackEig, float>( int lwork, float *rwork, int *info) { - paddle::platform::dynload::cgeev_( - &jobvl, - &jobvr, - &n, - reinterpret_cast *>(a), - &lda, - reinterpret_cast *>(w), - reinterpret_cast *>(vl), - &ldvl, - reinterpret_cast *>(vr), - &ldvr, - reinterpret_cast *>(work), - &lwork, - rwork, - info); + dynload::cgeev_(&jobvl, + &jobvr, + &n, + reinterpret_cast *>(a), + &lda, + reinterpret_cast *>(w), + reinterpret_cast *>(vl), + &ldvl, + reinterpret_cast *>(vr), + &ldvr, + reinterpret_cast *>(work), + &lwork, + rwork, + info); } template <> @@ -280,8 +276,7 @@ void lapackGels(char trans, double *work, int lwork, int *info) { - paddle::platform::dynload::dgels_( - &trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); + dynload::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); } template <> @@ -296,8 +291,7 @@ void lapackGels(char trans, float *work, int lwork, int *info) { - paddle::platform::dynload::sgels_( - &trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); + dynload::sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info); } template <> @@ -316,20 +310,20 @@ void lapackGelsd(int m, double *rwork, int *iwork, int *info) { - paddle::platform::dynload::dgelsd_(&m, - &n, - &nrhs, - a, - &lda, - b, - &ldb, - s, - &rcond, - rank, - work, - &lwork, - iwork, - info); + dynload::dgelsd_(&m, + &n, + &nrhs, + a, + &lda, + b, + &ldb, + s, + &rcond, + rank, + work, + &lwork, + iwork, + info); } template <> @@ -348,20 +342,20 @@ void lapackGelsd(int m, float *rwork, int *iwork, int *info) { - paddle::platform::dynload::sgelsd_(&m, - &n, - &nrhs, - a, - &lda, - b, - &ldb, - s, - &rcond, - rank, - work, - &lwork, - iwork, - info); + dynload::sgelsd_(&m, + &n, + &nrhs, + a, + &lda, + b, + &ldb, + s, + &rcond, + rank, + work, + &lwork, + iwork, + info); } template <> @@ -379,7 +373,7 @@ void lapackGelsy(int m, int lwork, double *rwork, int *info) { - paddle::platform::dynload::dgelsy_( + dynload::dgelsy_( &m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info); } @@ -398,7 +392,7 @@ void lapackGelsy(int m, int lwork, float *rwork, int *info) { - paddle::platform::dynload::sgelsy_( + dynload::sgelsy_( &m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond, rank, work, &lwork, info); } @@ 
-417,7 +411,7 @@ void lapackGelss(int m, int lwork, double *rwork, int *info) { - paddle::platform::dynload::dgelss_( + dynload::dgelss_( &m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info); } @@ -436,7 +430,7 @@ void lapackGelss(int m, int lwork, float *rwork, int *info) { - paddle::platform::dynload::sgelss_( + dynload::sgelss_( &m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank, work, &lwork, info); } @@ -450,15 +444,14 @@ void lapackCholeskySolve>( phi::dtype::complex *b, int ldb, int *info) { - paddle::platform::dynload::zpotrs_( - &uplo, - &n, - &nrhs, - reinterpret_cast *>(a), - &lda, - reinterpret_cast *>(b), - &ldb, - info); + dynload::zpotrs_(&uplo, + &n, + &nrhs, + reinterpret_cast *>(a), + &lda, + reinterpret_cast *>(b), + &ldb, + info); } template <> @@ -471,14 +464,14 @@ void lapackCholeskySolve>( phi::dtype::complex *b, int ldb, int *info) { - paddle::platform::dynload::cpotrs_(&uplo, - &n, - &nrhs, - reinterpret_cast *>(a), - &lda, - reinterpret_cast *>(b), - &ldb, - info); + dynload::cpotrs_(&uplo, + &n, + &nrhs, + reinterpret_cast *>(a), + &lda, + reinterpret_cast *>(b), + &ldb, + info); } template <> @@ -490,7 +483,7 @@ void lapackCholeskySolve(char uplo, double *b, int ldb, int *info) { - paddle::platform::dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); + dynload::dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); } template <> @@ -502,7 +495,7 @@ void lapackCholeskySolve(char uplo, float *b, int ldb, int *info) { - paddle::platform::dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); + dynload::spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info); } } // namespace funcs diff --git a/paddle/phi/kernels/gpu/cholesky_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/cholesky_solve_grad_kernel.cu new file mode 100644 index 0000000000..82b1282cc3 --- /dev/null +++ b/paddle/phi/kernels/gpu/cholesky_solve_grad_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP +// backward reuse forward, HIP not support forward + +#include "paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_REGISTER_KERNEL(cholesky_solve_grad, // cuda_only + GPU, + ALL_LAYOUT, + phi::CholeskySolveGradKernel, + float, + double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/gpu/cholesky_solve_kernel.cu b/paddle/phi/kernels/gpu/cholesky_solve_kernel.cu new file mode 100644 index 0000000000..f1c91f3824 --- /dev/null +++ b/paddle/phi/kernels/gpu/cholesky_solve_kernel.cu @@ -0,0 +1,141 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PADDLE_WITH_HIP +// HIP not support cusolver + +#include "paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h" + +#include "paddle/fluid/platform/enforce.h" +#include "paddle/phi/backends/dynload/cusolver.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/complex.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/lapack/lapack_function.h" + +namespace phi { + +template +void cusolver_potrs(const solverHandle_t &handle, + cublasFillMode_t uplo, + int M, + int N, + T *Adata, + int lda, + T *Bdata, + int ldb, + int *devInfo); + +template <> +void cusolver_potrs(const solverHandle_t &handle, + cublasFillMode_t uplo, + int M, + int N, + float *Adata, + int lda, + float *Bdata, + int ldb, + int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSpotrs( + handle, uplo, M, N, Adata, lda, Bdata, ldb, devInfo)); +} + +template <> +void cusolver_potrs(const solverHandle_t &handle, + cublasFillMode_t uplo, + int M, + int N, + double *Adata, + int lda, + double *Bdata, + int ldb, + int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDpotrs( + handle, uplo, M, N, Adata, lda, Bdata, ldb, devInfo)); +} + +template <> +void cusolver_potrs>( + const solverHandle_t &handle, + cublasFillMode_t uplo, + int M, + int N, + phi::dtype::complex *Adata, + int lda, + phi::dtype::complex *Bdata, + int ldb, + int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cusolverDnCpotrs(handle, + uplo, + M, + N, + reinterpret_cast(Adata), + lda, + reinterpret_cast(Bdata), + ldb, + devInfo)); +} + +template <> +void cusolver_potrs>( + const cusolverDnHandle_t &handle, + cublasFillMode_t uplo, + int M, + int N, + phi::dtype::complex *Adata, + int lda, + phi::dtype::complex *Bdata, + int ldb, + int *devInfo) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnZpotrs( + handle, + uplo, + M, + N, + reinterpret_cast(Adata), + lda, + reinterpret_cast(Bdata), + ldb, + devInfo)); +} + +template +class CholeskySolveFunctor { + public: + void operator()(const GPUContext &dev_ctx, + bool upper, + int M, + int N, + T *Adata, + int lda, + T *Bdata, + int *devInfo) { + cublasFillMode_t uplo = + upper ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + auto handle = dev_ctx.cusolver_dn_handle(); + cusolver_potrs(handle, uplo, M, N, Adata, lda, Bdata, lda, devInfo); + } +}; + +} // namespace phi + +PD_REGISTER_KERNEL(cholesky_solve, // cuda_only + GPU, + ALL_LAYOUT, + phi::CholeskySolveKernel, + float, + double) {} + +#endif // not PADDLE_WITH_HIP diff --git a/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h new file mode 100644 index 0000000000..9f557e7463 --- /dev/null +++ b/paddle/phi/kernels/impl/cholesky_solve_grad_kernel_impl.h @@ -0,0 +1,134 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/cholesky_solve_grad_kernel.h" + +#include "paddle/phi/kernels/cholesky_solve_kernel.h" +#include "paddle/phi/kernels/complex_kernel.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/funcs/complex_functors.h" +#include "paddle/phi/kernels/funcs/matrix_reduce.h" +#include "paddle/phi/kernels/math_kernel.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/operators/tril_triu_op.h" + +namespace phi { + +template +void CholeskySolveGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const DenseTensor& out, + const DenseTensor& dout, + bool upper, + DenseTensor* dx, + DenseTensor* dy) { + // get broadcast dim + std::vector x_bst_dims_vec; + std::vector y_bst_dims_vec; + std::tie(x_bst_dims_vec, y_bst_dims_vec) = + funcs::MatrixGetBroadcastDims(x, y); + ScalarArray x_bst_dims(x_bst_dims_vec); + ScalarArray y_bst_dims(y_bst_dims_vec); + + // Tensor broadcast to temp 'y_bst' + DenseTensor y_bst = phi::Empty(dev_ctx, y_bst_dims); + ExpandKernel(dev_ctx, y, y_bst_dims, &y_bst); + + // reuse forward to calculate dx_bst, which is broad_cast of dx + DenseTensor dx_bst = phi::Empty(dev_ctx, x_bst_dims); + CholeskySolveKernel(dev_ctx, dout, y_bst, upper, &dx_bst); + + // get 'dx' according to 'dx_bst' + dx->Resize(x.dims()); + dev_ctx.template Alloc(dx); + if (dx_bst.dims() == x.dims()) { + Copy(dev_ctx, dx_bst, dev_ctx.GetPlace(), false, dx); + } else { + funcs::MatrixReduceSumFunctor functor; + functor(dev_ctx, dx_bst, dx); + dx->Resize(x.dims()); + } + + // calculate out's conjugate for complex + DenseTensor out_conj = Conj(dev_ctx, out); + out_conj = phi::TransposeLast2Dim(dev_ctx, out_conj); + + DenseTensor commonterm = phi::Empty(dev_ctx, y_bst_dims); + auto blas = phi::funcs::GetBlas(dev_ctx); + blas.MatMul(dx_bst, + phi::funcs::CreateMatrixDescriptor(dx_bst.dims(), 0, false), + out_conj, + phi::funcs::CreateMatrixDescriptor(out_conj.dims(), 0, false), + static_cast(1), + &commonterm, + static_cast(0)); + + // calculate commonterm's conjugate for complex + DenseTensor commonterm_conj = Conj(dev_ctx, commonterm); + commonterm_conj = phi::TransposeLast2Dim(dev_ctx, commonterm_conj); + + phi::AddRawKernel(dev_ctx, commonterm, commonterm_conj, -1, &commonterm); + + DenseTensor dy_bst = phi::Empty(dev_ctx, y_bst_dims); + if (upper) { + blas.MatMul(y_bst, + phi::funcs::CreateMatrixDescriptor(y_bst.dims(), 0, false), + commonterm, + phi::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false), + static_cast(-1), + &dy_bst, + static_cast(0)); + } else { + blas.MatMul(commonterm, + phi::funcs::CreateMatrixDescriptor(commonterm.dims(), 0, false), + y_bst, + phi::funcs::CreateMatrixDescriptor(y_bst.dims(), 0, false), + static_cast(-1), + &dy_bst, + static_cast(0)); + } + + // get upper or lower of 'dy_bst' + 
DenseTensor dy_bst_upper = phi::Empty(dev_ctx, y_bst_dims); + + int y_bst_ndim = y_bst_dims_vec.size(); + const auto H = y_bst_dims_vec[y_bst_ndim - 2]; + const auto W = y_bst_dims_vec[y_bst_ndim - 1]; + phi::funcs::ForRange y_for_range(dev_ctx, dy_bst.numel()); + paddle::operators::TrilTriuCompute tril_triu_functor( + dy_bst.data(), 0, !upper, H, W, dy_bst_upper.data()); + y_for_range(tril_triu_functor); + + // get 'dy' according to 'dy_bst' + dy->Resize(y.dims()); + dev_ctx.template Alloc(dy); + if (dy_bst_upper.dims() == y.dims()) { + Copy(dev_ctx, dy_bst_upper, dev_ctx.GetPlace(), false, dy); + } else { + funcs::MatrixReduceSumFunctor functor; + functor(dev_ctx, dy_bst_upper, dy); + dy->Resize(y.dims()); + } +} + +} // namespace phi diff --git a/paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h b/paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h new file mode 100644 index 0000000000..16ceb776f1 --- /dev/null +++ b/paddle/phi/kernels/impl/cholesky_solve_kernel_impl.h @@ -0,0 +1,104 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/kernels/cholesky_solve_kernel.h" + +#include "paddle/phi/kernels/complex_kernel.h" +#include "paddle/phi/kernels/copy_kernel.h" +#include "paddle/phi/kernels/empty_kernel.h" +#include "paddle/phi/kernels/expand_kernel.h" +#include "paddle/phi/kernels/funcs/common_shape.h" +#include "paddle/phi/kernels/transpose_kernel.h" + +namespace phi { + +template +class CholeskySolveFunctor { + public: + void operator()(const Context& dev_ctx, + bool upper, + int M, + int N, + T* Adata, + int lda, + T* Bdata, + int* devInfo); +}; + +template +void CholeskySolveKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + bool upper, + DenseTensor* out) { + // get broadcast dim + std::vector x_bst_dims_vec; + std::vector y_bst_dims_vec; + std::tie(x_bst_dims_vec, y_bst_dims_vec) = + funcs::MatrixGetBroadcastDims(x, y); + ScalarArray x_bst_dims(x_bst_dims_vec); + ScalarArray y_bst_dims(y_bst_dims_vec); + + DenseTensor y_bst = phi::Empty(dev_ctx, y_bst_dims); + ExpandKernel(dev_ctx, y, y_bst_dims, &y_bst); + + // Tensor broadcast to temp 'x_bst' and 'y_bst' + DenseTensor x_bst = phi::Empty(dev_ctx, x_bst_dims); + ExpandKernel(dev_ctx, x, x_bst_dims, &x_bst); + + // calculate y_bst's conjugate for complex + DenseTensor y_bst_conj = Conj(dev_ctx, y_bst); + y_bst_conj = phi::TransposeLast2Dim(dev_ctx, y_bst_conj); + T* y_bst_conj_data = y_bst_conj.data(); + + // calculate x_bst's conjugate for complex + DenseTensor x_bst_conj = Conj(dev_ctx, x_bst); + x_bst_conj = phi::TransposeLast2Dim(dev_ctx, x_bst_conj); + + // copy x_bst's conjugate to 'result' + DenseTensor result; + Copy(dev_ctx, x_bst_conj, dev_ctx.GetPlace(), false, &result); + T* res_data = result.data(); + + // CPU use lapack, GPU use cusolver + int x_bst_ndim = x_bst_dims_vec.size(); + int M = static_cast(x_bst_dims_vec[x_bst_ndim - 2]); + int N = 
static_cast(x_bst_dims_vec[x_bst_ndim - 1]); + int batchsize = product(phi::slice_ddim(x_bst.dims(), 0, x_bst_ndim - 2)); + + DenseTensor info = + phi::Empty(dev_ctx, ScalarArray({batchsize})); + int* info_data = info.data(); + + CholeskySolveFunctor functor; + for (int i = 0; i < batchsize; ++i) { + functor(dev_ctx, + upper, + M, + N, + y_bst_conj_data + i * M * M, + std::max(1, M), + res_data + i * M * N, + info_data + i); + } + + // calculate out's conjugate for complex + result = phi::TransposeLast2Dim(dev_ctx, result); + out->Resize(phi::make_ddim(x_bst_dims_vec)); + ConjKernel(dev_ctx, result, out); +} + +} // namespace phi diff --git a/paddle/phi/ops/compat/cholesky_solve_sig.cc b/paddle/phi/ops/compat/cholesky_solve_sig.cc new file mode 100644 index 0000000000..6a9759f835 --- /dev/null +++ b/paddle/phi/ops/compat/cholesky_solve_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature CholeskySolveGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("cholesky_solve_grad", + {"X", "Y", "Out", GradVarName("Out")}, + {"upper"}, + {GradVarName("X"), GradVarName("Y")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(cholesky_solve_grad, + phi::CholeskySolveGradOpArgumentMapping); -- GitLab
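
For reference, the CPU path of the migrated kernel reaches LAPACK's *potrs routines through phi::funcs::lapackCholeskySolve, which (as the patch shows) forwards to dynload::dpotrs_/spotrs_/cpotrs_/zpotrs_. The standalone C++ sketch below is not part of the patch; it only illustrates, under the assumption that a system LAPACK is linked (e.g. with -llapack), the dpotrf_/dpotrs_ pair that solves A*X = B from a Cholesky factor, which is what the kernel does per batch element.

// Standalone illustration of the LAPACK calls wrapped by
// phi::funcs::lapackCholeskySolve in this patch (double precision case).
// Assumed build line: g++ potrs_demo.cc -llapack
#include <cstdio>

extern "C" {
// Cholesky factorization: A = L*L^T (uplo='L') or A = U^T*U (uplo='U').
void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
// Solve A*X = B reusing the factor produced by dpotrf_.
void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b,
             int *ldb, int *info);
}

int main() {
  // Symmetric positive-definite A (2x2, column-major) and right-hand side b.
  double a[4] = {4.0, 1.0, 1.0, 3.0};
  double b[2] = {1.0, 2.0};
  char uplo = 'L';
  int n = 2, nrhs = 1, lda = 2, ldb = 2, info = 0;

  dpotrf_(&uplo, &n, a, &lda, &info);                  // factor A in place
  dpotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, &info);  // b now holds x

  // [[4,1],[1,3]] x = [1,2] gives x = [1/11, 7/11] ~ [0.0909, 0.6364]
  std::printf("x = [%f, %f], info = %d\n", b[0], b[1], info);
  return 0;
}

The GPU path in cholesky_solve_kernel.cu performs the same per-batch solve through cusolverDn{S,D,C,Z}potrs on a cuSOLVER handle instead of the LAPACK symbols above.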