Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
90ae3533
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
90ae3533
编写于
6月 22, 2022
作者:
X
xiaoxiaohehe001
提交者:
GitHub
6月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
gpu_context (#43661)
上级
1aafc31b
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
461 additions
and
267 deletions
+461
-267
paddle/phi/backends/gpu/CMakeLists.txt
paddle/phi/backends/gpu/CMakeLists.txt
+2
-1
paddle/phi/backends/gpu/gpu_context.cc
paddle/phi/backends/gpu/gpu_context.cc
+133
-266
paddle/phi/backends/gpu/gpu_context.h
paddle/phi/backends/gpu/gpu_context.h
+4
-0
paddle/phi/backends/gpu/gpu_resources.cc
paddle/phi/backends/gpu/gpu_resources.cc
+271
-0
paddle/phi/backends/gpu/gpu_resources.h
paddle/phi/backends/gpu/gpu_resources.h
+51
-0
未找到文件。
paddle/phi/backends/gpu/CMakeLists.txt
浏览文件 @
90ae3533
...
...
@@ -6,4 +6,5 @@ elseif(WITH_ROCM)
hip_library
(
phi_gpu_info SRCS gpu_info.cc DEPS phi_rocm_info gflags glog enforce phi_dynload_cuda
)
endif
()
cc_library
(
gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3
)
cc_library
(
gpu_resources SRCS gpu_resources.cc DEPS phi_device_context phi_gpu_info
)
cc_library
(
gpu_context SRCS gpu_context.cc DEPS phi_device_context phi_gpu_info eigen3 gpu_resources
)
paddle/phi/backends/gpu/gpu_context.cc
浏览文件 @
90ae3533
...
...
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include <algorithm>
#include <array>
#include <functional>
...
...
@@ -21,10 +22,11 @@ limitations under the License. */
#include <memory>
#include <mutex>
#include "glog/logging.h"
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
...
...
@@ -202,27 +204,31 @@ struct GPUContext::Impl {
void
Init
()
{
owned_
=
true
;
backends
::
gpu
::
GPUDeviceGuard
guard
(
place_
.
device
);
InitGpuProperties
();
InitStream
();
phi
::
InitGpuProperties
(
place_
,
&
compute_capability_
,
&
runtime_version_
,
&
driver_version_
,
&
multi_process_
,
&
max_threads_per_mp_
,
&
max_threads_per_block_
,
&
max_grid_dim_size_
);
phi
::
InitStream
(
&
stream_
);
InitEigenDevice
();
InitBlasHandle
();
InitBlasLtHandle
();
InitDNNHandle
();
InitSolverHandle
();
InitSparseHandle
();
InitDnnWorkspace
();
}
void
PartialInitWithoutAllocator
()
{
owned_
=
true
;
backends
::
gpu
::
GPUDeviceGuard
guard
(
place_
.
device
);
InitGpuProperties
();
InitStream
();
InitBlasHandle
();
InitBlasLtHandle
();
InitDNNHandle
();
InitSolverHandle
();
InitSparseHandle
();
phi
::
InitGpuProperties
(
place_
,
&
compute_capability_
,
&
runtime_version_
,
&
driver_version_
,
&
multi_process_
,
&
max_threads_per_mp_
,
&
max_threads_per_block_
,
&
max_grid_dim_size_
);
phi
::
InitStream
(
&
stream_
);
}
void
PartialInitWithAllocator
()
{
...
...
@@ -238,19 +244,23 @@ struct GPUContext::Impl {
~
Impl
()
{
backends
::
gpu
::
GPUDeviceGuard
guard
(
place_
.
device
);
DestoryInternalWorkspace
();
DestoryInternalEigenDevice
();
DestroyInternalSparseHandle
();
DestroyInternalSolverHandle
();
DestroyInternalDnnHandle
();
if
(
owned_
)
{
DestoryInternalWorkspace
();
DestoryInternalEigenDevice
();
phi
::
DestroySparseHandle
(
sparse_handle_
);
phi
::
DestroySolverHandle
(
solver_handle_
);
phi
::
DestroyDnnHandle
(
dnn_handle_
);
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
if
(
nccl_comm_
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
dynload
::
ncclCommDestroy
(
nccl_comm_
));
}
if
(
nccl_comm_
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
dynload
::
ncclCommDestroy
(
nccl_comm_
));
}
#endif
DestroyInternalBlasHandle
();
DestroyInternalBlasLtHandle
();
DestoryInternalStream
();
phi
::
DestroyBlasHandle
(
blas_handle_
);
phi
::
DestroyBlasHandle
(
blas_tensor_core_handle_
);
phi
::
DestroyBlasHandle
(
blas_tf32_tensor_core_handle_
);
phi
::
DestroyBlasLtHandle
(
blaslt_handle_
);
phi
::
DestoryStream
(
stream_
);
}
}
const
Place
&
GetPlace
()
const
{
return
place_
;
}
...
...
@@ -259,73 +269,6 @@ struct GPUContext::Impl {
return
blas_tensor_core_handle_
!=
nullptr
;
}
void
InitGpuProperties
()
{
backends
::
gpu
::
GPUDeviceGuard
guard
(
place_
.
GetDeviceId
());
compute_capability_
=
backends
::
gpu
::
GetGPUComputeCapability
(
place_
.
GetDeviceId
());
multi_process_
=
backends
::
gpu
::
GetGPUMultiProcessors
(
place_
.
GetDeviceId
());
max_threads_per_mp_
=
backends
::
gpu
::
GetGPUMaxThreadsPerMultiProcessor
(
place_
.
GetDeviceId
());
max_grid_dim_size_
=
backends
::
gpu
::
GetGpuMaxGridDimSize
(
place_
.
GetDeviceId
());
max_threads_per_block_
=
backends
::
gpu
::
GetGPUMaxThreadsPerBlock
(
place_
.
GetDeviceId
());
driver_version_
=
backends
::
gpu
::
GetGPUDriverVersion
(
place_
.
GetDeviceId
());
runtime_version_
=
backends
::
gpu
::
GetGPURuntimeVersion
(
place_
.
GetDeviceId
());
// TODO(wilber): glog may be replaced in the future?
LOG_FIRST_N
(
WARNING
,
1
)
<<
"Please NOTE: device: "
<<
static_cast
<
int
>
(
place_
.
device
)
<<
", GPU Compute Capability: "
<<
compute_capability_
/
10
<<
"."
<<
compute_capability_
%
10
<<
", Driver API Version: "
<<
driver_version_
/
1000
<<
"."
<<
(
driver_version_
%
100
)
/
10
<<
", Runtime API Version: "
<<
runtime_version_
/
1000
<<
"."
<<
(
runtime_version_
%
100
)
/
10
;
#ifdef PADDLE_WITH_HIP
size_t
miopen_major
,
miopen_minor
,
miopen_patch
;
PADDLE_ENFORCE_GPU_SUCCESS
(
dynload
::
miopenGetVersion
(
&
miopen_major
,
&
miopen_minor
,
&
miopen_patch
));
auto
cudnn_dso_ver
=
(
miopen_major
*
1000
+
miopen_minor
*
10
+
miopen_patch
)
/
10
;
auto
compile_miopen_version
=
MIOPEN_VERSION
/
10
;
if
(
cudnn_dso_ver
<
static_cast
<
size_t
>
(
compile_miopen_version
))
{
LOG_FIRST_N
(
WARNING
,
1
)
<<
"WARNING: device: "
<<
static_cast
<
int
>
(
place_
.
device
)
<<
". The installed Paddle is compiled with MIOPEN "
<<
compile_miopen_version
/
100
<<
"."
<<
compile_miopen_version
%
100
<<
", but MIOPEN version in your machine is "
<<
cudnn_dso_ver
/
100
<<
"."
<<
cudnn_dso_ver
%
100
<<
", which may cause serious incompatible bug. "
<<
"Please recompile or reinstall Paddle with compatible MIOPEN "
"version."
;
}
#else
size_t
cudnn_dso_ver
=
dynload
::
cudnnGetVersion
();
LOG_FIRST_N
(
WARNING
,
1
)
<<
"device: "
<<
static_cast
<
int
>
(
place_
.
device
)
<<
", cuDNN Version: "
<<
cudnn_dso_ver
/
1000
<<
"."
<<
(
cudnn_dso_ver
%
1000
)
/
100
<<
"."
;
// Check CUDA/CUDNN version compatiblity
auto
local_cuda_version
=
(
driver_version_
/
1000
)
*
10
+
(
driver_version_
%
100
)
/
10
;
auto
compile_cuda_version
=
(
CUDA_VERSION
/
1000
)
*
10
+
(
CUDA_VERSION
%
100
)
/
10
;
if
(
local_cuda_version
<
compile_cuda_version
)
{
LOG_FIRST_N
(
WARNING
,
1
)
<<
"WARNING: device: "
<<
static_cast
<
int
>
(
place_
.
device
)
<<
". The installed Paddle is compiled with CUDA "
<<
compile_cuda_version
/
10
<<
"."
<<
compile_cuda_version
%
10
<<
", but CUDA runtime version in your machine is "
<<
local_cuda_version
/
10
<<
"."
<<
local_cuda_version
%
10
<<
", which may cause serious incompatible bug. "
<<
"Please recompile or reinstall Paddle with compatible CUDA "
"version."
;
}
#endif
}
void
InitDnnWorkspace
()
{
PD_CHECK
(
allocator_
!=
nullptr
,
"the device allocator for gpu context is nullptr."
);
...
...
@@ -350,27 +293,6 @@ struct GPUContext::Impl {
return
DnnWorkspaceHandle
(
allocator_
,
stream_
);
}
void
InitStream
()
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipStreamCreateWithPriority
(
&
stream_
,
hipStreamDefault
,
0
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaStreamCreateWithPriority
(
&
stream_
,
cudaStreamDefault
,
0
));
#endif
}
void
DestoryInternalStream
()
{
if
(
owned_
&&
stream_
!=
nullptr
)
{
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
hipStreamDestroy
(
stream_
));
#else
PADDLE_ENFORCE_GPU_SUCCESS
(
cudaStreamDestroy
(
stream_
));
#endif
}
stream_
=
nullptr
;
}
void
SetStream
(
gpuStream_t
stream
)
{
stream_
=
stream
;
}
gpuStream_t
GetStream
()
const
{
...
...
@@ -400,129 +322,56 @@ struct GPUContext::Impl {
return
eigen_device_
;
}
void
InitBlasHandle
()
{
#ifdef PADDLE_WITH_HIP
phi
::
dynload
::
rocblas_create_handle
(
&
blas_handle_
);
phi
::
dynload
::
rocblas_set_stream
(
blas_handle_
,
stream_
);
#else // PADDLE_WITH_CUDA
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasCreate
(
&
blas_handle_
));
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetStream
(
blas_handle_
,
stream_
));
blasHandle_t
GetBlasHandle
()
{
std
::
call_once
(
flag_blas_
,
[
=
]()
{
if
(
!
blas_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_handle_
,
stream_
);
}
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasCreate
(
&
blas_tensor_core_handle_
)
);
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetStream
(
blas_tensor_core_handle_
,
stream_
));
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tensor_core_handle_
,
CUBLAS_TENSOR_OP_MATH
));
if
(
!
blas_tensor_core_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_tensor_core_handle_
,
stream_
);
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tensor_core_handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
#if CUDA_VERSION >= 11000
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasCreate
(
&
blas_tf32_tensor_core_handle_
));
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetStream
(
blas_tf32_tensor_core_handle_
,
stream_
));
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tf32_tensor_core_handle_
,
CUBLAS_TF32_TENSOR_OP_MATH
));
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 9000
#endif // PADDLE_WITH_HIP
}
void
DestroyInternalBlasHandle
()
{
#ifdef PADDLE_WITH_HIP
if
(
owned_
&&
blas_handle_
!=
nullptr
)
{
phi
::
dynload
::
rocblas_destroy_handle
(
blas_handle_
);
blas_handle_
=
nullptr
;
}
#else
if
(
owned_
&&
blas_handle_
!=
nullptr
)
{
phi
::
dynload
::
cublasDestroy
(
blas_handle_
);
blas_handle_
=
nullptr
;
}
if
(
owned_
&&
blas_tensor_core_handle_
!=
nullptr
)
{
phi
::
dynload
::
cublasDestroy
(
blas_tensor_core_handle_
);
blas_tensor_core_handle_
=
nullptr
;
}
if
(
owned_
&&
blas_tf32_tensor_core_handle_
!=
nullptr
)
{
phi
::
dynload
::
cublasDestroy
(
blas_tf32_tensor_core_handle_
);
blas_tf32_tensor_core_handle_
=
nullptr
;
}
#endif // PADDLE_WITH_HIP
}
blasHandle_t
GetBlasHandle
()
const
{
if
(
!
blas_tf32_tensor_core_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_tf32_tensor_core_handle_
,
stream_
);
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tf32_tensor_core_handle_
,
CUBLAS_TF32_TENSOR_OP_MATH
));
}
#endif
#endif
});
PD_CHECK
(
blas_handle_
!=
nullptr
,
"the gpu blas handle is nullptr."
);
return
blas_handle_
;
}
void
SetBlasHandle
(
blasHandle_t
blas
)
{
blas_handle_
=
blas
;
}
void
InitBlasLtHandle
()
{
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi
::
dynload
::
cublasLtCreate
(
&
blaslt_handle_
);
#endif
void
SetBlasTensorCoreHandle
(
blasHandle_t
handle
)
{
blas_tensor_core_handle_
=
handle
;
}
void
DestroyInternalBlasLtHandle
()
{
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
phi
::
dynload
::
cublasLtDestroy
(
blaslt_handle_
);
#endif
void
SetBlasTF32Handle
(
blasHandle_t
handle
)
{
blas_tf32_tensor_core_handle_
=
handle
;
}
void
SetBlasLtHandle
(
blasLtHandle_t
blaslt
)
{
blaslt_handle_
=
blaslt
;
}
blasLtHandle_t
GetBlasLtHandle
()
const
{
blasLtHandle_t
GetBlasLtHandle
()
{
std
::
call_once
(
flag_blaslt_
,
[
=
]()
{
if
(
!
blaslt_handle_
)
phi
::
InitBlasLtHandle
(
&
blaslt_handle_
);
});
PD_CHECK
(
blaslt_handle_
!=
nullptr
,
"the gpu blasLt handle is nullptr."
);
return
blaslt_handle_
;
}
void
InitDNNHandle
()
{
if
(
phi
::
dynload
::
HasCUDNN
())
{
#ifdef PADDLE_WITH_HIP
size_t
miopen_major
,
miopen_minor
,
miopen_patch
;
PADDLE_ENFORCE_GPU_SUCCESS
(
dynload
::
miopenGetVersion
(
&
miopen_major
,
&
miopen_minor
,
&
miopen_patch
));
auto
local_miopen_version
=
(
miopen_major
*
1000
+
miopen_minor
*
10
+
miopen_patch
)
/
10
;
auto
compile_miopen_version
=
MIOPEN_VERSION
/
10
;
if
(
local_miopen_version
<
static_cast
<
size_t
>
(
compile_miopen_version
))
{
LOG_FIRST_N
(
WARNING
,
1
)
<<
"WARNING: device: "
<<
place_
.
device
<<
". The installed Paddle is compiled with MIOPEN "
<<
compile_miopen_version
/
100
<<
"."
<<
compile_miopen_version
%
100
<<
", but MIOPEN version in your machine is "
<<
local_miopen_version
/
100
<<
"."
<<
local_miopen_version
%
100
<<
", which may cause serious incompatible bug. "
<<
"Please recompile or reinstall Paddle with compatible MIOPEN "
"version."
;
}
PADDLE_ENFORCE_GPU_SUCCESS
(
dynload
::
miopenCreate
(
&
dnn_handle_
));
PADDLE_ENFORCE_GPU_SUCCESS
(
dynload
::
miopenSetStream
(
dnn_handle_
,
stream_
));
#else
auto
local_cudnn_version
=
phi
::
dynload
::
cudnnGetVersion
()
/
100
;
auto
compile_cudnn_version
=
CUDNN_VERSION
/
100
;
if
(
local_cudnn_version
<
static_cast
<
size_t
>
(
compile_cudnn_version
))
{
LOG_FIRST_N
(
WARNING
,
1
)
<<
"WARNING: device: "
<<
place_
.
device
<<
". The installed Paddle is compiled with CUDNN "
<<
compile_cudnn_version
/
10
<<
"."
<<
compile_cudnn_version
%
10
<<
", but CUDNN version in your machine is "
<<
local_cudnn_version
/
10
<<
"."
<<
local_cudnn_version
%
10
<<
", which may cause serious incompatible bug. "
<<
"Please recompile or reinstall Paddle with compatible CUDNN "
"version."
;
}
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cudnnCreate
(
&
dnn_handle_
));
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cudnnSetStream
(
dnn_handle_
,
stream_
));
#endif
}
else
{
dnn_handle_
=
nullptr
;
}
}
dnnHandle_t
GetDnnHandle
()
{
std
::
call_once
(
flag_dnn_
,
[
=
]()
{
if
(
!
dnn_handle_
)
phi
::
InitDnnHandle
(
&
dnn_handle_
,
stream_
,
place_
);
});
PD_CHECK
(
dnn_handle_
!=
nullptr
,
"the gpu dnn handle is nullptr."
);
return
dnn_handle_
;
}
...
...
@@ -543,54 +392,16 @@ struct GPUContext::Impl {
void
SetDnnHandle
(
dnnHandle_t
handle
)
{
dnn_handle_
=
handle
;
}
void
InitSolverHandle
()
{
#ifndef PADDLE_WITH_HIP
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cusolverDnCreate
(
&
solver_handle_
));
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cusolverDnSetStream
(
solver_handle_
,
stream_
));
#endif
}
void
DestroyInternalSolverHandle
()
{
#ifndef PADDLE_WITH_HIP
if
(
owned_
&&
solver_handle_
!=
nullptr
)
{
PADDLE_ENFORCE_GPU_SUCCESS
(
phi
::
dynload
::
cusolverDnDestroy
(
solver_handle_
));
solver_handle_
=
nullptr
;
}
#endif
}
solverHandle_t
GetSolverHandle
()
const
{
solverHandle_t
GetSolverHandle
()
{
std
::
call_once
(
flag_slover_
,
[
=
]()
{
if
(
!
solver_handle_
)
phi
::
InitSolverHandle
(
&
solver_handle_
,
stream_
);
});
PD_CHECK
(
solver_handle_
!=
nullptr
,
"the gpu solver handle is nullptr."
);
return
solver_handle_
;
}
void
SetSolverHandle
(
solverHandle_t
handle
)
{
solver_handle_
=
handle
;
}
void
InitSparseHandle
()
{
// ROCM is not yet supported
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
PADDLE_RETRY_CUDA_SUCCESS
(
dynload
::
cusparseCreate
(
&
sparse_handle_
));
PADDLE_RETRY_CUDA_SUCCESS
(
dynload
::
cusparseSetStream
(
sparse_handle_
,
stream_
));
#endif
#endif
}
void
DestroyInternalSparseHandle
()
{
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
if
(
owned_
&&
sparse_handle_
!=
nullptr
)
{
PADDLE_RETRY_CUDA_SUCCESS
(
dynload
::
cusparseDestroy
(
sparse_handle_
));
sparse_handle_
=
nullptr
;
}
#endif
#endif
}
sparseHandle_t
GetSparseHandle
()
const
{
PD_CHECK
(
sparse_handle_
!=
nullptr
,
"the gpu sparse handle is nullptr."
);
return
sparse_handle_
;
...
...
@@ -646,8 +457,28 @@ struct GPUContext::Impl {
#endif
}
inline
void
CublasCall
(
const
std
::
function
<
void
(
blasHandle_t
)
>&
callback
)
const
{
inline
void
CublasCall
(
const
std
::
function
<
void
(
blasHandle_t
)
>&
callback
)
{
std
::
call_once
(
flag_cublas_
,
[
=
]()
{
if
(
!
blas_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_handle_
,
stream_
);
}
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
if
(
!
blas_tensor_core_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_tensor_core_handle_
,
stream_
);
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tensor_core_handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
#if CUDA_VERSION >= 11000
if
(
!
blas_tf32_tensor_core_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_tf32_tensor_core_handle_
,
stream_
);
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tf32_tensor_core_handle_
,
CUBLAS_TF32_TENSOR_OP_MATH
));
}
#endif
#endif
});
if
(
blas_tf32_tensor_core_handle_
!=
nullptr
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
blas_tf32_mtx_
);
callback
(
blas_tf32_tensor_core_handle_
);
...
...
@@ -658,7 +489,26 @@ struct GPUContext::Impl {
}
inline
void
TensorCoreCublasCallIfAvailable
(
const
std
::
function
<
void
(
blasHandle_t
)
>&
callback
)
const
{
const
std
::
function
<
void
(
blasHandle_t
)
>&
callback
)
{
std
::
call_once
(
flag_tensorcore_cublas_
,
[
=
]()
{
if
(
!
blas_handle_
)
phi
::
InitBlasHandle
(
&
blas_handle_
,
stream_
);
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 9000
if
(
!
blas_tensor_core_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_tensor_core_handle_
,
stream_
);
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tensor_core_handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
#if CUDA_VERSION >= 11000
if
(
!
blas_tf32_tensor_core_handle_
)
{
phi
::
InitBlasHandle
(
&
blas_tf32_tensor_core_handle_
,
stream_
);
PADDLE_RETRY_CUDA_SUCCESS
(
phi
::
dynload
::
cublasSetMathMode
(
blas_tf32_tensor_core_handle_
,
CUBLAS_TF32_TENSOR_OP_MATH
));
}
#endif
#endif
});
if
(
blas_tensor_core_handle_
!=
nullptr
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
blas_tensor_core_mtx_
);
callback
(
blas_tensor_core_handle_
);
...
...
@@ -689,8 +539,7 @@ struct GPUContext::Impl {
void
AddStreamCallback
(
const
std
::
function
<
void
()
>&
callback
)
const
{
// NOTE(zhiqiu): better use threadpool here, otherwise "std::async" may
// launch too
// many threads and result in thread oversubscription.
// launch too many threads and result in thread oversubscription.
auto
*
callback_func
=
new
std
::
function
<
void
()
>
(
std
::
move
(
callback
));
auto
*
func
=
new
std
::
function
<
void
()
>
([
this
,
callback_func
]
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
stream_call_back_mtx_
);
...
...
@@ -749,6 +598,13 @@ struct GPUContext::Impl {
sparseHandle_t
sparse_handle_
{
nullptr
};
DnnWorkspaceHandle
*
workspace_
{
nullptr
};
std
::
once_flag
flag_blas_
;
std
::
once_flag
flag_blaslt_
;
std
::
once_flag
flag_dnn_
;
std
::
once_flag
flag_slover_
;
std
::
once_flag
flag_cublas_
;
std
::
once_flag
flag_tensorcore_cublas_
;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// NCCL communicator (single process version) for NCCL collective operations.
// NCCL collective operations provides fast collectives over multiple GPUs
...
...
@@ -878,7 +734,10 @@ void GPUContext::Init() {
impl_
->
Init
();
}
void
GPUContext
::
SetStream
(
gpuStream_t
stream
)
{
impl_
->
SetStream
(
stream
);
}
void
GPUContext
::
SetStream
(
gpuStream_t
stream
)
{
impl_
->
allocator_
=
const_cast
<
Allocator
*>
(
&
this
->
GetAllocator
());
impl_
->
SetStream
(
stream
);
}
void
GPUContext
::
SetEigenDevice
(
Eigen
::
GpuDevice
*
device
)
{
impl_
->
SetEigenDevice
(
device
);
...
...
@@ -888,6 +747,14 @@ void GPUContext::SetBlasHandle(blasHandle_t blas) {
impl_
->
SetBlasHandle
(
blas
);
}
void
GPUContext
::
SetBlasTensorCoreHandle
(
blasHandle_t
handle
)
{
impl_
->
SetBlasTensorCoreHandle
(
handle
);
}
void
GPUContext
::
SetBlasTF32Handle
(
blasHandle_t
handle
)
{
impl_
->
SetBlasTF32Handle
(
handle
);
}
void
GPUContext
::
SetBlasLtHandle
(
blasLtHandle_t
blaslt
)
{
impl_
->
SetBlasLtHandle
(
blaslt
);
}
...
...
paddle/phi/backends/gpu/gpu_context.h
浏览文件 @
90ae3533
...
...
@@ -199,6 +199,10 @@ class PADDLE_API GPUContext : public DeviceContext {
void
SetBlasHandle
(
blasHandle_t
);
void
SetBlasTensorCoreHandle
(
blasHandle_t
);
void
SetBlasTF32Handle
(
blasHandle_t
);
void
SetBlasLtHandle
(
blasLtHandle_t
);
void
SetDnnHandle
(
dnnHandle_t
);
...
...
paddle/phi/backends/gpu/gpu_resources.cc
0 → 100644
浏览文件 @
90ae3533
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/gpu/gpu_resources.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/phi/backends/dynload/cublas.h"
#include "paddle/phi/backends/dynload/cudnn.h"
#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/dynload/cusparse.h"
#if !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#include "paddle/phi/backends/dynload/nccl.h"
#endif // !defined(__APPLE__) && defined(PADDLE_WITH_NCCL)
#endif // PADDLE_WITH_CUDA
#include "unsupported/Eigen/CXX11/Tensor"
// TODO(phi): remove fluid header.
#include "paddle/fluid/platform/enforce.h"
namespace
phi
{
// Queries the device attributes of `place`'s GPU and writes them through the
// out-parameters, then logs the device / driver / runtime versions once and
// warns if the installed CUDA (or MIOPEN on ROCm) runtime is older than the
// version Paddle was compiled against.
// All out-pointers must be non-null; each is fully overwritten.
void InitGpuProperties(Place place,
                       int* compute_capability,
                       int* runtime_version,
                       int* driver_version,
                       int* multi_process,
                       int* max_threads_per_mp,
                       int* max_threads_per_block,
                       std::array<int, 3>* max_grid_dim_size) {
  // Make the queries below target the device of `place`.
  backends::gpu::GPUDeviceGuard guard(place.GetDeviceId());
  *compute_capability =
      backends::gpu::GetGPUComputeCapability(place.GetDeviceId());
  *multi_process = backends::gpu::GetGPUMultiProcessors(place.GetDeviceId());
  *max_threads_per_mp =
      backends::gpu::GetGPUMaxThreadsPerMultiProcessor(place.GetDeviceId());
  *max_grid_dim_size = backends::gpu::GetGpuMaxGridDimSize(place.GetDeviceId());
  *max_threads_per_block =
      backends::gpu::GetGPUMaxThreadsPerBlock(place.GetDeviceId());
  *driver_version = backends::gpu::GetGPUDriverVersion(place.GetDeviceId());
  *runtime_version = backends::gpu::GetGPURuntimeVersion(place.GetDeviceId());

  // TODO(wilber): glog may be replaced in the future?
  // Compute capability is encoded as major*10+minor; driver/runtime versions
  // as major*1000 + minor*10 (CUDA convention), hence the /10, /1000, %100
  // arithmetic below.
  LOG_FIRST_N(WARNING, 1)
      << "Please NOTE: device: " << static_cast<int>(place.device)
      << ", GPU Compute Capability: " << *compute_capability / 10 << "."
      << *compute_capability % 10
      << ", Driver API Version: " << *driver_version / 1000 << "."
      << (*driver_version % 100) / 10
      << ", Runtime API Version: " << *runtime_version / 1000 << "."
      << (*runtime_version % 100) / 10;
#ifdef PADDLE_WITH_HIP
  // MIOPEN reports major/minor/patch separately; fold into major*100+minor
  // (the /10 drops the patch digit) to compare with MIOPEN_VERSION / 10.
  size_t miopen_major, miopen_minor, miopen_patch;
  PADDLE_ENFORCE_GPU_SUCCESS(
      dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
  auto cudnn_dso_ver =
      (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
  auto compile_miopen_version = MIOPEN_VERSION / 10;
  if (cudnn_dso_ver < static_cast<size_t>(compile_miopen_version)) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << static_cast<int>(place.device)
        << ". The installed Paddle is compiled with MIOPEN "
        << compile_miopen_version / 100 << "." << compile_miopen_version % 100
        << ", but MIOPEN version in your machine is " << cudnn_dso_ver / 100
        << "." << cudnn_dso_ver % 100
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible MIOPEN "
           "version.";
  }
#else
  size_t cudnn_dso_ver = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << static_cast<int>(place.device)
                          << ", cuDNN Version: " << cudnn_dso_ver / 1000 << "."
                          << (cudnn_dso_ver % 1000) / 100 << ".";

  // Check CUDA/CUDNN version compatibility.
  // Both sides are reduced to major*10+minor before comparing.
  auto local_cuda_version =
      (*driver_version / 1000) * 10 + (*driver_version % 100) / 10;
  auto compile_cuda_version =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
  if (local_cuda_version < compile_cuda_version) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << static_cast<int>(place.device)
        << ". The installed Paddle is compiled with CUDA "
        << compile_cuda_version / 10 << "." << compile_cuda_version % 10
        << ", but CUDA runtime version in your machine is "
        << local_cuda_version / 10 << "." << local_cuda_version % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }
#endif
}
// Creates a default-priority GPU stream into *stream (HIP or CUDA build).
// Failure is checked via PADDLE_ENFORCE_GPU_SUCCESS.
void InitStream(gpuStream_t* stream) {
#ifdef PADDLE_WITH_HIP
  PADDLE_ENFORCE_GPU_SUCCESS(
      hipStreamCreateWithPriority(stream, hipStreamDefault, 0));
#else
  PADDLE_ENFORCE_GPU_SUCCESS(
      cudaStreamCreateWithPriority(stream, cudaStreamDefault, 0));
#endif
}
// Destroys a GPU stream if it is non-null. Safe to call with nullptr (no-op).
// NOTE(review): the function name keeps the historical "Destory" spelling —
// it is declared in gpu_resources.h and called elsewhere, so renaming here
// would break the build.
// The parameter is passed by value, so the caller's handle is NOT reset;
// callers must null out their own copy after this returns.
void DestoryStream(gpuStream_t stream) {
  if (stream != nullptr) {
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipStreamDestroy(stream));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamDestroy(stream));
#endif
  }
  // Removed dead store `stream = nullptr;`: assigning to a by-value
  // parameter has no effect outside this function.
}
// Creates a BLAS handle into *blas_handle and binds it to `stream`
// (rocBLAS on HIP builds, cuBLAS on CUDA builds).
// NOTE(review): the rocBLAS calls below discard their return status, unlike
// the CUDA branch which retries via PADDLE_RETRY_CUDA_SUCCESS — presumably
// intentional, but worth confirming.
void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream) {
#ifdef PADDLE_WITH_HIP
  phi::dynload::rocblas_create_handle(blas_handle);
  phi::dynload::rocblas_set_stream(*blas_handle, stream);
#else  // PADDLE_WITH_CUDA
  PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cublasCreate(blas_handle));
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cublasSetStream(*blas_handle, stream));
#endif  // PADDLE_WITH_HIP
}
// Destroys a BLAS handle if non-null (rocBLAS on HIP, cuBLAS on CUDA).
// Safe to call with nullptr (no-op). The parameter is by value, so the
// caller's handle variable is NOT reset; callers must clear their own copy.
void DestroyBlasHandle(blasHandle_t handle) {
#ifdef PADDLE_WITH_HIP
  if (handle != nullptr) {
    phi::dynload::rocblas_destroy_handle(handle);
  }
#else
  if (handle != nullptr) {
    phi::dynload::cublasDestroy(handle);
  }
#endif  // PADDLE_WITH_HIP
  // Removed dead stores `handle = nullptr;` from both branches: assigning to
  // a by-value parameter has no effect outside this function.
}
// Creates a cuBLASLt handle into *blaslt_handle. Compiled to a no-op unless
// building with CUDA >= 11.06 (cublasLt stream/workspace APIs used by Paddle
// require it). NOTE(review): the cublasLtCreate status is discarded here —
// confirm whether it should be wrapped in an enforce macro.
void InitBlasLtHandle(blasLtHandle_t* blaslt_handle) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
  phi::dynload::cublasLtCreate(blaslt_handle);
#endif
}
// Destroys a cuBLASLt handle if non-null; no-op on nullptr or on builds
// without CUDA >= 11.06. The parameter is by value, so the caller's handle
// variable is NOT reset.
void DestroyBlasLtHandle(blasLtHandle_t handle) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
  if (handle != nullptr) {
    phi::dynload::cublasLtDestroy(handle);
  }
  // Removed dead store `handle = nullptr;`: assigning to a by-value
  // parameter has no effect outside this function.
#endif
}
// Creates a DNN-library handle (MIOPEN on HIP, cuDNN on CUDA) into *handle
// and binds it to `stream`. If the cuDNN/MIOPEN shared library is not
// available (HasCUDNN() is false), *handle is set to nullptr instead.
// Before creating the handle, warns once if the runtime library is older
// than the version Paddle was compiled against.
void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place) {
  if (phi::dynload::HasCUDNN()) {
#ifdef PADDLE_WITH_HIP
    // MIOPEN reports major/minor/patch; fold into major*100+minor (the /10
    // drops the patch digit) to compare with MIOPEN_VERSION / 10.
    size_t miopen_major, miopen_minor, miopen_patch;
    PADDLE_ENFORCE_GPU_SUCCESS(
        dynload::miopenGetVersion(&miopen_major, &miopen_minor, &miopen_patch));
    auto local_miopen_version =
        (miopen_major * 1000 + miopen_minor * 10 + miopen_patch) / 10;
    auto compile_miopen_version = MIOPEN_VERSION / 10;
    if (local_miopen_version < static_cast<size_t>(compile_miopen_version)) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << place.device
          << ". The installed Paddle is compiled with MIOPEN "
          << compile_miopen_version / 100 << "." << compile_miopen_version % 100
          << ", but MIOPEN version in your machine is "
          << local_miopen_version / 100 << "." << local_miopen_version % 100
          << ", which may cause serious incompatible bug. "
          << "Please recompile or reinstall Paddle with compatible MIOPEN "
             "version.";
    }
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenCreate(handle));
    PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenSetStream(*handle, stream));
