Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ccb322d6
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ccb322d6
编写于
1月 07, 2019
作者:
Z
Zeng Jinle
提交者:
sneaxiy
1月 07, 2019
浏览文件
操作
浏览文件
下载
差异文件
merge develop
上级
1df2399e
d0a8a1e9
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
159 addition
and
130 deletion
+159
-130
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+64
-70
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+58
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+13
-5
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+24
-52
paddle/fluid/platform/device_context_test.cu
paddle/fluid/platform/device_context_test.cu
+0
-3
未找到文件。
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
ccb322d6
...
@@ -62,27 +62,19 @@ struct CUBlas<float> {
...
@@ -62,27 +62,19 @@ struct CUBlas<float> {
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
)
{
cudaDataType_t
Ctype
,
int
ldc
)
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// appears in a lambda-expression, I can not use template parameter pack
// here.
// here.
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
VLOG
(
5
)
<<
"use_tensor_op_math: "
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
platform
::
TensorCoreAvailable
()
?
"True"
:
"False"
);
<<
(
dev_ctx
->
tensor_core_available
()
?
"True"
:
"False"
);
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
beta
,
C
,
Ctype
,
ldc
));
});
#else
#else
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
#endif
#endif
}
}
};
};
...
@@ -170,11 +162,10 @@ struct CUBlas<platform::float16> {
...
@@ -170,11 +162,10 @@ struct CUBlas<platform::float16> {
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
computeType
)
{
cudaDataType_t
computeType
)
{
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
#if CUDA_VERSION >= 9000
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
bool
use_tensor_op_math
=
dev_ctx
->
tensor_core_a
vailable
();
if
(
use_tensor_op_math
)
{
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
}
...
@@ -182,20 +173,13 @@ struct CUBlas<platform::float16> {
...
@@ -182,20 +173,13 @@ struct CUBlas<platform::float16> {
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
#endif // CUDA_VERSION >= 9000
#endif // CUDA_VERSION >= 9000
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
});
#else
#else
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
#endif
#endif
}
}
};
};
...
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
...
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F
,
N
);
CUDA_R_32F
,
N
);
}
else
{
}
else
{
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
lda
,
&
beta
,
C
,
N
);
});
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
}
}
...
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
...
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
#else
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
&
h_beta
,
h_C
,
N
);
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
});
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
}
}
...
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
...
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
}
else
{
}
else
{
#endif // CUDA_VERSION >= 8000
#endif // CUDA_VERSION >= 8000
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
#if CUDA_VERSION >= 8000
#if CUDA_VERSION >= 8000
}
}
...
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
...
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
ldc
);
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
}
}
template
<
>
template
<
>
template
<
typename
T
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
T
*
y
)
const
{
T
*
y
)
const
{
CUBlas
<
T
>::
AXPY
(
context_
.
cublas_handle
(),
n
,
&
alpha
,
x
,
1
,
y
,
1
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
AXPY
(
handle
,
n
,
&
alpha
,
x
,
1
,
y
,
1
);
});
}
}
template
<
>
template
<
>
...
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
...
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T
beta
,
T
*
C
)
const
{
T
beta
,
T
*
C
)
const
{
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
T
>::
GEMV
(
context_
.
cublas_handle
(),
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
&
beta
,
C
,
1
);
CUBlas
<
T
>::
GEMV
(
handle
,
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
);
});
}
}
template
<
>
template
<
>
...
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
...
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
#if CUDA_VERSION >= 9010
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
auto
cublas_call
=
[
&
]()
{
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
bool
use_tensor_op_math
=
context_
.
tensor_core_a
vailable
();
if
(
use_tensor_op_math
)
{
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
};
});
auto
&
dev_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
dev_ctx
.
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
}
else
{
}
else
{
#endif // CUDA_VERSION >= 9010
#endif // CUDA_VERSION >= 9010
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
strideB
,
A
,
lda
,
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
B
,
ldb
,
strideB
,
A
,
lda
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
});
#if CUDA_VERSION >= 9010
#if CUDA_VERSION >= 9010
}
}
...
...
paddle/fluid/platform/cuda_helper.h
0 → 100644
浏览文件 @
ccb322d6
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum
cublasMath_t
{
CUBLAS_DEFAULT_MATH
=
0
};
#endif
namespace
paddle
{
namespace
platform
{
class
CublasHandleHolder
{
public:
CublasHandleHolder
(
cudaStream_t
stream
,
cublasMath_t
math_type
)
{
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
handle_
,
stream
));
#if CUDA_VERSION >= 9000
if
(
math_type
==
CUBLAS_TENSOR_OP_MATH
)
{
PADDLE_ENFORCE
(
dynload
::
cublasSetMathMode
(
handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
}
~
CublasHandleHolder
()
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
handle_
));
}
template
<
typename
Callback
>
inline
void
Call
(
Callback
&&
callback
)
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
(
handle_
);
}
private:
DISABLE_COPY_AND_ASSIGN
(
CublasHandleHolder
);
cublasHandle_t
handle_
;
mutable
std
::
mutex
mtx_
;
};
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device_context.cc
浏览文件 @
ccb322d6
...
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
...
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
cublas_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_DEFAULT_MATH
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_TENSOR_OP_MATH
));
#endif
}
if
(
dynload
::
HasCUDNN
())
{
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
}
}
...
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
...
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId
(
place_
.
device
);
SetDeviceId
(
place_
.
device
);
Wait
();
Wait
();
WaitStreamCallback
();
WaitStreamCallback
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
cublas_handle_
.
reset
();
cublas_tensor_core_handle_
.
reset
();
eigen_stream_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
...
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return
eigen_device_
.
get
();
return
eigen_device_
.
get
();
}
}
cublasHandle_t
CUDADeviceContext
::
cublas_hand
le
()
const
{
bool
CUDADeviceContext
::
tensor_core_availab
le
()
const
{
return
cublas_
handle_
;
return
cublas_
tensor_core_handle_
!=
nullptr
;
}
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
ccb322d6
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
...
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
...
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
};
};
#if CUDA_VERSION >= 9000
class
ScopedCublasMathMode
{
public:
ScopedCublasMathMode
(
cublasHandle_t
handle
,
cublasMath_t
new_math_mode
)
:
handle_
(
handle
)
{
need_reset
=
false
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGetMathMode
(
handle_
,
&
old_math_mode_
),
"Failed to get old cublas math mode"
);
if
(
old_math_mode_
!=
new_math_mode
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
new_math_mode
),
"Failed to set old cublas math mode"
);
need_reset
=
true
;
}
}
~
ScopedCublasMathMode
()
{
if
(
need_reset
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
old_math_mode_
),
"Failed to set old cublas math mode"
);
}
}
private:
cublasHandle_t
handle_
;
cublasMath_t
old_math_mode_
;
bool
need_reset
;
};
#endif
class
CUDADeviceContext
:
public
DeviceContext
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
Eigen
::
GpuDevice
*
eigen_device
()
const
;
/*! \brief Return cublas handle in the device context. */
/*! \brief Call cublas function safely. */
cublasHandle_t
cublas_handle
()
const
;
template
<
typename
Callback
>
inline
void
CublasCall
(
Callback
&&
callback
)
const
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
/*! \brief Check whether tensor core is supported */
bool
tensor_core_available
()
const
;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template
<
typename
Callback
>
inline
void
TensorCoreCublasCallIfAvailable
(
Callback
&&
callback
)
const
{
if
(
cublas_tensor_core_handle_
)
{
cublas_tensor_core_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
else
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
}
/*! \brief Return cudnn handle in the device context. */
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
cudnnHandle_t
cudnn_handle
()
const
;
...
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
();
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
}
...
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template
<
typename
Callback
>
void
CublasCall
(
Callback
callback
,
cublasMath_t
new_math
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
cublas_mtx_
);
ScopedCublasMathMode
scoped_cublas_math
(
cublas_handle_
,
new_math
);
callback
();
}
#endif
private:
private:
CUDAPlace
place_
;
CUDAPlace
place_
;
...
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cudaStream_t
stream_
;
cublasHandle_t
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_tensor_core_handle_
;
int
compute_capability_
;
int
compute_capability_
;
int
runtime_version_
;
int
runtime_version_
;
...
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int
multi_process_
;
int
multi_process_
;
int
max_threads_per_mp_
;
int
max_threads_per_mp_
;
mutable
std
::
mutex
mtx_
;
// StreamCallbackManager is thread-safe
// StreamCallbackManager is thread-safe
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
mutable
std
::
mutex
cublas_mtx_
;
DISABLE_COPY_AND_ASSIGN
(
CUDADeviceContext
)
;
};
};
template
<
>
template
<
>
...
...
paddle/fluid/platform/device_context_test.cu
浏览文件 @
ccb322d6
...
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
...
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
gpu_device
);
ASSERT_NE
(
nullptr
,
gpu_device
);
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
delete
device_context
;
}
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录