Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
9f11da59
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9f11da59
编写于
4月 23, 2018
作者:
F
fengjiayi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add synchronous TensorCopy and use it in double buffer
上级
3863c6a9
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
83 addition
and
29 deletion
+83
-29
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+17
-10
paddle/fluid/memory/memcpy.cc
paddle/fluid/memory/memcpy.cc
+37
-12
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+2
-1
paddle/fluid/platform/gpu_info.cc
paddle/fluid/platform/gpu_info.cc
+16
-3
paddle/fluid/platform/gpu_info.h
paddle/fluid/platform/gpu_info.h
+11
-3
未找到文件。
paddle/fluid/framework/tensor_util.cc
浏览文件 @
9f11da59
...
@@ -20,7 +20,8 @@ namespace paddle {
...
@@ -20,7 +20,8 @@ namespace paddle {
namespace
framework
{
namespace
framework
{
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
)
{
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
,
bool
sync
=
false
)
{
VLOG
(
3
)
<<
"TensorCopy "
<<
src
.
dims
()
<<
" from "
<<
src
.
place
()
<<
" to "
VLOG
(
3
)
<<
"TensorCopy "
<<
src
.
dims
()
<<
" from "
<<
src
.
place
()
<<
" to "
<<
dst_place
;
<<
dst_place
;
src
.
check_memory_size
();
src
.
check_memory_size
();
...
@@ -47,9 +48,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
...
@@ -47,9 +48,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
src_gpu_place
,
ctx_gpu_place
);
PADDLE_ENFORCE_EQ
(
src_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
auto
stream
=
dst_cpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
sync
?
nullptr
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
:
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
)
.
stream
();
memory
::
Copy
(
dst_cpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
stream
);
}
else
if
(
platform
::
is_cpu_place
(
src_place
)
&&
}
else
if
(
platform
::
is_cpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
);
auto
src_cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
);
...
@@ -58,18 +61,22 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
...
@@ -58,18 +61,22 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
dst_gpu_place
,
ctx_gpu_place
);
PADDLE_ENFORCE_EQ
(
dst_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
auto
stream
=
dst_gpu_place
,
dst_ptr
,
src_cpu_place
,
src_ptr
,
size
,
sync
?
nullptr
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
:
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
)
.
stream
();
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_cpu_place
,
src_ptr
,
size
,
stream
);
}
else
if
(
platform
::
is_gpu_place
(
src_place
)
&&
}
else
if
(
platform
::
is_gpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
memory
::
Copy
(
auto
stream
=
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
sync
?
nullptr
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
:
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
)
.
stream
();
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
stream
);
}
}
#endif
#endif
}
}
...
...
paddle/fluid/memory/memcpy.cc
浏览文件 @
9f11da59
...
@@ -30,29 +30,46 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
...
@@ -30,29 +30,46 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
template
<
>
template
<
>
void
Copy
<
platform
::
CPUPlace
,
platform
::
CUDAPlace
>
(
void
Copy
<
platform
::
CPUPlace
,
platform
::
CUDAPlace
>
(
platform
::
CPUPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPlace
src_place
,
platform
::
CPUPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
=
nullptr
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
);
}
}
}
template
<
>
template
<
>
void
Copy
<
platform
::
CUDAPlace
,
platform
::
CPUPlace
>
(
void
Copy
<
platform
::
CUDAPlace
,
platform
::
CPUPlace
>
(
platform
::
CUDAPlace
dst_place
,
void
*
dst
,
platform
::
CPUPlace
src_place
,
platform
::
CUDAPlace
dst_place
,
void
*
dst
,
platform
::
CPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
=
nullptr
)
{
platform
::
SetDeviceId
(
dst_place
.
device
);
platform
::
SetDeviceId
(
dst_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
);
}
}
}
template
<
>
template
<
>
void
Copy
<
platform
::
CUDAPlace
,
platform
::
CUDAPlace
>
(
void
Copy
<
platform
::
CUDAPlace
,
platform
::
CUDAPlace
>
(
platform
::
CUDAPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPlace
src_place
,
platform
::
CUDAPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
=
nullptr
)
{
if
(
dst_place
==
src_place
)
{
if
(
dst_place
==
src_place
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
,
stream
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
);
}
}
else
{
}
else
{
platform
::
GpuMemcpyPeer
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
if
(
stream
)
{
stream
);
platform
::
GpuMemcpyPeerAsync
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
stream
);
}
else
{
platform
::
GpuMemcpyPeerSync
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
stream
);
}
}
}
}
}
...
@@ -81,18 +98,26 @@ template <>
...
@@ -81,18 +98,26 @@ template <>
void
Copy
<
platform
::
CUDAPinnedPlace
,
platform
::
CUDAPlace
>
(
void
Copy
<
platform
::
CUDAPinnedPlace
,
platform
::
CUDAPlace
>
(
platform
::
CUDAPinnedPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPinnedPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPlace
src_place
,
const
void
*
src
,
size_t
num
,
platform
::
CUDAPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
cudaStream_t
stream
=
nullptr
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
);
}
}
}
template
<
>
template
<
>
void
Copy
<
platform
::
CUDAPlace
,
platform
::
CUDAPinnedPlace
>
(
void
Copy
<
platform
::
CUDAPlace
,
platform
::
CUDAPinnedPlace
>
(
platform
::
CUDAPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPinnedPlace
src_place
,
const
void
*
src
,
size_t
num
,
platform
::
CUDAPinnedPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
cudaStream_t
stream
=
nullptr
)
{
platform
::
SetDeviceId
(
dst_place
.
device
);
platform
::
SetDeviceId
(
dst_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
);
}
}
}
#endif
#endif
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
9f11da59
...
@@ -180,7 +180,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
...
@@ -180,7 +180,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
gpu_batch
.
resize
(
cpu_batch
.
size
());
gpu_batch
.
resize
(
cpu_batch
.
size
());
for
(
size_t
i
=
0
;
i
<
cpu_batch
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
cpu_batch
.
size
();
++
i
)
{
framework
::
TensorCopy
(
cpu_batch
[
i
],
place_
,
*
gpu_ctx
,
&
gpu_batch
[
i
]);
framework
::
TensorCopy
(
cpu_batch
[
i
],
place_
,
*
gpu_ctx
,
&
gpu_batch
[
i
],
true
);
gpu_batch
[
i
].
set_lod
(
cpu_batch
[
i
].
lod
());
gpu_batch
[
i
].
set_lod
(
cpu_batch
[
i
].
lod
());
}
}
}
}
...
...
paddle/fluid/platform/gpu_info.cc
浏览文件 @
9f11da59
...
@@ -127,11 +127,24 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count,
...
@@ -127,11 +127,24 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count,
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync"
);
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync"
);
}
}
void
GpuMemcpyPeer
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
void
GpuMemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
size_t
count
,
cudaStream_t
stream
)
{
enum
cudaMemcpyKind
kind
)
{
PADDLE_ENFORCE
(
cudaMemcpy
(
dst
,
src
,
count
,
kind
),
"cudaMemcpy failed in paddle::platform::GpuMemcpySync"
);
}
void
GpuMemcpyPeerAsync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
)
{
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
cudaMemcpyPeerAsync
(
dst
,
dst_device
,
src
,
src_device
,
count
,
stream
),
cudaMemcpyPeerAsync
(
dst
,
dst_device
,
src
,
src_device
,
count
,
stream
),
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer"
);
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync"
);
}
void
GpuMemcpyPeerSync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
)
{
PADDLE_ENFORCE
(
cudaMemcpyPeer
(
dst
,
dst_device
,
src
,
src_device
,
count
),
"cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync"
);
}
}
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
)
{
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
)
{
...
...
paddle/fluid/platform/gpu_info.h
浏览文件 @
9f11da59
...
@@ -57,9 +57,17 @@ size_t GpuMaxChunkSize();
...
@@ -57,9 +57,17 @@ size_t GpuMaxChunkSize();
void
GpuMemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
void
GpuMemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
cudaMemcpyKind
kind
,
cudaStream_t
stream
);
enum
cudaMemcpyKind
kind
,
cudaStream_t
stream
);
//! Copy memory from one device to another device.
//! Copy memory from address src to dst synchronously.
void
GpuMemcpyPeer
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
void
GpuMemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
size_t
count
,
cudaStream_t
stream
);
enum
cudaMemcpyKind
kind
);
//! Copy memory from one device to another device asynchronously.
void
GpuMemcpyPeerAsync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
);
//! Copy memory from one device to another device synchronously.
void
GpuMemcpyPeerSync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
);
//! Set memory dst with value count size asynchronously
//! Set memory dst with value count size asynchronously
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
);
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录