#else
    // cuDNN versions are compared as major*10+minor (CUDNN_VERSION is
    // major*1000 + minor*100 + patchlevel, so /100 keeps major.minor).
    auto local_cudnn_version = phi::dynload::cudnnGetVersion() / 100;
    auto compile_cudnn_version = CUDNN_VERSION / 100;
    if (local_cudnn_version < static_cast<size_t>(compile_cudnn_version)) {
      LOG_FIRST_N(WARNING, 1)
          << "WARNING: device: " << place.device
          << ". The installed Paddle is compiled with CUDNN "
          << compile_cudnn_version / 10 << "." << compile_cudnn_version % 10
          << ", but CUDNN version in your machine is "
          << local_cudnn_version / 10 << "." << local_cudnn_version % 10
          << ", which may cause serious incompatible bug. "
          << "Please recompile or reinstall Paddle with compatible CUDNN "
             "version.";
    }
    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnCreate(handle));
    PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cudnnSetStream(*handle, stream));
#endif
  } else {
    // DNN library not present: callers observe a null handle.
    *handle = nullptr;
  }
}
// Destroys a DNN handle if non-null (MIOPEN on HIP, cuDNN on CUDA).
// Safe to call with nullptr (no-op). The parameter is by value, so the
// caller's handle variable is NOT reset; callers must clear their own copy.
void DestroyDnnHandle(dnnHandle_t handle) {
#ifdef PADDLE_WITH_HIP
  if (handle != nullptr) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::miopenDestroy(handle));
  }
