Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6f06e6cd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6f06e6cd
编写于
1月 02, 2019
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
差异文件
Merge remote origin
test=develop
上级
adc96e06
d25395fc
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
66 addition
and
101 deletion
+66
-101
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+33
-56
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+25
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+8
-45
未找到文件。
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
6f06e6cd
...
...
@@ -62,27 +62,17 @@ struct CUBlas<float> {
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
)
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
auto
cublas_call
=
[
&
]()
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
#if CUDA_VERSION >= 8000
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
platform
::
TensorCoreA
vailable
()
?
"True"
:
"False"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
dev_ctx
->
tensor_core_a
vailable
()
?
"True"
:
"False"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
dev_ctx
->
possible_cublas_tensor_core_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
#else
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
}
};
...
...
@@ -170,32 +160,23 @@ struct CUBlas<platform::float16> {
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
computeType
)
{
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
bool
use_tensor_op_math
=
dev_ctx
->
tensor_core_a
vailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
#endif // CUDA_VERSION >= 9000
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
#else
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
dev_ctx
->
possible_cublas_tensor_core_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
#else
cublas_call
(
);
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
}
};
...
...
@@ -353,22 +334,18 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
auto
cublas_call
=
[
&
]()
{
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
platform
::
TensorCoreAvailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
};
auto
&
dev_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
dev_ctx
.
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
context_
.
tensor_core_available
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
context_
.
possible_cublas_tensor_core_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
}
else
{
#endif // CUDA_VERSION >= 9010
...
...
paddle/fluid/platform/device_context.cc
浏览文件 @
6f06e6cd
...
...
@@ -247,6 +247,18 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
cublasHandle_t
());
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
cublas_tensor_core_handle_
.
get
()));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
*
cublas_tensor_core_handle_
,
stream_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetMathMode
(
*
cublas_tensor_core_handle_
,
CUBLAS_TENSOR_OP_MATH
));
#endif
}
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
}
...
...
@@ -307,6 +319,10 @@ CUDADeviceContext::~CUDADeviceContext() {
Wait
();
WaitStreamCallback
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
if
(
cublas_tensor_core_handle_
)
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
*
cublas_tensor_core_handle_
));
cublas_tensor_core_handle_
.
reset
();
}
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
...
@@ -339,6 +355,15 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
return
cublas_handle_
;
}
cublasHandle_t
CUDADeviceContext
::
possible_cublas_tensor_core_handle
()
const
{
return
cublas_tensor_core_handle_
?
*
cublas_tensor_core_handle_
:
cublas_handle_
;
}
bool
CUDADeviceContext
::
tensor_core_available
()
const
{
return
cublas_tensor_core_handle_
!=
nullptr
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
return
cudnn_holder_
->
cudnn_handle
();
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
6f06e6cd
...
...
@@ -209,39 +209,6 @@ class CudnnWorkspaceHandle {
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
};
#if CUDA_VERSION >= 9000
class
ScopedCublasMathMode
{
public:
ScopedCublasMathMode
(
cublasHandle_t
handle
,
cublasMath_t
new_math_mode
)
:
handle_
(
handle
)
{
need_reset
=
false
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGetMathMode
(
handle_
,
&
old_math_mode_
),
"Failed to get old cublas math mode"
);
if
(
old_math_mode_
!=
new_math_mode
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
new_math_mode
),
"Failed to set old cublas math mode"
);
need_reset
=
true
;
}
}
~
ScopedCublasMathMode
()
{
if
(
need_reset
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
old_math_mode_
),
"Failed to set old cublas math mode"
);
}
}
private:
cublasHandle_t
handle_
;
cublasMath_t
old_math_mode_
;
bool
need_reset
;
};
#endif
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
...
@@ -265,6 +232,13 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return cublas handle in the device context. */
cublasHandle_t
cublas_handle
()
const
;
/*! \brief Check whether tensor core is supported */
bool
tensor_core_available
()
const
;
/*! \brief Return cublas handle supporting Tensor Core. If Tensor Core is
* not supported, return the same handle as cublas_handle(). */
cublasHandle_t
possible_cublas_tensor_core_handle
()
const
;
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
...
...
@@ -294,18 +268,6 @@ class CUDADeviceContext : public DeviceContext {
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template
<
typename
Callback
>
void
CublasCall
(
Callback
callback
,
cublasMath_t
new_math
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
cublas_mtx_
);
ScopedCublasMathMode
scoped_cublas_math
(
cublas_handle_
,
new_math
);
callback
();
}
#endif
private:
CUDAPlace
place_
;
...
...
@@ -314,6 +276,7 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cublasHandle_t
cublas_handle_
;
std
::
unique_ptr
<
cublasHandle_t
>
cublas_tensor_core_handle_
;
int
compute_capability_
;
int
runtime_version_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录