未验证 提交 667dc9f0 编写于 作者: Z zhangkaihuo 提交者: GitHub

Add cusparse and unittest (#38431)



    将cuSparse的handle与DeviceContext进行绑定,避免op中进行创建和销毁
    添加对cuSparse中dense和sparse转换的API进行封装
    添加对封装的API的单测
上级 3658405c
......@@ -111,6 +111,15 @@ if(WITH_ROCM)
hip_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu.cc DEPS selected_rows_functor math_function)
endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
if(WITH_GPU AND (NOT WITH_ROCM))
#currenty not yet support ROCM
#the generic conversion APIs of dense and sparse are only supported after cuda11.2
if((NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 11.2))
cc_test(cusparse_conversion_api_test SRCS cusparse_conversion_api_test.cc DEPS tensor)
endif()
endif()
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
if(WITH_TESTING AND TEST im2col_test)
set_tests_properties(im2col_test PROPERTIES TIMEOUT 120)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/sparse.h"
template <typename T>
void TestNNZ(const std::vector<T>& dense_data, const int correct_nnz,
const int rows, const int cols) {
paddle::platform::CUDADeviceContext* context =
new paddle::platform::CUDADeviceContext(paddle::platform::CUDAPlace());
auto sparse =
paddle::operators::math::GetSparse<paddle::platform::CUDADeviceContext,
T>(*context);
paddle::framework::Tensor dense, nnz_tensor;
auto dense_dims = paddle::framework::make_ddim({rows, cols});
auto nnz_dims = paddle::framework::make_ddim({dense_dims[0] + 1});
dense.mutable_data<T>(dense_dims, paddle::platform::CUDAPlace());
paddle::framework::TensorFromVector<T>(dense_data, *context, &dense);
int32_t* nnz_ptr =
nnz_tensor.mutable_data<int32_t>(nnz_dims, paddle::platform::CUDAPlace());
sparse.nnz(rows, cols, dense.data<T>(), nnz_ptr, nnz_ptr + 1);
std::vector<int32_t> nnz_vec(dense_dims[0] + 1);
paddle::framework::TensorToVector<int32_t>(nnz_tensor, *context, &nnz_vec);
delete context;
CHECK_EQ(correct_nnz, nnz_vec[0]);
}
TEST(sparse, nnz) {
std::vector<float> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0};
TestNNZ<float>(dense_data, 4, 3, 3);
}
TEST(sparse, nnz_double) {
std::vector<double> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0};
TestNNZ<double>(dense_data, 4, 4, 2);
}
template <typename T>
void TestDenseToSparse(const std::vector<T>& correct_dense_data,
const std::vector<int64_t>& correct_rows,
const std::vector<int64_t>& correct_cols,
const std::vector<T>& correct_values,
const int correct_nnz, const int rows, const int cols,
const std::string& mode) {
paddle::platform::CUDADeviceContext* context =
new paddle::platform::CUDADeviceContext(paddle::platform::CUDAPlace());
// get sparse
auto sparse =
paddle::operators::math::GetSparse<paddle::platform::CUDADeviceContext,
T>(*context);
// create tensor and copy vector to tensor
paddle::framework::Tensor dense_tensor, rows_tensor, cols_tensor,
values_tensor, actual_dense_tensor;
auto dense_dims = paddle::framework::make_ddim({rows, cols});
T* dense_data =
dense_tensor.mutable_data<T>(dense_dims, paddle::platform::CUDAPlace());
T* actual_dense_data = actual_dense_tensor.mutable_data<T>(
dense_dims, paddle::platform::CUDAPlace());
paddle::framework::TensorFromVector<T>(correct_dense_data, *context,
&dense_tensor);
auto nnz_dims = paddle::framework::make_ddim({correct_nnz});
auto crows_dims = paddle::framework::make_ddim({rows + 1});
int64_t* rows_data = nullptr;
if (mode == "COO") {
rows_data = rows_tensor.mutable_data<int64_t>(
nnz_dims, paddle::platform::CUDAPlace());
} else {
rows_data = rows_tensor.mutable_data<int64_t>(
crows_dims, paddle::platform::CUDAPlace());
}
int64_t* cols_data = cols_tensor.mutable_data<int64_t>(
nnz_dims, paddle::platform::CUDAPlace());
T* values_data =
values_tensor.mutable_data<T>(nnz_dims, paddle::platform::CUDAPlace());
// test dense_to_sparse
if (mode == "COO") {
sparse.DenseToSparseCoo(rows, cols, dense_data, rows_data, cols_data,
values_data);
} else {
sparse.DenseToSparseCsr(rows, cols, dense_data, rows_data, cols_data,
values_data);
}
std::vector<int64_t> actual_rows(correct_nnz), actual_crows(rows + 1),
actual_cols(correct_nnz);
std::vector<T> actual_values(correct_nnz), actual_dense_vec(rows * cols);
if (mode == "COO") {
paddle::framework::TensorToVector<int64_t>(rows_tensor, *context,
&actual_rows);
} else {
paddle::framework::TensorToVector<int64_t>(rows_tensor, *context,
&actual_crows);
}
paddle::framework::TensorToVector<int64_t>(cols_tensor, *context,
&actual_cols);
paddle::framework::TensorToVector<T>(values_tensor, *context, &actual_values);
for (int i = 0; i < correct_nnz; i++) {
if (mode == "COO") {
CHECK_EQ(correct_rows[i], actual_rows[i]);
}
CHECK_EQ(correct_cols[i], actual_cols[i]);
CHECK_EQ(correct_values[i], actual_values[i]);
}
if (mode == "CSR") {
for (int i = 0; i < rows + 1; i++) {
CHECK_EQ(correct_rows[i], actual_crows[i]);
}
}
// test sparse_to_dense
if (mode == "COO") {
sparse.SparseCooToDense(rows, cols, correct_nnz, rows_data, cols_data,
values_data, actual_dense_data);
} else {
sparse.SparseCsrToDense(rows, cols, correct_nnz, rows_data, cols_data,
values_data, actual_dense_data);
}
paddle::framework::TensorToVector<T>(actual_dense_tensor, *context,
&actual_dense_vec);
for (uint64_t i = 0; i < correct_dense_data.size(); i++) {
CHECK_EQ(correct_dense_data[i], actual_dense_vec[i]);
}
delete context;
}
TEST(sparse, dense_to_sparse) {
std::vector<float> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0, 0.0};
std::vector<float> values = {1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> rows = {0, 1, 1, 2};
std::vector<int64_t> crows = {0, 1, 3, 4};
std::vector<int64_t> cols = {1, 0, 2, 0};
TestDenseToSparse<float>(dense_data, rows, cols, values, 4, 3, 3, "COO");
TestDenseToSparse<float>(dense_data, crows, cols, values, 4, 3, 3, "CSR");
}
TEST(sparse, dense_to_sparse_double) {
std::vector<double> dense_data = {0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 3.2, 0.0};
std::vector<double> values = {1.0, 2.0, 3.0, 3.2};
std::vector<int64_t> rows = {0, 1, 2, 3};
std::vector<int64_t> crows = {0, 1, 2, 3, 4};
std::vector<int64_t> cols = {1, 1, 1, 0};
TestDenseToSparse<double>(dense_data, rows, cols, values, 4, 4, 2, "COO");
TestDenseToSparse<double>(dense_data, crows, cols, values, 4, 4, 2, "CSR");
}
TEST(sparse, dense_to_sparse_fp16) {
using float16 = paddle::platform::float16;
std::vector<float16> dense_data = {float16(0.0), float16(1.0), float16(0.0),
float16(2.0), float16(0.0), float16(3.0),
float16(3.2), float16(0.0)};
std::vector<float16> values = {float16(1.0), float16(2.0), float16(3.0),
float16(3.2)};
std::vector<int64_t> rows = {0, 1, 2, 3};
std::vector<int64_t> crows = {0, 1, 2, 3, 4};
std::vector<int64_t> cols = {1, 1, 1, 0};
TestDenseToSparse<float16>(dense_data, rows, cols, values, 4, 4, 2, "COO");
TestDenseToSparse<float16>(dense_data, crows, cols, values, 4, 4, 2, "CSR");
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
class ExecutionContext;
class Tensor;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
namespace math {
template <typename DeviceContext>
class Sparse {
public:
explicit Sparse(const DeviceContext& context) : context_(context) {}
template <typename T>
void nnz(const int M, const int N, const T* dense, int* nnz,
int* nnzPerRowColumn) const;
template <typename T>
void DenseToSparseCoo(const int M, const int N, const T* dense, int64_t* rows,
int64_t* cols, T* values) const;
template <typename T>
void DenseToSparseCsr(const int M, const int N, const T* dense,
int64_t* crows, int64_t* cols, T* values) const;
template <typename T>
void SparseCooToDense(const int64_t M, const int64_t N, const int64_t nnz,
const int64_t* rows, const int64_t* cols,
const T* values, T* dense) const;
template <typename T>
void SparseCsrToDense(const int64_t M, const int64_t N, const int64_t nnz,
const int64_t* crows, const int64_t* cols,
const T* values, T* dense) const;
private:
const DeviceContext& context_;
};
template <typename DeviceContext, typename T>
class SparseT : private Sparse<DeviceContext> {
public:
using Sparse<DeviceContext>::Sparse;
template <typename... ARGS>
void nnz(ARGS... args) const {
Base()->template nnz<T>(args...);
}
template <typename... ARGS>
void DenseToSparseCoo(ARGS... args) const {
Base()->template DenseToSparseCoo<T>(args...);
}
template <typename... ARGS>
void DenseToSparseCsr(ARGS... args) const {
Base()->template DenseToSparseCsr<T>(args...);
}
template <typename... ARGS>
void SparseCooToDense(ARGS... args) const {
Base()->template SparseCooToDense<T>(args...);
}
template <typename... ARGS>
void SparseCsrToDense(ARGS... args) const {
Base()->template SparseCsrToDense<T>(args...);
}
private:
const Sparse<DeviceContext>* Base() const {
return static_cast<const Sparse<DeviceContext>*>(this);
}
};
template <typename DeviceContext, typename T>
inline SparseT<DeviceContext, T> GetSparse(
const framework::ExecutionContext& exe_ctx) {
return SparseT<DeviceContext, T>(
exe_ctx.template device_context<DeviceContext>());
}
template <typename DeviceContext, typename T>
inline SparseT<DeviceContext, T> GetSparse(const DeviceContext& dev_ctx) {
return SparseT<DeviceContext, T>(dev_ctx);
}
} // namespace math
} // namespace operators
} // namespace paddle
#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 11020
#include "paddle/fluid/operators/math/sparse_impl.cu.h"
#endif
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/dynload/cusparse.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
cudaDataType_t GetGpuDataType() {
if (std::is_same<T, float>::value) {
return CUDA_R_32F;
} else if (std::is_same<T, double>::value) {
return CUDA_R_64F;
} else if (std::is_same<T, platform::float16>::value) {
return CUDA_R_16F;
}
}
template <>
template <typename T>
void Sparse<platform::CUDADeviceContext>::nnz(const int M, const int N,
const T* dense, int* nnz,
int* nnzPerRowColumn) const {}
template <>
template <>
void Sparse<platform::CUDADeviceContext>::nnz(const int M, const int N,
const float* dense, int* nnz,
int* nnzPerRowColumn) const {
cusparseMatDescr_t descr = 0;
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cusparseCreateMatDescr(&descr));
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatType(
descr, CUSPARSE_MATRIX_TYPE_GENERAL));
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatIndexBase(
descr, CUSPARSE_INDEX_BASE_ZERO));
context_.CusparseCall([&](cusparseHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSnnz(
handle, CUSPARSE_DIRECTION_ROW, M, N, descr, dense, M, nnzPerRowColumn,
nnz));
});
}
template <>
template <>
void Sparse<platform::CUDADeviceContext>::nnz(const int M, const int N,
const double* dense, int* nnz,
int* nnzPerRowColumn) const {
cusparseMatDescr_t descr = 0;
PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cusparseCreateMatDescr(&descr));
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatType(
descr, CUSPARSE_MATRIX_TYPE_GENERAL));
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseSetMatIndexBase(
descr, CUSPARSE_INDEX_BASE_ZERO));
context_.CusparseCall([&](cusparseHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::cusparseDnnz(
handle, CUSPARSE_DIRECTION_ROW, M, N, descr, dense, M, nnzPerRowColumn,
nnz));
});
}
template <typename T>
inline void DenseToSparse(const platform::CUDADeviceContext& context,
const int M, const int N, const T* dense,
int64_t* rows, int64_t* cols, T* values,
const cusparseFormat_t format) {
cusparseSpMatDescr_t matB;
cusparseDnMatDescr_t matA;
cudaDataType_t dtype = GetGpuDataType<T>();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateDnMat(
&matA, M, N, N, const_cast<void*>(reinterpret_cast<const void*>(dense)),
dtype, CUSPARSE_ORDER_ROW));
if (format == CUSPARSE_FORMAT_COO) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCoo(
&matB, M, N, 0, nullptr, nullptr, nullptr, CUSPARSE_INDEX_64I,
CUSPARSE_INDEX_BASE_ZERO, dtype));
} else if (format == CUSPARSE_FORMAT_CSR) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCsr(
&matB, M, N, 0, rows, nullptr, nullptr, CUSPARSE_INDEX_64I,
CUSPARSE_INDEX_64I, CUSPARSE_INDEX_BASE_ZERO, dtype));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"the sparse format [%s] is not supported", format));
}
size_t buffer_size = 0;
context.CusparseCall([&](cusparseHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusparseDenseToSparse_bufferSize(
handle, matA, matB, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
&buffer_size));
});
framework::Tensor buffer;
float* buffer_data = buffer.mutable_data<float>(
{static_cast<int64_t>(buffer_size)}, context.GetPlace());
context.CusparseCall([&](cusparseHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusparseDenseToSparse_analysis(
handle, matA, matB, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT,
buffer_data));
});
if (format == CUSPARSE_FORMAT_COO) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCooSetPointers(
matB, rows, cols, reinterpret_cast<void*>(values)));
} else if (format == CUSPARSE_FORMAT_CSR) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCsrSetPointers(
matB, rows, cols, reinterpret_cast<void*>(values)));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"the sparse format [%s] is not supported", format));
}
context.CusparseCall([&](cusparseHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseDenseToSparse_convert(
handle, matA, matB, CUSPARSE_DENSETOSPARSE_ALG_DEFAULT, buffer_data));
});
}
template <>
template <typename T>
void Sparse<platform::CUDADeviceContext>::DenseToSparseCoo(
const int M, const int N, const T* dense, int64_t* rows, int64_t* cols,
T* values) const {
DenseToSparse<T>(context_, M, N, dense, rows, cols, values,
CUSPARSE_FORMAT_COO);
}
template <>
template <typename T>
void Sparse<platform::CUDADeviceContext>::DenseToSparseCsr(
const int M, const int N, const T* dense, int64_t* crows, int64_t* cols,
T* values) const {
DenseToSparse<T>(context_, M, N, dense, crows, cols, values,
CUSPARSE_FORMAT_CSR);
}
template <typename T>
void SparseToDense(const platform::CUDADeviceContext& context, const int64_t M,
const int64_t N, const int64_t nnz, const int64_t* rows,
const int64_t* cols, const T* values, T* dense,
const cusparseFormat_t format) {
cusparseSpMatDescr_t matA;
cusparseDnMatDescr_t matB;
cudaDataType_t dtype = GetGpuDataType<T>();
if (format == CUSPARSE_FORMAT_COO) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCoo(
&matA, M, N, nnz,
const_cast<void*>(reinterpret_cast<const void*>(rows)),
const_cast<void*>(reinterpret_cast<const void*>(cols)),
const_cast<void*>(reinterpret_cast<const void*>(values)),
CUSPARSE_INDEX_64I, CUSPARSE_INDEX_BASE_ZERO, dtype));
} else if (format == CUSPARSE_FORMAT_CSR) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateCsr(
&matA, M, N, nnz,
const_cast<void*>(reinterpret_cast<const void*>(rows)),
const_cast<void*>(reinterpret_cast<const void*>(cols)),
const_cast<void*>(reinterpret_cast<const void*>(values)),
CUSPARSE_INDEX_64I, CUSPARSE_INDEX_64I, CUSPARSE_INDEX_BASE_ZERO,
dtype));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"the sparse format [%s] is not supported", format));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseCreateDnMat(
&matB, M, N, N, reinterpret_cast<void*>(dense), dtype,
CUSPARSE_ORDER_ROW));
size_t buffer_size = 0;
context.CusparseCall([&](cusparseHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cusparseSparseToDense_bufferSize(
handle, matA, matB, CUSPARSE_SPARSETODENSE_ALG_DEFAULT,
&buffer_size));
});
framework::Tensor buffer;
float* buffer_data = buffer.mutable_data<float>(
{static_cast<int64_t>(buffer_size)}, context.GetPlace());
context.CusparseCall([&](cusparseHandle_t handle) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cusparseSparseToDense(
handle, matA, matB, CUSPARSE_SPARSETODENSE_ALG_DEFAULT, buffer_data));
});
}
template <>
template <typename T>
void Sparse<platform::CUDADeviceContext>::SparseCooToDense(
const int64_t M, const int64_t N, const int64_t nnz, const int64_t* rows,
const int64_t* cols, const T* values, T* dense) const {
SparseToDense<T>(context_, M, N, nnz, rows, cols, values, dense,
CUSPARSE_FORMAT_COO);
}
template <>
template <typename T>
void Sparse<platform::CUDADeviceContext>::SparseCsrToDense(
const int64_t M, const int64_t N, const int64_t nnz, const int64_t* crows,
const int64_t* cols, const T* values, T* dense) const {
SparseToDense<T>(context_, M, N, nnz, crows, cols, values, dense,
CUSPARSE_FORMAT_CSR);
}
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cusparse.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace platform {
class CusparseHandleHolder {
public:
explicit CusparseHandleHolder(cudaStream_t stream) {
// ROCM is not yet supported
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(&handle_));
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(handle_, stream));
#endif
#endif
}
const cusparseHandle_t& GetCusparseHandle() const { return handle_; }
~CusparseHandleHolder() PADDLE_MAY_THROW {
#if defined(PADDLE_WITH_CUDA)
#if CUDA_VERSION >= 10010
PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(handle_));
#endif
#endif
}
template <typename Callback>
inline void Call(Callback&& callback) const {
std::lock_guard<std::mutex> guard(mtx_);
callback(handle_);
}
private:
DISABLE_COPY_AND_ASSIGN(CusparseHandleHolder);
cusparseHandle_t handle_;
mutable std::mutex mtx_;
};
} // namespace platform
} // namespace paddle
......@@ -19,6 +19,7 @@
#include "paddle/fluid/platform/device/gpu/rocm/rocm_helper.h"
#else
#include "paddle/fluid/platform/device/gpu/cuda/cuda_helper.h"
#include "paddle/fluid/platform/device/gpu/cuda/cusparse_helper.h"
#endif
#define CUDA_KERNEL_LOOP(i, num) CUDA_KERNEL_LOOP_TYPE(i, num, int)
......
......@@ -484,6 +484,7 @@ CUDAContext::CUDAContext(const CUDAPlace& place,
InitCuBlasContext();
InitCuDNNContext();
#ifndef PADDLE_WITH_HIP
InitCuSparseContext();
InitCuSolverContext();
#endif
}
......@@ -513,6 +514,7 @@ CUDAContext::~CUDAContext() {
DestoryCuDNNContext();
DestoryCuBlasContext();
#ifndef PADDLE_WITH_HIP
DestoryCuSparseContext();
DestoryCuSolverContext();
#endif
}
......@@ -630,6 +632,9 @@ rocblas_handle CUDADeviceContext::cublas_handle() const {
cublasHandle_t CUDADeviceContext::cublas_handle() const {
return context()->CublasHandle()->GetCublasHandle();
}
cusparseHandle_t CUDADeviceContext::cusparse_handle() const {
return context()->CusparseHandle()->GetCusparseHandle();
}
#endif
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
......
......@@ -358,6 +358,12 @@ class CUDAContext {
return cublas_tensor_core_handle_;
}
#ifndef PADDLE_WITH_HIP
const std::unique_ptr<CusparseHandleHolder>& CusparseHandle() const {
return cusparse_handle_;
}
#endif
/*! \brief Call cublas function safely. */
template <typename Callback>
inline void CublasCall(Callback&& callback) const {
......@@ -368,6 +374,14 @@ class CUDAContext {
}
}
#ifndef PADDLE_WITH_HIP
/*! \brief Call cusparse function safely. */
template <typename Callback>
inline void CusparseCall(Callback&& callback) const {
cusparse_handle_->Call(std::forward<Callback>(callback));
}
#endif
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
......@@ -406,6 +420,12 @@ class CUDAContext {
}
#endif
#ifndef PADDLE_WITH_HIP
void InitCuSparseContext() {
cusparse_handle_.reset(new CusparseHandleHolder(RawStream()));
}
#endif
void InitCuDNNContext() {
if (dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
......@@ -478,6 +498,10 @@ class CUDAContext {
cublas_tf32_tensor_core_handle_.reset();
}
#ifndef PADDLE_WITH_HIP
void DestoryCuSparseContext() { cusparse_handle_.reset(); }
#endif
#ifndef PADDLE_WITH_HIP
void DestoryCuSolverContext() {
if (cusolver_dn_handle_) {
......@@ -501,6 +525,7 @@ class CUDAContext {
std::unique_ptr<CublasHandleHolder> cublas_tf32_tensor_core_handle_;
#ifndef PADDLE_WITH_HIP
cusolverDnHandle_t cusolver_dn_handle_;
std::unique_ptr<CusparseHandleHolder> cusparse_handle_;
#endif
DISABLE_COPY_AND_ASSIGN(CUDAContext);
};
......@@ -540,6 +565,14 @@ class CUDADeviceContext : public DeviceContext {
return context()->CublasCall(callback);
}
#ifndef PADDLE_WITH_HIP
/*! \brief Call cusparse function safely. */
template <typename Callback>
inline void CusparseCall(Callback&& callback) const {
return context()->CusparseCall(callback);
}
#endif
/*! \brief Check whether tensor core is supported */
bool tensor_core_available() const;
......@@ -562,6 +595,7 @@ class CUDADeviceContext : public DeviceContext {
rocblas_handle cublas_handle() const;
#else
cublasHandle_t cublas_handle() const;
cusparseHandle_t cusparse_handle() const;
#endif
/*! \brief Return a cudnn workspace handle to call multiple cudnn
......
......@@ -30,6 +30,10 @@ CUSPARSE_ROUTINE_EACH(DEFINE_WRAP);
#ifdef CUBLAS_BLAS_ROUTINE_EACH_R2
CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
#endif
#ifdef CUSPARSE_ROUTINE_EACH_11020
CUSPARSE_ROUTINE_EACH_11020(DEFINE_WRAP);
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -41,21 +41,41 @@ extern void *cusparse_dso_handle;
}; \
extern DynLoad__##__name __name
#if !defined(PADDLE_WITH_ARM) && !defined(_WIN32)
// APIs available after CUDA 11.0
#if CUDA_VERSION >= 11000
#if !defined(PADDLE_WITH_ARM)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroy);
__macro(cusparseSetStream); \
__macro(cusparseCreateMatDescr); \
__macro(cusparseDestroy); \
__macro(cusparseSnnz); \
__macro(cusparseDnnz); \
__macro(cusparseSetMatType); \
__macro(cusparseSetMatIndexBase);
CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP);
// APIs available after CUDA 11.2
#if CUDA_VERSION >= 11020
#define CUSPARSE_ROUTINE_EACH_11020(__macro) \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateCoo); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseCooSetPointers); \
__macro(cusparseCsrSetPointers); \
__macro(cusparseDenseToSparse_bufferSize); \
__macro(cusparseDenseToSparse_analysis); \
__macro(cusparseDenseToSparse_convert); \
__macro(cusparseSparseToDense_bufferSize); \
__macro(cusparseSparseToDense);
CUSPARSE_ROUTINE_EACH_11020(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
// APIs available after CUDA 11.3
#if CUDA_VERSION >= 11030
#define CUSPARSE_ROUTINE_EACH_R2(__macro) \
......@@ -67,6 +87,7 @@ CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
#endif
#endif
#endif
#undef DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
} // namespace dynload
......
......@@ -33,6 +33,7 @@ limitations under the License. */
#include <cudnn.h>
#include <cufft.h>
#include <curand.h>
#include <cusparse.h>
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#include "paddle/fluid/platform/external_error.pb.h"
......@@ -707,6 +708,7 @@ DEFINE_EXTERNAL_API_TYPE(cudaError_t, cudaSuccess, CUDA);
DEFINE_EXTERNAL_API_TYPE(curandStatus_t, CURAND_STATUS_SUCCESS, CURAND);
DEFINE_EXTERNAL_API_TYPE(cudnnStatus_t, CUDNN_STATUS_SUCCESS, CUDNN);
DEFINE_EXTERNAL_API_TYPE(cublasStatus_t, CUBLAS_STATUS_SUCCESS, CUBLAS);
DEFINE_EXTERNAL_API_TYPE(cusparseStatus_t, CUSPARSE_STATUS_SUCCESS, CUSPARSE);
DEFINE_EXTERNAL_API_TYPE(cusolverStatus_t, CUSOLVER_STATUS_SUCCESS, CUSOLVER);
DEFINE_EXTERNAL_API_TYPE(cufftResult_t, CUFFT_SUCCESS, CUFFT);
DEFINE_EXTERNAL_API_TYPE(CUresult, CUDA_SUCCESS, CU);
......@@ -750,6 +752,10 @@ inline const char* GetErrorMsgUrl(T status) {
break;
case platform::proto::ApiType::CUFFT:
return "https://docs.nvidia.com/cuda/cufft/index.html#cufftresult";
case platform::proto::ApiType::CUSPARSE:
return "https://docs.nvidia.com/cuda/cusparse/"
"index.html#cusparseStatus_t";
break;
default:
return "Unknown type of External API, can't get error message URL!";
break;
......@@ -837,6 +843,7 @@ template std::string GetExternalErrorMsg<cudaError_t>(cudaError_t);
template std::string GetExternalErrorMsg<curandStatus_t>(curandStatus_t);
template std::string GetExternalErrorMsg<cudnnStatus_t>(cudnnStatus_t);
template std::string GetExternalErrorMsg<cublasStatus_t>(cublasStatus_t);
template std::string GetExternalErrorMsg<cusparseStatus_t>(cusparseStatus_t);
template std::string GetExternalErrorMsg<cusolverStatus_t>(cusolverStatus_t);
template std::string GetExternalErrorMsg<cufftResult_t>(cufftResult_t);
template std::string GetExternalErrorMsg<CUresult>(CUresult);
......@@ -889,6 +896,17 @@ inline std::string build_nvidia_error_msg(cublasStatus_t stat) {
return sout.str();
}
/*************** CUSPARSE ERROR ***************/
inline bool is_error(cusparseStatus_t stat) {
return stat != CUSPARSE_STATUS_SUCCESS;
}
inline std::string build_nvidia_error_msg(cusparseStatus_t stat) {
std::ostringstream sout;
sout << "CUSparse error(" << stat << "). " << GetExternalErrorMsg(stat);
return sout.str();
}
/*************** CUSOLVER ERROR ***************/
inline bool is_error(cusolverStatus_t stat) {
return stat != CUSOLVER_STATUS_SUCCESS;
......
......@@ -26,6 +26,7 @@ enum ApiType {
NCCL = 5;
CUFFT = 6;
CU = 7;
CUSPARSE = 8;
}
message MessageDesc {
......@@ -45,4 +46,4 @@ message AllMessageDesc {
message ExternalErrorDesc {
// Error messages of different kind of external third party API
repeated AllMessageDesc errors = 1;
}
\ No newline at end of file
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册