#else
  if (handle != nullptr) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnDestroy(handle));
  }
#endif  // PADDLE_WITH_HIP
  // Removed dead stores `handle = nullptr;` from both branches: assigning to
  // a by-value parameter has no effect outside this function.
}
// Creates a cuSOLVER dense handle into *handle and binds it to `stream`.
// Compiled to a no-op on HIP builds (no ROCm solver support here).
void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream) {
#ifndef PADDLE_WITH_HIP
  PADDLE_RETRY_CUDA_SUCCESS(phi::dynload::cusolverDnCreate(handle));
  PADDLE_RETRY_CUDA_SUCCESS(
      phi::dynload::cusolverDnSetStream(*handle, stream));
#endif
}
// Destroys a cuSOLVER handle if non-null; no-op on nullptr or HIP builds.
// The parameter is by value, so the caller's handle variable is NOT reset.
void DestroySolverHandle(solverHandle_t solver_handle) {
#ifndef PADDLE_WITH_HIP
  if (solver_handle != nullptr) {
    PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cusolverDnDestroy(solver_handle));
  }
  // Removed dead store `solver_handle = nullptr;`: assigning to a by-value
  // parameter has no effect outside this function.
#endif
}
// Creates a cuSPARSE handle into *handle and binds it to `stream`.
// Compiled to a no-op on ROCm builds and on CUDA < 10.1 (the generic
// cuSPARSE APIs Paddle relies on appeared in CUDA 10.1).
void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream) {
// ROCM is not yet supported
#if defined(PADDLE_WITH_CUDA)
// The generic APIs is supported from CUDA10.1
#if CUDA_VERSION >= 10010
  PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseCreate(handle));
  PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseSetStream(*handle, stream));
