Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
be2c1a3b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
be2c1a3b
编写于
7月 12, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
follow comments
上级
a07deac9
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
18 addition
and
20 deletion
+18
-20
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+2
-2
paddle/platform/device_context.h
paddle/platform/device_context.h
+12
-14
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+4
-4
未找到文件。
paddle/platform/device_context.cc
浏览文件 @
be2c1a3b
...
...
@@ -15,13 +15,13 @@ namespace paddle {
namespace
platform
{
template
<
>
Eigen
::
DefaultDevice
DeviceContext
::
get_eigen_device
<
Eigen
::
DefaultDevice
>
()
{
Eigen
::
DefaultDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
DefaultDevice
>
()
{
return
reinterpret_cast
<
CPUDeviceContext
*>
(
this
)
->
eigen_device
();
}
#ifndef PADDLE_ONLY_CPU
template
<
>
Eigen
::
GpuDevice
DeviceContext
::
get_eigen_device
<
Eigen
::
GpuDevice
>
()
{
Eigen
::
GpuDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
GpuDevice
>
()
{
return
reinterpret_cast
<
CUDADeviceContext
*>
(
this
)
->
eigen_device
();
}
#endif
...
...
paddle/platform/device_context.h
浏览文件 @
be2c1a3b
...
...
@@ -31,16 +31,16 @@ class DeviceContext {
virtual
Place
GetPlace
()
const
=
0
;
template
<
typename
DeviceType
>
DeviceType
get_eigen_device
();
DeviceType
*
get_eigen_device
();
};
class
CPUDeviceContext
:
public
DeviceContext
{
public:
Eigen
::
DefaultDevice
eigen_device
()
{
Eigen
::
DefaultDevice
*
eigen_device
()
{
if
(
!
eigen_device_
)
{
eigen_device_
=
new
Eigen
::
DefaultDevice
(
);
eigen_device_
.
reset
(
new
Eigen
::
DefaultDevice
()
);
}
return
*
eigen_device_
;
return
eigen_device_
.
get
()
;
}
Place
GetPlace
()
const
override
{
...
...
@@ -49,7 +49,7 @@ class CPUDeviceContext : public DeviceContext {
}
private:
Eigen
::
DefaultDevice
*
eigen_device_
{
nullptr
}
;
std
::
unique_ptr
<
Eigen
::
DefaultDevice
>
eigen_device_
;
};
#ifndef PADDLE_ONLY_CPU
...
...
@@ -74,8 +74,8 @@ class CUDADeviceContext : public DeviceContext {
GPUPlaceGuard
guard
(
gpu_place_
);
paddle
::
platform
::
throw_on_error
(
cudaStreamCreate
(
&
stream_
),
"cudaStreamCreate failed"
);
eigen_stream_
=
new
Eigen
::
CudaStreamDevice
(
&
stream_
);
eigen_device_
=
new
Eigen
::
GpuDevice
(
eigen_stream_
);
eigen_stream_
.
reset
(
new
Eigen
::
CudaStreamDevice
(
&
stream_
)
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
())
);
}
Place
GetPlace
()
const
override
{
...
...
@@ -90,7 +90,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t
stream
()
{
return
stream_
;
}
Eigen
::
GpuDevice
eigen_device
()
{
return
*
eigen_device_
;
}
Eigen
::
GpuDevice
*
eigen_device
()
{
return
eigen_device_
.
get
()
;
}
cublasHandle_t
cublas_handle
()
{
if
(
!
blas_handle_
)
{
...
...
@@ -155,10 +155,8 @@ class CUDADeviceContext : public DeviceContext {
rand_generator_
)
==
CURAND_STATUS_SUCCESS
,
"curandDestroyGenerator failed"
);
}
delete
eigen_stream_
;
delete
eigen_device_
;
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
paddle
::
platform
::
throw_on_error
(
cudaStreamDestroy
(
stream_
),
"cudaStreamDestroy failed"
);
}
...
...
@@ -167,8 +165,8 @@ class CUDADeviceContext : public DeviceContext {
GPUPlace
gpu_place_
;
cudaStream_t
stream_
;
Eigen
::
CudaStreamDevice
*
eigen_stream_
;
Eigen
::
GpuDevice
*
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
CudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
cublasHandle_t
blas_handle_
{
nullptr
};
...
...
paddle/platform/device_context_test.cc
浏览文件 @
be2c1a3b
...
...
@@ -21,9 +21,9 @@ TEST(Device, Init) {
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
paddle
::
platform
::
DeviceContext
*
device_context
=
new
paddle
::
platform
::
CUDADeviceContext
(
i
);
Eigen
::
GpuDevice
gpu_device
=
Eigen
::
GpuDevice
*
gpu_device
=
device_context
->
template
get_eigen_device
<
DEVICE_GPU
>();
ASSERT_NE
(
nullptr
,
gpu_device
.
stream
()
);
ASSERT_NE
(
nullptr
,
gpu_device
);
delete
device_context
;
}
}
...
...
@@ -33,8 +33,8 @@ TEST(Device, CUDADeviceContext) {
for
(
int
i
=
0
;
i
<
count
;
i
++
)
{
paddle
::
platform
::
CUDADeviceContext
*
device_context
=
new
paddle
::
platform
::
CUDADeviceContext
(
i
);
Eigen
::
GpuDevice
gpu_device
=
device_context
->
eigen_device
();
ASSERT_NE
(
nullptr
,
gpu_device
.
stream
()
);
Eigen
::
GpuDevice
*
gpu_device
=
device_context
->
eigen_device
();
ASSERT_NE
(
nullptr
,
gpu_device
);
cudnnHandle_t
cudnn_handle
=
device_context
->
cudnn_handle
();
ASSERT_NE
(
nullptr
,
cudnn_handle
);
cublasHandle_t
cublas_handle
=
device_context
->
cublas_handle
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录