From 667dc9f083a2b65b8f8be90b4aa7c0afd599394a Mon Sep 17 00:00:00 2001
From: zhangkaihuo
Date: Thu, 30 Dec 2021 16:16:38 +0800
Subject: [PATCH] Add cusparse and unittest (#38431)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Bind the cuSparse handle to the DeviceContext, so ops avoid creating and
destroying it themselves.
Add wrappers around the cuSparse dense/sparse conversion APIs.
Add unit tests for the wrapped APIs.

---
 paddle/fluid/operators/math/CMakeLists.txt    |   9 +
 .../math/cusparse_conversion_api_test.cc      | 180 ++++++++++++++
 paddle/fluid/operators/math/sparse.h          | 114 +++++++++
 paddle/fluid/operators/math/sparse_impl.cu.h  | 231 ++++++++++++++++++
 .../device/gpu/cuda/cusparse_helper.h         |  62 +++++
 paddle/fluid/platform/device/gpu/gpu_helper.h |   1 +
 paddle/fluid/platform/device_context.cc       |   5 +
 paddle/fluid/platform/device_context.h        |  34 +++
 paddle/fluid/platform/dynload/cusparse.cc     |   4 +
 paddle/fluid/platform/dynload/cusparse.h      |  41 +++-
 paddle/fluid/platform/enforce.h               |  18 ++
 paddle/fluid/platform/external_error.proto    |   3 +-
 12 files changed, 691 insertions(+), 11 deletions(-)
 create mode 100644 paddle/fluid/operators/math/cusparse_conversion_api_test.cc
 create mode 100644 paddle/fluid/operators/math/sparse.h
 create mode 100644 paddle/fluid/operators/math/sparse_impl.cu.h
 create mode 100644 paddle/fluid/platform/device/gpu/cuda/cusparse_helper.h

diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index a2f619d84a..fcf988efcd 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -111,6 +111,15 @@ if(WITH_ROCM)
  hip_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu.cc DEPS selected_rows_functor math_function)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
+
+if(WITH_GPU AND (NOT WITH_ROCM))
+# currently not yet supported on ROCM
+# the generic dense/sparse conversion APIs are only supported since CUDA 11.2
+  if((NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2))
+    cc_test(cusparse_conversion_api_test SRCS cusparse_conversion_api_test.cc DEPS tensor)
+  endif()
+endif()
+
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
 if(WITH_TESTING AND TEST im2col_test)
   set_tests_properties(im2col_test PROPERTIES TIMEOUT 120)
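For orientation while reading the diffs below, here is a minimal sketch of the
call pattern this patch enables inside an op kernel. The kernel function is
hypothetical; GetSparse and DenseToSparseCoo are the wrappers added by this
patch, and the cusparse handle they use is owned by the device context, so the
op itself performs no cusparseCreate/cusparseDestroy:

    // Hypothetical kernel body (sketch only, not part of the patch).
    template <typename T>
    void DenseToCooSketch(const paddle::platform::CUDADeviceContext& dev_ctx,
                          int M, int N, const T* dense_data,
                          int64_t* rows, int64_t* cols, T* values) {
      auto sparse = paddle::operators::math::GetSparse<
          paddle::platform::CUDADeviceContext, T>(dev_ctx);
      // Reuses the handle created once with dev_ctx.
      sparse.DenseToSparseCoo(M, N, dense_data, rows, cols, values);
    }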
diff --git a/paddle/fluid/operators/math/cusparse_conversion_api_test.cc b/paddle/fluid/operators/math/cusparse_conversion_api_test.cc
new file mode 100644
index 0000000000..d45b57420e
--- /dev/null
+++ b/paddle/fluid/operators/math/cusparse_conversion_api_test.cc
@@ -0,0 +1,180 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+
+#include "glog/logging.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/math/sparse.h"
+
+template <typename T>
+void TestNNZ(const std::vector<T>& dense_data, const int correct_nnz,
+             const int rows, const int cols) {
+  paddle::platform::CUDADeviceContext* context =
+      new paddle::platform::CUDADeviceContext(paddle::platform::CUDAPlace());
+  auto sparse =
+      paddle::operators::math::GetSparse<paddle::platform::CUDADeviceContext,
+                                         T>(*context);
+
+  paddle::framework::Tensor dense, nnz_tensor;
+  auto dense_dims = paddle::framework::make_ddim({rows, cols});
+  auto nnz_dims = paddle::framework::make_ddim({dense_dims[0] + 1});
+  dense.mutable_data<T>(dense_dims, paddle::platform::CUDAPlace());
+  paddle::framework::TensorFromVector(dense_data, *context, &dense);
+  int32_t* nnz_ptr =
+      nnz_tensor.mutable_data<int32_t>(nnz_dims, paddle::platform::CUDAPlace());
+  sparse.nnz(rows, cols, dense.data<T>(), nnz_ptr, nnz_ptr + 1);
+  std::vector<int32_t> nnz_vec(dense_dims[0] + 1);
+  paddle::framework::TensorToVector(nnz_tensor, *context, &nnz_vec);
+  delete context;
+  CHECK_EQ(correct_nnz, nnz_vec[0]);
+}
+
+TEST(sparse, nnz) {
+  std::vector<float> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0,
+                                   3.0, 3.2, 0.0, 0.0};
+  TestNNZ(dense_data, 4, 3, 3);
+}
+
+TEST(sparse, nnz_double) {
+  std::vector<double> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0};
+  TestNNZ(dense_data, 4, 4, 2);
+}
+template <typename T>
+void TestDenseToSparse(const std::vector<T>& correct_dense_data,
+                       const std::vector<int64_t>& correct_rows,
+                       const std::vector<int64_t>& correct_cols,
+                       const std::vector<T>& correct_values,
+                       const int correct_nnz, const int rows, const int cols,
+                       const std::string& mode) {
+  paddle::platform::CUDADeviceContext* context =
+      new paddle::platform::CUDADeviceContext(paddle::platform::CUDAPlace());
+  // get sparse
+  auto sparse =
+      paddle::operators::math::GetSparse<paddle::platform::CUDADeviceContext,
+                                         T>(*context);
+
+  // create tensor and copy vector to tensor
+  paddle::framework::Tensor dense_tensor, rows_tensor, cols_tensor,
+      values_tensor, actual_dense_tensor;
+  auto dense_dims = paddle::framework::make_ddim({rows, cols});
+  T* dense_data =
+      dense_tensor.mutable_data<T>(dense_dims, paddle::platform::CUDAPlace());
+  T* actual_dense_data = actual_dense_tensor.mutable_data<T>(
+      dense_dims, paddle::platform::CUDAPlace());
+  paddle::framework::TensorFromVector(correct_dense_data, *context,
+                                      &dense_tensor);
+
+  auto nnz_dims = paddle::framework::make_ddim({correct_nnz});
+  auto crows_dims = paddle::framework::make_ddim({rows + 1});
+  int64_t* rows_data = nullptr;
+  if (mode == "COO") {
+    rows_data = rows_tensor.mutable_data<int64_t>(
+        nnz_dims, paddle::platform::CUDAPlace());
+  } else {
+    rows_data = rows_tensor.mutable_data<int64_t>(
+        crows_dims, paddle::platform::CUDAPlace());
+  }
+  int64_t* cols_data = cols_tensor.mutable_data<int64_t>(
+      nnz_dims, paddle::platform::CUDAPlace());
+  T* values_data =
+      values_tensor.mutable_data<T>(nnz_dims, paddle::platform::CUDAPlace());
+
+  // test dense_to_sparse
+  if (mode == "COO") {
+    sparse.DenseToSparseCoo(rows, cols, dense_data, rows_data, cols_data,
+                            values_data);
+  } else {
+    sparse.DenseToSparseCsr(rows, cols, dense_data, rows_data, cols_data,
+                            values_data);
+  }
+
+  std::vector<int64_t> actual_rows(correct_nnz), actual_crows(rows + 1),
+      actual_cols(correct_nnz);
+  std::vector<T> actual_values(correct_nnz), actual_dense_vec(rows * cols);
+  if (mode == "COO") {
+    paddle::framework::TensorToVector<int64_t>(rows_tensor, *context,
+                                               &actual_rows);
+  } else {
+    paddle::framework::TensorToVector<int64_t>(rows_tensor, *context,
+                                               &actual_crows);
+  }
+  paddle::framework::TensorToVector<int64_t>(cols_tensor, *context,
+                                             &actual_cols);
+  paddle::framework::TensorToVector(values_tensor, *context, &actual_values);
+
+  for (int i = 0; i < correct_nnz; i++) {
+    if (mode == "COO") {
+      CHECK_EQ(correct_rows[i], actual_rows[i]);
+    }
+    CHECK_EQ(correct_cols[i], actual_cols[i]);
+    CHECK_EQ(correct_values[i], actual_values[i]);
+  }
+  if (mode == "CSR") {
+    for (int i = 0; i < rows + 1; i++) {
+      CHECK_EQ(correct_rows[i], actual_crows[i]);
+    }
+  }
+
+  // test sparse_to_dense
+  if (mode == "COO") {
+    sparse.SparseCooToDense(rows, cols, correct_nnz, rows_data, cols_data,
+                            values_data, actual_dense_data);
+  } else {
+    sparse.SparseCsrToDense(rows, cols, correct_nnz, rows_data, cols_data,
+                            values_data, actual_dense_data);
+  }
+  paddle::framework::TensorToVector(actual_dense_tensor, *context,
+                                    &actual_dense_vec);
+  for (uint64_t i = 0; i < correct_dense_data.size(); i++) {
+    CHECK_EQ(correct_dense_data[i], actual_dense_vec[i]);
+  }
+
+  delete context;
+}
+
+TEST(sparse, dense_to_sparse) {
+  std::vector<float> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0,
+                                   3.0, 3.2, 0.0, 0.0};
+  std::vector<float> values = {1.0, 2.0, 3.0, 3.2};
+  std::vector<int64_t> rows = {0, 1, 1, 2};
+  std::vector<int64_t> crows = {0, 1, 3, 4};
+  std::vector<int64_t> cols = {1, 0, 2, 0};
+  TestDenseToSparse(dense_data, rows, cols, values, 4, 3, 3, "COO");
+  TestDenseToSparse(dense_data, crows, cols, values, 4, 3, 3, "CSR");
+}
+
+TEST(sparse, dense_to_sparse_double) {
+  std::vector<double> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0};
+  std::vector<double> values = {1.0, 2.0, 3.0, 3.2};
+  std::vector<int64_t> rows = {0, 1, 2, 3};
+  std::vector<int64_t> crows = {0, 1, 2, 3, 4};
+  std::vector<int64_t> cols = {1, 1, 1, 0};
+  TestDenseToSparse(dense_data, rows, cols, values, 4, 4, 2, "COO");
+  TestDenseToSparse(dense_data, crows, cols, values, 4, 4, 2, "CSR");
+}
+
+TEST(sparse, dense_to_sparse_fp16) {
+  using float16 = paddle::platform::float16;
+  std::vector<float16> dense_data = {float16(0.0), float16(1.0), float16(0.0),
+                                     float16(2.0), float16(0.0), float16(3.0),
+                                     float16(3.2), float16(0.0)};
+  std::vector<float16> values = {float16(1.0), float16(2.0), float16(3.0),
+                                 float16(3.2)};
+  std::vector<int64_t> rows = {0, 1, 2, 3};
+  std::vector<int64_t> crows = {0, 1, 2, 3, 4};
+  std::vector<int64_t> cols = {1, 1, 1, 0};
+  TestDenseToSparse(dense_data, rows, cols, values, 4, 4, 2, "COO");
+  TestDenseToSparse(dense_data, crows, cols, values, 4, 4, 2, "CSR");
+}
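For reference, the 3x3 fixture shared by TestNNZ and TEST(sparse,
dense_to_sparse) above encodes as follows; the same arithmetic explains the
4x2 double/fp16 variants:

    // dense, row-major 3x3:          COO (nnz = 4):
    //   0.0  1.0  0.0                rows   = {0, 1, 1, 2}
    //   2.0  0.0  3.0                cols   = {1, 0, 2, 0}
    //   3.2  0.0  0.0                values = {1.0, 2.0, 3.0, 3.2}
    //
    // CSR: crows = {0, 1, 3, 4}; row i owns entries [crows[i], crows[i+1])
    // of cols/values, so the total nnz is crows[rows] = 4.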
diff --git a/paddle/fluid/operators/math/sparse.h b/paddle/fluid/operators/math/sparse.h
new file mode 100644
index 0000000000..4ac68a3bdc
--- /dev/null
+++ b/paddle/fluid/operators/math/sparse.h
@@ -0,0 +1,114 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/tensor.h"
+
+namespace paddle {
+namespace framework {
+class ExecutionContext;
+class Tensor;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename DeviceContext>
+class Sparse {
+ public:
+  explicit Sparse(const DeviceContext& context) : context_(context) {}
+
+  template <typename T>
+  void nnz(const int M, const int N, const T* dense, int* nnz,
+           int* nnzPerRowColumn) const;
+
+  template <typename T>
+  void DenseToSparseCoo(const int M, const int N, const T* dense,
+                        int64_t* rows, int64_t* cols, T* values) const;
+
+  template <typename T>
+  void DenseToSparseCsr(const int M, const int N, const T* dense,
+                        int64_t* crows, int64_t* cols, T* values) const;
+
+  template <typename T>
+  void SparseCooToDense(const int64_t M, const int64_t N, const int64_t nnz,
+                        const int64_t* rows, const int64_t* cols,
+                        const T* values, T* dense) const;
+  template <typename T>
+  void SparseCsrToDense(const int64_t M, const int64_t N, const int64_t nnz,
+                        const int64_t* crows, const int64_t* cols,
+                        const T* values, T* dense) const;
+
+ private:
+  const DeviceContext& context_;
+};
+
+template <typename DeviceContext, typename T>
+class SparseT : private Sparse<DeviceContext> {
+ public:
+  using Sparse<DeviceContext>::Sparse;
+
+  template <typename... ARGS>
+  void nnz(ARGS... args) const {
+    Base()->template nnz<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void DenseToSparseCoo(ARGS... args) const {
+    Base()->template DenseToSparseCoo<T>(args...);
+  }
+  template <typename... ARGS>
+  void DenseToSparseCsr(ARGS... args) const {
+    Base()->template DenseToSparseCsr<T>(args...);
+  }
+  template <typename... ARGS>
+  void SparseCooToDense(ARGS... args) const {
+    Base()->template SparseCooToDense<T>(args...);
+  }
+  template <typename... ARGS>
+  void SparseCsrToDense(ARGS... args) const {
+    Base()->template SparseCsrToDense<T>(args...);
+  }
+
+ private:
+  const Sparse<DeviceContext>* Base() const {
+    return static_cast<const Sparse<DeviceContext>*>(this);
+  }
+};
+
+template <typename DeviceContext, typename T>
+inline SparseT<DeviceContext, T> GetSparse(
+    const framework::ExecutionContext& exe_ctx) {
+  return SparseT<DeviceContext, T>(
+      exe_ctx.template device_context<DeviceContext>());
+}
+
+template <typename DeviceContext, typename T>
+inline SparseT<DeviceContext, T> GetSparse(const DeviceContext& dev_ctx) {
+  return SparseT<DeviceContext, T>(dev_ctx);
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
+
+#if defined(PADDLE_WITH_CUDA)
+#if CUDA_VERSION >= 11020
+#include "paddle/fluid/operators/math/sparse_impl.cu.h"
+#endif
+#endif
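A note on the design just above: SparseT binds the element type once so call
sites do not repeat it, forwarding each call to the matching
Sparse<DeviceContext> member with T supplied explicitly. A hedged sketch of
the two equivalent call styles, assuming dev_ctx is a CUDADeviceContext and
the pointers are already allocated:

    // Untyped wrapper: T is supplied per call.
    paddle::operators::math::Sparse<paddle::platform::CUDADeviceContext>
        sparse(dev_ctx);
    sparse.nnz<float>(M, N, dense_data, nnz_ptr, nnz_per_row_ptr);

    // Typed wrapper via GetSparse: T is bound once, later calls stay terse.
    auto sparse_f = paddle::operators::math::GetSparse<
        paddle::platform::CUDADeviceContext, float>(dev_ctx);
    sparse_f.nnz(M, N, dense_data, nnz_ptr, nnz_per_row_ptr);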
diff --git a/paddle/fluid/operators/math/sparse_impl.cu.h b/paddle/fluid/operators/math/sparse_impl.cu.h
new file mode 100644
index 0000000000..8ff2f4b27d
--- /dev/null
+++ b/paddle/fluid/operators/math/sparse_impl.cu.h
@@ -0,0 +1,231 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/dynload/cusparse.h"
+
+#include "paddle/fluid/platform/device/gpu/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+cudaDataType_t GetGpuDataType() {
+  if (std::is_same<T, float>::value) {
+    return CUDA_R_32F;
+  } else if (std::is_same<T, double>::value) {
+    return CUDA_R_64F;
+  } else if (std::is_same<T, platform::float16>::value) {
+    return CUDA_R_16F;
+  }
+  // Reject unsupported element types instead of falling off the end of a
+  // non-void function.
+  PADDLE_THROW(platform::errors::InvalidArgument(
+      "the data type is not supported by the cusparse wrappers"));
+}
+
+template <>
+template <typename T>
+void Sparse<platform::CUDADeviceContext>::nnz(const int M, const int N,
+                                              const T* dense, int* nnz,
+                                              int* nnzPerRowColumn) const {}
+
+template <>
+template <>
+void Sparse<platform::CUDADeviceContext>::nnz(const int M, const int N,
+                                              const float* dense, int* nnz,
+                                              int* nnzPerRowColumn) const {
+  cusparseMatDescr_t descr = 0;
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      paddle::platform::dynload::cusparseCreateMatDescr(&descr));
+  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatType(
+      descr, CUSPARSE_MATRIX_TYPE_GENERAL));
+  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatIndexBase(
+      descr, CUSPARSE_INDEX_BASE_ZERO));
+
+  context_.CusparseCall([&](cusparseHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSnnz(
+        handle, CUSPARSE_DIRECTION_ROW, M, N, descr, dense, M, nnzPerRowColumn,
+        nnz));
+  });
+}
+
+template <>
+template <>
+void Sparse<platform::CUDADeviceContext>::nnz(const int M, const int N,
+                                              const double* dense, int* nnz,
+                                              int* nnzPerRowColumn) const {
+  cusparseMatDescr_t descr = 0;
+  PADDLE_ENFORCE_GPU_SUCCESS(
+      paddle::platform::dynload::cusparseCreateMatDescr(&descr));
+  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatType(
+      descr, CUSPARSE_MATRIX_TYPE_GENERAL));
+  PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatIndexBase(
+      descr, CUSPARSE_INDEX_BASE_ZERO));
+
+  context_.CusparseCall([&](cusparseHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseDnnz(
+        handle, CUSPARSE_DIRECTION_ROW, M, N, descr, dense, M, nnzPerRowColumn,
+        nnz));
+  });
+}
+
+template <typename T>
+inline void DenseToSparse(const platform::CUDADeviceContext& context,
+                          const int M, const int N, const T* dense,
+                          int64_t* rows, int64_t* cols, T* values,
+                          const cusparseFormat_t format) {
+  cusparseSpMatDescr_t matB;
+  cusparseDnMatDescr_t matA;
+
+  cudaDataType_t dtype = GetGpuDataType<T>();
+
+  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateDnMat(
+      &matA, M, N, N, const_cast<void*>(reinterpret_cast<const void*>(dense)),
+      dtype, CUSPARSE_ORDER_ROW));
+
+  if (format == CUSPARSE_FORMAT_COO) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCoo(
+        &matB, M, N, 0, nullptr, nullptr, nullptr, CUSPARSE_INDEX_64I,
+        CUSPARSE_INDEX_BASE_ZERO, dtype));
+  } else if (format == CUSPARSE_FORMAT_CSR) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCsr(
+        &matB, M, N, 0, rows, nullptr, nullptr, CUSPARSE_INDEX_64I,
+        CUSPARSE_INDEX_64I, CUSPARSE_INDEX_BASE_ZERO, dtype));
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "the sparse format [%s] is not supported", format));
+  }
+
+  size_t buffer_size = 0;
+  context.CusparseCall([&](cusparseHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cusparseDenseToSparse_bufferSize(
+            handle, matA, matB, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
+            &buffer_size));
+  });
+  framework::Tensor buffer;
+  float* buffer_data = buffer.mutable_data<float>(
+      {static_cast<int64_t>(buffer_size)}, context.GetPlace());
+
+  context.CusparseCall([&](cusparseHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cusparseDenseToSparse_analysis(
+            handle, matA, matB, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
+            buffer_data));
+  });
+
+  if (format == CUSPARSE_FORMAT_COO) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCooSetPointers(
+        matB, rows, cols, reinterpret_cast<void*>(values)));
+  } else if (format == CUSPARSE_FORMAT_CSR) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCsrSetPointers(
+        matB, rows, cols, reinterpret_cast<void*>(values)));
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "the sparse format [%s] is not supported", format));
+  }
+  context.CusparseCall([&](cusparseHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseDenseToSparse_convert(
+        handle, matA, matB, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, buffer_data));
+  });
+}
+template <>
+template <typename T>
+void Sparse<platform::CUDADeviceContext>::DenseToSparseCoo(
+    const int M, const int N, const T* dense, int64_t* rows, int64_t* cols,
+    T* values) const {
+  DenseToSparse(context_, M, N, dense, rows, cols, values,
+                CUSPARSE_FORMAT_COO);
+}
+
+template <>
+template <typename T>
+void Sparse<platform::CUDADeviceContext>::DenseToSparseCsr(
+    const int M, const int N, const T* dense, int64_t* crows, int64_t* cols,
+    T* values) const {
+  DenseToSparse(context_, M, N, dense, crows, cols, values,
+                CUSPARSE_FORMAT_CSR);
+}
+
+template <typename T>
+void SparseToDense(const platform::CUDADeviceContext& context, const int64_t M,
+                   const int64_t N, const int64_t nnz, const int64_t* rows,
+                   const int64_t* cols, const T* values, T* dense,
+                   const cusparseFormat_t format) {
+  cusparseSpMatDescr_t matA;
+  cusparseDnMatDescr_t matB;
+
+  cudaDataType_t dtype = GetGpuDataType<T>();
+  if (format == CUSPARSE_FORMAT_COO) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCoo(
+        &matA, M, N, nnz,
+        const_cast<void*>(reinterpret_cast<const void*>(rows)),
+        const_cast<void*>(reinterpret_cast<const void*>(cols)),
+        const_cast<void*>(reinterpret_cast<const void*>(values)),
+        CUSPARSE_INDEX_64I, CUSPARSE_INDEX_BASE_ZERO, dtype));
+  } else if (format == CUSPARSE_FORMAT_CSR) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCsr(
+        &matA, M, N, nnz,
+        const_cast<void*>(reinterpret_cast<const void*>(rows)),
+        const_cast<void*>(reinterpret_cast<const void*>(cols)),
+        const_cast<void*>(reinterpret_cast<const void*>(values)),
+        CUSPARSE_INDEX_64I, CUSPARSE_INDEX_64I, CUSPARSE_INDEX_BASE_ZERO,
+        dtype));
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "the sparse format [%s] is not supported", format));
+  }
+
+  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateDnMat(
+      &matB, M, N, N, reinterpret_cast<void*>(dense), dtype,
+      CUSPARSE_ORDER_ROW));
+
+  size_t buffer_size = 0;
+  context.CusparseCall([&](cusparseHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(
+        platform::dynload::cusparseSparseToDense_bufferSize(
+            handle, matA, matB, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
+            &buffer_size));
+  });
+  framework::Tensor buffer;
+  float* buffer_data = buffer.mutable_data<float>(
+      {static_cast<int64_t>(buffer_size)}, context.GetPlace());
+
+  context.CusparseCall([&](cusparseHandle_t handle) {
+    PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseSparseToDense(
+        handle, matA, matB, CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffer_data));
+  });
+}
+
+template <>
+template <typename T>
+void Sparse<platform::CUDADeviceContext>::SparseCooToDense(
+    const int64_t M, const int64_t N, const int64_t nnz, const int64_t* rows,
+    const int64_t* cols, const T* values, T* dense) const {
+  SparseToDense(context_, M, N, nnz, rows, cols, values, dense,
+                CUSPARSE_FORMAT_COO);
+}
+
+template <>
+template <typename T>
+void Sparse<platform::CUDADeviceContext>::SparseCsrToDense(
+    const int64_t M, const int64_t N, const int64_t nnz, const int64_t* crows,
+    const int64_t* cols, const T* values, T* dense) const {
+  SparseToDense(context_, M, N, nnz, crows, cols, values, dense,
+                CUSPARSE_FORMAT_CSR);
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
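Both helpers above follow the generic conversion protocol that cuSPARSE
exposes from CUDA 11.2; condensed, with descriptor setup and error handling
omitted:

    // dense -> sparse (DenseToSparse above):
    //   1. cusparseDenseToSparse_bufferSize(...)   query workspace bytes
    //   2. cusparseDenseToSparse_analysis(...)     compute the nnz layout
    //   3. cusparseCooSetPointers/CsrSetPointers   attach caller's buffers
    //   4. cusparseDenseToSparse_convert(...)      write indices and values
    // sparse -> dense (SparseToDense above):
    //   1. cusparseSparseToDense_bufferSize(...)   query workspace bytes
    //   2. cusparseSparseToDense(...)              scatter into the dense mat
    // Note: the workspace tensor is allocated as buffer_size floats, i.e.
    // four times the byte count cuSPARSE asked for; harmless over-allocation.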
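CusparseCall above is the intended access path to the shared handle: the
holder's mutex serializes callbacks, since a cusparse handle is not safe for
concurrent use by multiple host threads. A sketch of a call site, assuming M,
N, descr, and the device pointers are set up as in Sparse::nnz earlier in
this patch:

    dev_ctx.CusparseCall([&](cusparseHandle_t handle) {
      // The mutex in CusparseHandleHolder::Call is held for the whole lambda.
      PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSnnz(
          handle, CUSPARSE_DIRECTION_ROW, M, N, descr, dense_data, M,
          nnz_per_row_ptr, nnz_ptr));
    });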
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 4b38e5ddf3..e1fcc3ae90 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -358,6 +358,12 @@ class CUDAContext {
     return cublas_tensor_core_handle_;
   }
 
+#ifndef PADDLE_WITH_HIP
+  const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
+    return cusparse_handle_;
+  }
+#endif
+
   /*! \brief Call cublas function safely. */
   template <typename Callback>
   inline void CublasCall(Callback&& callback) const {
@@ -368,6 +374,14 @@ class CUDAContext {
     }
   }
 
+#ifndef PADDLE_WITH_HIP
+  /*! \brief Call cusparse function safely. */
+  template <typename Callback>
+  inline void CusparseCall(Callback&& callback) const {
+    cusparse_handle_->Call(std::forward<Callback>(callback));
+  }
+#endif
+
   /*! \brief Check whether tensor core is supported */
   bool tensor_core_available() const;
 
@@ -406,6 +420,12 @@ class CUDAContext {
   }
 #endif
 
+#ifndef PADDLE_WITH_HIP
+  void InitCuSparseContext() {
+    cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
+  }
+#endif
+
   void InitCuDNNContext() {
     if (dynload::HasCUDNN()) {
 #ifdef PADDLE_WITH_HIP
@@ -478,6 +498,10 @@ class CUDAContext {
     cublas_tf32_tensor_core_handle_.reset();
   }
 
+#ifndef PADDLE_WITH_HIP
+  void DestoryCuSparseContext() { cusparse_handle_.reset(); }
+#endif
+
 #ifndef PADDLE_WITH_HIP
   void DestoryCuSolverContext() {
     if (cusolver_dn_handle_) {
@@ -501,6 +525,7 @@ class CUDAContext {
   std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
 #ifndef PADDLE_WITH_HIP
   cusolverDnHandle_t cusolver_dn_handle_;
+  std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
 #endif
   DISABLE_COPY_AND_ASSIGN(CUDAContext);
 };
@@ -540,6 +565,14 @@ class CUDADeviceContext : public DeviceContext {
     return context()->CublasCall(callback);
   }
 
+#ifndef PADDLE_WITH_HIP
+  /*! \brief Call cusparse function safely. */
+  template <typename Callback>
+  inline void CusparseCall(Callback&& callback) const {
+    return context()->CusparseCall(callback);
+  }
+#endif
+
   /*! \brief Check whether tensor core is supported */
   bool tensor_core_available() const;
 
@@ -562,6 +595,7 @@ class CUDADeviceContext : public DeviceContext {
   rocblas_handle cublas_handle() const;
 #else
   cublasHandle_t cublas_handle() const;
+  cusparseHandle_t cusparse_handle() const;
 #endif
 
   /*! \brief Return a cudnn workspace handle to call multiple cudnn
diff --git a/paddle/fluid/platform/dynload/cusparse.cc b/paddle/fluid/platform/dynload/cusparse.cc
index 2a1fe322da..be67f121d6 100644
--- a/paddle/fluid/platform/dynload/cusparse.cc
+++ b/paddle/fluid/platform/dynload/cusparse.cc
@@ -30,6 +30,10 @@ CUSPARSE_ROUTINE_EACH(DEFINE_WRAP);
 #ifdef CUBLAS_BLAS_ROUTINE_EACH_R2
 CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
 #endif
+
+#ifdef CUSPARSE_ROUTINE_EACH_11020
+CUSPARSE_ROUTINE_EACH_11020(DEFINE_WRAP);
+#endif
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h
index e44e8ed085..9f53c06ac0 100644
--- a/paddle/fluid/platform/dynload/cusparse.h
+++ b/paddle/fluid/platform/dynload/cusparse.h
@@ -41,21 +41,41 @@ extern void *cusparse_dso_handle;
   };                                                       \
   extern DynLoad__##__name __name
 
-#if !defined(PADDLE_WITH_ARM) && !defined(_WIN32)
-// APIs available after CUDA 11.0
-#if CUDA_VERSION >= 11000
+#if !defined(PADDLE_WITH_ARM)
+// The generic APIs are supported since CUDA 10.1
+#if CUDA_VERSION >= 10010
 #define CUSPARSE_ROUTINE_EACH(__macro) \
   __macro(cusparseCreate);             \
-  __macro(cusparseCreateCsr);          \
-  __macro(cusparseCreateDnMat);        \
-  __macro(cusparseSpMM_bufferSize);    \
-  __macro(cusparseSpMM);               \
-  __macro(cusparseDestroySpMat);       \
-  __macro(cusparseDestroyDnMat);       \
-  __macro(cusparseDestroy);
+  __macro(cusparseSetStream);          \
+  __macro(cusparseCreateMatDescr);     \
+  __macro(cusparseDestroy);            \
+  __macro(cusparseSnnz);               \
+  __macro(cusparseDnnz);               \
+  __macro(cusparseSetMatType);         \
+  __macro(cusparseSetMatIndexBase);
 
 CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP);
 
+// APIs available after CUDA 11.2
+#if CUDA_VERSION >= 11020
+#define CUSPARSE_ROUTINE_EACH_11020(__macro) \
+  __macro(cusparseCreateCsr);                \
+  __macro(cusparseCreateCoo);                \
+  __macro(cusparseCreateDnMat);              \
+  __macro(cusparseSpMM_bufferSize);          \
+  __macro(cusparseSpMM);                     \
+  __macro(cusparseDestroySpMat);             \
+  __macro(cusparseDestroyDnMat);             \
+  __macro(cusparseCooSetPointers);           \
+  __macro(cusparseCsrSetPointers);           \
+  __macro(cusparseDenseToSparse_bufferSize); \
+  __macro(cusparseDenseToSparse_analysis);   \
+  __macro(cusparseDenseToSparse_convert);    \
+  __macro(cusparseSparseToDense_bufferSize); \
+  __macro(cusparseSparseToDense);
+
+CUSPARSE_ROUTINE_EACH_11020(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
+
 // APIs available after CUDA 11.3
 #if CUDA_VERSION >= 11030
 #define CUSPARSE_ROUTINE_EACH_R2(__macro) \
@@ -67,6 +87,7 @@ CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
 #endif
 #endif
 #endif
+#endif
 
 #undef DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
 }  // namespace dynload
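Each routine listed above becomes a callable functor declared by
DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP that resolves its symbol from libcusparse
on first use. Roughly what one expansion looks like; this is a simplified
sketch only (Paddle's real macro goes through its dynamic loader with a
std::once_flag rather than the bare dlsym shown here):

    // Simplified expansion for __macro(cusparseCreate):
    struct DynLoad__cusparseCreate {
      template <typename... Args>
      cusparseStatus_t operator()(Args... args) {
        using Func = decltype(&::cusparseCreate);
        // Resolved once from the cusparse shared library, then cached.
        static auto* func = reinterpret_cast<Func>(
            dlsym(cusparse_dso_handle, "cusparseCreate"));
        return func(args...);
      }
    };
    extern DynLoad__cusparseCreate cusparseCreate;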
diff --git a/paddle/fluid/platform/enforce.h b/paddle/fluid/platform/enforce.h
index 530ae6ba79..30930897ea 100644
--- a/paddle/fluid/platform/enforce.h
+++ b/paddle/fluid/platform/enforce.h
@@ -33,6 +33,7 @@ limitations under the License. */
 #include <cudnn.h>
 #include <cufft.h>
 #include <curand.h>
+#include <cusparse.h>
 #include <thrust/system/cuda/error.h>
 #include <thrust/system_error.h>
 #include "paddle/fluid/platform/external_error.pb.h"
@@ -707,6 +708,7 @@ DEFINE_EXTERNAL_API_TYPE(cudaError_t, cudaSuccess, CUDA);
 DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND);
 DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN);
 DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS);
+DEFINE_EXTERNAL_API_TYPE(cusparseStatus_t, CUSPARSE_STATUS_SUCCESS, CUSPARSE);
 DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER);
 DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT);
 DEFINE_EXTERNAL_API_TYPE(CUresult, CUDA_SUCCESS, CU);
@@ -750,6 +752,10 @@ inline const char* GetErrorMsgUrl(T status) {
       break;
     case platform::proto::ApiType::CUFFT:
       return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult";
+    case platform::proto::ApiType::CUSPARSE:
+      return "https://docs.nvidia.com/cuda/cusparse/"
+             "index.html#cusparseStatus_t";
+      break;
     default:
       return "Unknown type of External API, can't get error message URL!";
       break;
@@ -837,6 +843,7 @@ template std::string GetExternalErrorMsg(cudaError_t);
 template std::string GetExternalErrorMsg(curandStatus_t);
 template std::string GetExternalErrorMsg(cudnnStatus_t);
 template std::string GetExternalErrorMsg(cublasStatus_t);
+template std::string GetExternalErrorMsg(cusparseStatus_t);
 template std::string GetExternalErrorMsg(cusolverStatus_t);
 template std::string GetExternalErrorMsg(cufftResult_t);
 template std::string GetExternalErrorMsg(CUresult);
@@ -889,6 +896,17 @@ inline std::string build_nvidia_error_msg(cublasStatus_t stat) {
   return sout.str();
 }
 
+/*************** CUSPARSE ERROR ***************/
+inline bool is_error(cusparseStatus_t stat) {
+  return stat != CUSPARSE_STATUS_SUCCESS;
+}
+
+inline std::string build_nvidia_error_msg(cusparseStatus_t stat) {
+  std::ostringstream sout;
+  sout << "CUSparse error(" << stat << "). " << GetExternalErrorMsg(stat);
+  return sout.str();
+}
+
 /*************** CUSOLVER ERROR ***************/
 inline bool is_error(cusolverStatus_t stat) {
   return stat != CUSOLVER_STATUS_SUCCESS;
diff --git a/paddle/fluid/platform/external_error.proto b/paddle/fluid/platform/external_error.proto
index fcbbb41626..8861c2c2ff 100644
--- a/paddle/fluid/platform/external_error.proto
+++ b/paddle/fluid/platform/external_error.proto
@@ -26,6 +26,7 @@ enum ApiType {
   NCCL = 5;
   CUFFT = 6;
   CU = 7;
+  CUSPARSE = 8;
 }
 
 message MessageDesc {
@@ -45,4 +46,4 @@ message AllMessageDesc {
 message ExternalErrorDesc {
   // Error messages of different kind of external third party API
   repeated AllMessageDesc errors = 1;
-}
\ No newline at end of file
+}
-- 
GitLab