diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index d641c5c135499f836b4b50ab1eb494cbfc767fa2..0d7d0a5e13bf3d0aab33a2a4af6925b926bb8af5 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -123,7 +123,7 @@ lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling segment_pooling executor device_memory_aligment generator) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col) -set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} box_wrapper boost ps_gpu_wrapper) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} common_infer_shape_functions) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} eigen_function) diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index d8d79bc4086bb7c998aa7071d4ac4703403035d3..6177ec749ac031a0749f3d5c3b090015c48d0bf7 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -89,6 +89,7 @@ math_library(bert_encoder_functor) math_library(tree2col DEPS math_function) math_library(matrix_inverse) math_library(segment_pooling) +math_library(matrix_solve) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function) cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index bbf7516c538fc3504777af62709221f167f3e0e8..6546f854df0f4ca7f1e08f3f178ac5c836633312 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -247,6 +247,12 @@ class Blas { template void BatchedMatInv(int n, const T** a, T** a_inv, int* info, int batch_size) const; + + // cuBlas solve + template + void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a, + int lda, int* ipiv, T** b, int ldb, int* info, + int batch_size) const; #endif private: @@ -402,6 +408,12 @@ class BlasT : private Blas { void BatchedMatInv(ARGS... args) const { Base()->template BatchedMatInv(args...); } + + // solve + template + void BatchedGETRS(ARGS... args) const { + Base()->template BatchedGETRS(args...); + } #endif private: diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 477f3e0f6a2dc5cfd6fcc0b0624f8f0c2563fe8b..6f83faf1e40d865c6435dcd1fe7dfaab7693dc02 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -114,6 +114,12 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasSmatinvBatched(args...)); } + + template + static void GETRS_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasSgetrsBatched(args...)); + } }; template <> @@ -182,6 +188,12 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasDmatinvBatched(args...)); } + + template + static void GETRS_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::cublasDgetrsBatched(args...)); + } }; template <> @@ -871,6 +883,20 @@ void Blas::BatchedMatInv(int n, const T **a, }); } +template <> +template +void Blas::BatchedGETRS( + CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv, + T **b, int ldb, int *info, int batch_size) const { + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTrans = + (trans == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, + batch_size); + }); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.hip.h b/paddle/fluid/operators/math/blas_impl.hip.h index 788ebc6ad985c5fb6e6667220713783f014d2a62..1ce5bac5242ab872cb3ef423c9ae7940ad38db38 100644 --- a/paddle/fluid/operators/math/blas_impl.hip.h +++ b/paddle/fluid/operators/math/blas_impl.hip.h @@ -717,6 +717,19 @@ void Blas::BatchedMatInv(int n, const T **a, }); } +template <> +template +void Blas::BatchedGETRS( + CBLAS_TRANSPOSE trans, int n, int nrhs, const T **a, int lda, int *ipiv, + T **b, int ldb, int *info, int batch_size) const { + rocblas_operation cuTrans = (trans == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::GETRS_BATCH(handle, cuTrans, n, nrhs, a, lda, ipiv, b, ldb, info, + batch_size); + }); +} } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.cc b/paddle/fluid/operators/math/matrix_solve.cc new file mode 100644 index 0000000000000000000000000000000000000000..7f13b5c8a70eef7b33e2776d3314bcf18c972dad --- /dev/null +++ b/paddle/fluid/operators/math/matrix_solve.cc @@ -0,0 +1,39 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/matrix_solve.h" +#include "Eigen/Core" +#include "Eigen/LU" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { +namespace math { + +template +class MatrixSolveFunctor { + public: + void operator()(const platform::CPUDeviceContext& dev_ctx, + const framework::Tensor& a, const framework::Tensor& b, + framework::Tensor* out) { + compute_solve_eigen(dev_ctx, a, b, out); + } +}; + +template class MatrixSolveFunctor; +template class MatrixSolveFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.cu.cc b/paddle/fluid/operators/math/matrix_solve.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..efb3a07e4c1b47d649642f630c1e9adc49a9598c --- /dev/null +++ b/paddle/fluid/operators/math/matrix_solve.cu.cc @@ -0,0 +1,168 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/matrix_solve.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/solve_op.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace platform { +class CUDADeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { +namespace math { + +template +class MatrixSolveFunctor; + +template +class MatrixSolveFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, + const framework::Tensor& a, const framework::Tensor& b, + framework::Tensor* out) { +#ifndef PADDLE_WITH_HIP + + // solve the equation: Ax = B, + // use cuBlas cublasgetrfBatched funcion to performs the LU + // factorization of each matrix A, + // and then use cuBlas cublasgetriBatched function to solve the + // equation after LU factorization. + // ref: + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched + const auto& a_dims = a.dims(); + const int a_rank = a_dims.size(); + int n = a_dims[a_rank - 1]; + int lda = n; + int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1; + + const auto& b_dims = b.dims(); + const int b_rank = b_dims.size(); + int nrhs = b_dims[b_rank - 1]; + int ldb = b_dims[b_rank - 2]; + + // make sure the out dims is right + out->Resize(b_dims); + out->mutable_data(context.GetPlace()); + + // copy input A to a temporary tensor tmp_a, + // LU factorization, written back to original matrix A, so in the beginning, + // it's necessary to create a temporary tensor tmp_a. + Tensor tmp_a(a.type()); + tmp_a.Resize(a.dims()); + tmp_a.mutable_data(context.GetPlace()); + TensorCopy(a, context.GetPlace(), &tmp_a); + + // copy input B to a temporary tensor tmp_b, and transpose tmp_b, + // because cuBlas assumes column-major while Paddle uses row-majar. + Tensor tmp_b(b.type()); + const auto& new_dims_vec = getNewDimsVec(b_dims); + tmp_b.Resize(framework::make_ddim(new_dims_vec)); + tmp_b.mutable_data(context.GetPlace()); + math::TransposeNormal trans; + std::vector new_axis = getNewAxis(b_rank); + trans(context, b, &tmp_b, new_axis); + + const T* a_data_in_gpu = tmp_a.data(); + const T* b_data_in_gpu = tmp_b.data(); + + std::vector cpu_ptrs(batch_size * 2); + for (int i = 0; i < batch_size; ++i) { + cpu_ptrs[i] = a_data_in_gpu + i * n * n; + cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs; + } + + // Copy the addresses of A and tmp_b from host to device. + memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + memory::Alloc(context, cpu_ptrs.size() * sizeof(T*)); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), + tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(), + static_cast(cpu_ptrs.data()), + cpu_ptrs.size() * sizeof(T*), context.stream()); + + T** gpu_tmp_b_ptrs = + reinterpret_cast(tmp_gpu_ptrs_data->ptr()) + batch_size; + + // Allocate device memory for BatchedGETRF's info and pivots. + int num_ints = n < 32 ? batch_size : batch_size * (n + 1); + memory::allocation::AllocationPtr tmp_gpu_info_data = + memory::Alloc(context, num_ints * sizeof(int)); + int* gpu_info_ptr = reinterpret_cast(tmp_gpu_info_data->ptr()); + + auto blas = math::GetBlas(context); + + // only for singular checking + std::vector info; + info.resize(batch_size); + + int* gpu_pivot_ptr = + reinterpret_cast(tmp_gpu_info_data->ptr()) + batch_size; + + // This function performs the LU factorization of each matrix A by the + // equation A = L * U. L and U are written back to original matrix A, + // and diagonal elements of L are discarded. + blas.BatchedGETRF(n, reinterpret_cast(tmp_gpu_ptrs_data->ptr()), + gpu_pivot_ptr, gpu_info_ptr, batch_size); + + // check whether BatchedGETRF is executed successfully or not + memory::Copy(platform::CPUPlace(), info.data(), + BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), + gpu_info_ptr, sizeof(int) * batch_size, context.stream()); + for (int i = 0; i < batch_size; ++i) { + PADDLE_ENFORCE_EQ(info[i], 0, + platform::errors::PreconditionNotMet( + "For batch [%d]: U(%d, %d) is zero, singular U. " + "Please check the matrix value and change it to a " + "non-singular matrix", + i, info[i], info[i])); + } + + // hold the result code from BatchedGETRS + int host_info = 0; + + // to solve the equation after LU factorization + CBLAS_TRANSPOSE transA = CblasTrans; + blas.BatchedGETRS( + transA, n, nrhs, reinterpret_cast(tmp_gpu_ptrs_data->ptr()), + lda, gpu_pivot_ptr, gpu_tmp_b_ptrs, ldb, &host_info, batch_size); + + // check whether BatchedGETRS is executed successfully or not + PADDLE_ENFORCE_EQ(host_info, 0, + platform::errors::InvalidArgument( + "The [%d]'th argument to cublas*getrsBatched had " + "an illegal value.", + -host_info)); + + // transpose tmp_b to get the final result in row-major form. + math::TransposeNormal trans2; + trans2(context, tmp_b, out, new_axis); + +#else + compute_solve_eigen(context, a, b, out); +#endif + } +}; + +template class MatrixSolveFunctor; +template class MatrixSolveFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.h b/paddle/fluid/operators/math/matrix_solve.h new file mode 100644 index 0000000000000000000000000000000000000000..93c37ae425640f2b2c4c4c147e9e593438b7c8ed --- /dev/null +++ b/paddle/fluid/operators/math/matrix_solve.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "Eigen/Core" +#include "Eigen/LU" +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +void compute_solve_eigen(const DeviceContext& context, + const framework::Tensor& a, const framework::Tensor& b, + framework::Tensor* out) { + using Matrix = + Eigen::Matrix; + using EigenMatrixMap = Eigen::Map; + using ConstEigenMatrixMap = Eigen::Map; + // prepare for a + const auto& a_mat_dims = a.dims(); + const int a_rank = a_mat_dims.size(); + int n = a_mat_dims[a_rank - 1]; + int a_batch_size = a_rank > 2 ? a.numel() / (n * n) : 1; + + // prepare for b + const auto& b_mat_dims = b.dims(); + const int b_rank = b_mat_dims.size(); + int b_h = n; + int b_w = b_mat_dims[b_rank - 1]; + int b_batch_size = b_rank > 2 ? b.numel() / (b_h * b_w) : 1; + + const T* a_ptr = a.data(); + const T* b_ptr = b.data(); + out->Resize(b_mat_dims); // make sure the out dims is right + + T* out_ptr = out->mutable_data(context.GetPlace()); + if (a_batch_size == b_batch_size) { + for (int i = 0; i < a_batch_size; ++i) { + ConstEigenMatrixMap a_mat(a_ptr + i * n * n, n, n); + ConstEigenMatrixMap b_mat(b_ptr + i * b_h * b_w, b_h, b_w); + EigenMatrixMap out_mat(out_ptr + i * b_h * b_w, b_h, b_w); + Eigen::PartialPivLU lu; + lu.compute(a_mat); + const T min_abs_pivot = lu.matrixLU().diagonal().cwiseAbs().minCoeff(); + PADDLE_ENFORCE_GT( + min_abs_pivot, static_cast(0), + platform::errors::InvalidArgument("Input is not invertible.")); + out_mat.noalias() = lu.solve(b_mat); + } + } else { + PADDLE_ENFORCE_EQ(a_batch_size, b_batch_size, + platform::errors::InvalidArgument( + "All input tensors must have the same rank.")); + } +} + +template +class MatrixSolveFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor& a, + const framework::Tensor& b, framework::Tensor* out); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/solve_op.cc b/paddle/fluid/operators/solve_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6e89eec7493dacaa1152f3ae530a49bebdded376 --- /dev/null +++ b/paddle/fluid/operators/solve_op.cc @@ -0,0 +1,216 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/solve_op.h" +#include +#include +#include +#include +#include "paddle/fluid/framework/ddim.h" + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class SolveOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Solve"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "Solve"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Solve"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + std::vector x_dims_vec = + paddle::framework::vectorize(ctx->GetInputDim("X")); + std::vector y_dims_vec = + paddle::framework::vectorize(ctx->GetInputDim("Y")); + + auto x_dims_n = x_dims_vec.size(); + auto y_dims_n = y_dims_vec.size(); + + PADDLE_ENFORCE_GT(x_dims_n, 1, + platform::errors::InvalidArgument( + "The input tensor X's dimensions of SolveOp " + "should be larger than 1. But received X's " + "dimensions = %d, X's shape = [%s]", + x_dims_n, x_dims)); + + PADDLE_ENFORCE_GE(y_dims_n, 1, + platform::errors::InvalidArgument( + "The input tensor Y's dimensions of SolveOp " + "should be larger than or equal 1. But received Y's " + "dimensions = %d, Y's shape = [%s]", + y_dims_n, y_dims)); + + PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2], x_dims[x_dims_n - 1], + platform::errors::InvalidArgument( + "The inner-most 2 dimensions of Input(X) all should " + "be square matrices " + "But received X's shape[-2] = %d and shape[-1] = %d.", + x_dims[x_dims_n - 2], x_dims[x_dims_n - 1])); + + bool x_broadcasted = false, y_broadcasted = false; + bool trans_x = false, trans_y = false; + if (x_dims_n == 1) { + x_dims_vec.insert(x_dims_vec.begin(), 1); + x_dims_n = 2; + x_broadcasted = true; + } + + if (y_dims_n == 1) { + y_dims_vec.push_back(1); + y_dims_n = 2; + y_broadcasted = true; + } + + size_t M, N; + if (trans_x) { + M = x_dims_vec[x_dims_n - 1]; + } else { + M = x_dims_vec[x_dims_n - 2]; + } + if (trans_y) { + N = y_dims_vec[y_dims_n - 2]; + } else { + N = y_dims_vec[y_dims_n - 1]; + } + + std::vector new_dims; + if (x_dims_n >= y_dims_n) { + new_dims.assign(x_dims_vec.begin(), x_dims_vec.end() - 2); + } else { + new_dims.assign(y_dims_vec.begin(), y_dims_vec.end() - 2); + } + if (!x_broadcasted) { + new_dims.push_back(M); + } + if (!y_broadcasted) { + new_dims.push_back(N); + } + if (x_broadcasted && y_broadcasted) { + new_dims.push_back(1); + } + + auto out_dims = framework::make_ddim(new_dims); + ctx->SetOutputDim("Out", out_dims); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library, customized_type_value); + } +}; + +class SolveOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), The first input tensor of solve op."); + AddInput("Y", "(Tensor), The second input tensor of solve op."); + AddOutput("Out", "(Tensor), The output tensor of solve op."); + AddComment(R"DOC( + Solve Operator. + This operator is used to computes the solution of a square system of + linear equations with a unique solution for input $X$ and $Y$. + + The equation is: + $$Out = X^-1 * Y$$ +)DOC"); + } +}; + +class SolveOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + +class SolveGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "solve"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "solve"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "solve"); + // reuse the linalg.solve forward output + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "solve"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +template +class SolveOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("solve_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput("Y", this->Input("Y")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + // reuse the linalg.solve forward output + retv->SetInput("Out", this->Output("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(solve, ops::SolveOp, ops::SolveOpMaker, + ops::SolveOpInferVarType, + ops::SolveOpGradMaker, + ops::SolveOpGradMaker); + +REGISTER_OPERATOR(solve_grad, ops::SolveGradOp); + +REGISTER_OP_CPU_KERNEL( + solve, ops::SolveKernel, + ops::SolveKernel); +REGISTER_OP_CPU_KERNEL( + solve_grad, ops::SolveGradKernel, + ops::SolveGradKernel); diff --git a/paddle/fluid/operators/solve_op.cu b/paddle/fluid/operators/solve_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..2ca0bcdd7f68b485308bce908ab9d02faebd2d84 --- /dev/null +++ b/paddle/fluid/operators/solve_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/solve_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(solve, ops::SolveKernel, + ops::SolveKernel); + +REGISTER_OP_CUDA_KERNEL(solve_grad, + ops::SolveGradKernel, + ops::SolveGradKernel); diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d55c2647c1f3ad59143e4e92dbf002fa860d324e --- /dev/null +++ b/paddle/fluid/operators/solve_op.h @@ -0,0 +1,732 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "Eigen/Core" +#include "Eigen/LU" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/eigen/eigen_function.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/math/matrix_solve.h" +#include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" +#include "paddle/fluid/operators/squeeze_op.h" +#if defined(__NVCC__) || defined(__HIPCC__) +#include "paddle/fluid/operators/reduce_ops/cub_reduce.h" +#endif + +#define MAX_RANK_SUPPORTED 6 + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using framework::To32BitIndex; + +constexpr int kMULMKLDNNINT8 = 1; + +struct IdentityFunctor { + HOSTDEVICE explicit inline IdentityFunctor() {} + + template + HOSTDEVICE inline U operator()(const U& x) const { + return x; + } +}; + +template +void ReduceSumForSolveGrad(const Tensor* input, Tensor* output, + const std::vector& reduce_dims, bool keep_dim, + const paddle::framework::ExecutionContext& ctx) { +#if defined(__NVCC__) || defined(__HIPCC__) + auto stream = ctx.cuda_device_context().stream(); + TensorReduce(*input, output, reduce_dims, + static_cast(0), cub::Sum(), + IdentityFunctor(), stream); +#else + ReduceKernelFunctor( + input, output, reduce_dims, keep_dim, false, ctx) + .template apply(); +#endif +} + +// check the input other is vector_case or not +static inline bool is_vector_rhs(const Tensor& input, const Tensor& other) { + auto x_dim = input.dims(); + auto y_dim = other.dims(); + auto x_dim_size = x_dim.size(); + auto y_dim_size = y_dim.size(); + std::vector x_dims_vec = paddle::framework::vectorize(x_dim); + std::vector y_dims_vec = paddle::framework::vectorize(y_dim); + + std::vector::const_iterator f = x_dims_vec.begin(); + std::vector::const_iterator l = x_dims_vec.end() - 1; + std::vector x_dims_vec_cut(f, l); // input.shape[:-1] + + std::vector expected_batched_rhs_shape(x_dims_vec_cut); + bool vector_case = + y_dim_size == 1 || (x_dim_size - 1 == y_dim_size && + y_dims_vec == (expected_batched_rhs_shape)); + + return vector_case; +} + +// unsqueeze operation helper +static framework::DDim GetOutputShapeUnsqueeze( + const std::vector unsqz_dims, const framework::DDim& in_dims) { + int output_size = in_dims.size() + static_cast(unsqz_dims.size()); + int cur_output_size = in_dims.size(); + std::vector output_shape(output_size, 0); + + // Validity Check: rank range. + PADDLE_ENFORCE_LE(output_size, 6, + platform::errors::InvalidArgument( + "The output " + "tensor's rank should be less than 6.")); + + for (int axis : unsqz_dims) { + int cur = axis < 0 ? axis + cur_output_size + 1 : axis; + // Vaildity Check: the axis bound + PADDLE_ENFORCE_GE(cur, 0, platform::errors::InvalidArgument( + "The insert dimension value should " + "not be less than 0")); + PADDLE_ENFORCE_LE(cur, cur_output_size, + platform::errors::InvalidArgument( + "The insert dimension value shoule not be larger " + "than the dimension size of input tensor")); + // Move old axis, and insert new axis + for (int i = cur_output_size; i >= cur; --i) { + if (output_shape[i] == 1) { + // Move axis + output_shape[i + 1] = 1; + output_shape[i] = 0; + } + } + output_shape[cur] = 1; + // Add the output size. + cur_output_size++; + } + + // Make output shape + for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) { + if (output_shape[out_idx] == 0) { + output_shape[out_idx] = in_dims[in_idx++]; + } + } + + return framework::make_ddim(output_shape); +} + +// operation like squeeze(-1) +static void to_squeeze(const framework::ExecutionContext& context, + const framework::Tensor& in, framework::Tensor* out) { + auto x_dims = in.dims(); + std::vector sqz_dims = {-1}; + auto out_dims = GetOutputShape(sqz_dims, x_dims, true); + out->mutable_data(context.GetPlace(), in.type()); + framework::TensorCopy( + in, context.GetPlace(), + context.template device_context(), out); + out->Resize(out_dims); +} + +// vector_case, need to operate like unsqueeze(-1) +static void to_unsqueeze(const framework::ExecutionContext& context, + const framework::Tensor& in, framework::Tensor* out) { + auto x_dims = in.dims(); + std::vector unsqz_dims = {-1}; + framework::DDim out_dims = out->dims(); + out_dims = GetOutputShapeUnsqueeze(unsqz_dims, x_dims); + framework::TensorCopy( + in, context.GetPlace(), + context.template device_context(), out); + out->Resize(out_dims); +} + +template +Container infer_size_impl(std::vector a, std::vector b) { + size_t dimsA = a.size(); + size_t dimsB = b.size(); + size_t ndim = dimsA > dimsB ? dimsA : dimsB; + Container expandedSizes(ndim); + + for (ptrdiff_t i = (ptrdiff_t)ndim - 1; i >= 0; --i) { + ptrdiff_t offset = ndim - 1 - i; + ptrdiff_t dimA = dimsA - 1 - offset; + ptrdiff_t dimB = dimsB - 1 - offset; + int64_t sizeA = (dimA >= 0) ? a[dimA] : 1; + int64_t sizeB = (dimB >= 0) ? b[dimB] : 1; + + PADDLE_ENFORCE_EQ( + (sizeA == sizeB || sizeA == 1 || sizeB == 1), true, + platform::errors::PreconditionNotMet( + "The size of tensor a (%d) must match the size of tensor b " + "(%d) at non-singleton dimension %d.", + sizeA, sizeB, i)); + + expandedSizes[i] = sizeA == 1 ? sizeB : sizeA; + } + return expandedSizes; +} + +// infer size for broadcast operation +static std::vector infer_size(std::vector a, + std::vector b) { + return infer_size_impl>(a, b); +} + +// necessary check before expand operation +static void expand_check(const Tensor& arg1, + std::vector expand_shape) { + auto rank = arg1.dims().size(); + PADDLE_ENFORCE_GE( + rank, 1, platform::errors::InvalidArgument( + "The rank of the input 'X' for expand must be positive, " + "but the value received is %d.", + rank)); + PADDLE_ENFORCE_LE( + rank, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The rank of the input 'X' for expand must be less than " + "or equal to %d, but the value received is %d.", + MAX_RANK_SUPPORTED, rank)); + auto shape_size = static_cast(expand_shape.size()); + PADDLE_ENFORCE_GE( + shape_size, rank, + platform::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for expand must be " + "greater than or equal to the rank (%d) of the input 'X'.", + shape_size, rank)); + PADDLE_ENFORCE_LE( + shape_size, MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The number (%d) of elements of 'shape' for expand must be " + "less than or equal to %d.", + shape_size, MAX_RANK_SUPPORTED)); +} + +// broadcast the batch dimensions of arg1 and arg2. +static inline std::tuple, std::vector> +_broadcast_batch_dims(const Tensor& arg1, const Tensor& arg2) { + std::vector arg1_dims_vec = + paddle::framework::vectorize(arg1.dims()); + std::vector arg2_dims_vec = + paddle::framework::vectorize(arg2.dims()); + + std::vector::const_iterator f1 = arg1_dims_vec.begin(); + std::vector::const_iterator l1 = arg1_dims_vec.end() - 2; + std::vector arg1_dims_vec_cut(f1, l1); + + std::vector::const_iterator f2 = arg2_dims_vec.begin(); + std::vector::const_iterator l2 = arg2_dims_vec.end() - 2; + std::vector arg2_dims_vec_cut(f2, l2); + + std::vector expand_batch_portion = + infer_size(arg1_dims_vec_cut, arg2_dims_vec_cut); + + std::vector arg1_expand_size({expand_batch_portion}); + arg1_expand_size.insert( + arg1_expand_size.end(), + {arg1_dims_vec[static_cast(arg1_dims_vec.size()) - 2], + arg1_dims_vec[static_cast(arg1_dims_vec.size()) - 1]}); + + std::vector arg2_expand_size({expand_batch_portion}); + arg2_expand_size.insert( + arg2_expand_size.end(), + {arg2_dims_vec[static_cast(arg2_dims_vec.size()) - 2], + arg2_dims_vec[static_cast(arg2_dims_vec.size()) - 1]}); + + return std::make_tuple(arg1_expand_size, arg2_expand_size); +} + +template +void tensor_expand(const framework::ExecutionContext& context, + const Tensor& arg1, Tensor* out0, + std::vector expand_size) { + auto in_dims = arg1.dims(); + auto expand_shape = expand_size; + auto vec_in_dims = framework::vectorize(in_dims); + auto diff = expand_shape.size() - vec_in_dims.size(); + vec_in_dims.insert(vec_in_dims.begin(), diff, 1); + std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { + PADDLE_ENFORCE_NE( + expand_shape[i], 0, + platform::errors::InvalidArgument("The expanded size cannot be zero.")); + if (i < diff) { + PADDLE_ENFORCE_GT( + expand_shape[i], 0, + platform::errors::InvalidArgument( + "The expanded size (%d) for non-existing dimensions must be " + "positive for expand operation.", + expand_shape[i])); + repeat_times[i] = expand_shape[i]; + } else if (expand_shape[i] > 0) { + if (vec_in_dims[i] != 1) { + PADDLE_ENFORCE_EQ( + vec_in_dims[i], expand_shape[i], + platform::errors::InvalidArgument( + "The value (%d) of the non-singleton dimension does not match" + " the corresponding value (%d) in shape for expand operation.", + vec_in_dims[i], expand_shape[i])); + repeat_times[i] = 1; + } else { + repeat_times[i] = expand_shape[i]; + } + } else { + PADDLE_ENFORCE_EQ( + expand_shape[i], -1, + platform::errors::InvalidArgument( + "When the value in shape is negative for expand_v2 op, " + "only -1 is supported, but the value received is %d.", + expand_shape[i])); + repeat_times[i] = 1; + } + } + + Eigen::DSizes bcast_dims; + for (size_t i = 0; i < repeat_times.size(); ++i) { + bcast_dims[i] = repeat_times[i]; + } + + framework::DDim new_in_dims = framework::make_ddim(vec_in_dims); + framework::DDim out_dims(new_in_dims); + for (size_t i = 0; i < repeat_times.size(); ++i) { + out_dims[i] *= repeat_times[i]; + } + + out0->Resize(out_dims); + auto x = EigenTensor::From(arg1, new_in_dims); + out0->mutable_data(context.GetPlace()); + auto y = EigenTensor::From(*out0, out_dims); + auto& place = + *context.template device_context().eigen_device(); + // use 32-bit index to speed up + bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); + if (use_32bit_index) { + EigenBroadcast, T, Rank>::Eval( + place, To32BitIndex(y), To32BitIndex(x), bcast_dims); + } else { + EigenBroadcast, T, Rank>::Eval(place, y, x, + bcast_dims); + } +} + +template +static void linalg_solve(const framework::ExecutionContext& context, + const framework::Tensor* x, const framework::Tensor* y, + framework::Tensor* out) { + out->mutable_data(context.GetPlace()); + + auto& dev_ctx = context.template device_context(); + math::MatrixSolveFunctor mat_solve; + + // input y can be vector or matrix + // but need to be unsqueezed if y is a vector + bool is_vector = false; + is_vector = is_vector_rhs(*x, *y); + + Tensor tmp_y; + if (is_vector) { + tmp_y.mutable_data(context.GetPlace(), y->type()); + to_unsqueeze(context, *y, &tmp_y); + } else { + tmp_y.Resize(y->dims()); + tmp_y.mutable_data(context.GetPlace(), y->type()); + framework::TensorCopy( + *y, context.GetPlace(), + context.template device_context(), &tmp_y); + } + + Tensor tmp_x; + tmp_x.Resize(x->dims()); + tmp_x.mutable_data(context.GetPlace(), x->type()); + framework::TensorCopy( + *x, context.GetPlace(), + context.template device_context(), &tmp_x); + + std::vector x_broadcast_dims; + std::vector y_broadcast_dims; + std::tie(x_broadcast_dims, y_broadcast_dims) = + _broadcast_batch_dims(tmp_x, tmp_y); + + expand_check(tmp_x, x_broadcast_dims); + expand_check(tmp_y, y_broadcast_dims); + + Tensor tmp_x_bc; + Tensor tmp_y_bc; + auto tmp_x_rank = tmp_x.dims().size(); + auto tmp_y_rank = tmp_y.dims().size(); + + auto rank_0 = std::max(tmp_x_rank, static_cast(x_broadcast_dims.size())); + switch (rank_0) { + case 1: + tensor_expand<1, T, DeviceContext>(context, tmp_x, &tmp_x_bc, + x_broadcast_dims); + break; + case 2: + tensor_expand<2, T, DeviceContext>(context, tmp_x, &tmp_x_bc, + x_broadcast_dims); + break; + case 3: + tensor_expand<3, T, DeviceContext>(context, tmp_x, &tmp_x_bc, + x_broadcast_dims); + break; + case 4: + tensor_expand<4, T, DeviceContext>(context, tmp_x, &tmp_x_bc, + x_broadcast_dims); + break; + case 5: + tensor_expand<5, T, DeviceContext>(context, tmp_x, &tmp_x_bc, + x_broadcast_dims); + break; + case 6: + tensor_expand<6, T, DeviceContext>(context, tmp_x, &tmp_x_bc, + x_broadcast_dims); + break; + } + + auto rank_1 = std::max(tmp_y_rank, static_cast(y_broadcast_dims.size())); + switch (rank_1) { + case 1: + tensor_expand<1, T, DeviceContext>(context, tmp_y, &tmp_y_bc, + y_broadcast_dims); + break; + case 2: + tensor_expand<2, T, DeviceContext>(context, tmp_y, &tmp_y_bc, + y_broadcast_dims); + break; + case 3: + tensor_expand<3, T, DeviceContext>(context, tmp_y, &tmp_y_bc, + y_broadcast_dims); + break; + case 4: + tensor_expand<4, T, DeviceContext>(context, tmp_y, &tmp_y_bc, + y_broadcast_dims); + break; + case 5: + tensor_expand<5, T, DeviceContext>(context, tmp_y, &tmp_y_bc, + y_broadcast_dims); + break; + case 6: + tensor_expand<6, T, DeviceContext>(context, tmp_y, &tmp_y_bc, + y_broadcast_dims); + break; + } + + auto x_dim = x->dims(); + auto y_dim = y->dims(); + auto x_dim_size = x_dim.size(); + auto y_dim_size = y_dim.size(); + + if (is_vector) { // vector case + out->Resize(tmp_y_bc.dims()); // out.unsqueeze(-1) + mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out); + + Tensor out_tmp; + out_tmp.Resize(out->dims()); + out_tmp = *out; + to_squeeze(context, out_tmp, out); // out.squeeze(-1) + } else { + PADDLE_ENFORCE_EQ( + x_dim[x_dim_size - 1], y_dim[y_dim_size - 2], + platform::errors::InvalidArgument( + "Matrix X1 with dimension greater than 2 and any matrix Y1," + "the matrix X1's width must be equal with matrix Y1's " + "height. But received X's shape = [%s], X1's shape = [%s], X1's " + "width = %s; Y's shape = [%s], Y1's shape = [%s], Y1's height = " + "%s.", + x_dim, x_dim, x_dim[x_dim_size - 1], y_dim, y_dim, + y_dim[y_dim_size - 2])); + mat_solve(dev_ctx, tmp_x_bc, tmp_y_bc, out); + } +} + +// for TransposeNormal +static std::vector getNewAxis(const int b_rank) { + std::vector axis_1 = {0}; + std::vector axis_2 = {1, 0}; + std::vector axis_3 = {0, 2, 1}; + std::vector axis_4 = {0, 1, 3, 2}; + std::vector axis_5 = {0, 1, 2, 4, 3}; + std::vector axis_6 = {0, 1, 2, 3, 5, 4}; + std::vector axis_7 = {0, 1, 2, 3, 4, 6, 5}; + std::vector axis_8 = {0, 1, 2, 3, 4, 5, 7, 6}; + std::vector axis_9 = {0, 1, 2, 3, 4, 5, 6, 8, 7}; + switch (b_rank) { + case 1: + return axis_1; + break; + case 2: + return axis_2; + break; + case 3: + return axis_3; + break; + case 4: + return axis_4; + break; + case 5: + return axis_5; + break; + case 6: + return axis_6; + break; + case 7: + return axis_7; + break; + case 8: + return axis_8; + break; + default: + return axis_9; + } +} + +// for Resize +static std::vector getNewDimsVec(const DDim& b_dims) { + std::vector b_dims_vec = paddle::framework::vectorize(b_dims); + int size = b_dims_vec.size(); + if (size >= 2) { + // swap the last 2 elements in b_dims_vec + int64_t temp = b_dims_vec[size - 1]; + b_dims_vec[size - 1] = b_dims_vec[size - 2]; + b_dims_vec[size - 2] = temp; + return b_dims_vec; + } + PADDLE_ENFORCE_NE( + b_dims_vec.empty(), true, + platform::errors::PreconditionNotMet( + "The size of tensor b must not be %d after getting new dims", 0)); + // if b_dims_vec.size() == 1, just retun original vec + return b_dims_vec; +} + +template +class SolveKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const auto* x = context.Input("X"); + const auto* y = context.Input("Y"); + Tensor* out = context.Output("Out"); + linalg_solve(context, x, y, out); + } +}; + +template +class SolveGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + + // reuse the linalg.solve forward output + auto* out = ctx.Input("Out"); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + bool is_vector = false; + is_vector = is_vector_rhs(*input, *y); + + Tensor tmp_y; + if (is_vector) { + tmp_y.mutable_data(ctx.GetPlace(), y->type()); + to_unsqueeze(ctx, *y, &tmp_y); + } else { + tmp_y.Resize(y->dims()); + tmp_y.mutable_data(ctx.GetPlace(), y->type()); + framework::TensorCopy( + *y, ctx.GetPlace(), + ctx.template device_context(), &tmp_y); + } + + Tensor tmp_x; + tmp_x.Resize(input->dims()); + tmp_x.mutable_data(ctx.GetPlace(), input->type()); + framework::TensorCopy( + *input, ctx.GetPlace(), + ctx.template device_context(), &tmp_x); + + std::vector x_broadcast_dims; + std::vector y_broadcast_dims; + std::tie(x_broadcast_dims, y_broadcast_dims) = + _broadcast_batch_dims(tmp_x, tmp_y); + + // tmp_dx + Tensor tmp_dx; + tmp_dx.Resize(framework::make_ddim(x_broadcast_dims)); + tmp_dx.mutable_data(ctx.GetPlace()); + + // tmp_dy + Tensor tmp_dy; + tmp_dy.Resize(framework::make_ddim(y_broadcast_dims)); + tmp_dy.mutable_data(ctx.GetPlace()); + + Tensor tmp_input(input->type()); + const auto& new_dims_vec = getNewDimsVec(input->dims()); + tmp_input.Resize(framework::make_ddim(new_dims_vec)); + tmp_input.mutable_data(ctx.GetPlace()); + math::TransposeNormal trans; + std::vector new_axis = getNewAxis(input->dims().size()); + auto& dev_ctx = ctx.template device_context(); + trans(dev_ctx, *input, &tmp_input, new_axis); + + if (dy) { + dy->mutable_data(ctx.GetPlace()); + // reuse linalg_solve forward logics to get tmp_dy + linalg_solve(ctx, &tmp_input, dout, &tmp_dy); + } + + if (dx) { + dx->mutable_data(ctx.GetPlace()); + // to get dx + auto blas = math::GetBlas(ctx); + if (input->dims().size() == 2 && y->dims().size() == 2) { + auto mat_dim_a1 = math::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); + auto mat_dim_b1 = math::CreateMatrixDescriptor(out->dims(), 0, true); + blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0)); + } else if (is_vector_rhs(*input, *y)) { + Tensor tmp_dy_; + tmp_dy_.mutable_data(ctx.GetPlace(), y->type()); + to_unsqueeze(ctx, tmp_dy, &tmp_dy_); + + Tensor tmp_out_; + tmp_out_.mutable_data(ctx.GetPlace(), out->type()); + to_unsqueeze(ctx, *out, &tmp_out_); + + auto mat_dim_a1 = + math::CreateMatrixDescriptor(tmp_dy_.dims(), 0, false); + auto mat_dim_b1 = + math::CreateMatrixDescriptor(tmp_out_.dims(), 0, true); + blas.MatMul(tmp_dy_, mat_dim_a1, tmp_out_, mat_dim_b1, T(-1), &tmp_dx, + T(0)); + } else { + auto mat_dim_a1 = math::CreateMatrixDescriptor(tmp_dy.dims(), 0, false); + auto mat_dim_b1 = math::CreateMatrixDescriptor(out->dims(), 0, true); + blas.MatMul(tmp_dy, mat_dim_a1, *out, mat_dim_b1, T(-1), &tmp_dx, T(0)); + } + } + + if (y->dims() != tmp_dy.dims()) { + Tensor dy_help; + dy_help.Resize(tmp_dy.dims()); + dy_help.mutable_data(ctx.GetPlace(), tmp_dy.type()); + framework::TensorCopy( + tmp_dy, ctx.GetPlace(), + ctx.template device_context(), &dy_help); + + // get dims + std::vector x_dims = vectorize(input->dims()); + std::vector y_dims = vectorize(y->dims()); + std::vector dout_dims = vectorize(dout->dims()); + + if (is_vector_rhs(*input, *y)) { + dout_dims.push_back(1); + } + + int y_ndim = y_dims.size(); + int ndim = dout_dims.size(); + + const std::vector dy_help_dims = vectorize(dy_help.dims()); + std::vector dy_broadcast_dims(ndim); + + std::fill(dy_broadcast_dims.data(), + dy_broadcast_dims.data() + ndim - y_ndim, 1); + std::copy(y_dims.data(), y_dims.data() + y_ndim, + dy_broadcast_dims.data() + ndim - y_ndim); + + std::vector dy_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dy_help_dims[idx] != 1 && dy_broadcast_dims[idx] == 1) { + dy_reduce_dims.push_back(idx); + } + } + // reduce sum to get grad by ReduceSum + if (dy) { + if (dy_reduce_dims.empty()) { + *dy = std::move(dy_help); + } else { + bool keep_dim = true; + if (dy_help.dims().size() != dy->dims().size()) { + keep_dim = false; + } + ReduceSumForSolveGrad(&dy_help, dy, dy_reduce_dims, + keep_dim, ctx); + } + dy->Resize(y->dims()); + } + } else { + framework::TensorCopy( + tmp_dy, ctx.GetPlace(), + ctx.template device_context(), dy); + } + + if (input->dims() != tmp_dx.dims()) { + Tensor dx_help; + dx_help.Resize(tmp_dx.dims()); + dx_help.mutable_data(ctx.GetPlace(), tmp_dx.type()); + framework::TensorCopy( + tmp_dx, ctx.GetPlace(), + ctx.template device_context(), &dx_help); + + // get dims + std::vector x_dims = vectorize(input->dims()); + std::vector y_dims = vectorize(y->dims()); + + int x_ndim = x_dims.size(); + int ndim = x_broadcast_dims.size(); + + const std::vector dx_help_dims = vectorize(dx_help.dims()); + std::vector dx_broadcast_dims(ndim); + + std::fill(dx_broadcast_dims.data(), + dx_broadcast_dims.data() + ndim - x_ndim, 1); + std::copy(x_dims.data(), x_dims.data() + x_ndim, + dx_broadcast_dims.data() + ndim - x_ndim); + + std::vector dx_reduce_dims; + for (int idx = 0; idx <= ndim - 3; idx++) { + if (dx_help_dims[idx] != 1 && dx_broadcast_dims[idx] == 1) { + dx_reduce_dims.push_back(idx); + } + } + // reduce sum to get grad by ReduceSum + if (dx) { + dx->mutable_data(ctx.GetPlace()); + if (dx_reduce_dims.empty()) { + *dx = std::move(dx_help); + } else { + bool keep_dim = true; + if (dx_help.dims().size() != dx->dims().size()) { + keep_dim = false; + } + ReduceSumForSolveGrad(&dx_help, dx, dx_reduce_dims, + keep_dim, ctx); + } + dx->Resize(input->dims()); + } + } else { + framework::TensorCopy( + tmp_dx, ctx.GetPlace(), + ctx.template device_context(), dx); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index 96e16894c78c659a6173e6c0a5f57bdfa4e80827..ab30ab307a9c7cb688aefc072dbf2a639d1a2531 100644 --- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -89,7 +89,9 @@ extern void *cublas_dso_handle; __macro(cublasDgetrfBatched); \ __macro(cublasDgetriBatched); \ __macro(cublasSmatinvBatched); \ - __macro(cublasDmatinvBatched); + __macro(cublasDmatinvBatched); \ + __macro(cublasSgetrsBatched); \ + __macro(cublasDgetrsBatched); CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 60e6c954c161e195066989b93e75156fbaff3651..e09138ef09409906ef31b09c72db404549ac371e 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -108,6 +108,7 @@ from .tensor.linalg import matrix_power # noqa: F401 from .tensor.linalg import svd # noqa: F401 from .tensor.linalg import eigh # noqa: F401 from .tensor.linalg import pinv # noqa: F401 +from .tensor.linalg import solve # noqa: F401 from .tensor.logic import equal # noqa: F401 from .tensor.logic import greater_equal # noqa: F401 from .tensor.logic import greater_than # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 3496021892f34203928e88ddb1b828a5ddb60eb1..4b887da8382576c95ec7af65578a92707037c726 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -967,6 +967,7 @@ set_tests_properties(test_dataloader_unkeep_order PROPERTIES TIMEOUT 120) set_tests_properties(test_reader_reset PROPERTIES TIMEOUT 120) set_tests_properties(test_pool3d_api PROPERTIES TIMEOUT 120) set_tests_properties(test_cumprod_op PROPERTIES TIMEOUT 120) +set_tests_properties(test_solve_op PROPERTIES TIMEOUT 120) if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL) set_tests_properties(test_parallel_dygraph_dataparallel PROPERTIES TIMEOUT 120) set_tests_properties(test_parallel_dygraph_unused_variables PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/test_solve_op.py b/python/paddle/fluid/tests/unittests/test_solve_op.py new file mode 100644 index 0000000000000000000000000000000000000000..fd527ec90f217110a56786b1ee900f78c2781ecc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_solve_op.py @@ -0,0 +1,563 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.w + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +import sys +sys.path.append("..") +from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard + + +# 2D normal case +class TestSolveOp(OpTest): + def config(self): + self.input_x_matrix_shape = [15, 15] + self.input_y_matrix_shape = [15, 10] + self.dtype = "float64" + + def setUp(self): + paddle.enable_static() + self.config() + self.op_type = "solve" + + np.random.seed(2021) + self.inputs = { + 'X': np.random.random(self.input_x_matrix_shape).astype(self.dtype), + 'Y': np.random.random(self.input_y_matrix_shape).astype(self.dtype) + } + self.outputs = { + 'Out': np.linalg.solve(self.inputs['X'], self.inputs['Y']) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +# x broadcast + 3D batch case +class TestSolveOpBatched_case0(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((11, 11)).astype(self.dtype), + 'Y': np.random.random((2, 11, 7)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-1) + + +# 3D batch + y vector case +class TestSolveOpBatched_case1(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((20, 6, 6)).astype(self.dtype), + 'Y': np.random.random((20, 6)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04) + + +# 3D batch + y broadcast case +class TestSolveOpBatched_case2(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((2, 10, 10)).astype(self.dtype), + 'Y': np.random.random((1, 10, 10)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02) + + +# x broadcast + 3D batch case +class TestSolveOpBatched_case3(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((1, 10, 10)).astype(self.dtype), + 'Y': np.random.random((2, 10, 10)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.02) + + +# 3D normal batch case +class TestSolveOpBatched_case4(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((3, 6, 6)).astype(self.dtype), + 'Y': np.random.random((3, 6, 7)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +# 4D normal batch case +class TestSolveOpBatched_case5(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((2, 2, 6, 6)).astype(self.dtype), + 'Y': np.random.random((2, 2, 6, 6)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +# 4D batch + y broadcast case +class TestSolveOpBatched_case6(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((2, 2, 6, 6)).astype(self.dtype), + 'Y': np.random.random((1, 2, 6, 9)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +# 5D normal batch case +class TestSolveOpBatched_case7(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((2, 2, 2, 4, 4)).astype(self.dtype), + 'Y': np.random.random((2, 2, 2, 4, 4)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04) + + +# 5D batch + y broadcast case +class TestSolveOpBatched_case8(OpTest): + def setUp(self): + self.op_type = "solve" + self.dtype = "float64" + np.random.seed(2021) + self.inputs = { + 'X': np.random.random((2, 2, 2, 4, 4)).astype(self.dtype), + 'Y': np.random.random((1, 2, 2, 4, 7)).astype(self.dtype) + } + result = np.linalg.solve(self.inputs['X'], self.inputs['Y']) + self.outputs = {'Out': result} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.04) + + +class TestSolveOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of solve_op must be Variable. + x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + y1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + self.assertRaises(TypeError, paddle.linalg.solve, x1, y1) + + # The data type of input must be float32 or float64. + x2 = fluid.data(name="x2", shape=[30, 30], dtype="bool") + y2 = fluid.data(name="y2", shape=[30, 10], dtype="bool") + self.assertRaises(TypeError, paddle.linalg.solve, x2, y2) + + x3 = fluid.data(name="x3", shape=[30, 30], dtype="int32") + y3 = fluid.data(name="y3", shape=[30, 10], dtype="int32") + self.assertRaises(TypeError, paddle.linalg.solve, x3, y3) + + x4 = fluid.data(name="x4", shape=[30, 30], dtype="int64") + y4 = fluid.data(name="y4", shape=[30, 10], dtype="int64") + self.assertRaises(TypeError, paddle.linalg.solve, x4, y4) + + x5 = fluid.data(name="x5", shape=[30, 30], dtype="float16") + y5 = fluid.data(name="y5", shape=[30, 10], dtype="float16") + self.assertRaises(TypeError, paddle.linalg.solve, x5, y5) + + # The number of dimensions of input'X must be >= 2. + x6 = fluid.data(name="x6", shape=[30], dtype="float64") + y6 = fluid.data(name="y6", shape=[30], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.solve, x6, y6) + + # The inner-most 2 dimensions of input'X should be equal to each other + x7 = fluid.data(name="x7", shape=[2, 3, 4], dtype="float64") + y7 = fluid.data(name="y7", shape=[2, 4, 3], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.solve, x7, y7) + + +# 2D + vector case, FP64 +class TestSolveOpAPI_1(unittest.TestCase): + def setUp(self): + np.random.seed(2021) + self.place = [paddle.CPUPlace()] + self.dtype = "float64" + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + paddle_input_x = fluid.data( + name="input_x", shape=[3, 3], dtype=self.dtype) + paddle_input_y = fluid.data( + name="input_y", shape=[3], dtype=self.dtype) + paddle_result = paddle.linalg.solve(paddle_input_x, paddle_input_y) + + np_input_x = np.random.random([3, 3]).astype(self.dtype) + np_input_y = np.random.random([3]).astype(self.dtype) + + np_result = np.linalg.solve(np_input_x, np_input_y) + + exe = fluid.Executor(place) + fetches = exe.run( + fluid.default_main_program(), + feed={"input_x": np_input_x, + "input_y": np_input_y}, + fetch_list=[paddle_result]) + self.assertTrue( + np.allclose(fetches[0], np.linalg.solve(np_input_x, + np_input_y))) + + def test_static(self): + for place in self.place: + self.check_static_result(place=place) + + def test_dygraph(self): + def run(place): + paddle.disable_static(place) + np.random.seed(2021) + input_x_np = np.random.random([3, 3]).astype(self.dtype) + input_y_np = np.random.random([3]).astype(self.dtype) + + tensor_input_x = paddle.to_tensor(input_x_np) + tensor_input_y = paddle.to_tensor(input_y_np) + + numpy_output = np.linalg.solve(input_x_np, input_y_np) + paddle_output = paddle.linalg.solve(tensor_input_x, tensor_input_y) + self.assertEqual( + np.allclose(numpy_output, paddle_output.numpy()), True) + self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) + paddle.enable_static() + + for place in self.place: + run(place) + + +# 2D normal case, FP64 +class TestSolveOpAPI_2(unittest.TestCase): + def setUp(self): + np.random.seed(2021) + self.place = [paddle.CPUPlace()] + self.dtype = "float64" + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + paddle.enable_static() + with fluid.program_guard(fluid.Program(), fluid.Program()): + paddle_input_x = fluid.data( + name="input_x", shape=[10, 10], dtype=self.dtype) + paddle_input_y = fluid.data( + name="input_y", shape=[10, 4], dtype=self.dtype) + paddle_result = paddle.linalg.solve(paddle_input_x, paddle_input_y) + + np_input_x = np.random.random([10, 10]).astype(self.dtype) + np_input_y = np.random.random([10, 4]).astype(self.dtype) + + np_result = np.linalg.solve(np_input_x, np_input_y) + + exe = fluid.Executor(place) + fetches = exe.run( + fluid.default_main_program(), + feed={"input_x": np_input_x, + "input_y": np_input_y}, + fetch_list=[paddle_result]) + self.assertTrue( + np.allclose(fetches[0], np.linalg.solve(np_input_x, + np_input_y))) + + def test_static(self): + for place in self.place: + self.check_static_result(place=place) + + def test_dygraph(self): + def run(place): + paddle.disable_static(place) + np.random.seed(2021) + input_x_np = np.random.random([10, 10]).astype(self.dtype) + input_y_np = np.random.random([10, 4]).astype(self.dtype) + + tensor_input_x = paddle.to_tensor(input_x_np) + tensor_input_y = paddle.to_tensor(input_y_np) + + numpy_output = np.linalg.solve(input_x_np, input_y_np) + paddle_output = paddle.linalg.solve(tensor_input_x, tensor_input_y) + self.assertEqual( + np.allclose(numpy_output, paddle_output.numpy()), True) + self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) + paddle.enable_static() + + for place in self.place: + run(place) + + +# 2D normal case, FP32 +class TestSolveOpAPI_3(unittest.TestCase): + def setUp(self): + np.random.seed(2021) + self.place = [paddle.CPUPlace()] + self.dtype = "float32" + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + paddle.enable_static() + with fluid.program_guard(fluid.Program(), fluid.Program()): + paddle_input_x = fluid.data( + name="input_x", shape=[10, 10], dtype=self.dtype) + paddle_input_y = fluid.data( + name="input_y", shape=[10, 4], dtype=self.dtype) + paddle_result = paddle.linalg.solve(paddle_input_x, paddle_input_y) + + np_input_x = np.random.random([10, 10]).astype(self.dtype) + np_input_y = np.random.random([10, 4]).astype(self.dtype) + + np_result = np.linalg.solve(np_input_x, np_input_y) + + exe = fluid.Executor(place) + fetches = exe.run( + fluid.default_main_program(), + feed={"input_x": np_input_x, + "input_y": np_input_y}, + fetch_list=[paddle_result]) + self.assertTrue( + np.allclose( + fetches[0], + np.linalg.solve(np_input_x, np_input_y), + rtol=1.e-4)) + + def test_static(self): + for place in self.place: + self.check_static_result(place=place) + + def test_dygraph(self): + def run(place): + paddle.disable_static(place) + np.random.seed(2021) + input_x_np = np.random.random([10, 10]).astype(self.dtype) + input_y_np = np.random.random([10, 4]).astype(self.dtype) + + tensor_input_x = paddle.to_tensor(input_x_np) + tensor_input_y = paddle.to_tensor(input_y_np) + + numpy_output = np.linalg.solve(input_x_np, input_y_np) + paddle_output = paddle.linalg.solve(tensor_input_x, tensor_input_y) + self.assertEqual( + np.allclose( + numpy_output, paddle_output.numpy(), rtol=1.e-4), + True) + self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) + paddle.enable_static() + + for place in self.place: + run(place) + + +# 3D + y broadcast case, FP64 +class TestSolveOpAPI_4(unittest.TestCase): + def setUp(self): + np.random.seed(2021) + self.place = [paddle.CPUPlace()] + self.dtype = "float64" + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + paddle_input_x = fluid.data( + name="input_x", shape=[2, 3, 3], dtype=self.dtype) + paddle_input_y = fluid.data( + name="input_y", shape=[1, 3, 3], dtype=self.dtype) + paddle_result = paddle.linalg.solve(paddle_input_x, paddle_input_y) + + np_input_x = np.random.random([2, 3, 3]).astype(self.dtype) + np_input_y = np.random.random([1, 3, 3]).astype(self.dtype) + + np_result = np.linalg.solve(np_input_x, np_input_y) + + exe = fluid.Executor(place) + fetches = exe.run( + fluid.default_main_program(), + feed={"input_x": np_input_x, + "input_y": np_input_y}, + fetch_list=[paddle_result]) + self.assertTrue( + np.allclose(fetches[0], np.linalg.solve(np_input_x, + np_input_y))) + + def test_static(self): + for place in self.place: + self.check_static_result(place=place) + + def test_dygraph(self): + def run(place): + paddle.disable_static(place) + np.random.seed(2021) + input_x_np = np.random.random([2, 3, 3]).astype(self.dtype) + input_y_np = np.random.random([1, 3, 3]).astype(self.dtype) + + tensor_input_x = paddle.to_tensor(input_x_np) + tensor_input_y = paddle.to_tensor(input_y_np) + + numpy_output = np.linalg.solve(input_x_np, input_y_np) + paddle_output = paddle.linalg.solve(tensor_input_x, tensor_input_y) + self.assertEqual( + np.allclose(numpy_output, paddle_output.numpy()), True) + self.assertEqual(numpy_output.shape, paddle_output.numpy().shape) + paddle.enable_static() + + for place in self.place: + run(place) + + +class TestSolveOpSingularAPI(unittest.TestCase): + # Singular matrix is ​​not invertible + def setUp(self): + self.places = [fluid.CPUPlace()] + self.dtype = "float64" + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[4, 4], dtype=self.dtype) + y = fluid.data(name="y", shape=[4, 4], dtype=self.dtype) + + result = paddle.linalg.solve(x, y) + + input_x_np = np.ones([4, 4]).astype(self.dtype) + input_y_np = np.ones([4, 4]).astype(self.dtype) + + exe = fluid.Executor(place) + try: + fetches = exe.run(fluid.default_main_program(), + feed={"x": input_x_np, + "y": input_y_np}, + fetch_list=[result]) + except RuntimeError as ex: + print("The mat is singular") + pass + except ValueError as ex: + print("The mat is singular") + pass + + def test_static(self): + for place in self.places: + paddle.enable_static() + self.check_static_result(place=place) + + def test_dygraph(self): + for place in self.places: + with fluid.dygraph.guard(place): + input_x_np = np.ones([4, 4]).astype(self.dtype) + input_y_np = np.ones([4, 4]).astype(self.dtype) + input_x = fluid.dygraph.to_variable(input_x_np) + input_y = fluid.dygraph.to_variable(input_y_np) + + try: + result = paddle.linalg.solve(input_x, input_y) + except RuntimeError as ex: + print("The mat is singular") + pass + except ValueError as ex: + print("The mat is singular") + pass + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py index 2b3383239a0ce350d5f47ff151b9ecfec41b660d..26d63826cc87a975e46576a1de06493f8a52c4ee 100644 --- a/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py +++ b/python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py @@ -48,6 +48,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [ 'lgamma', \ 'svd', \ 'matrix_power', \ + 'solve', \ ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = ['bilinear_interp',\ diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index f12cafd3421d61439b21c7d10cc6707bf6d5465a..d57d9a4bdb6780c1eed2f8a65fc71bddc45c1c82 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -16,6 +16,7 @@ from .tensor.linalg import cholesky # noqa: F401 from .tensor.linalg import norm # noqa: F401 from .tensor.linalg import cond # noqa: F401 from .tensor.linalg import matrix_power # noqa: F401 +from .tensor.linalg import solve # noqa: F401 from .tensor import inverse as inv # noqa: F401 from .tensor.linalg import eigvals # noqa: F401 from .tensor.linalg import multi_dot # noqa: F401 @@ -39,5 +40,6 @@ __all__ = [ 'det', 'slogdet', 'eigh', - 'pinv' + 'pinv', + 'solve' ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index c0f7d88d3bf58624cb8d3d3a3179fb1bc5042893..02b34bb21a79204cd1caaa0e3eda420476898243 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -51,6 +51,7 @@ from .linalg import multi_dot # noqa: F401 from .linalg import svd # noqa: F401 from .linalg import eigh # noqa: F401 from .linalg import pinv # noqa: F401 +from .linalg import solve # noqa: F401 from .logic import equal # noqa: F401 from .logic import greater_equal # noqa: F401 from .logic import greater_than # noqa: F401 @@ -386,6 +387,7 @@ tensor_method_func = [ #noqa 'bitwise_not', 'broadcast_tensors', 'uniform_', + 'solve', ] #this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index fbe6bd1697dbd434b9e367ef0e5abdb99176ad0a..b9fb0e7c563e708f26e43b112fcf20dc359db939 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2072,3 +2072,60 @@ def pinv(x, rcond=1e-15, hermitian=False, name=None): attrs={'trans_x': False, 'trans_y': True}, ) return out_2 + + +def solve(x, y, name=None): + r""" + Computes the solution of a square system of linear equations with a unique solution for input 'X' and 'Y'. + Let :math: `X` be a sqaure matrix or a batch of square matrices, :math:`Y` be + a vector/matrix or a batch of vectors/matrices, the equation should be: + + .. math:: + Out = X^-1 * Y + Specifically, + - This system of linear equations has one solution if and only if input 'X' is invertible. + + Args: + x (Tensor): A square matrix or a batch of square matrices. Its shape should be `[*, M, M]`, where `*` is zero or + more batch dimensions. Its data type should be float32 or float64. + y (Tensor): A vector/matrix or a batch of vectors/matrices. Its shape should be `[*, M, K]`, where `*` is zero or + more batch dimensions. Its data type should be float32 or float64. + name(str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The solution of a square system of linear equations with a unique solution for input 'x' and 'y'. + Its data type should be the same as that of `x`. + + Examples: + .. code-block:: python + + # a square system of linear equations: + # 2*X0 + X1 = 9 + # X0 + 2*X1 = 8 + + import paddle + import numpy as np + + np_x = np.array([[3, 1],[1, 2]]) + np_y = np.array([9, 8]) + x = paddle.to_tensor(np_x, dtype="float64") + y = paddle.to_tensor(np_y, dtype="float64") + out = paddle.linalg.solve(x, y) + + print(out) + # [2., 3.]) + """ + if in_dygraph_mode(): + return _C_ops.solve(x, y) + + inputs = {"X": [x], "Y": [y]} + helper = LayerHelper("solve", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'solve') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'solve') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type="solve", inputs={"X": x, + "Y": y}, outputs={"Out": out}) + return out