Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
6a4790c2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6a4790c2
编写于
1月 08, 2019
作者:
X
Xin Pan
提交者:
GitHub
1月 08, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #15203 from sneaxiy/revert-15139-remove_op_handle_lock
Revert "Remove op handle lock"
上级
6ca9a481
dacfaaa9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
130 addition
and
159 deletion
+130
-159
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+70
-64
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+0
-58
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+5
-13
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+52
-24
paddle/fluid/platform/device_context_test.cu
paddle/fluid/platform/device_context_test.cu
+3
-0
未找到文件。
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
6a4790c2
...
@@ -62,19 +62,27 @@ struct CUBlas<float> {
...
@@ -62,19 +62,27 @@ struct CUBlas<float> {
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
)
{
cudaDataType_t
Ctype
,
int
ldc
)
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// appears in a lambda-expression, I can not use template parameter pack
// here.
// here.
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
VLOG
(
5
)
<<
"use_tensor_op_math: "
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
dev_ctx
->
tensor_core_available
()
?
"True"
:
"False"
);
<<
(
platform
::
TensorCoreAvailable
()
?
"True"
:
"False"
);
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
beta
,
C
,
Ctype
,
ldc
));
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
});
#else
#else
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
#endif
#endif
}
}
};
};
...
@@ -162,10 +170,11 @@ struct CUBlas<platform::float16> {
...
@@ -162,10 +170,11 @@ struct CUBlas<platform::float16> {
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
computeType
)
{
cudaDataType_t
computeType
)
{
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 9000
bool
use_tensor_op_math
=
dev_ctx
->
tensor_core_a
vailable
();
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
if
(
use_tensor_op_math
)
{
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
}
...
@@ -173,13 +182,20 @@ struct CUBlas<platform::float16> {
...
@@ -173,13 +182,20 @@ struct CUBlas<platform::float16> {
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
#endif // CUDA_VERSION >= 9000
#endif // CUDA_VERSION >= 9000
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
});
#else
#else
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
#endif
#endif
}
}
};
};
...
@@ -207,10 +223,9 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
...
@@ -207,10 +223,9 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F
,
N
);
CUDA_R_32F
,
N
);
}
else
{
}
else
{
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
lda
,
&
beta
,
C
,
N
);
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
});
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
}
}
...
@@ -251,12 +266,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
...
@@ -251,12 +266,9 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
#else
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_beta
,
h_C
,
N
);
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
});
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
}
}
...
@@ -280,10 +292,8 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
...
@@ -280,10 +292,8 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
}
else
{
}
else
{
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
lda
,
&
beta
,
C
,
ldc
);
});
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
}
}
...
@@ -301,19 +311,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
...
@@ -301,19 +311,16 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
ldc
);
});
}
}
template
<
>
template
<
>
template
<
typename
T
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
T
*
y
)
const
{
T
*
y
)
const
{
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
AXPY
(
context_
.
cublas_handle
(),
n
,
&
alpha
,
x
,
1
,
y
,
1
);
CUBlas
<
T
>::
AXPY
(
handle
,
n
,
&
alpha
,
x
,
1
,
y
,
1
);
});
}
}
template
<
>
template
<
>
...
@@ -323,9 +330,8 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
...
@@ -323,9 +330,8 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T
beta
,
T
*
C
)
const
{
T
beta
,
T
*
C
)
const
{
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMV
(
context_
.
cublas_handle
(),
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
CUBlas
<
T
>::
GEMV
(
handle
,
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
);
&
beta
,
C
,
1
);
});
}
}
template
<
>
template
<
>
...
@@ -347,28 +353,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
...
@@ -347,28 +353,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
#if CUDA_VERSION >= 9010
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
auto
cublas_call
=
[
&
]()
{
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
context_
.
tensor_core_a
vailable
();
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
if
(
use_tensor_op_math
)
{
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
});
};
auto
&
dev_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
dev_ctx
.
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
}
else
{
}
else
{
#endif // CUDA_VERSION >= 9010
#endif // CUDA_VERSION >= 9010
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
strideB
,
A
,
lda
,
B
,
ldb
,
strideB
,
A
,
lda
,
strideA
,
&
beta
,
C
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
ldc
,
strideC
,
batchCount
);
});
#if CUDA_VERSION >= 9010
#if CUDA_VERSION >= 9010
}
}
...
...
paddle/fluid/platform/cuda_helper.h
已删除
100644 → 0
浏览文件 @
6ca9a481
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum
cublasMath_t
{
CUBLAS_DEFAULT_MATH
=
0
};
#endif
namespace
paddle
{
namespace
platform
{
class
CublasHandleHolder
{
public:
CublasHandleHolder
(
cudaStream_t
stream
,
cublasMath_t
math_type
)
{
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
handle_
,
stream
));
#if CUDA_VERSION >= 9000
if
(
math_type
==
CUBLAS_TENSOR_OP_MATH
)
{
PADDLE_ENFORCE
(
dynload
::
cublasSetMathMode
(
handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
}
~
CublasHandleHolder
()
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
handle_
));
}
template
<
typename
Callback
>
inline
void
Call
(
Callback
&&
callback
)
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
(
handle_
);
}
private:
DISABLE_COPY_AND_ASSIGN
(
CublasHandleHolder
);
cublasHandle_t
handle_
;
mutable
std
::
mutex
mtx_
;
};
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device_context.cc
浏览文件 @
6a4790c2
...
@@ -245,15 +245,8 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
...
@@ -245,15 +245,8 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
cublas_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_DEFAULT_MATH
));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_TENSOR_OP_MATH
));
#endif
}
if
(
dynload
::
HasCUDNN
())
{
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
}
}
...
@@ -313,8 +306,7 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -313,8 +306,7 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
Wait
();
Wait
();
WaitStreamCallback
();
WaitStreamCallback
();
cublas_handle_
.
reset
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
cublas_tensor_core_handle_
.
reset
();
eigen_stream_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
@@ -343,8 +335,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
...
@@ -343,8 +335,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return
eigen_device_
.
get
();
return
eigen_device_
.
get
();
}
}
bool
CUDADeviceContext
::
tensor_core_availab
le
()
const
{
cublasHandle_t
CUDADeviceContext
::
cublas_hand
le
()
const
{
return
cublas_
tensor_core_handle_
!=
nullptr
;
return
cublas_
handle_
;
}
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
6a4790c2
...
@@ -20,7 +20,6 @@ limitations under the License. */
...
@@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
...
@@ -210,6 +209,39 @@ class CudnnWorkspaceHandle {
...
@@ -210,6 +209,39 @@ class CudnnWorkspaceHandle {
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
};
};
#if CUDA_VERSION >= 9000
class
ScopedCublasMathMode
{
public:
ScopedCublasMathMode
(
cublasHandle_t
handle
,
cublasMath_t
new_math_mode
)
:
handle_
(
handle
)
{
need_reset
=
false
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGetMathMode
(
handle_
,
&
old_math_mode_
),
"Failed to get old cublas math mode"
);
if
(
old_math_mode_
!=
new_math_mode
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
new_math_mode
),
"Failed to set old cublas math mode"
);
need_reset
=
true
;
}
}
~
ScopedCublasMathMode
()
{
if
(
need_reset
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
old_math_mode_
),
"Failed to set old cublas math mode"
);
}
}
private:
cublasHandle_t
handle_
;
cublasMath_t
old_math_mode_
;
bool
need_reset
;
};
#endif
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
@@ -230,25 +262,8 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -230,25 +262,8 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
Eigen
::
GpuDevice
*
eigen_device
()
const
;
/*! \brief Call cublas function safely. */
/*! \brief Return cublas handle in the device context. */
template
<
typename
Callback
>
cublasHandle_t
cublas_handle
()
const
;
inline
void
CublasCall
(
Callback
&&
callback
)
const
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
/*! \brief Check whether tensor core is supported */
bool
tensor_core_available
()
const
;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template
<
typename
Callback
>
inline
void
TensorCoreCublasCallIfAvailable
(
Callback
&&
callback
)
const
{
if
(
cublas_tensor_core_handle_
)
{
cublas_tensor_core_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
else
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
}
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
cudnnHandle_t
cudnn_handle
()
const
;
...
@@ -267,6 +282,7 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -267,6 +282,7 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
();
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
}
...
@@ -278,6 +294,18 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -278,6 +294,18 @@ class CUDADeviceContext : public DeviceContext {
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template
<
typename
Callback
>
void
CublasCall
(
Callback
callback
,
cublasMath_t
new_math
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
cublas_mtx_
);
ScopedCublasMathMode
scoped_cublas_math
(
cublas_handle_
,
new_math
);
callback
();
}
#endif
private:
private:
CUDAPlace
place_
;
CUDAPlace
place_
;
...
@@ -285,9 +313,7 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -285,9 +313,7 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cublasHandle_t
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_tensor_core_handle_
;
int
compute_capability_
;
int
compute_capability_
;
int
runtime_version_
;
int
runtime_version_
;
...
@@ -295,10 +321,12 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -295,10 +321,12 @@ class CUDADeviceContext : public DeviceContext {
int
multi_process_
;
int
multi_process_
;
int
max_threads_per_mp_
;
int
max_threads_per_mp_
;
mutable
std
::
mutex
mtx_
;
// StreamCallbackManager is thread-safe
// StreamCallbackManager is thread-safe
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
DISABLE_COPY_AND_ASSIGN
(
CUDADeviceContext
)
;
mutable
std
::
mutex
cublas_mtx_
;
};
};
template
<
>
template
<
>
...
...
paddle/fluid/platform/device_context_test.cu
浏览文件 @
6a4790c2
...
@@ -43,6 +43,9 @@ TEST(Device, CUDADeviceContext) {
...
@@ -43,6 +43,9 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
gpu_device
);
ASSERT_NE
(
nullptr
,
gpu_device
);
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
delete
device_context
;
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录