Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
bcc6e5ce
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bcc6e5ce
编写于
7月 27, 2017
作者:
G
gangliao
提交者:
GitHub
7月 27, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3084 from gangliao/device_ctx
Refine device context and fix GetPlace()
上级
2200ff5e
c2b8bd34
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
133 addition
and
118 deletion
+133
-118
paddle/memory/CMakeLists.txt
paddle/memory/CMakeLists.txt
+1
-1
paddle/memory/memcpy.cc
paddle/memory/memcpy.cc
+3
-3
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+85
-1
paddle/platform/device_context.h
paddle/platform/device_context.h
+39
-108
paddle/platform/enforce.h
paddle/platform/enforce.h
+5
-5
未找到文件。
paddle/memory/CMakeLists.txt
浏览文件 @
bcc6e5ce
add_subdirectory
(
detail
)
cc_library
(
memory SRCS memory.cc
)
cc_library
(
memcpy SRCS memcpy.cc
DEPS device_context
)
cc_library
(
memcpy SRCS memcpy.cc
)
cc_library
(
paddle_memory
DEPS
...
...
paddle/memory/memcpy.cc
浏览文件 @
bcc6e5ce
...
...
@@ -35,7 +35,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
platform
::
GPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
GPUPlaceGuard
g
(
src_place
.
device
);
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
...
...
@@ -45,7 +45,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
platform
::
CPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
GPUPlaceGuard
g
(
dst_place
.
device
);
platform
::
SetDeviceId
(
dst_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
...
...
@@ -56,7 +56,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
if
(
dst_place
==
src_place
)
{
platform
::
GPUPlaceGuard
g
(
src_place
.
device
);
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpyPeer
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
...
...
paddle/platform/device_context.cc
浏览文件 @
bcc6e5ce
...
...
@@ -20,12 +20,96 @@ Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>()
return
reinterpret_cast
<
const
CPUDeviceContext
*>
(
this
)
->
eigen_device
();
}
CPUDeviceContext
::
CPUDeviceContext
()
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
CPUDeviceContext
::
CPUDeviceContext
(
CPUPlace
place
)
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
Eigen
::
DefaultDevice
*
CPUDeviceContext
::
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
Place
CPUDeviceContext
::
GetPlace
()
const
{
return
CPUPlace
();
}
#ifndef PADDLE_ONLY_CPU
template
<
>
Eigen
::
GpuDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
GpuDevice
>
()
const
{
return
reinterpret_cast
<
const
CUDADeviceContext
*>
(
this
)
->
eigen_device
();
}
#endif
CUDADeviceContext
::
CUDADeviceContext
(
GPUPlace
place
)
:
place_
(
place
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
eigen_stream_
.
reset
(
new
Eigen
::
CudaStreamDevice
(
&
stream_
));
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
}
CUDADeviceContext
::~
CUDADeviceContext
()
{
SetDeviceId
(
place_
.
device
);
Wait
();
if
(
cublas_handle_
)
{
PADDLE_ENFORCE
(
dynload
::
cublasDestroy
(
cublas_handle_
));
}
if
(
cudnn_handle_
)
{
PADDLE_ENFORCE
(
dynload
::
cudnnDestroy
(
cudnn_handle_
));
}
if
(
curand_generator_
)
{
PADDLE_ENFORCE
(
dynload
::
curandDestroyGenerator
(
curand_generator_
));
}
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
}
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
}
Eigen
::
GpuDevice
*
CUDADeviceContext
::
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
cublasHandle_t
CUDADeviceContext
::
cublas_handle
()
{
if
(
!
cublas_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
}
return
cublas_handle_
;
}
cudnnHandle_t
CUDADeviceContext
::
cudnn_handle
()
{
if
(
!
cudnn_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
return
cudnn_handle_
;
}
curandGenerator_t
CUDADeviceContext
::
curand_generator
()
{
if
(
!
curand_generator_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
curandCreateGenerator
(
&
curand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
return
curand_generator_
;
}
#endif // PADDLE_ONLY_CPU
}
// namespace platform
}
// namespace paddle
paddle/platform/device_context.h
浏览文件 @
bcc6e5ce
...
...
@@ -39,14 +39,13 @@ class DeviceContext {
class
CPUDeviceContext
:
public
DeviceContext
{
public:
CPUDeviceContext
()
{
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
());
}
CPUDeviceContext
();
CPUDeviceContext
(
CPUPlace
);
virtual
~
CPUDeviceContext
()
{}
Eigen
::
DefaultDevice
*
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
Eigen
::
DefaultDevice
*
eigen_device
()
const
;
Place
GetPlace
()
const
override
{
Place
retv
=
CPUPlace
();
return
retv
;
}
Place
GetPlace
()
const
override
;
private:
std
::
unique_ptr
<
Eigen
::
DefaultDevice
>
eigen_device_
;
...
...
@@ -54,119 +53,51 @@ class CPUDeviceContext : public DeviceContext {
#ifndef PADDLE_ONLY_CPU
class
GPUPlaceGuard
{
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
GPUPlaceGuard
(
GPUPlace
new_place
)
:
previous_
(
GetCurrentDeviceId
())
{
if
(
previous_
!=
new_place
)
{
paddle
::
platform
::
SetDeviceId
(
new_place
.
device
);
}
}
explicit
CUDADeviceContext
(
GPUPlace
);
virtual
~
CUDADeviceContext
();
~
GPUPlaceGuard
()
{
paddle
::
platform
::
SetDeviceId
(
previous_
.
device
);
}
/*! \brief Wait for all operations completion in the stream. */
void
Wait
()
const
;
private:
GPUPlace
previous_
;
};
/*! \brief Return CUDA stream in the device context. */
cudaStream_t
stream
()
const
;
class
CUDADeviceContext
:
public
DeviceContext
{
public:
explicit
CUDADeviceContext
(
const
GPUPlace
gpu_place
)
:
gpu_place_
(
gpu_place
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
),
"cudaStreamCreate failed"
);
eigen_stream_
.
reset
(
new
Eigen
::
CudaStreamDevice
(
&
stream_
));
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
}
Place
GetPlace
()
const
override
{
Place
retv
=
GPUPlace
();
return
retv
;
}
void
Wait
()
{
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
),
"cudaStreamSynchronize failed"
);
}
cudaStream_t
stream
()
const
{
return
stream_
;
}
Eigen
::
GpuDevice
*
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
cublasHandle_t
cublas_handle
()
{
if
(
!
blas_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasCreate
(
&
blas_handle_
),
"cublasCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasSetStream
(
blas_handle_
,
stream_
),
"cublasSetStream failed"
);
}
return
blas_handle_
;
}
cudnnHandle_t
cudnn_handle
()
{
if
(
!
dnn_handle_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnCreate
(
&
dnn_handle_
),
"cudnnCreate failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnSetStream
(
dnn_handle_
,
stream_
),
"cudnnSetStream failed"
);
}
return
dnn_handle_
;
}
curandGenerator_t
curand_generator
()
{
if
(
!
rand_generator_
)
{
GPUPlaceGuard
guard
(
gpu_place_
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandCreateGenerator
(
&
rand_generator_
,
CURAND_RNG_PSEUDO_DEFAULT
),
"curandCreateGenerator failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetPseudoRandomGeneratorSeed
(
rand_generator_
,
random_seed_
),
"curandSetPseudoRandomGeneratorSeed failed"
);
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandSetStream
(
rand_generator_
,
stream_
),
"curandSetStream failed"
);
}
return
rand_generator_
;
}
~
CUDADeviceContext
()
{
Wait
();
if
(
blas_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cublasDestroy
(
blas_handle_
),
"cublasDestroy failed"
);
}
if
(
dnn_handle_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
cudnnDestroy
(
dnn_handle_
),
"cudnnDestroy failed"
);
}
if
(
rand_generator_
)
{
PADDLE_ENFORCE
(
paddle
::
platform
::
dynload
::
curandDestroyGenerator
(
rand_generator_
),
"curandDestroyGenerator failed"
);
}
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
),
"cudaStreamDestroy failed"
);
}
/*! \brief Return place in the device context. */
Place
GetPlace
()
const
override
;
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
// clang-format off
/*! \brief Return cublas handle in the device context. */
cublasHandle_t
cublas_handle
();
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t
cudnn_handle
();
/*! \brief Return curand handle in the device context. */
curandGenerator_t
curand_generator
();
// clang-format on
private:
GPUPlace
gpu_place_
;
cudaStream_t
stream_
;
GPUPlace
place_
;
std
::
unique_ptr
<
Eigen
::
CudaStreamDevice
>
eigen_stream_
;
private:
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
CudaStreamDevice
>
eigen_stream_
;
cublasHandle_t
blas_handle_
{
nullptr
};
private:
uint64_t
seed_
;
cud
nnHandle_t
dnn_handle_
{
nullptr
}
;
cud
aStream_t
stream_
;
int
random_seed_
;
curandGenerator_t
rand_generator_
{
nullptr
};
// clang-format off
cudnnHandle_t
cudnn_handle_
=
nullptr
;
cublasHandle_t
cublas_handle_
=
nullptr
;
curandGenerator_t
curand_generator_
=
nullptr
;
// clang-format on
};
#endif
...
...
paddle/platform/enforce.h
浏览文件 @
bcc6e5ce
...
...
@@ -58,11 +58,6 @@ struct EnforceNotMet : public std::exception {
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
template
<
typename
T
>
inline
void
throw_on_error
(
T
e
)
{
throw_on_error
(
e
,
""
);
}
template
<
typename
...
Args
>
inline
typename
std
::
enable_if
<
sizeof
...(
Args
)
!=
0
,
void
>::
type
throw_on_error
(
int
stat
,
const
Args
&
...
args
)
{
...
...
@@ -132,6 +127,11 @@ inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
#endif // PADDLE_ONLY_CPU
template
<
typename
T
>
inline
void
throw_on_error
(
T
e
)
{
throw_on_error
(
e
,
""
);
}
#define PADDLE_THROW(...) \
do { \
throw ::paddle::platform::EnforceNotMet( \
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录