diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index 6546f854df0f4ca7f1e08f3f178ac5c836633312..f245bad01aa4c11eaf67e22d00864b31be60c1ee 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -253,6 +253,12 @@ class Blas { void BatchedGETRS(CBLAS_TRANSPOSE trans, int n, int nrhs, const T** a, int lda, int* ipiv, T** b, int ldb, int* info, int batch_size) const; + + // cuBlas triangular_solve + template + void BatchedTRSM(CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, + CBLAS_DIAG diag, int M, int N, T alpha, const T** a, int lda, + T** b, int ldb, int batch_size) const; #endif private: @@ -414,6 +420,12 @@ class BlasT : private Blas { void BatchedGETRS(ARGS... args) const { Base()->template BatchedGETRS(args...); } + + // triangular_solve + template + void BatchedTRSM(ARGS... args) const { + Base()->template BatchedTRSM(args...); + } #endif private: diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index 6f83faf1e40d865c6435dcd1fe7dfaab7693dc02..70c6cf9dcab03619a5ed8c57036b3f03365da7a7 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -120,6 +120,11 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasSgetrsBatched(args...)); } + + template + static void TRSM_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasStrsmBatched(args...)); + } }; template <> @@ -194,6 +199,11 @@ struct CUBlas { PADDLE_ENFORCE_CUDA_SUCCESS( platform::dynload::cublasDgetrsBatched(args...)); } + + template + static void TRSM_BATCH(ARGS... args) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasDtrsmBatched(args...)); + } }; template <> @@ -339,6 +349,19 @@ struct CUBlas> { reinterpret_cast(C), ldc)); } + static void TRSM(cublasHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, cublasOperation_t transa, + cublasDiagType_t diag, int m, int n, + const paddle::platform::complex *alpha, + const paddle::platform::complex *A, int lda, + paddle::platform::complex *B, int ldb) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCtrsm( + handle, side, uplo, transa, diag, m, n, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb)); + } + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. 
// https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode template @@ -370,6 +393,20 @@ struct CUBlas> { "cublasGemmEx is not supported on cuda <= 7.5")); #endif } + + static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, cublasOperation_t transa, + cublasDiagType_t diag, int m, int n, + const paddle::platform::complex *alpha, + const paddle::platform::complex **A, int lda, + paddle::platform::complex **B, int ldb, + int batch_size) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasCtrsmBatched( + handle, side, uplo, transa, diag, m, n, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, batch_size)); + } }; template <> @@ -440,6 +477,33 @@ struct CUBlas> { reinterpret_cast(C), ldc)); } + static void TRSM(cublasHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, cublasOperation_t transa, + cublasDiagType_t diag, int m, int n, + const paddle::platform::complex *alpha, + const paddle::platform::complex *A, int lda, + paddle::platform::complex *B, int ldb) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZtrsm( + handle, side, uplo, transa, diag, m, n, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb)); + } + + static void TRSM_BATCH(cublasHandle_t handle, cublasSideMode_t side, + cublasFillMode_t uplo, cublasOperation_t transa, + cublasDiagType_t diag, int m, int n, + const paddle::platform::complex *alpha, + const paddle::platform::complex **A, int lda, + paddle::platform::complex **B, int ldb, + int batch_size) { + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cublasZtrsmBatched( + handle, side, uplo, transa, diag, m, n, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(B), ldb, batch_size)); + } + // NOTES: GEMM_EX can use Tensor Core to accelerate matrix multiply. // https://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode template @@ -897,6 +961,30 @@ void Blas::BatchedGETRS( }); } +template <> +template +void Blas::BatchedTRSM( + CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, + int M, int N, T alpha, const T **A, int lda, T **B, int ldb, + int batch_size) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + cublasSideMode_t cuSide = + (side == CblasLeft) ? CUBLAS_SIDE_RIGHT : CUBLAS_SIDE_LEFT; + cublasFillMode_t cuUplo = + (uplo == CblasLower) ? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER; + // use CUBLAS_OP_C (conjugate transpose) for complex + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasDiagType_t cuDiag = + (diag == CblasUnit) ? 
CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; + + context_.CublasCall([&](cublasHandle_t handle) { + CUBlas::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, + &alpha, A, lda, B, ldb, batch_size); + }); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index cb4044b1b08c7a154a07ecf6b1cf58a84a46876a..4bcf3baa6493252785702dc6433a379e6142f295 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -434,6 +434,17 @@ struct CBlas> { a_, lda, b_, ldb, &beta, c_, ldc); } + static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + paddle::platform::complex *B, int ldb) { + const void *a_ = (const void *)(A); + void *b_ = static_cast(B); + platform::dynload::cblas_ctrsm(layout, side, uplo, trans_a, diag, M, N, + &alpha, a_, lda, b_, ldb); + } + template static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, @@ -562,6 +573,17 @@ struct CBlas> { a_, lda, b_, ldb, &beta, c_, ldc); } + static void TRSM(CBLAS_LAYOUT layout, CBLAS_SIDE side, CBLAS_UPLO uplo, + CBLAS_TRANSPOSE trans_a, CBLAS_DIAG diag, int M, int N, + paddle::platform::complex alpha, + const paddle::platform::complex *A, int lda, + paddle::platform::complex *B, int ldb) { + const void *a_ = (const void *)(A); + void *b_ = static_cast(B); + platform::dynload::cblas_ztrsm(layout, side, uplo, trans_a, diag, M, N, + &alpha, a_, lda, b_, ldb); + } + template static void GEMM_BATCH(CBLAS_LAYOUT layout, CBLAS_TRANSPOSE *trans_a, CBLAS_TRANSPOSE *trans_b, int *M, int *N, int *K, @@ -682,6 +704,15 @@ struct CBlas> { cblas_cgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } + + static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, + const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, + const CBLAS_DIAG diag, const int M, const int N, + const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + paddle::platform::complex *B, const int ldb) { + cblas_ctrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); + } }; template <> @@ -720,6 +751,15 @@ struct CBlas> { cblas_zgemm(layout, TransA, TransB, M, N, K, &alpha, A, lda, B, ldb, &beta, C, ldc); } + + static void TRSM(const CBLAS_LAYOUT layout, const CBLAS_SIDE side, + const CBLAS_UPLO uplo, const CBLAS_TRANSPOSE transA, + const CBLAS_DIAG diag, const int M, const int N, + const paddle::platform::complex alpha, + const paddle::platform::complex *A, const int lda, + paddle::platform::complex *B, const int ldb) { + cblas_ztrsm(layout, side, uplo, transA, diag, M, N, &alpha, A, lda, B, ldb); + } }; #endif diff --git a/paddle/fluid/operators/math/blas_impl.hip.h b/paddle/fluid/operators/math/blas_impl.hip.h index 1ce5bac5242ab872cb3ef423c9ae7940ad38db38..f972d38adda5fbb2e507af4936546a19de4cdd41 100644 --- a/paddle/fluid/operators/math/blas_impl.hip.h +++ b/paddle/fluid/operators/math/blas_impl.hip.h @@ -90,6 +90,12 @@ struct CUBlas { PADDLE_THROW(platform::errors::Unimplemented( "cublasSmatinvBatched is not supported on HIP platform.")); } + + template + static void TRSM_BATCH(ARGS... 
args) { + PADDLE_THROW(platform::errors::Unimplemented( + "cublasStrsmBatched is not supported on HIP platform.")); + } }; template <> @@ -153,6 +159,12 @@ struct CUBlas { PADDLE_THROW(platform::errors::Unimplemented( "cublasDmatinvBatched is not supported on HIP platform.")); } + + template + static void TRSM_BATCH(ARGS... args) { + PADDLE_THROW(platform::errors::Unimplemented( + "cublasDtrsmBatched is not supported on HIP platform.")); + } }; template <> @@ -730,6 +742,32 @@ void Blas::BatchedGETRS( batch_size); }); } + +template <> +template +void Blas::BatchedTRSM( + CBLAS_SIDE side, CBLAS_UPLO uplo, CBLAS_TRANSPOSE transA, CBLAS_DIAG diag, + int M, int N, T alpha, const T **A, int lda, T **B, int ldb, + int batch_size) const { + // solve row major `op ( A ) X = α B` by taking it as `X' op ( A' ) = α B'` + // where ' stands for transpose + rocblas_side cuSide = + (side == CblasLeft) ? rocblas_side_right : rocblas_side_left; + rocblas_fill cuUplo = + (uplo == CblasLower) ? rocblas_fill_upper : rocblas_fill_lower; + // use CUBLAS_OP_C (conjugate transpose) for complex + rocblas_operation cuTransA = (transA == CblasNoTrans) + ? rocblas_operation_none + : rocblas_operation_transpose; + rocblas_diagonal cuDiag = + (diag == CblasUnit) ? rocblas_diagonal_unit : rocblas_diagonal_non_unit; + + context_.CublasCall([&](rocblas_handle handle) { + CUBlas::TRSM_BATCH(handle, cuSide, cuUplo, cuTransA, cuDiag, N, M, + &alpha, A, lda, B, ldb, batch_size); + }); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.cc b/paddle/fluid/operators/math/matrix_solve.cc index 7f13b5c8a70eef7b33e2776d3314bcf18c972dad..95c84d83976f529fae6223eba57dd9661b362018 100644 --- a/paddle/fluid/operators/math/matrix_solve.cc +++ b/paddle/fluid/operators/math/matrix_solve.cc @@ -34,6 +34,45 @@ class MatrixSolveFunctor { template class MatrixSolveFunctor; template class MatrixSolveFunctor; +template +class TriangularSolveFunctor { + public: + void operator()(const platform::CPUDeviceContext& context, + const framework::Tensor* a, framework::Tensor* b, bool left, + bool upper, bool transpose, bool unitriangular) { + CBLAS_SIDE side = left ? CblasLeft : CblasRight; + CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower; + CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans; + CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit; + + const T* a_data = a->data(); + T* b_data = b->mutable_data(context.GetPlace()); + + int a_dim_size = a->dims().size(); + int b_dim_size = b->dims().size(); + + int M = static_cast(b->dims()[b_dim_size - 2]); + int N = static_cast(b->dims()[b_dim_size - 1]); + auto lda = left ? 
std::max(1, M) : std::max(1, N); + auto ldb = std::max(1, N); + + int batch_size = 1; + auto& a_dim = a->dims(); + for (int i = 0; i < a_dim_size - 2; i++) { + batch_size *= a_dim[i]; + } + + auto blas = math::GetBlas(context); + for (int i = 0; i < batch_size; i++) { + blas.TRSM(side, uplo, transA, diag, M, N, T(1), a_data + i * M * M, lda, + b_data + i * N * M, ldb); + } + } +}; + +template class TriangularSolveFunctor; +template class TriangularSolveFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.cu.cc b/paddle/fluid/operators/math/matrix_solve.cu.cc index efb3a07e4c1b47d649642f630c1e9adc49a9598c..4e5601248c1a2b34d9aea42a4075c616944299b0 100644 --- a/paddle/fluid/operators/math/matrix_solve.cu.cc +++ b/paddle/fluid/operators/math/matrix_solve.cu.cc @@ -163,6 +163,68 @@ class MatrixSolveFunctor { template class MatrixSolveFunctor; template class MatrixSolveFunctor; +template +class TriangularSolveFunctor { + public: + void operator()(const platform::CUDADeviceContext& context, const Tensor* a, + Tensor* b, bool left, bool upper, bool transpose, + bool unitriangular) { + CBLAS_SIDE side = left ? CblasLeft : CblasRight; + CBLAS_UPLO uplo = upper ? CblasUpper : CblasLower; + CBLAS_TRANSPOSE transA = transpose ? CblasTrans : CblasNoTrans; + CBLAS_DIAG diag = unitriangular ? CblasUnit : CblasNonUnit; + + const T* a_data = a->data(); + T* b_data = b->mutable_data(context.GetPlace()); + + int a_dim_size = a->dims().size(); + int b_dim_size = b->dims().size(); + + int M = static_cast(b->dims()[b_dim_size - 2]); + int N = static_cast(b->dims()[b_dim_size - 1]); + auto lda = left ? std::max(1, M) : std::max(1, N); + auto ldb = std::max(1, N); + + int batch_size = 1; + auto& a_dim = a->dims(); + for (int i = 0; i < a_dim_size - 2; i++) { + batch_size *= a_dim[i]; + } + + auto blas = math::GetBlas(context); + if (batch_size <= 8 && M >= 64) { + for (auto i = 0; i < batch_size; i++) { + blas.TRSM(side, uplo, transA, diag, M, N, static_cast(1.0), + a_data + i * M * M, lda, b_data + i * N * M, ldb); + } + } else { + std::vector cpu_ptrs(batch_size * 2); + for (int i = 0; i < batch_size; ++i) { + cpu_ptrs[i] = a_data + i * M * M; + cpu_ptrs[i + batch_size] = b_data + i * M * N; + } + + // Copy the addresses of A and tmp_b from host to device. 
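+      // cpu_ptrs packs the batch_size pointers to the A matrices first,
+      // followed by the batch_size pointers to the B matrices, so a single
+      // device allocation and one Copy can feed cublas<T>trsmBatched below.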
+ memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + memory::Alloc(context, cpu_ptrs.size() * sizeof(T*)); + memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), + tmp_gpu_ptrs_data->ptr(), platform::CPUPlace(), + static_cast(cpu_ptrs.data()), + cpu_ptrs.size() * sizeof(T*), context.stream()); + + const T** gpu_a_ptrs = + reinterpret_cast(tmp_gpu_ptrs_data->ptr()); + T** gpu_b_ptrs = + reinterpret_cast(tmp_gpu_ptrs_data->ptr()) + batch_size; + blas.BatchedTRSM(side, uplo, transA, diag, M, N, static_cast(1.0), + gpu_a_ptrs, lda, gpu_b_ptrs, ldb, batch_size); + } + } +}; + +template class TriangularSolveFunctor; +template class TriangularSolveFunctor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.h b/paddle/fluid/operators/math/matrix_solve.h index 415d0c6dd8e0cf51958783c32aa49c66cce9e15c..1dc43205592f69cc105b43fe49b2f7872f8251c3 100644 --- a/paddle/fluid/operators/math/matrix_solve.h +++ b/paddle/fluid/operators/math/matrix_solve.h @@ -117,6 +117,14 @@ class MatrixSolveFunctor { const framework::Tensor& b, framework::Tensor* out); }; +template +class TriangularSolveFunctor { + public: + void operator()(const DeviceContext& context, const framework::Tensor* a, + framework::Tensor* b, bool left, bool upper, bool transpose, + bool unitriangular); +}; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h index f70baf486db461d3951ac1b06c61627ed2aa5098..ec72269f697e87b5cb957682312d0a1fa7a8d506 100644 --- a/paddle/fluid/operators/solve_op.h +++ b/paddle/fluid/operators/solve_op.h @@ -49,9 +49,9 @@ struct IdentityFunctor { }; template -void ReduceSumForSolveGrad(const Tensor* input, Tensor* output, - const std::vector& reduce_dims, bool keep_dim, - const paddle::framework::ExecutionContext& ctx) { +void ReduceSumForSolve(const Tensor* input, Tensor* output, + const std::vector& reduce_dims, bool keep_dim, + const paddle::framework::ExecutionContext& ctx) { #if defined(__NVCC__) || defined(__HIPCC__) auto stream = ctx.cuda_device_context().stream(); TensorReduce(*input, output, reduce_dims, @@ -185,36 +185,6 @@ static std::vector get_broadcast_batch_portion( return batchPortion; } -// necessary check before expand operation -static void expand_check(const Tensor& arg1, - std::vector expand_shape) { - auto rank = arg1.dims().size(); - PADDLE_ENFORCE_GE( - rank, 1, platform::errors::InvalidArgument( - "The rank of the input 'X' for expand must be positive, " - "but the value received is %d.", - rank)); - PADDLE_ENFORCE_LE( - rank, MAX_RANK_SUPPORTED, - platform::errors::InvalidArgument( - "The rank of the input 'X' for expand must be less than " - "or equal to %d, but the value received is %d.", - MAX_RANK_SUPPORTED, rank)); - auto shape_size = static_cast(expand_shape.size()); - PADDLE_ENFORCE_GE( - shape_size, rank, - platform::errors::InvalidArgument( - "The number (%d) of elements of 'shape' for expand must be " - "greater than or equal to the rank (%d) of the input 'X'.", - shape_size, rank)); - PADDLE_ENFORCE_LE( - shape_size, MAX_RANK_SUPPORTED, - platform::errors::InvalidArgument( - "The number (%d) of elements of 'shape' for expand must be " - "less than or equal to %d.", - shape_size, MAX_RANK_SUPPORTED)); -} - // broadcast the batch dimensions of tensor x and tensor y. 
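// The trailing two (matrix) dims of x and y are kept as-is; only the leading
// batch dims are broadcast against each other, using the same rule as
// get_broadcast_batch_portion above.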
static inline std::tuple, std::vector> get_broadcast_dims(const Tensor& x, const Tensor& y) { @@ -246,15 +216,13 @@ get_broadcast_dims(const Tensor& x, const Tensor& y) { } template -void tensor_expand(const framework::ExecutionContext& context, - const Tensor& arg1, Tensor* out0, - std::vector expand_size) { - auto in_dims = arg1.dims(); - auto expand_shape = expand_size; - auto vec_in_dims = framework::vectorize(in_dims); +void expand_impl(const DeviceContext& context, const Tensor& in, Tensor* out, + const std::vector& expand_shape) { + auto vec_in_dims = framework::vectorize(in.dims()); auto diff = expand_shape.size() - vec_in_dims.size(); vec_in_dims.insert(vec_in_dims.begin(), diff, 1); std::vector repeat_times(vec_in_dims.size()); + for (size_t i = 0; i < vec_in_dims.size(); ++i) { PADDLE_ENFORCE_NE( expand_shape[i], 0, @@ -301,12 +269,11 @@ void tensor_expand(const framework::ExecutionContext& context, out_dims[i] *= repeat_times[i]; } - out0->Resize(out_dims); - auto x = EigenTensor::From(arg1, new_in_dims); - out0->mutable_data(context.GetPlace()); - auto y = EigenTensor::From(*out0, out_dims); - auto& place = - *context.template device_context().eigen_device(); + out->Resize(out_dims); + out->mutable_data(context.GetPlace()); + auto x = EigenTensor::From(in, new_in_dims); + auto y = EigenTensor::From(*out, out_dims); + auto& place = *context.eigen_device(); // use 32-bit index to speed up bool use_32bit_index = y.size() < Eigen::NumTraits::highest(); if (use_32bit_index) { @@ -318,6 +285,41 @@ void tensor_expand(const framework::ExecutionContext& context, } } +template +void TensorExpand(const DeviceContext& context, const Tensor& in, Tensor* out, + const std::vector& expand_shape) { + // necessary check before expand operation + PADDLE_ENFORCE_GE(expand_shape.size(), in.dims().size(), + platform::errors::InvalidArgument( + "The size of 'expand_shape' (%d) should >= the input " + "Tensor's rank (%d).", + expand_shape.size(), in.dims().size())); + PADDLE_ENFORCE_LE(expand_shape.size(), MAX_RANK_SUPPORTED, + platform::errors::InvalidArgument( + "The size of 'expand_shape' (%d) should be <= %d", + expand_shape.size(), MAX_RANK_SUPPORTED)); + switch (expand_shape.size()) { + case 1: + expand_impl<1, T, DeviceContext>(context, in, out, expand_shape); + break; + case 2: + expand_impl<2, T, DeviceContext>(context, in, out, expand_shape); + break; + case 3: + expand_impl<3, T, DeviceContext>(context, in, out, expand_shape); + break; + case 4: + expand_impl<4, T, DeviceContext>(context, in, out, expand_shape); + break; + case 5: + expand_impl<5, T, DeviceContext>(context, in, out, expand_shape); + break; + case 6: + expand_impl<6, T, DeviceContext>(context, in, out, expand_shape); + break; + } +} + template static void linalg_solve(const framework::ExecutionContext& context, const framework::Tensor* x, const framework::Tensor* y, @@ -356,69 +358,11 @@ static void linalg_solve(const framework::ExecutionContext& context, std::tie(x_broadcast_dims, y_broadcast_dims) = get_broadcast_dims(tmp_x, tmp_y); - expand_check(tmp_x, x_broadcast_dims); - expand_check(tmp_y, y_broadcast_dims); - Tensor tmp_x_bc; - Tensor tmp_y_bc; - auto tmp_x_rank = tmp_x.dims().size(); - auto tmp_y_rank = tmp_y.dims().size(); - - auto rank_0 = std::max(tmp_x_rank, static_cast(x_broadcast_dims.size())); - switch (rank_0) { - case 1: - tensor_expand<1, T, DeviceContext>(context, tmp_x, &tmp_x_bc, - x_broadcast_dims); - break; - case 2: - tensor_expand<2, T, DeviceContext>(context, tmp_x, &tmp_x_bc, - x_broadcast_dims); 
- break; - case 3: - tensor_expand<3, T, DeviceContext>(context, tmp_x, &tmp_x_bc, - x_broadcast_dims); - break; - case 4: - tensor_expand<4, T, DeviceContext>(context, tmp_x, &tmp_x_bc, - x_broadcast_dims); - break; - case 5: - tensor_expand<5, T, DeviceContext>(context, tmp_x, &tmp_x_bc, - x_broadcast_dims); - break; - case 6: - tensor_expand<6, T, DeviceContext>(context, tmp_x, &tmp_x_bc, - x_broadcast_dims); - break; - } + TensorExpand(dev_ctx, tmp_x, &tmp_x_bc, x_broadcast_dims); - auto rank_1 = std::max(tmp_y_rank, static_cast(y_broadcast_dims.size())); - switch (rank_1) { - case 1: - tensor_expand<1, T, DeviceContext>(context, tmp_y, &tmp_y_bc, - y_broadcast_dims); - break; - case 2: - tensor_expand<2, T, DeviceContext>(context, tmp_y, &tmp_y_bc, - y_broadcast_dims); - break; - case 3: - tensor_expand<3, T, DeviceContext>(context, tmp_y, &tmp_y_bc, - y_broadcast_dims); - break; - case 4: - tensor_expand<4, T, DeviceContext>(context, tmp_y, &tmp_y_bc, - y_broadcast_dims); - break; - case 5: - tensor_expand<5, T, DeviceContext>(context, tmp_y, &tmp_y_bc, - y_broadcast_dims); - break; - case 6: - tensor_expand<6, T, DeviceContext>(context, tmp_y, &tmp_y_bc, - y_broadcast_dims); - break; - } + Tensor tmp_y_bc; + TensorExpand(dev_ctx, tmp_y, &tmp_y_bc, y_broadcast_dims); auto x_dim = x->dims(); auto y_dim = y->dims(); @@ -658,8 +602,8 @@ class SolveGradKernel : public framework::OpKernel { if (dy_help.dims().size() != dy->dims().size()) { keep_dim = false; } - ReduceSumForSolveGrad(&dy_help, dy, dy_reduce_dims, - keep_dim, ctx); + ReduceSumForSolve(&dy_help, dy, dy_reduce_dims, + keep_dim, ctx); } dy->Resize(y->dims()); } @@ -708,8 +652,8 @@ class SolveGradKernel : public framework::OpKernel { if (dx_help.dims().size() != dx->dims().size()) { keep_dim = false; } - ReduceSumForSolveGrad(&dx_help, dx, dx_reduce_dims, - keep_dim, ctx); + ReduceSumForSolve(&dx_help, dx, dx_reduce_dims, + keep_dim, ctx); } dx->Resize(input->dims()); } diff --git a/paddle/fluid/operators/triangular_solve_op.cc b/paddle/fluid/operators/triangular_solve_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b01669bf55b407896017e0f598fe6e73cf534ba --- /dev/null +++ b/paddle/fluid/operators/triangular_solve_op.cc @@ -0,0 +1,187 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/triangular_solve_op.h" +#include "paddle/fluid/operators/solve_op.h" + +namespace paddle { +namespace operators { + +class TriangularSolveOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TriangularSolve"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "TriangularSolve"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "TriangularSolve"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + auto x_dims_n = x_dims.size(); + auto y_dims_n = y_dims.size(); + + PADDLE_ENFORCE_GE( + x_dims_n, 2, platform::errors::InvalidArgument( + "The input tensor X's dimensions of TriangularSolveOp " + "should be >= 2. But received X's " + "dimensions = %d, X's shape = [%s]", + x_dims.size(), x_dims)); + + PADDLE_ENFORCE_GE( + y_dims_n, 2, platform::errors::InvalidArgument( + "The input tensor Y's dimensions of TriangularSolveOp " + "should be >=2. But received Y's " + "dimensions = %d, Y's shape = [%s]", + y_dims.size(), y_dims)); + + PADDLE_ENFORCE_EQ(x_dims[x_dims_n - 2], x_dims[x_dims_n - 1], + platform::errors::InvalidArgument( + "The inner-most 2 dimensions of Input(X) all should " + "be square matrices " + "But received X's shape[-2] = %d and shape[-1] = %d.", + x_dims[x_dims_n - 2], x_dims[x_dims_n - 1])); + + std::vector x_dims_vec = paddle::framework::vectorize(x_dims); + std::vector y_dims_vec = paddle::framework::vectorize(y_dims); + + std::vector x_dims_vec_cut(x_dims_vec.begin(), + x_dims_vec.end() - 2); + std::vector y_dims_vec_cut(y_dims_vec.begin(), + y_dims_vec.end() - 2); + + std::vector expand_batch_portion = + get_broadcast_batch_portion(x_dims_vec_cut, y_dims_vec_cut); + + std::vector y_broadcast_dims({expand_batch_portion}); + y_broadcast_dims.insert(y_broadcast_dims.end(), {y_dims_vec[y_dims_n - 2], + y_dims_vec[y_dims_n - 1]}); + + // dim of 'Out' is the same with 'Y' after broadcast + ctx->SetOutputDim("Out", framework::make_ddim(y_broadcast_dims)); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); + } +}; + +class TriangularSolveOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor), The first input tensor of triangular solve op, which " + "is the triangular coefficient matrix."); + AddInput("Y", + "(Tensor), The second input tensor of triangular solve op, which " + "is multiple right-hand."); + AddOutput("Out", "(Tensor), The solution tensor of triangular solve op."); + AddAttr("upper", + "whether to solve the upper-triangular or the " + "lower-triangular system of equations") + .SetDefault(true); + AddAttr("transpose", "whether X should be transposed firstly.") + .SetDefault(false); + AddAttr("unitriangular", "whether X is unit triangular.") + .SetDefault(false); + AddComment(R"DOC( + Triangular Solve Operator. + This operator is used to computes the solution of equations with a triangular coefficient matrix. 
+ + The equation is: + $$Out = X^-1 * Y$$ +)DOC"); + } +}; + +class TriangularSolveOpInferVarType + : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map& GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + +class TriangularSolveGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "triangular_solve"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "triangular_solve"); + OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "triangular_solve"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "triangular_solve"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +template +class TriangularSolveOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("triangular_solve_grad"); + retv->SetInput("X", this->Input("X")); + retv->SetInput("Y", this->Input("Y")); + retv->SetInput("Out", this->Output("Out")); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + retv->SetOutput(framework::GradVarName("Y"), this->InputGrad("Y")); + retv->SetAttrMap(this->Attrs()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(triangular_solve, ops::TriangularSolveOp, + ops::TriangularSolveOpMaker, + ops::TriangularSolveOpInferVarType, + ops::TriangularSolveOpGradMaker, + ops::TriangularSolveOpGradMaker); + +REGISTER_OPERATOR(triangular_solve_grad, ops::TriangularSolveGradOp); + +REGISTER_OP_CPU_KERNEL( + triangular_solve, + ops::TriangularSolveKernel, + ops::TriangularSolveKernel); + +REGISTER_OP_CPU_KERNEL( + triangular_solve_grad, + ops::TriangularSolveGradKernel, + ops::TriangularSolveGradKernel); diff --git a/paddle/fluid/operators/triangular_solve_op.cu b/paddle/fluid/operators/triangular_solve_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..c5218aec03e2820f21bdc85d6fd0e278376ebf64 --- /dev/null +++ b/paddle/fluid/operators/triangular_solve_op.cu @@ -0,0 +1,64 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" +#include "paddle/fluid/operators/triangular_solve_op.h" + +namespace paddle { +namespace operators { + +template +struct MatrixReduceSumFunctor { + void operator()(const Tensor& in, Tensor* out, + const framework::ExecutionContext& ctx) { + // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] + // out_reduce_dim should be [0, 2] + const std::vector in_dims = framework::vectorize(in.dims()); + auto in_size = in_dims.size(); + const std::vector out_dims = + framework::vectorize(out->dims()); + auto out_size = out_dims.size(); + + std::vector out_bst_dims(in_size); + + std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); + std::copy(out_dims.data(), out_dims.data() + out_size, + out_bst_dims.data() + in_size - out_size); + + std::vector out_reduce_dims; + for (size_t idx = 0; idx <= in_size - 3; idx++) { + if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { + out_reduce_dims.push_back(idx); + } + } + gpuStream_t stream = ctx.cuda_device_context().stream(); + TensorReduceFunctorImpl(in, out, out_reduce_dims, stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + triangular_solve, + ops::TriangularSolveKernel, + ops::TriangularSolveKernel); + +REGISTER_OP_CUDA_KERNEL( + triangular_solve_grad, + ops::TriangularSolveGradKernel, + ops::TriangularSolveGradKernel); diff --git a/paddle/fluid/operators/triangular_solve_op.h b/paddle/fluid/operators/triangular_solve_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f64b016366e39b2260f4f8aebbb2e371ee2a8a7a --- /dev/null +++ b/paddle/fluid/operators/triangular_solve_op.h @@ -0,0 +1,227 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include "glog/logging.h" +#include "paddle/fluid/framework/ddim.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/complex_functors.h" +#include "paddle/fluid/operators/reduce_ops/reduce_op.h" +#include "paddle/fluid/operators/solve_op.h" +#include "paddle/fluid/operators/tril_triu_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +static void triangular_solve(const DeviceContext& context, const Tensor& x, + const Tensor& y, Tensor* out, bool upper, + bool transpose, bool unitriangular) { + // Tensor broadcast use eigen + std::vector x_bst_dims_vec; + std::vector y_bst_dims_vec; + std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(x, y); + + Tensor x_bst(x.type()); + TensorExpand(context, x, &x_bst, x_bst_dims_vec); + + Tensor y_bst(y.type()); + TensorExpand(context, y, &y_bst, y_bst_dims_vec); + + // TriangularSolveFunctor performs calculations in-place + // x_clone should be a copy of 'x' after broadcast + // out should be a copy of 'y' after broadcast + Tensor x_clone(x.type()); + x_clone.Resize(framework::make_ddim(x_bst_dims_vec)); + x_clone.mutable_data(context.GetPlace()); + framework::TensorCopy(x_bst, context.GetPlace(), context, &x_clone); + + out->Resize(framework::make_ddim(y_bst_dims_vec)); + out->mutable_data(context.GetPlace()); + framework::TensorCopy(y_bst, context.GetPlace(), context, out); + + math::TriangularSolveFunctor functor; + functor(context, &x_clone, out, /*left=*/true, upper, transpose, + unitriangular); +} + +template +class MatrixReduceSumFunctor { + public: + void operator()(const Tensor& input, Tensor* output, + const framework::ExecutionContext& ctx); +}; + +template +class MatrixReduceSumFunctor { + public: + void operator()(const Tensor& in, Tensor* out, + const framework::ExecutionContext& ctx) { + // For example: in's dim = [5, 3, 2, 7, 3] ; out's dim = [3, 1, 7, 3] + // out_reduce_dim should be [0, 2] + const std::vector in_dims = framework::vectorize(in.dims()); + auto in_size = in_dims.size(); + const std::vector out_dims = + framework::vectorize(out->dims()); + auto out_size = out_dims.size(); + + std::vector out_bst_dims(in_size); + + std::fill(out_bst_dims.data(), out_bst_dims.data() + in_size - out_size, 1); + std::copy(out_dims.data(), out_dims.data() + out_size, + out_bst_dims.data() + in_size - out_size); + out->Resize(framework::make_ddim(out_bst_dims)); + + std::vector out_reduce_dims; + for (size_t idx = 0; idx <= in_size - 3; idx++) { + if (in_dims[idx] != 1 && out_bst_dims[idx] == 1) { + out_reduce_dims.push_back(idx); + } + } + + ReduceKernelFunctor( + &in, out, out_reduce_dims, true, false, ctx) + .template apply(); + out->Resize(framework::make_ddim(out_dims)); + } +}; + +template +class TriangularSolveKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* x = ctx.Input("X"); + const auto* y = ctx.Input("Y"); + auto* out = ctx.Output("Out"); + + bool upper = ctx.template Attr("upper"); + bool transpose = ctx.template Attr("transpose"); + bool unitriangular = ctx.template Attr("unitriangular"); + + const auto& dev_ctx = ctx.template device_context(); + triangular_solve(dev_ctx, *x, *y, out, upper, transpose, + unitriangular); + } +}; + +template +class TriangularSolveGradKernel : public 
framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* x = ctx.Input("X"); + const auto* y = ctx.Input("Y"); + const auto* out = ctx.Input("Out"); + const auto* dout = + ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dy = ctx.Output(framework::GradVarName("Y")); + + bool upper = ctx.template Attr("upper"); + bool transpose = ctx.template Attr("transpose"); + bool unitriangular = ctx.template Attr("unitriangular"); + + auto& dev_ctx = ctx.template device_context(); + + std::vector x_bst_dims_vec; + std::vector y_bst_dims_vec; + std::tie(x_bst_dims_vec, y_bst_dims_vec) = get_broadcast_dims(*x, *y); + + Tensor dy_bst(y->type()); + if (dy) { + dy->mutable_data(y->dims(), dev_ctx.GetPlace()); + dy_bst.Resize(framework::make_ddim(y_bst_dims_vec)); + dy_bst.mutable_data(dev_ctx.GetPlace()); + + // calculate x's conjugate for complex + Tensor x_conj(x->type()); + platform::ForRange x_for_range(dev_ctx, x->numel()); + math::ConjFunctor x_functor( + x->data(), x->numel(), + x_conj.mutable_data(x->dims(), dev_ctx.GetPlace())); + x_for_range(x_functor); + + // reuse forward to get dy_bst, and the result has been broadcated. + triangular_solve(dev_ctx, x_conj, *dout, &dy_bst, upper, + !transpose, unitriangular); + + if (dy_bst.dims() == dy->dims()) { + framework::TensorCopy(dy_bst, dev_ctx.GetPlace(), dev_ctx, dy); + } else { + MatrixReduceSumFunctor functor; + functor(dy_bst, dy, ctx); + dy->Resize(y->dims()); + } + } + + Tensor dx_bst(x->type()); + if (dx) { + dx->mutable_data(x->dims(), dev_ctx.GetPlace()); + dx_bst.Resize(framework::make_ddim(x_bst_dims_vec)); + dx_bst.mutable_data(dev_ctx.GetPlace()); + + // calculate out's conjugate for complex + Tensor out_conj(out->type()); + platform::ForRange out_for_range(dev_ctx, out->numel()); + math::ConjFunctor out_functor( + out->data(), out->numel(), + out_conj.mutable_data(out->dims(), dev_ctx.GetPlace())); + out_for_range(out_functor); + + auto blas = math::GetBlas(ctx); + if (transpose) { + auto mat_dim_a = + math::CreateMatrixDescriptor(out_conj.dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(dy_bst.dims(), 0, true); + blas.MatMul(out_conj, mat_dim_a, dy_bst, mat_dim_b, static_cast(-1), + &dx_bst, static_cast(0)); + } else { + auto mat_dim_a = math::CreateMatrixDescriptor(dy_bst.dims(), 0, false); + auto mat_dim_b = math::CreateMatrixDescriptor(out_conj.dims(), 0, true); + blas.MatMul(dy_bst, mat_dim_a, out_conj, mat_dim_b, static_cast(-1), + &dx_bst, static_cast(0)); + } + + Tensor dx_bst_upper(x->type()); + // get upper or lower triangular + dx_bst_upper.Resize(dx_bst.dims()); + dx_bst_upper.mutable_data(dev_ctx.GetPlace()); + + const auto& dims = dx_bst.dims(); + const auto H = dims[dims.size() - 2]; + const auto W = dims[dims.size() - 1]; + platform::ForRange x_for_range(dev_ctx, dx_bst.numel()); + TrilTriuCompute tril_triu_computer(dx_bst.data(), unitriangular, + !upper, H, W, + dx_bst_upper.data()); + x_for_range(tril_triu_computer); + + if (dx_bst_upper.dims() == dx->dims()) { + framework::TensorCopy(dx_bst_upper, dev_ctx.GetPlace(), dev_ctx, dx); + } else { + MatrixReduceSumFunctor functor; + functor(dx_bst_upper, dx, ctx); + dx->Resize(x->dims()); + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cublas.h b/paddle/fluid/platform/dynload/cublas.h index ab30ab307a9c7cb688aefc072dbf2a639d1a2531..17ae4d5bf03d7b20862b6d384719b25d5fc69e90 100644 
--- a/paddle/fluid/platform/dynload/cublas.h +++ b/paddle/fluid/platform/dynload/cublas.h @@ -75,6 +75,8 @@ extern void *cublas_dso_handle; __macro(cublasDgeam); \ __macro(cublasStrsm_v2); \ __macro(cublasDtrsm_v2); \ + __macro(cublasCtrsm_v2); \ + __macro(cublasZtrsm_v2); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \ @@ -84,6 +86,10 @@ extern void *cublas_dso_handle; __macro(cublasDgemmBatched); \ __macro(cublasCgemmBatched); \ __macro(cublasZgemmBatched); \ + __macro(cublasStrsmBatched); \ + __macro(cublasDtrsmBatched); \ + __macro(cublasCtrsmBatched); \ + __macro(cublasZtrsmBatched); \ __macro(cublasSgetrfBatched); \ __macro(cublasSgetriBatched); \ __macro(cublasDgetrfBatched); \ diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h index 11208289165935ab8843435ff39477378a554efd..335b919f41c34b08fb7ea4398f2db96620058e4f 100644 --- a/paddle/fluid/platform/dynload/mklml.h +++ b/paddle/fluid/platform/dynload/mklml.h @@ -25,7 +25,7 @@ namespace platform { namespace dynload { extern std::once_flag mklml_dso_flag; -extern void* mklml_dso_handle; +extern void *mklml_dso_handle; /** * The following macro definition can generate structs @@ -40,7 +40,7 @@ extern void* mklml_dso_handle; std::call_once(mklml_dso_flag, []() { \ mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \ }); \ - static void* p_##_name = dlsym(mklml_dso_handle, #__name); \ + static void *p_##_name = dlsym(mklml_dso_handle, #__name); \ return reinterpret_cast(p_##_name)(args...); \ } \ }; \ @@ -67,6 +67,8 @@ extern void* mklml_dso_handle; __macro(cblas_zgemv); \ __macro(cblas_strsm); \ __macro(cblas_dtrsm); \ + __macro(cblas_ctrsm); \ + __macro(cblas_ztrsm); \ __macro(cblas_sgemm_alloc); \ __macro(cblas_dgemm_alloc); \ __macro(cblas_sgemm_pack); \ diff --git a/python/paddle/fluid/tests/unittests/test_triangular_solve_op.py b/python/paddle/fluid/tests/unittests/test_triangular_solve_op.py new file mode 100644 index 0000000000000000000000000000000000000000..45e88d681d8e095bdfe732de2f66eb0720cb7346 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_triangular_solve_op.py @@ -0,0 +1,339 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License.w + +from __future__ import print_function + +import unittest +import numpy as np + +import sys +sys.path.append("..") +import paddle +from op_test import OpTest +import paddle.fluid as fluid +from paddle.fluid import Program, program_guard, core + +paddle.enable_static() + + +# 2D + 2D , test 'upper' +class TestTriangularSolveOp(OpTest): + """ + case 1 + """ + + def config(self): + self.x_shape = [12, 12] + self.y_shape = [12, 10] + self.upper = True + self.transpose = False + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + self.output = np.linalg.solve( + np.triu(self.inputs['X']), self.inputs['Y']) + + def setUp(self): + self.op_type = "triangular_solve" + self.config() + + self.inputs = { + 'X': np.random.random(self.x_shape).astype(self.dtype), + 'Y': np.random.random(self.y_shape).astype(self.dtype) + } + self.attrs = { + 'upper': self.upper, + 'transpose': self.transpose, + 'unitriangular': self.unitriangular, + } + self.set_output() + self.outputs = {'Out': self.output} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out') + + +# 2D(broadcast) + 3D, test 'transpose' +class TestTriangularSolveOp2(TestTriangularSolveOp): + """ + case 2 + """ + + def config(self): + self.x_shape = [10, 10] + self.y_shape = [3, 10, 8] + self.upper = False + self.transpose = True + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + x = np.tril(self.inputs['X']).transpose(1, 0) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + +# 3D(broadcast) + 3D +class TestTriangularSolveOp3(TestTriangularSolveOp): + """ + case 3 + """ + + def config(self): + self.x_shape = [1, 10, 10] + self.y_shape = [6, 10, 12] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + +# 3D + 3D(broadcast), test 'transpose' +class TestTriangularSolveOp4(TestTriangularSolveOp): + """ + case 4 + """ + + def config(self): + self.x_shape = [3, 10, 10] + self.y_shape = [1, 10, 12] + self.upper = True + self.transpose = True + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + x = np.triu(self.inputs['X']).transpose(0, 2, 1) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + +# 2D + 2D , test 'unitriangular' specially +class TestTriangularSolveOp5(TestTriangularSolveOp): + """ + case 5 + """ + + def config(self): + self.x_shape = [10, 10] + self.y_shape = [10, 10] + self.upper = True + self.transpose = False + self.unitriangular = True + self.dtype = "float64" + + def set_output(self): + x = np.triu(self.inputs['X']) + np.fill_diagonal(x, 1.) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + def test_check_grad_normal(self): + x = np.triu(self.inputs['X']) + np.fill_diagonal(x, 1.) + grad_out = np.ones([10, 10]).astype('float64') + grad_y = np.linalg.solve(x.transpose(1, 0), grad_out) + + grad_x = -np.matmul(grad_y, self.output.transpose(1, 0)) + grad_x = np.triu(grad_x) + np.fill_diagonal(grad_x, 0.) 
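+        # check_grad below compares against the analytic gradients derived
+        # above; since unitriangular mode never references the diagonal of X,
+        # grad_x keeps a zero diagonal.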
+ + self.check_grad( + ['X', 'Y'], + 'Out', + user_defined_grads=[grad_x, grad_y], + user_defined_grad_outputs=[grad_out]) + + +# 4D(broadcast) + 4D(broadcast) +class TestTriangularSolveOp6(TestTriangularSolveOp): + """ + case 6 + """ + + def config(self): + self.x_shape = [1, 3, 10, 10] + self.y_shape = [2, 1, 10, 5] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + +# 3D(broadcast) + 4D(broadcast), test 'upper' +class TestTriangularSolveOp7(TestTriangularSolveOp): + """ + case 7 + """ + + def config(self): + self.x_shape = [2, 10, 10] + self.y_shape = [5, 1, 10, 2] + self.upper = True + self.transpose = True + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + x = np.triu(self.inputs['X']).transpose(0, 2, 1) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + +# 3D(broadcast) + 5D +class TestTriangularSolveOp8(TestTriangularSolveOp): + """ + case 8 + """ + + def config(self): + self.x_shape = [12, 3, 3] + self.y_shape = [2, 3, 12, 3, 2] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.linalg.solve(x, y) + + +# 5D + 4D(broadcast) +class TestTriangularSolveOp9(TestTriangularSolveOp): + """ + case 9 + """ + + def config(self): + self.x_shape = [2, 4, 2, 3, 3] + self.y_shape = [4, 1, 3, 10] + self.upper = False + self.transpose = False + self.unitriangular = False + self.dtype = "float64" + + def set_output(self): + x = np.tril(self.inputs['X']) + y = self.inputs['Y'] + self.output = np.matmul(np.linalg.inv(x), y) + + +class TestTriangularSolveAPI(unittest.TestCase): + def setUp(self): + np.random.seed(2021) + self.place = [paddle.CPUPlace()] + self.dtype = "float64" + if core.is_compiled_with_cuda(): + self.place.append(paddle.CUDAPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + x = fluid.data(name="x", shape=[3, 3], dtype=self.dtype) + y = fluid.data(name="y", shape=[3, 2], dtype=self.dtype) + z = paddle.linalg.triangular_solve(x, y) + + x_np = np.random.random([3, 3]).astype(self.dtype) + y_np = np.random.random([3, 2]).astype(self.dtype) + z_np = np.linalg.solve(np.triu(x_np), y_np) + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"x": x_np, + "y": y_np}, + fetch_list=[z]) + self.assertTrue(np.allclose(fetches[0], z_np)) + + def test_static(self): + for place in self.place: + self.check_static_result(place=place) + + def test_dygraph(self): + def run(place): + paddle.disable_static(place) + x_np = np.random.random([3, 3]).astype(self.dtype) + y_np = np.random.random([3, 2]).astype(self.dtype) + z_np = np.linalg.solve(np.tril(x_np), y_np) + + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + z = paddle.linalg.triangular_solve(x, y, upper=False) + + self.assertTrue(np.allclose(z_np, z.numpy())) + self.assertEqual(z_np.shape, z.numpy().shape) + paddle.enable_static() + + for place in self.place: + run(place) + + +class TestTriangularSolveOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # The input type of solve_op must be Variable. 
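+            # a LoDTensor is not a Variable, so a TypeError is expected here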
+ x1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + y1 = fluid.create_lod_tensor( + np.array([[-1]]), [[1]], fluid.CPUPlace()) + self.assertRaises(TypeError, paddle.linalg.triangular_solve, x1, y1) + + # The data type of input must be float32 or float64. + x2 = fluid.data(name="x2", shape=[30, 30], dtype="bool") + y2 = fluid.data(name="y2", shape=[30, 10], dtype="bool") + self.assertRaises(TypeError, paddle.linalg.triangular_solve, x2, y2) + + x3 = fluid.data(name="x3", shape=[30, 30], dtype="int32") + y3 = fluid.data(name="y3", shape=[30, 10], dtype="int32") + self.assertRaises(TypeError, paddle.linalg.triangular_solve, x3, y3) + + x4 = fluid.data(name="x4", shape=[30, 30], dtype="float16") + y4 = fluid.data(name="y4", shape=[30, 10], dtype="float16") + self.assertRaises(TypeError, paddle.linalg.triangular_solve, x4, y4) + + # The number of dimensions of input'X must be >= 2. + x5 = fluid.data(name="x5", shape=[30], dtype="float64") + y5 = fluid.data(name="y5", shape=[30, 30], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.triangular_solve, x5, + y5) + + # The number of dimensions of input'Y must be >= 2. + x6 = fluid.data(name="x6", shape=[30, 30], dtype="float64") + y6 = fluid.data(name="y6", shape=[30], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.triangular_solve, x6, + y6) + + # The inner-most 2 dimensions of input'X should be equal to each other + x7 = fluid.data(name="x7", shape=[2, 3, 4], dtype="float64") + y7 = fluid.data(name="y7", shape=[2, 4, 3], dtype="float64") + self.assertRaises(ValueError, paddle.linalg.triangular_solve, x7, + y7) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py index 63a058cd25b1e6f8cbbde8d183609d371e9d5466..824e5d2e3b59455cfeaea65bf3d2d358b074ff2a 100644 --- a/python/paddle/linalg.py +++ b/python/paddle/linalg.py @@ -29,6 +29,7 @@ from .tensor.linalg import eigvalsh from .tensor.linalg import det from .tensor.linalg import slogdet from .tensor.linalg import pinv +from .tensor.linalg import triangular_solve __all__ = [ 'cholesky', #noqa @@ -47,5 +48,6 @@ __all__ = [ 'eigh', 'eigvalsh', 'pinv', - 'solve' + 'solve', + 'triangular_solve', ] diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index d046b666c3ef33a738367b4d405bfa8389fca187..9458b35f00d967a4311be09f641fa0f4b9c90e72 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -397,6 +397,7 @@ tensor_method_func = [ #noqa 'uniform_', 'multi_dot', 'solve', + 'triangular_solve' ] #this list used in math_op_patch.py for magic_method bind diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py index e25b10236869e7b3c8f0530a7529b46142dcbd34..8f4efed1f0a8ed8e41a8584ed4005d1cb8a66eb9 100644 --- a/python/paddle/tensor/linalg.py +++ b/python/paddle/tensor/linalg.py @@ -2315,6 +2315,79 @@ def solve(x, y, name=None): return out +def triangular_solve(x, + y, + upper=True, + transpose=False, + unitriangular=False, + name=None): + r""" + Computes the solution of a system of equations with a triangular coefficient matrix `x` and + multiple right-hand sides `y` . + + Input `x` and `y` is 2D matrices or batches of 2D matrices. If the inputs are batches, the outputs + is also batches. + + Args: + x (Tensor): The input triangular coefficient matrix. Its shape should be `[*, M, M]`, where `*` is zero or + more batch dimensions. Its data type should be float32 or float64. 
+ y (Tensor): Multiple right-hand sides of system of equations. Its shape should be `[*, M, K]`, where `*` is + zero or more batch dimensions. Its data type should be float32 or float64. + upper (bool, optional): Whether to solve the upper-triangular system of equations (default) or the lower-triangular + system of equations. Default: True. + transpose (bool, optional): whether `x` should be transposed before calculation. Default: False. + unitriangular (bool, optional): whether `x` is unit triangular. If True, the diagonal elements of `x` are assumed + to be 1 and not referenced from `x` . Default: False. + name(str, optional): Name for the operation (optional, default is None). + For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor: The solution of the system of equations. Its data type should be the same as that of `x`. + + Examples: + .. code-block:: python + + # a square system of linear equations: + # x1 + x2 + x3 = 0 + # 2*x2 + x3 = -9 + # -x3 = 5 + + import paddle + import numpy as np + + x = paddle.to_tensor([[1, 1, 1], + [0, 2, 1], + [0, 0,-1]], dtype="float64") + y = paddle.to_tensor([[0], [-9], [5]], dtype="float64") + out = paddle.linalg.triangular_solve(x, y, upper=True) + + print(out) + # [7, -2, -5] + """ + if in_dygraph_mode(): + return _C_ops.triangular_solve(x, y, 'upper', upper, 'transpose', + transpose, 'unitriangular', + unitriangular) + + inputs = {"X": [x], "Y": [y]} + helper = LayerHelper("triangular_solve", **locals()) + check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'triangular_solve') + check_variable_and_dtype(y, 'y', ['float32', 'float64'], 'triangular_solve') + out = helper.create_variable_for_type_inference(dtype=x.dtype) + + helper.append_op( + type='triangular_solve', + inputs={'X': x, + 'Y': y}, + outputs={'Out': out}, + attrs={ + 'upper': upper, + 'transpose': transpose, + 'unitriangular': unitriangular + }) + return out + + def eigvalsh(x, UPLO='L', name=None): """ Computes the eigenvalues of a