From ae8ca76468213739b19047bb37f2c525896b1083 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Wed, 13 Jul 2022 18:05:44 +0800 Subject: [PATCH] [Phi] Migrate matrix_solve to phi (#44298) * [Phi] Migrate matrix_solve to phi * replace mutable_data with Alloc --- paddle/fluid/operators/eig_op.h | 4 +- paddle/fluid/operators/lstsq_op.h | 2 +- paddle/fluid/operators/math/CMakeLists.txt | 1 - paddle/fluid/operators/math/matrix_solve.cc | 41 ---- .../fluid/operators/math/matrix_solve.cu.cc | 189 ------------------ paddle/fluid/operators/solve_op.h | 69 +------ paddle/phi/kernels/funcs/CMakeLists.txt | 1 + paddle/phi/kernels/funcs/matrix_solve.cc | 32 +++ paddle/phi/kernels/funcs/matrix_solve.cu | 178 +++++++++++++++++ .../math => phi/kernels/funcs}/matrix_solve.h | 108 +++++++--- 10 files changed, 302 insertions(+), 323 deletions(-) delete mode 100644 paddle/fluid/operators/math/matrix_solve.cc delete mode 100644 paddle/fluid/operators/math/matrix_solve.cu.cc create mode 100644 paddle/phi/kernels/funcs/matrix_solve.cc create mode 100644 paddle/phi/kernels/funcs/matrix_solve.cu rename paddle/{fluid/operators/math => phi/kernels/funcs}/matrix_solve.h (61%) diff --git a/paddle/fluid/operators/eig_op.h b/paddle/fluid/operators/eig_op.h index 138a987a0b..82c7fe6881 100644 --- a/paddle/fluid/operators/eig_op.h +++ b/paddle/fluid/operators/eig_op.h @@ -19,7 +19,6 @@ #include #include -#include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/complex_kernel.h" @@ -30,6 +29,7 @@ #include "paddle/phi/kernels/funcs/diag_functor.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/matrix_solve.h" #include "paddle/phi/kernels/funcs/slice.h" #include "paddle/phi/kernels/funcs/unsqueeze.h" #include "paddle/phi/kernels/matmul_kernel.h" @@ -366,7 +366,7 @@ void ComputeBackwardForComplexInput( int k = rhs.dims()[rhs.dims().size() - 1]; auto* matrix_data = Vh.data(); auto* rhs_data = rhs.data(); - math::SolveLinearSystem( + phi::funcs::SolveLinearSystem( matrix_data, rhs_data, x_grad_data, m, k, batch_count); } diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h index f99e027e9c..b3e5894a94 100644 --- a/paddle/fluid/operators/lstsq_op.h +++ b/paddle/fluid/operators/lstsq_op.h @@ -21,13 +21,13 @@ #include "paddle/fluid/operators/eig_op.h" #include "paddle/fluid/operators/math/eigen_values_vectors.h" -#include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/operators/svd_helper.h" #include "paddle/fluid/operators/transpose_op.h" #include "paddle/fluid/platform/for_range.h" #include "paddle/phi/kernels/funcs/complex_functors.h" #include "paddle/phi/kernels/funcs/lapack/lapack_function.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/matrix_solve.h" #define EPSILON 1e-6 diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 3f7206ac08..927feedd18 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -54,7 +54,6 @@ math_library(vol2col) math_library(prelu) math_library(bert_encoder_functor) math_library(tree2col DEPS math_function) -math_library(matrix_solve) cc_test( selected_rows_functor_test diff --git a/paddle/fluid/operators/math/matrix_solve.cc b/paddle/fluid/operators/math/matrix_solve.cc deleted file mode 100644 index b0f8843a53..0000000000 --- a/paddle/fluid/operators/math/matrix_solve.cc +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/matrix_solve.h" - -#include "Eigen/Core" -#include "Eigen/LU" -#include "paddle/phi/kernels/funcs/blas/blas.h" - -namespace paddle { -namespace operators { -namespace math { - -template -class MatrixSolveFunctor { - public: - void operator()(const phi::CPUContext& dev_ctx, - const framework::Tensor& a, - const framework::Tensor& b, - framework::Tensor* out) { - compute_solve_eigen(dev_ctx, a, b, out); - } -}; - -template class MatrixSolveFunctor; -template class MatrixSolveFunctor; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/matrix_solve.cu.cc b/paddle/fluid/operators/math/matrix_solve.cu.cc deleted file mode 100644 index 41b14c07b7..0000000000 --- a/paddle/fluid/operators/math/matrix_solve.cu.cc +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/matrix_solve.h" - -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/solve_op.h" -#include "paddle/fluid/platform/device_context.h" -#include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/funcs/math_function.h" - -namespace paddle { -namespace platform { -class CUDADeviceContext; -} // namespace platform -} // namespace paddle - -namespace paddle { -namespace operators { -namespace math { - -template -class MatrixSolveFunctor; - -template -class MatrixSolveFunctor { - public: - void operator()(const platform::CUDADeviceContext& context, - const framework::Tensor& a, - const framework::Tensor& b, - framework::Tensor* out) { -#ifndef PADDLE_WITH_HIP - - // solve the equation: Ax = B, - // use cuBlas cublasgetrfBatched funcion to performs the LU - // factorization of each matrix A, - // and then use cuBlas cublasgetriBatched function to solve the - // equation after LU factorization. - // ref: - // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched - const auto& a_dims = a.dims(); - const int a_rank = a_dims.size(); - int n = a_dims[a_rank - 1]; - int lda = n; - int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1; - - const auto& b_dims = b.dims(); - const int b_rank = b_dims.size(); - int nrhs = b_dims[b_rank - 1]; - int ldb = b_dims[b_rank - 2]; - - // make sure the out dims is right - out->Resize(b_dims); - out->mutable_data(context.GetPlace()); - - // copy input A to a temporary tensor tmp_a, - // LU factorization, written back to original matrix A, so in the beginning, - // it's necessary to create a temporary tensor tmp_a. - Tensor tmp_a(a.dtype()); - tmp_a.Resize(a.dims()); - tmp_a.mutable_data(context.GetPlace()); - framework::TensorCopy(a, context.GetPlace(), &tmp_a); - - // copy input B to a temporary tensor tmp_b, and transpose tmp_b, - // because cuBlas assumes column-major while Paddle uses row-majar. - Tensor tmp_b(b.type()); - const auto& new_dims_vec = getNewDimsVec(b_dims); - tmp_b.Resize(phi::make_ddim(new_dims_vec)); - tmp_b.mutable_data(context.GetPlace()); - phi::funcs::TransposeNormal trans; - std::vector new_axis = getNewAxis(b_rank); - trans(context, b, &tmp_b, new_axis); - - const T* a_data_in_gpu = tmp_a.data(); - const T* b_data_in_gpu = tmp_b.data(); - - std::vector cpu_ptrs(batch_size * 2); - for (int i = 0; i < batch_size; ++i) { - cpu_ptrs[i] = a_data_in_gpu + i * n * n; - cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs; - } - - // Copy the addresses of A and tmp_b from host to device. - memory::allocation::AllocationPtr tmp_gpu_ptrs_data = - memory::Alloc(context, cpu_ptrs.size() * sizeof(T*)); - memory::Copy(context.GetPlace(), - tmp_gpu_ptrs_data->ptr(), - platform::CPUPlace(), - static_cast(cpu_ptrs.data()), - cpu_ptrs.size() * sizeof(T*), - context.stream()); - - T** gpu_tmp_b_ptrs = - reinterpret_cast(tmp_gpu_ptrs_data->ptr()) + batch_size; - - // Allocate device memory for BatchedGETRF's info and pivots. - int num_ints = n < 32 ? batch_size : batch_size * (n + 1); - memory::allocation::AllocationPtr tmp_gpu_info_data = - memory::Alloc(context, num_ints * sizeof(int)); - int* gpu_info_ptr = reinterpret_cast(tmp_gpu_info_data->ptr()); - - auto blas = phi::funcs::GetBlas(context); - - // only for singular checking - std::vector info; - info.resize(batch_size); - - int* gpu_pivot_ptr = - reinterpret_cast(tmp_gpu_info_data->ptr()) + batch_size; - - // This function performs the LU factorization of each matrix A by the - // equation A = L * U. L and U are written back to original matrix A, - // and diagonal elements of L are discarded. - blas.BatchedGETRF(n, - reinterpret_cast(tmp_gpu_ptrs_data->ptr()), - gpu_pivot_ptr, - gpu_info_ptr, - batch_size); - - // check whether BatchedGETRF is executed successfully or not - memory::Copy(platform::CPUPlace(), - info.data(), - context.GetPlace(), - gpu_info_ptr, - sizeof(int) * batch_size, - context.stream()); - for (int i = 0; i < batch_size; ++i) { - PADDLE_ENFORCE_EQ(info[i], - 0, - platform::errors::PreconditionNotMet( - "For batch [%d]: U(%d, %d) is zero, singular U. " - "Please check the matrix value and change it to a " - "non-singular matrix", - i, - info[i], - info[i])); - } - - // hold the result code from BatchedGETRS - int host_info = 0; - - // to solve the equation after LU factorization - CBLAS_TRANSPOSE transA = CblasTrans; - blas.BatchedGETRS(transA, - n, - nrhs, - reinterpret_cast(tmp_gpu_ptrs_data->ptr()), - lda, - gpu_pivot_ptr, - gpu_tmp_b_ptrs, - ldb, - &host_info, - batch_size); - - // check whether BatchedGETRS is executed successfully or not - PADDLE_ENFORCE_EQ(host_info, - 0, - platform::errors::InvalidArgument( - "The [%d]'th argument to cublas*getrsBatched had " - "an illegal value.", - -host_info)); - - // transpose tmp_b to get the final result in row-major form. - phi::funcs::TransposeNormal trans2; - trans2(context, tmp_b, out, new_axis); - -#else - compute_solve_eigen(context, a, b, out); -#endif - } -}; - -template class MatrixSolveFunctor; -template class MatrixSolveFunctor; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/solve_op.h b/paddle/fluid/operators/solve_op.h index b97b8d01cc..1152237494 100644 --- a/paddle/fluid/operators/solve_op.h +++ b/paddle/fluid/operators/solve_op.h @@ -20,11 +20,11 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#include "paddle/fluid/operators/math/matrix_solve.h" #include "paddle/fluid/operators/reduce_ops/reduce_sum_op.h" #include "paddle/fluid/operators/squeeze_op.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/matrix_solve.h" #if defined(__NVCC__) || defined(__HIPCC__) #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #endif @@ -351,7 +351,7 @@ static void linalg_solve(const framework::ExecutionContext& context, out->mutable_data(context.GetPlace()); auto& dev_ctx = context.template device_context(); - math::MatrixSolveFunctor mat_solve; + phi::funcs::MatrixSolveFunctor mat_solve; // input y can be vector or matrix // but need to be unsqueezed if y is a vector @@ -425,67 +425,6 @@ static void linalg_solve(const framework::ExecutionContext& context, } } -// for TransposeNormal -static std::vector getNewAxis(const int b_rank) { - std::vector axis_1 = {0}; - std::vector axis_2 = {1, 0}; - std::vector axis_3 = {0, 2, 1}; - std::vector axis_4 = {0, 1, 3, 2}; - std::vector axis_5 = {0, 1, 2, 4, 3}; - std::vector axis_6 = {0, 1, 2, 3, 5, 4}; - std::vector axis_7 = {0, 1, 2, 3, 4, 6, 5}; - std::vector axis_8 = {0, 1, 2, 3, 4, 5, 7, 6}; - std::vector axis_9 = {0, 1, 2, 3, 4, 5, 6, 8, 7}; - switch (b_rank) { - case 1: - return axis_1; - break; - case 2: - return axis_2; - break; - case 3: - return axis_3; - break; - case 4: - return axis_4; - break; - case 5: - return axis_5; - break; - case 6: - return axis_6; - break; - case 7: - return axis_7; - break; - case 8: - return axis_8; - break; - default: - return axis_9; - } -} - -// for Resize -static std::vector getNewDimsVec(const DDim& b_dims) { - std::vector b_dims_vec = phi::vectorize(b_dims); - int size = b_dims_vec.size(); - if (size >= 2) { - // swap the last 2 elements in b_dims_vec - int64_t temp = b_dims_vec[size - 1]; - b_dims_vec[size - 1] = b_dims_vec[size - 2]; - b_dims_vec[size - 2] = temp; - return b_dims_vec; - } - PADDLE_ENFORCE_NE( - b_dims_vec.empty(), - true, - platform::errors::PreconditionNotMet( - "The size of tensor b must not be %d after getting new dims", 0)); - // if b_dims_vec.size() == 1, just retun original vec - return b_dims_vec; -} - template class SolveKernel : public framework::OpKernel { public: @@ -553,11 +492,11 @@ class SolveGradKernel : public framework::OpKernel { tmp_dy.mutable_data(ctx.GetPlace()); Tensor tmp_input(input->dtype()); - const auto& new_dims_vec = getNewDimsVec(input->dims()); + const auto& new_dims_vec = phi::funcs::getNewDimsVec(input->dims()); tmp_input.Resize(phi::make_ddim(new_dims_vec)); tmp_input.mutable_data(ctx.GetPlace()); phi::funcs::TransposeNormal trans; - std::vector new_axis = getNewAxis(input->dims().size()); + std::vector new_axis = phi::funcs::getNewAxis(input->dims().size()); auto& dev_ctx = ctx.template device_context(); trans(dev_ctx, *input, &tmp_input, new_axis); diff --git a/paddle/phi/kernels/funcs/CMakeLists.txt b/paddle/phi/kernels/funcs/CMakeLists.txt index 6d16fc8f81..25696a34e3 100644 --- a/paddle/phi/kernels/funcs/CMakeLists.txt +++ b/paddle/phi/kernels/funcs/CMakeLists.txt @@ -14,3 +14,4 @@ math_library(matrix_inverse DEPS dense_tensor eigen3 blas) math_library(pooling DEPS dense_tensor) math_library(segment_pooling) math_library(sequence2batch) +math_library(matrix_solve DEPS dense_tensor eigen3 blas math_function) diff --git a/paddle/phi/kernels/funcs/matrix_solve.cc b/paddle/phi/kernels/funcs/matrix_solve.cc new file mode 100644 index 0000000000..31baedb3c3 --- /dev/null +++ b/paddle/phi/kernels/funcs/matrix_solve.cc @@ -0,0 +1,32 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/funcs/matrix_solve.h" + +namespace phi { +namespace funcs { + +template +void MatrixSolveFunctor::operator()(const Context& dev_ctx, + const DenseTensor& a, + const DenseTensor& b, + DenseTensor* out) { + compute_solve_eigen(dev_ctx, a, b, out); +} + +template class MatrixSolveFunctor; +template class MatrixSolveFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/funcs/matrix_solve.cu b/paddle/phi/kernels/funcs/matrix_solve.cu new file mode 100644 index 0000000000..fccceb7e20 --- /dev/null +++ b/paddle/phi/kernels/funcs/matrix_solve.cu @@ -0,0 +1,178 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/funcs/matrix_solve.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/blas/blas.h" +#include "paddle/phi/kernels/funcs/math_function.h" + +namespace phi { +namespace funcs { + +template +void MatrixSolveFunctor::operator()(const Context& context, + const DenseTensor& a, + const DenseTensor& b, + DenseTensor* out) { +#ifndef PADDLE_WITH_HIP + + // solve the equation: Ax = B, + // use cuBlas cublasgetrfBatched funcion to performs the LU + // factorization of each matrix A, + // and then use cuBlas cublasgetriBatched function to solve the + // equation after LU factorization. + // ref: + // https://docs.nvidia.com/cuda/cublas/index.html#cublas-lt-t-gt-getrfbatched + const auto& a_dims = a.dims(); + const int a_rank = a_dims.size(); + int n = a_dims[a_rank - 1]; + int lda = n; + int batch_size = a_rank > 2 ? a.numel() / (n * n) : 1; + + const auto& b_dims = b.dims(); + const int b_rank = b_dims.size(); + int nrhs = b_dims[b_rank - 1]; + int ldb = b_dims[b_rank - 2]; + + // make sure the out dims is right + out->Resize(b_dims); + + context.template Alloc(out); + + // copy input A to a temporary tensor tmp_a, + // LU factorization, written back to original matrix A, so in the beginning, + // it's necessary to create a temporary tensor tmp_a. + DenseTensor tmp_a(a.dtype()); + tmp_a.Resize(a.dims()); + + context.template Alloc(&tmp_a); + paddle::framework::TensorCopy(a, context.GetPlace(), &tmp_a); + + // copy input B to a temporary tensor tmp_b, and transpose tmp_b, + // because cuBlas assumes column-major while Paddle uses row-majar. + DenseTensor tmp_b(b.type()); + const auto& new_dims_vec = getNewDimsVec(b_dims); + tmp_b.Resize(phi::make_ddim(new_dims_vec)); + context.template Alloc(&tmp_b); + phi::funcs::TransposeNormal trans; + std::vector new_axis = getNewAxis(b_rank); + trans(context, b, &tmp_b, new_axis); + + const T* a_data_in_gpu = tmp_a.data(); + const T* b_data_in_gpu = tmp_b.data(); + + std::vector cpu_ptrs(batch_size * 2); + for (int i = 0; i < batch_size; ++i) { + cpu_ptrs[i] = a_data_in_gpu + i * n * n; + cpu_ptrs[i + batch_size] = b_data_in_gpu + i * n * nrhs; + } + + // Copy the addresses of A and tmp_b from host to device. + paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + paddle::memory::Alloc(context, cpu_ptrs.size() * sizeof(T*)); + paddle::memory::Copy(context.GetPlace(), + tmp_gpu_ptrs_data->ptr(), + phi::CPUPlace(), + static_cast(cpu_ptrs.data()), + cpu_ptrs.size() * sizeof(T*), + context.stream()); + + T** gpu_tmp_b_ptrs = + reinterpret_cast(tmp_gpu_ptrs_data->ptr()) + batch_size; + + // Allocate device memory for BatchedGETRF's info and pivots. + int num_ints = n < 32 ? batch_size : batch_size * (n + 1); + paddle::memory::allocation::AllocationPtr tmp_gpu_info_data = + paddle::memory::Alloc(context, num_ints * sizeof(int)); + int* gpu_info_ptr = reinterpret_cast(tmp_gpu_info_data->ptr()); + + auto blas = phi::funcs::GetBlas(context); + + // only for singular checking + std::vector info; + info.resize(batch_size); + + int* gpu_pivot_ptr = + reinterpret_cast(tmp_gpu_info_data->ptr()) + batch_size; + + // This function performs the LU factorization of each matrix A by the + // equation A = L * U. L and U are written back to original matrix A, + // and diagonal elements of L are discarded. + blas.BatchedGETRF(n, + reinterpret_cast(tmp_gpu_ptrs_data->ptr()), + gpu_pivot_ptr, + gpu_info_ptr, + batch_size); + + // check whether BatchedGETRF is executed successfully or not + paddle::memory::Copy(phi::CPUPlace(), + info.data(), + context.GetPlace(), + gpu_info_ptr, + sizeof(int) * batch_size, + context.stream()); + for (int i = 0; i < batch_size; ++i) { + PADDLE_ENFORCE_EQ(info[i], + 0, + phi::errors::PreconditionNotMet( + "For batch [%d]: U(%d, %d) is zero, singular U. " + "Please check the matrix value and change it to a " + "non-singular matrix", + i, + info[i], + info[i])); + } + + // hold the result code from BatchedGETRS + int host_info = 0; + + // to solve the equation after LU factorization + CBLAS_TRANSPOSE transA = CblasTrans; + blas.BatchedGETRS(transA, + n, + nrhs, + reinterpret_cast(tmp_gpu_ptrs_data->ptr()), + lda, + gpu_pivot_ptr, + gpu_tmp_b_ptrs, + ldb, + &host_info, + batch_size); + + // check whether BatchedGETRS is executed successfully or not + PADDLE_ENFORCE_EQ(host_info, + 0, + phi::errors::InvalidArgument( + "The [%d]'th argument to cublas*getrsBatched had " + "an illegal value.", + -host_info)); + + // transpose tmp_b to get the final result in row-major form. + phi::funcs::TransposeNormal trans2; + trans2(context, tmp_b, out, new_axis); + +#else + compute_solve_eigen(context, a, b, out); +#endif +} + +template class MatrixSolveFunctor; +template class MatrixSolveFunctor; + +// TODO(wuweilong): remove these instantiations later +template class MatrixSolveFunctor; +template class MatrixSolveFunctor; + +} // namespace funcs +} // namespace phi diff --git a/paddle/fluid/operators/math/matrix_solve.h b/paddle/phi/kernels/funcs/matrix_solve.h similarity index 61% rename from paddle/fluid/operators/math/matrix_solve.h rename to paddle/phi/kernels/funcs/matrix_solve.h index 6852d04e5a..3856c06c1b 100644 --- a/paddle/fluid/operators/math/matrix_solve.h +++ b/paddle/phi/kernels/funcs/matrix_solve.h @@ -18,18 +18,79 @@ limitations under the License. */ #include "Eigen/Core" #include "Eigen/LU" -#include "paddle/fluid/framework/tensor.h" -#include "paddle/fluid/platform/device_context.h" - -namespace paddle { -namespace operators { -namespace math { - -template -void compute_solve_eigen(const DeviceContext& context, - const framework::Tensor& a, - const framework::Tensor& b, - framework::Tensor* out) { +#include "paddle/phi/backends/all_context.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" + +namespace phi { +namespace funcs { + +// for TransposeNormal +static std::vector getNewAxis(const int b_rank) { + std::vector axis_1 = {0}; + std::vector axis_2 = {1, 0}; + std::vector axis_3 = {0, 2, 1}; + std::vector axis_4 = {0, 1, 3, 2}; + std::vector axis_5 = {0, 1, 2, 4, 3}; + std::vector axis_6 = {0, 1, 2, 3, 5, 4}; + std::vector axis_7 = {0, 1, 2, 3, 4, 6, 5}; + std::vector axis_8 = {0, 1, 2, 3, 4, 5, 7, 6}; + std::vector axis_9 = {0, 1, 2, 3, 4, 5, 6, 8, 7}; + switch (b_rank) { + case 1: + return axis_1; + break; + case 2: + return axis_2; + break; + case 3: + return axis_3; + break; + case 4: + return axis_4; + break; + case 5: + return axis_5; + break; + case 6: + return axis_6; + break; + case 7: + return axis_7; + break; + case 8: + return axis_8; + break; + default: + return axis_9; + } +} + +// for Resize +static std::vector getNewDimsVec(const DDim& b_dims) { + std::vector b_dims_vec = phi::vectorize(b_dims); + int size = b_dims_vec.size(); + if (size >= 2) { + // swap the last 2 elements in b_dims_vec + int64_t temp = b_dims_vec[size - 1]; + b_dims_vec[size - 1] = b_dims_vec[size - 2]; + b_dims_vec[size - 2] = temp; + return b_dims_vec; + } + PADDLE_ENFORCE_NE( + b_dims_vec.empty(), + true, + phi::errors::PreconditionNotMet( + "The size of tensor b must not be %d after getting new dims", 0)); + // if b_dims_vec.size() == 1, just retun original vec + return b_dims_vec; +} + +template +void compute_solve_eigen(const Context& context, + const DenseTensor& a, + const DenseTensor& b, + DenseTensor* out) { using Matrix = Eigen::Matrix; using EigenMatrixMap = Eigen::Map; @@ -51,7 +112,7 @@ void compute_solve_eigen(const DeviceContext& context, const T* b_ptr = b.data(); out->Resize(b_mat_dims); // make sure the out dims is right - T* out_ptr = out->mutable_data(context.GetPlace()); + T* out_ptr = context.template Alloc(out); if (a_batch_size == b_batch_size) { for (int i = 0; i < a_batch_size; ++i) { ConstEigenMatrixMap a_mat(a_ptr + i * n * n, n, n); @@ -63,13 +124,13 @@ void compute_solve_eigen(const DeviceContext& context, PADDLE_ENFORCE_GT( min_abs_pivot, static_cast(0), - platform::errors::InvalidArgument("Input is not invertible.")); + phi::errors::InvalidArgument("Input is not invertible.")); out_mat.noalias() = lu.solve(b_mat); } } else { PADDLE_ENFORCE_EQ(a_batch_size, b_batch_size, - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "All input tensors must have the same rank.")); } } @@ -114,22 +175,21 @@ void SolveLinearSystem(T* matrix_data, lu_decomposition.matrixLU().diagonal().cwiseAbs().minCoeff(); PADDLE_ENFORCE_GT(min_abs_piv, Treal(0), - platform::errors::InvalidArgument( + phi::errors::InvalidArgument( "Something's wrong with SolveLinearSystem. ")); output = lu_decomposition.solve(input_rhs); } } -template +template class MatrixSolveFunctor { public: - void operator()(const DeviceContext& context, - const framework::Tensor& a, - const framework::Tensor& b, - framework::Tensor* out); + void operator()(const Context& context, + const DenseTensor& a, + const DenseTensor& b, + DenseTensor* out); }; -} // namespace math -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi -- GitLab