#endif
#endif
}
// Destroys a cuSPARSE handle if non-null; no-op on nullptr, on ROCm builds,
// and on CUDA < 10.1 (mirrors the guards in InitSparseHandle).
// The parameter is by value, so the caller's handle variable is NOT reset.
void DestroySparseHandle(sparseHandle_t handle) {
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION >= 10010
  if (handle != nullptr) {
    PADDLE_RETRY_CUDA_SUCCESS(dynload::cusparseDestroy(handle));
  }
  // Removed dead store `handle = nullptr;`: assigning to a by-value
  // parameter has no effect outside this function.
#endif
#endif
}
}
// namespace phi
paddle/phi/backends/gpu/gpu_resources.h
0 → 100644
浏览文件 @
90ae3533
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/common/place.h"
namespace phi {

// Free functions that create/destroy the per-device GPU resources used by
// GPUContext (stream and cuBLAS/cuBLASLt/cuDNN/cuSOLVER/cuSPARSE handles, or
// their ROCm counterparts). Init* functions write through out-pointers;
// Destroy* functions take the handle by value and are no-ops on nullptr —
// callers must reset their own handle variable afterwards.
// NOTE(review): "Destory" in DestoryStream is a historical typo kept for
// source compatibility with existing callers.

// Queries device properties of `place`'s GPU into the out-parameters and
// logs/warns about driver/runtime/library version mismatches.
void InitGpuProperties(Place place,
                       int* compute_capability,
                       int* runtime_version,
                       int* driver_version,
                       int* multi_process,
                       int* max_threads_per_mp,
                       int* max_threads_per_block,
                       std::array<int, 3>* max_grid_dim_size);

// Creates a default-priority GPU stream into *stream.
void InitStream(gpuStream_t* stream);
// Destroys `stream` if non-null (sic: historical spelling).
void DestoryStream(gpuStream_t stream);

// Creates a BLAS handle into *blas_handle bound to `stream`.
void InitBlasHandle(blasHandle_t* blas_handle, gpuStream_t stream);
// Destroys `handle` if non-null.
void DestroyBlasHandle(blasHandle_t handle);

// Creates a cuBLASLt handle (CUDA >= 11.06 only; otherwise no-op).
void InitBlasLtHandle(blasLtHandle_t* blaslt_handle);
// Destroys `handle` if non-null.
void DestroyBlasLtHandle(blasLtHandle_t handle);

// Creates a cuDNN/MIOPEN handle into *handle bound to `stream`; sets
// *handle to nullptr when the DNN library is unavailable.
void InitDnnHandle(dnnHandle_t* handle, gpuStream_t stream, Place place);
// Destroys `handle` if non-null.
void DestroyDnnHandle(dnnHandle_t handle);

// Creates a cuSOLVER handle into *handle bound to `stream` (CUDA only).
void InitSolverHandle(solverHandle_t* handle, gpuStream_t stream);
// Destroys `solver_handle` if non-null.
void DestroySolverHandle(solverHandle_t solver_handle);

// Creates a cuSPARSE handle into *handle bound to `stream`
// (CUDA >= 10.1 only; otherwise no-op).
void InitSparseHandle(sparseHandle_t* handle, gpuStream_t stream);
// Destroys `handle` if non-null.
void DestroySparseHandle(sparseHandle_t handle);

// void InitDnnWorkspace();

}  // namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录