diff --git a/paddle/fluid/platform/dynload/cusparse.cc b/paddle/fluid/platform/dynload/cusparse.cc index 2b41da541d9ae03dc93119685ed22f60bd0eb849..2a1fe322dabcf735eacf1dccb3eaabd821a23421 100644 --- a/paddle/fluid/platform/dynload/cusparse.cc +++ b/paddle/fluid/platform/dynload/cusparse.cc @@ -26,6 +26,10 @@ void *cusparse_dso_handle; #ifdef CUSPARSE_ROUTINE_EACH CUSPARSE_ROUTINE_EACH(DEFINE_WRAP); #endif + +#ifdef CUBLAS_BLAS_ROUTINE_EACH_R2 +CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP); +#endif } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/cusparse.h b/paddle/fluid/platform/dynload/cusparse.h index 98841949676e471ae40c841a1e25aad06b0b3d64..e5be003fadf066bd54d24ba1b17f0dc64a8ebac1 100644 --- a/paddle/fluid/platform/dynload/cusparse.h +++ b/paddle/fluid/platform/dynload/cusparse.h @@ -41,8 +41,9 @@ extern void *cusparse_dso_handle; }; \ extern DynLoad__##__name __name -#ifndef _WIN32 -#if CUDA_VERSION >= 11020 +#if !defined(PADDLE_WITH_ARM) && !defined(_WIN32) +// APIs available after CUDA 11.0 +#if CUDA_VERSION >= 11000 #define CUSPARSE_ROUTINE_EACH(__macro) \ __macro(cusparseCreate); \ __macro(cusparseCreateCsr); \ @@ -51,12 +52,19 @@ extern void *cusparse_dso_handle; __macro(cusparseSpMM); \ __macro(cusparseDestroySpMat); \ __macro(cusparseDestroyDnMat); \ - __macro(cusparseDestroy); \ - __macro(cusparseSDDMM_bufferSize); \ - __macro(cusparseSDDMM_preprocess); \ - __macro(cusparseSDDMM); + __macro(cusparseDestroy); CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); + +// APIs available after CUDA 11.2 +#if CUDA_VERSION >= 11020 +#define CUSPARSE_ROUTINE_EACH_R2(__macro) \ + __macro(cusparseSDDMM_bufferSize); \ + __macro(cusparseSDDMM_preprocess); \ + __macro(cusparseSDDMM); + +CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP) +#endif #endif #endif diff --git a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py index ad618edd24d55b6bfdf309cf00e71b213ca648da..48401fb55ef3f568b4db4edd019f4b7375c032a6 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_attention_op.py @@ -169,13 +169,13 @@ class TestSparseAttentionOp(OpTest): 'Q': self.q, 'K': self.k, 'V': self.v, - 'offset': self.offset, - 'columns': self.columns + 'Offset': self.offset, + 'Columns': self.columns } self.outputs = { 'Out': result.astype(self.dtype), - 'ResultSdd': result_sdd.astype(self.dtype), - 'ResultSoftmax': result_softmax.astype(self.dtype) + 'SparseDotSdd': result_sdd.astype(self.dtype), + 'Softmax': result_softmax.astype(self.dtype) } def test_check_output(self):