Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
dd594652
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
dd594652
编写于
4月 24, 2018
作者:
F
fengjiayi
提交者:
GitHub
4月 24, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #10142 from JiayiFeng/Add_TensorCopySync
Add synchronous TensorCopy
上级
2486d563
c5e178f4
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
79 addition
and
25 deletion
+79
-25
paddle/fluid/framework/tensor_util.cc
paddle/fluid/framework/tensor_util.cc
+16
-10
paddle/fluid/framework/tensor_util.h
paddle/fluid/framework/tensor_util.h
+2
-1
paddle/fluid/memory/memcpy.cc
paddle/fluid/memory/memcpy.cc
+32
-7
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
.../fluid/operators/reader/create_double_buffer_reader_op.cc
+2
-1
paddle/fluid/platform/gpu_info.cc
paddle/fluid/platform/gpu_info.cc
+16
-3
paddle/fluid/platform/gpu_info.h
paddle/fluid/platform/gpu_info.h
+11
-3
未找到文件。
paddle/fluid/framework/tensor_util.cc
浏览文件 @
dd594652
...
...
@@ -20,7 +20,7 @@ namespace paddle {
namespace
framework
{
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
)
{
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
,
bool
sync
)
{
VLOG
(
3
)
<<
"TensorCopy "
<<
src
.
dims
()
<<
" from "
<<
src
.
place
()
<<
" to "
<<
dst_place
;
src
.
check_memory_size
();
...
...
@@ -47,9 +47,11 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
src_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_cpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
auto
stream
=
sync
?
nullptr
:
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
)
.
stream
();
memory
::
Copy
(
dst_cpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
stream
);
}
else
if
(
platform
::
is_cpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
);
...
...
@@ -58,18 +60,22 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
auto
ctx_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
ctx_place
);
PADDLE_ENFORCE_EQ
(
dst_gpu_place
,
ctx_gpu_place
);
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_cpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
auto
stream
=
sync
?
nullptr
:
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
)
.
stream
();
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_cpu_place
,
src_ptr
,
size
,
stream
);
}
else
if
(
platform
::
is_gpu_place
(
src_place
)
&&
platform
::
is_gpu_place
(
dst_place
))
{
auto
src_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
src_place
);
auto
dst_gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
dst_place
);
auto
ctx_place
=
ctx
.
GetPlace
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx_place
));
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
).
stream
());
auto
stream
=
sync
?
nullptr
:
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
)
.
stream
();
memory
::
Copy
(
dst_gpu_place
,
dst_ptr
,
src_gpu_place
,
src_ptr
,
size
,
stream
);
}
#endif
}
...
...
paddle/fluid/framework/tensor_util.h
浏览文件 @
dd594652
...
...
@@ -24,7 +24,8 @@ namespace paddle {
namespace
framework
{
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
);
const
platform
::
DeviceContext
&
ctx
,
Tensor
*
dst
,
bool
sync
=
false
);
void
TensorCopy
(
const
Tensor
&
src
,
const
platform
::
Place
&
dst_place
,
Tensor
*
dst
);
...
...
paddle/fluid/memory/memcpy.cc
浏览文件 @
dd594652
...
...
@@ -32,7 +32,11 @@ void Copy<platform::CPUPlace, platform::CUDAPlace>(
platform
::
CPUPlace
dst_place
,
void
*
dst
,
platform
::
CUDAPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
);
}
}
template
<
>
...
...
@@ -40,7 +44,11 @@ void Copy<platform::CUDAPlace, platform::CPUPlace>(
platform
::
CUDAPlace
dst_place
,
void
*
dst
,
platform
::
CPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
SetDeviceId
(
dst_place
.
device
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
);
}
}
template
<
>
...
...
@@ -49,10 +57,19 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
if
(
dst_place
==
src_place
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpyPeer
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
stream
);
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
);
}
}
else
{
if
(
stream
)
{
platform
::
GpuMemcpyPeerAsync
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
stream
);
}
else
{
platform
::
GpuMemcpyPeerSync
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
);
}
}
}
...
...
@@ -83,7 +100,11 @@ void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
platform
::
CUDAPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
);
}
}
template
<
>
...
...
@@ -92,7 +113,11 @@ void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
platform
::
CUDAPinnedPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
SetDeviceId
(
dst_place
.
device
);
if
(
stream
)
{
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
);
}
}
#endif
...
...
paddle/fluid/operators/reader/create_double_buffer_reader_op.cc
浏览文件 @
dd594652
...
...
@@ -180,7 +180,8 @@ void DoubleBufferReader::PrefetchThreadFunc() {
auto
*
gpu_ctx
=
ctxs_
[
cached_tensor_id
].
get
();
gpu_batch
.
resize
(
cpu_batch
.
size
());
for
(
size_t
i
=
0
;
i
<
cpu_batch
.
size
();
++
i
)
{
framework
::
TensorCopy
(
cpu_batch
[
i
],
place_
,
*
gpu_ctx
,
&
gpu_batch
[
i
]);
framework
::
TensorCopy
(
cpu_batch
[
i
],
place_
,
*
gpu_ctx
,
&
gpu_batch
[
i
],
true
);
gpu_batch
[
i
].
set_lod
(
cpu_batch
[
i
].
lod
());
}
}
...
...
paddle/fluid/platform/gpu_info.cc
浏览文件 @
dd594652
...
...
@@ -127,11 +127,24 @@ void GpuMemcpyAsync(void *dst, const void *src, size_t count,
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync"
);
}
void
GpuMemcpyPeer
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
)
{
void
GpuMemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
cudaMemcpyKind
kind
)
{
PADDLE_ENFORCE
(
cudaMemcpy
(
dst
,
src
,
count
,
kind
),
"cudaMemcpy failed in paddle::platform::GpuMemcpySync"
);
}
void
GpuMemcpyPeerAsync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
)
{
PADDLE_ENFORCE
(
cudaMemcpyPeerAsync
(
dst
,
dst_device
,
src
,
src_device
,
count
,
stream
),
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer"
);
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync"
);
}
void
GpuMemcpyPeerSync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
)
{
PADDLE_ENFORCE
(
cudaMemcpyPeer
(
dst
,
dst_device
,
src
,
src_device
,
count
),
"cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync"
);
}
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
)
{
...
...
paddle/fluid/platform/gpu_info.h
浏览文件 @
dd594652
...
...
@@ -57,9 +57,17 @@ size_t GpuMaxChunkSize();
void
GpuMemcpyAsync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
cudaMemcpyKind
kind
,
cudaStream_t
stream
);
//! Copy memory from one device to another device.
void
GpuMemcpyPeer
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
);
//! Copy memory from address src to dst synchronously.
void
GpuMemcpySync
(
void
*
dst
,
const
void
*
src
,
size_t
count
,
enum
cudaMemcpyKind
kind
);
//! Copy memory from one device to another device asynchronously.
void
GpuMemcpyPeerAsync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
,
cudaStream_t
stream
);
//! Copy memory from one device to another device synchronously.
void
GpuMemcpyPeerSync
(
void
*
dst
,
int
dst_device
,
const
void
*
src
,
int
src_device
,
size_t
count
);
//! Set memory dst with value count size asynchronously
void
GpuMemsetAsync
(
void
*
dst
,
int
value
,
size_t
count
,
cudaStream_t
stream
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录