Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ed409ac9
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ed409ac9
编写于
1月 08, 2019
作者:
S
sneaxiy
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "Revert "Remove op handle lock""
test=develop
上级
4a443ffc
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
159 addition
and
130 deletion
+159
-130
paddle/fluid/operators/math/blas_impl.cu.h
paddle/fluid/operators/math/blas_impl.cu.h
+64
-70
paddle/fluid/platform/cuda_helper.h
paddle/fluid/platform/cuda_helper.h
+58
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+13
-5
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+24
-52
paddle/fluid/platform/device_context_test.cu
paddle/fluid/platform/device_context_test.cu
+0
-3
未找到文件。
paddle/fluid/operators/math/blas_impl.cu.h
浏览文件 @
ed409ac9
...
...
@@ -62,27 +62,19 @@ struct CUBlas<float> {
cudaDataType_t
Atype
,
int
lda
,
const
void
*
B
,
cudaDataType_t
Btype
,
int
ldb
,
const
float
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
)
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
auto
cublas_call
=
[
&
]()
{
// Because the gcc 4.8 doesn't expand template parameter pack that
// appears in a lambda-expression, I can not use template parameter pack
// here.
#if CUDA_VERSION >= 8000
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
platform
::
TensorCoreAvailable
()
?
"True"
:
"False"
);
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
dev_ctx
->
tensor_core_available
()
?
"True"
:
"False"
);
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSgemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
));
});
#else
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
PADDLE_THROW
(
"cublasSgemmEx is supported on cuda >= 8.0"
);
#endif
}
};
...
...
@@ -170,32 +162,24 @@ struct CUBlas<platform::float16> {
cudaDataType_t
Btype
,
int
ldb
,
const
void
*
beta
,
void
*
C
,
cudaDataType_t
Ctype
,
int
ldc
,
cudaDataType_t
computeType
)
{
auto
cublas_call
=
[
&
]()
{
#if CUDA_VERSION >= 8000
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
#if CUDA_VERSION >= 9000
bool
use_tensor_op_math
=
platform
::
TensorCoreA
vailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
bool
use_tensor_op_math
=
dev_ctx
->
tensor_core_a
vailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
#endif // CUDA_VERSION >= 9000
dev_ctx
->
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmEx
(
dev_ctx
->
cublas_handle
(),
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
handle
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
A
,
Atype
,
lda
,
B
,
Btype
,
ldb
,
beta
,
C
,
Ctype
,
ldc
,
computeType
,
algo
));
});
#else
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
};
#if CUDA_VERSION >= 9000
// NOTES: To use Tensor Core, we should change the cublas config,
// but the cublas may be hold by multi-thread.
dev_ctx
->
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
#else
cublas_call
();
PADDLE_THROW
(
"cublasGemmEx is supported on cuda >= 8.0"
);
#endif
}
};
...
...
@@ -223,9 +207,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(CBLAS_TRANSPOSE transA,
CUDA_R_32F
,
N
);
}
else
{
#endif // CUDA_VERSION >= 8000
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
N
);
});
#if CUDA_VERSION >= 8000
}
...
...
@@ -266,9 +251,12 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
CUDA_R_16F
,
lda
,
&
h_beta
,
C
,
CUDA_R_16F
,
N
,
CUDA_R_32F
);
#else
// CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
h_alpha
,
h_B
,
ldb
,
h_A
,
lda
,
&
h_beta
,
h_C
,
N
);
});
#endif // CUDA_VERSION >= 8000
}
...
...
@@ -292,8 +280,10 @@ void Blas<platform::CUDADeviceContext>::GEMM(bool transA, bool transB, int M,
}
else
{
#endif // CUDA_VERSION >= 8000
CUBlas
<
T
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
#if CUDA_VERSION >= 8000
}
...
...
@@ -311,16 +301,19 @@ inline void Blas<platform::CUDADeviceContext>::GEMM(
cublasOperation_t
cuTransA
=
transA
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
cublasOperation_t
cuTransB
=
transB
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
platform
::
float16
>::
GEMM
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
platform
::
float16
>::
GEMM
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
A
,
lda
,
&
beta
,
C
,
ldc
);
});
}
template
<
>
template
<
typename
T
>
void
Blas
<
platform
::
CUDADeviceContext
>::
AXPY
(
int
n
,
T
alpha
,
const
T
*
x
,
T
*
y
)
const
{
CUBlas
<
T
>::
AXPY
(
context_
.
cublas_handle
(),
n
,
&
alpha
,
x
,
1
,
y
,
1
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
AXPY
(
handle
,
n
,
&
alpha
,
x
,
1
,
y
,
1
);
});
}
template
<
>
...
...
@@ -330,8 +323,9 @@ void Blas<platform::CUDADeviceContext>::GEMV(bool trans_a, int M, int N,
T
beta
,
T
*
C
)
const
{
cublasOperation_t
cuTransA
=
!
trans_a
?
CUBLAS_OP_T
:
CUBLAS_OP_N
;
CUBlas
<
T
>::
GEMV
(
context_
.
cublas_handle
(),
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMV
(
handle
,
cuTransA
,
N
,
M
,
&
alpha
,
A
,
N
,
B
,
1
,
&
beta
,
C
,
1
);
});
}
template
<
>
...
...
@@ -353,28 +347,28 @@ void Blas<platform::CUDADeviceContext>::BatchedGEMM(
#if CUDA_VERSION >= 9010
if
(
FLAGS_enable_cublas_tensor_op_math
&&
std
::
is_same
<
T
,
float
>::
value
)
{
auto
cublas_call
=
[
&
]()
{
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
platform
::
TensorCoreAvailable
();
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
cublasGemmAlgo_t
algo
=
CUBLAS_GEMM_DFALT
;
bool
use_tensor_op_math
=
context_
.
tensor_core_available
()
;
if
(
use_tensor_op_math
)
{
algo
=
CUBLAS_GEMM_DFALT_TENSOR_OP
;
}
VLOG
(
5
)
<<
"use_tensor_op_math: "
<<
(
use_tensor_op_math
?
"True"
:
"False"
);
context_
.
TensorCoreCublasCallIfAvailable
([
&
](
cublasHandle_t
handle
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGemmStridedBatchedEx
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
};
auto
&
dev_ctx
=
const_cast
<
platform
::
CUDADeviceContext
&>
(
context_
);
dev_ctx
.
CublasCall
(
cublas_call
,
CUBLAS_TENSOR_OP_MATH
);
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
CUDA_R_32F
,
ldb
,
strideB
,
A
,
CUDA_R_32F
,
lda
,
strideA
,
&
beta
,
C
,
CUDA_R_32F
,
ldc
,
strideC
,
batchCount
,
CUDA_R_32F
,
algo
));
});
}
else
{
#endif // CUDA_VERSION >= 9010
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
context_
.
cublas_handle
(),
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
strideB
,
A
,
lda
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
context_
.
CublasCall
([
&
](
cublasHandle_t
handle
)
{
CUBlas
<
T
>::
GEMM_STRIDED_BATCH
(
handle
,
cuTransB
,
cuTransA
,
N
,
M
,
K
,
&
alpha
,
B
,
ldb
,
strideB
,
A
,
lda
,
strideA
,
&
beta
,
C
,
ldc
,
strideC
,
batchCount
);
});
#if CUDA_VERSION >= 9010
}
...
...
paddle/fluid/platform/cuda_helper.h
0 → 100644
浏览文件 @
ed409ac9
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/macros.h"
#if CUDA_VERSION < 9000
enum
cublasMath_t
{
CUBLAS_DEFAULT_MATH
=
0
};
#endif
namespace
paddle
{
namespace
platform
{
class
CublasHandleHolder
{
public:
CublasHandleHolder
(
cudaStream_t
stream
,
cublasMath_t
math_type
)
{
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
handle_
,
stream
));
#if CUDA_VERSION >= 9000
if
(
math_type
==
CUBLAS_TENSOR_OP_MATH
)
{
PADDLE_ENFORCE
(
dynload
::
cublasSetMathMode
(
handle_
,
CUBLAS_TENSOR_OP_MATH
));
}
#endif
}
~
CublasHandleHolder
()
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
handle_
));
}
template
<
typename
Callback
>
inline
void
Call
(
Callback
&&
callback
)
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
(
handle_
);
}
private:
DISABLE_COPY_AND_ASSIGN
(
CublasHandleHolder
);
cublasHandle_t
handle_
;
mutable
std
::
mutex
mtx_
;
};
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device_context.cc
浏览文件 @
ed409ac9
...
...
@@ -245,8 +245,15 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
cublas_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_DEFAULT_MATH
));
if
(
TensorCoreAvailable
())
{
#if CUDA_VERSION >= 9000
cublas_tensor_core_handle_
.
reset
(
new
CublasHandleHolder
(
stream_
,
CUBLAS_TENSOR_OP_MATH
));
#endif
}
if
(
dynload
::
HasCUDNN
())
{
cudnn_holder_
.
reset
(
new
CudnnHolder
(
&
stream_
,
place
));
}
...
...
@@ -306,7 +313,8 @@ CUDADeviceContext::~CUDADeviceContext() {
SetDeviceId
(
place_
.
device
);
Wait
();
WaitStreamCallback
();
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
cublas_handle_
.
reset
();
cublas_tensor_core_handle_
.
reset
();
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
...
...
@@ -335,8 +343,8 @@ Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
return
eigen_device_
.
get
();
}
cublasHandle_t
CUDADeviceContext
::
cublas_hand
le
()
const
{
return
cublas_
handle_
;
bool
CUDADeviceContext
::
tensor_core_availab
le
()
const
{
return
cublas_
tensor_core_handle_
!=
nullptr
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
const
{
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
ed409ac9
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/temporary_allocator.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_helper.h"
#include "paddle/fluid/platform/dynload/cublas.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/fluid/platform/gpu_info.h"
...
...
@@ -209,39 +210,6 @@ class CudnnWorkspaceHandle {
std
::
unique_ptr
<
std
::
lock_guard
<
std
::
mutex
>>
guard_
;
};
#if CUDA_VERSION >= 9000
class
ScopedCublasMathMode
{
public:
ScopedCublasMathMode
(
cublasHandle_t
handle
,
cublasMath_t
new_math_mode
)
:
handle_
(
handle
)
{
need_reset
=
false
;
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasGetMathMode
(
handle_
,
&
old_math_mode_
),
"Failed to get old cublas math mode"
);
if
(
old_math_mode_
!=
new_math_mode
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
new_math_mode
),
"Failed to set old cublas math mode"
);
need_reset
=
true
;
}
}
~
ScopedCublasMathMode
()
{
if
(
need_reset
)
{
PADDLE_ENFORCE
(
platform
::
dynload
::
cublasSetMathMode
(
handle_
,
old_math_mode_
),
"Failed to set old cublas math mode"
);
}
}
private:
cublasHandle_t
handle_
;
cublasMath_t
old_math_mode_
;
bool
need_reset
;
};
#endif
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
CUDAPlace
place
);
...
...
@@ -262,8 +230,25 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
/*! \brief Return cublas handle in the device context. */
cublasHandle_t
cublas_handle
()
const
;
/*! \brief Call cublas function safely. */
template
<
typename
Callback
>
inline
void
CublasCall
(
Callback
&&
callback
)
const
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
/*! \brief Check whether tensor core is supported */
bool
tensor_core_available
()
const
;
/*! \brief Call cublas function with Tensor Core safely. If
Tensor Core is not available, use DEFAULT_MATH instead. */
template
<
typename
Callback
>
inline
void
TensorCoreCublasCallIfAvailable
(
Callback
&&
callback
)
const
{
if
(
cublas_tensor_core_handle_
)
{
cublas_tensor_core_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
else
{
cublas_handle_
->
Call
(
std
::
forward
<
Callback
>
(
callback
));
}
}
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
()
const
;
...
...
@@ -282,7 +267,6 @@ class CUDADeviceContext : public DeviceContext {
template
<
typename
Callback
>
void
RecordEvent
(
cudaEvent_t
ev
,
Callback
callback
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mtx_
);
callback
();
PADDLE_ENFORCE
(
cudaEventRecord
(
ev
,
stream_
));
}
...
...
@@ -294,18 +278,6 @@ class CUDADeviceContext : public DeviceContext {
void
WaitStreamCallback
()
const
{
callback_manager_
->
Wait
();
}
#if CUDA_VERSION >= 9000
/*! \brief CublasCall may need to change cublas's config,
* but the cublas may be hold by multi-thread, so we should
* add lock here. */
template
<
typename
Callback
>
void
CublasCall
(
Callback
callback
,
cublasMath_t
new_math
)
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
cublas_mtx_
);
ScopedCublasMathMode
scoped_cublas_math
(
cublas_handle_
,
new_math
);
callback
();
}
#endif
private:
CUDAPlace
place_
;
...
...
@@ -313,7 +285,9 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
CudnnHolder
>
cudnn_holder_
;
cudaStream_t
stream_
;
cublasHandle_t
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_handle_
;
std
::
unique_ptr
<
CublasHandleHolder
>
cublas_tensor_core_handle_
;
int
compute_capability_
;
int
runtime_version_
;
...
...
@@ -321,12 +295,10 @@ class CUDADeviceContext : public DeviceContext {
int
multi_process_
;
int
max_threads_per_mp_
;
mutable
std
::
mutex
mtx_
;
// StreamCallbackManager is thread-safe
std
::
unique_ptr
<
StreamCallbackManager
>
callback_manager_
;
mutable
std
::
mutex
cublas_mtx_
;
DISABLE_COPY_AND_ASSIGN
(
CUDADeviceContext
)
;
};
template
<
>
...
...
paddle/fluid/platform/device_context_test.cu
浏览文件 @
ed409ac9
...
...
@@ -43,9 +43,6 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
gpu_device
);
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
ASSERT_NE
(
nullptr
,
cublas_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录