Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1d8fe2a2
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1d8fe2a2
编写于
3月 22, 2018
作者:
Y
Yu Yang
提交者:
dzhwinter
3月 22, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enhance device context pool (#9293)
上级
64c5c8f8
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
35 addition
and
30 deletion
+35
-30
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+19
-16
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+4
-14
paddle/fluid/platform/place.h
paddle/fluid/platform/place.h
+12
-0
未找到文件。
paddle/fluid/platform/device_context.cc
浏览文件 @
1d8fe2a2
...
...
@@ -10,43 +10,45 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_context.h"
#include <unordered_set>
#include "paddle/fluid/memory/memory.h"
namespace
paddle
{
namespace
platform
{
DeviceContextPool
*
DeviceContextPool
::
pool
=
nullptr
;
const
platform
::
DeviceContext
*
DeviceContextPool
::
Get
(
const
platform
::
Place
&
place
)
{
platform
::
DeviceContext
*
DeviceContextPool
::
Get
(
const
platform
::
Place
&
place
)
{
auto
it
=
device_contexts_
.
find
(
place
);
if
(
it
==
device_contexts_
.
end
())
{
PADDLE_THROW
(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option"
);
}
return
it
->
second
;
return
it
->
second
.
get
()
;
}
DeviceContextPool
::
DeviceContextPool
(
const
std
::
vector
<
platform
::
Place
>&
places
)
{
PADDLE_ENFORCE_GT
(
places
.
size
(),
0
);
for
(
size_t
i
=
0
;
i
<
places
.
size
();
i
++
)
{
if
(
platform
::
is_cpu_place
(
places
[
i
]))
{
using
PtrType
=
std
::
unique_ptr
<
DeviceContext
>
;
std
::
unordered_set
<
Place
,
PlaceHash
>
set
;
for
(
auto
&
p
:
places
)
{
set
.
insert
(
p
);
}
for
(
auto
&
p
:
set
)
{
if
(
platform
::
is_cpu_place
(
p
))
{
#ifdef PADDLE_WITH_MKLDNN
device_contexts_
.
emplace
(
places
[
i
],
new
platform
::
MKLDNNDeviceContext
(
boost
::
get
<
platform
::
CPUPlace
>
(
places
[
i
])));
device_contexts_
.
emplace
(
p
,
PtrType
(
new
MKLDNNDeviceContext
(
boost
::
get
<
CPUPlace
>
(
p
))));
#else
device_contexts_
.
emplace
(
places
[
i
],
new
platform
::
CPUDeviceContext
(
boost
::
get
<
platform
::
CPUPlace
>
(
places
[
i
])));
device_contexts_
.
emplace
(
p
,
PtrType
(
new
CPUDeviceContext
(
boost
::
get
<
CPUPlace
>
(
p
))));
#endif
}
else
if
(
platform
::
is_gpu_place
(
p
laces
[
i
]
))
{
}
else
if
(
platform
::
is_gpu_place
(
p
))
{
#ifdef PADDLE_WITH_CUDA
device_contexts_
.
emplace
(
places
[
i
],
new
platform
::
CUDADeviceContext
(
boost
::
get
<
platform
::
CUDAPlace
>
(
places
[
i
])));
device_contexts_
.
emplace
(
p
,
PtrType
(
new
CUDADeviceContext
(
boost
::
get
<
CUDAPlace
>
(
p
))));
#else
PADDLE_THROW
(
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
...
...
@@ -159,6 +161,7 @@ CUDADeviceContext::~CUDADeviceContext() {
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
std
::
lock_guard
<
std
::
mutex
>
guard
(
mutex_
);
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaGetLastError
());
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
1d8fe2a2
...
...
@@ -103,6 +103,7 @@ class CUDADeviceContext : public DeviceContext {
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
mutable
std
::
mutex
mutex_
;
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
...
...
@@ -159,7 +160,7 @@ class DeviceContextPool {
}
/*! \brief Return handle of single device context. */
const
platform
::
DeviceContext
*
Get
(
const
platform
::
Place
&
place
);
platform
::
DeviceContext
*
Get
(
const
platform
::
Place
&
place
);
template
<
typename
Place
>
const
typename
DefaultDeviceContextType
<
Place
>::
TYPE
*
GetByPlace
(
...
...
@@ -172,19 +173,8 @@ class DeviceContextPool {
private:
static
DeviceContextPool
*
pool
;
constexpr
static
int
LEFT_SHIFT
=
8
;
struct
Hash
{
std
::
hash
<
int
>
hash_
;
size_t
operator
()(
const
platform
::
Place
&
place
)
const
{
int
pre_hash
=
place
.
which
()
<<
LEFT_SHIFT
;
if
(
platform
::
is_gpu_place
(
place
))
{
pre_hash
+=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
).
GetDeviceId
();
}
return
hash_
(
pre_hash
);
}
};
std
::
unordered_map
<
const
platform
::
Place
,
const
platform
::
DeviceContext
*
,
Hash
>
std
::
unordered_map
<
const
platform
::
Place
,
std
::
unique_ptr
<
platform
::
DeviceContext
>
,
PlaceHash
>
device_contexts_
;
DISABLE_COPY_AND_ASSIGN
(
DeviceContextPool
);
};
...
...
paddle/fluid/platform/place.h
浏览文件 @
1d8fe2a2
...
...
@@ -65,6 +65,18 @@ bool is_cpu_place(const Place &);
bool
places_are_same_class
(
const
Place
&
,
const
Place
&
);
bool
is_same_place
(
const
Place
&
,
const
Place
&
);
struct
PlaceHash
{
std
::
size_t
operator
()(
const
Place
&
p
)
const
{
constexpr
size_t
num_dev_bits
=
4
;
std
::
hash
<
int
>
ihash
;
size_t
dev_id
=
0
;
if
(
is_gpu_place
(
p
))
{
dev_id
=
boost
::
get
<
CUDAPlace
>
(
p
).
device
;
}
return
ihash
(
dev_id
<<
num_dev_bits
|
p
.
which
());
}
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
,
const
Place
&
);
template
<
typename
Visitor
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录