未验证 提交 3eb50715 编写于 作者: L Liu-xiandong 提交者: GitHub

fix cusparse compile problem, test=develop (#36199)

* fix cusparse compile problem, test=develop

* Modify file permissions
上级 1f93582c
...@@ -26,6 +26,10 @@ void *cusparse_dso_handle; ...@@ -26,6 +26,10 @@ void *cusparse_dso_handle;
#ifdef CUSPARSE_ROUTINE_EACH #ifdef CUSPARSE_ROUTINE_EACH
CUSPARSE_ROUTINE_EACH(DEFINE_WRAP); CUSPARSE_ROUTINE_EACH(DEFINE_WRAP);
#endif #endif
#ifdef CUBLAS_BLAS_ROUTINE_EACH_R2
CUSPARSE_ROUTINE_EACH_R2(DEFINE_WRAP);
#endif
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -41,8 +41,9 @@ extern void *cusparse_dso_handle; ...@@ -41,8 +41,9 @@ extern void *cusparse_dso_handle;
}; \ }; \
extern DynLoad__##__name __name extern DynLoad__##__name __name
#ifndef _WIN32 #if !defined(PADDLE_WITH_ARM) && !defined(_WIN32)
#if CUDA_VERSION >= 11020 // APIs available after CUDA 11.0
#if CUDA_VERSION >= 11000
#define CUSPARSE_ROUTINE_EACH(__macro) \ #define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \ __macro(cusparseCreate); \
__macro(cusparseCreateCsr); \ __macro(cusparseCreateCsr); \
...@@ -51,12 +52,19 @@ extern void *cusparse_dso_handle; ...@@ -51,12 +52,19 @@ extern void *cusparse_dso_handle;
__macro(cusparseSpMM); \ __macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \ __macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \ __macro(cusparseDestroyDnMat); \
__macro(cusparseDestroy); \ __macro(cusparseDestroy);
__macro(cusparseSDDMM_bufferSize); \
__macro(cusparseSDDMM_preprocess); \
__macro(cusparseSDDMM);
CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP); CUSPARSE_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP);
// APIs available after CUDA 11.2
#if CUDA_VERSION >= 11020
#define CUSPARSE_ROUTINE_EACH_R2(__macro) \
__macro(cusparseSDDMM_bufferSize); \
__macro(cusparseSDDMM_preprocess); \
__macro(cusparseSDDMM);
CUSPARSE_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP)
#endif
#endif #endif
#endif #endif
......
...@@ -169,13 +169,13 @@ class TestSparseAttentionOp(OpTest): ...@@ -169,13 +169,13 @@ class TestSparseAttentionOp(OpTest):
'Q': self.q, 'Q': self.q,
'K': self.k, 'K': self.k,
'V': self.v, 'V': self.v,
'offset': self.offset, 'Offset': self.offset,
'columns': self.columns 'Columns': self.columns
} }
self.outputs = { self.outputs = {
'Out': result.astype(self.dtype), 'Out': result.astype(self.dtype),
'ResultSdd': result_sdd.astype(self.dtype), 'SparseDotSdd': result_sdd.astype(self.dtype),
'ResultSoftmax': result_softmax.astype(self.dtype) 'Softmax': result_softmax.astype(self.dtype)
} }
def test_check_output(self): def test_check_output(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册