Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
7e18106a
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7e18106a
编写于
9月 15, 2021
作者:
L
Liu-xiandong
提交者:
GitHub
9月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add nvidia cusparse library, test=develop (#35675)
Put Nvidia's cusparse library into paddle.
上级
3760be06
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
115 addition
and
1 deletion
+115
-1
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+1
-0
paddle/fluid/platform/dynload/CMakeLists.txt
paddle/fluid/platform/dynload/CMakeLists.txt
+1
-1
paddle/fluid/platform/dynload/cusparse.cc
paddle/fluid/platform/dynload/cusparse.cc
+31
-0
paddle/fluid/platform/dynload/cusparse.h
paddle/fluid/platform/dynload/cusparse.h
+64
-0
paddle/fluid/platform/dynload/dynamic_loader.cc
paddle/fluid/platform/dynload/dynamic_loader.cc
+17
-0
paddle/fluid/platform/dynload/dynamic_loader.h
paddle/fluid/platform/dynload/dynamic_loader.h
+1
-0
未找到文件。
paddle/fluid/platform/device_context.h
浏览文件 @
7e18106a
...
...
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cusolver.h"
#include "paddle/fluid/platform/dynload/cusparse.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/dynload/nccl.h"
#endif
...
...
paddle/fluid/platform/dynload/CMakeLists.txt
浏览文件 @
7e18106a
cc_library
(
dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags enforce
)
list
(
APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc nvtx.cc
)
list
(
APPEND CUDA_SRCS cublas.cc cudnn.cc curand.cc cusolver.cc
cusparse.cc
nvtx.cc
)
if
(
NOT WITH_NV_JETSON
)
list
(
APPEND CUDA_SRCS nvjpeg.cc
)
...
...
paddle/fluid/platform/dynload/cusparse.cc
0 → 100644
浏览文件 @
7e18106a
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/cusparse.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
std
::
once_flag
cusparse_dso_flag
;
void
*
cusparse_dso_handle
;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
#ifdef CUSPARSE_ROUTINE_EACH
CUSPARSE_ROUTINE_EACH
(
DEFINE_WRAP
);
#endif
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/cusparse.h
0 → 100644
浏览文件 @
7e18106a
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
#include <cusparse.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/platform/port.h"
namespace
paddle
{
namespace
platform
{
namespace
dynload
{
extern
std
::
once_flag
cusparse_dso_flag
;
extern
void
*
cusparse_dso_handle
;
#define DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cusparseStatus_t operator()(Args... args) { \
using cusparseFunc = decltype(&::__name); \
std::call_once(cusparse_dso_flag, []() { \
cusparse_dso_handle = \
paddle::platform::dynload::GetCusparseDsoHandle(); \
}); \
static void *p_##__name = dlsym(cusparse_dso_handle, #__name); \
return reinterpret_cast<cusparseFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#if CUDA_VERSION >= 11020
#define CUSPARSE_ROUTINE_EACH(__macro) \
__macro(cusparseCreate); \
__macro(cusparseCreateCsr); \
__macro(cusparseCreateDnMat); \
__macro(cusparseSpMM_bufferSize); \
__macro(cusparseSpMM); \
__macro(cusparseDestroySpMat); \
__macro(cusparseDestroyDnMat); \
__macro(cusparseDestroy); \
__macro(cusparseSDDMM_bufferSize); \
__macro(cusparseSDDMM_preprocess); \
__macro(cusparseSDDMM);
CUSPARSE_ROUTINE_EACH
(
DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
);
#endif
#undef DECLARE_DYNAMIC_LOAD_CUSPARSE_WRAP
}
// namespace dynload
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/dynload/dynamic_loader.cc
浏览文件 @
7e18106a
...
...
@@ -106,6 +106,9 @@ static constexpr char* win_nvjpeg_lib =
static
constexpr
char
*
win_cusolver_lib
=
"cusolver64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;cusolver64_"
CUDA_VERSION_MAJOR
".dll;cusolver64_10.dll"
;
static
constexpr
char
*
win_cusparse_lib
=
"cusparse64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;cusparse64_"
CUDA_VERSION_MAJOR
".dll;cusparse64_10.dll"
;
#else
static
constexpr
char
*
win_curand_lib
=
"curand64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
...
...
@@ -116,6 +119,9 @@ static constexpr char* win_nvjpeg_lib =
static
constexpr
char
*
win_cusolver_lib
=
"cusolver64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;cusolver64_"
CUDA_VERSION_MAJOR
".dll"
;
static
constexpr
char
*
win_cusparse_lib
=
"cusparse64_"
CUDA_VERSION_MAJOR
CUDA_VERSION_MINOR
".dll;cusparse64_"
CUDA_VERSION_MAJOR
".dll"
;
#endif // CUDA_VERSION
#endif
...
...
@@ -358,6 +364,17 @@ void* GetCusolverDsoHandle() {
#endif
}
void
*
GetCusparseDsoHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
"libcusparse.dylib"
);
#elif defined(_WIN32) && defined(PADDLE_WITH_CUDA)
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
win_cusparse_lib
,
true
,
{
cuda_lib_path
});
#else
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
"libcusparse.so"
);
#endif
}
void
*
GetNVRTCDsoHandle
()
{
#if defined(__APPLE__) || defined(__OSX__)
return
GetDsoHandleFromSearchPath
(
FLAGS_cuda_dir
,
"libnvrtc.dylib"
,
false
);
...
...
paddle/fluid/platform/dynload/dynamic_loader.h
浏览文件 @
7e18106a
...
...
@@ -31,6 +31,7 @@ void* GetCUPTIDsoHandle();
void
*
GetCurandDsoHandle
();
void
*
GetNvjpegDsoHandle
();
void
*
GetCusolverDsoHandle
();
void
*
GetCusparseDsoHandle
();
void
*
GetNVRTCDsoHandle
();
void
*
GetCUDADsoHandle
();
void
*
GetWarpCTCDsoHandle
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录