Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
858dea88
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
858dea88
编写于
7月 21, 2017
作者:
Y
Yi Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move memory::Copy out from memory.h into memcpy.h
上级
6cae35b5
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
101 addition
and
51 deletion
+101
-51
paddle/memory/CMakeLists.txt
paddle/memory/CMakeLists.txt
+1
-0
paddle/memory/memcpy.cc
paddle/memory/memcpy.cc
+67
-0
paddle/memory/memcpy.h
paddle/memory/memcpy.h
+33
-0
paddle/memory/memory.cc
paddle/memory/memory.cc
+0
-42
paddle/memory/memory.h
paddle/memory/memory.h
+0
-9
未找到文件。
paddle/memory/CMakeLists.txt
浏览文件 @
858dea88
add_subdirectory
(
detail
)
add_subdirectory
(
detail
)
cc_library
(
memory SRCS memory.cc
)
cc_library
(
memory SRCS memory.cc
)
cc_library
(
memcpy SRCS memcpy.cc
)
cc_library
(
paddle_memory
cc_library
(
paddle_memory
DEPS
DEPS
...
...
paddle/memory/memcpy.cc
0 → 100644
浏览文件 @
858dea88
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/memory/memcpy.h"
#include <cstring> // for memcpy
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
memory
{
template
<
>
void
Copy
<
platform
::
CPUPlace
,
platform
::
CPUPlace
>
(
platform
::
CPUPlace
,
void
*
dst
,
platform
::
CPUPlace
,
const
void
*
src
,
size_t
num
)
{
std
::
memcpy
(
dst
,
src
,
num
);
}
#ifndef PADDLE_ONLY_CPU
template
<
>
void
Copy
<
platform
::
CPUPlace
,
platform
::
GPUPlace
>
(
platform
::
CPUPlace
dst_place
,
void
*
dst
,
platform
::
GPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
GPUPlaceGuard
g
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
template
<
>
void
Copy
<
platform
::
GPUPlace
,
platform
::
CPUPlace
>
(
platform
::
GPUPlace
dst_place
,
void
*
dst
,
platform
::
CPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
GPUPlaceGuard
g
(
dst_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
template
<
>
void
Copy
<
platform
::
GPUPlace
,
platform
::
GPUPlace
>
(
platform
::
GPUPlace
dst_place
,
void
*
dst
,
platform
::
GPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
if
(
dst_place
==
src_place
)
{
platform
::
GPUPlaceGuard
g
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpyPeer
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
stream
);
}
}
#endif // PADDLE_ONLY_CPU
paddle/memory/memcpy.h
0 → 100644
浏览文件 @
858dea88
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h"
namespace
paddle
{
namespace
memory
{
template
<
typename
DstPlace
,
typename
SrcPlace
>
void
Copy
(
DstPlace
,
void
*
dst
,
SrcPlace
,
const
void
*
src
,
size_t
num
);
#ifndef PADDLE_ONLY_CPU
template
<
typename
DstPlace
,
typename
SrcPlace
>
void
Copy
(
DstPlace
,
void
*
dst
,
SrcPlace
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
);
#endif // PADDLE_ONLY_CPU
}
// namespace memory
}
// namespace paddle
paddle/memory/memory.cc
浏览文件 @
858dea88
...
@@ -46,13 +46,6 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
...
@@ -46,13 +46,6 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
return
GetCPUBuddyAllocator
()
->
Used
();
return
GetCPUBuddyAllocator
()
->
Used
();
}
}
template
<
>
void
Copy
<
platform
::
CPUPlace
,
platform
::
CPUPlace
>
(
platform
::
CPUPlace
,
void
*
dst
,
platform
::
CPUPlace
,
const
void
*
src
,
size_t
num
)
{
std
::
memcpy
(
dst
,
src
,
num
);
}
#ifndef PADDLE_ONLY_CPU
#ifndef PADDLE_ONLY_CPU
detail
::
BuddyAllocator
*
GetGPUBuddyAllocator
(
int
gpu_id
)
{
detail
::
BuddyAllocator
*
GetGPUBuddyAllocator
(
int
gpu_id
)
{
...
@@ -85,41 +78,6 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
...
@@ -85,41 +78,6 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
return
GetGPUBuddyAllocator
(
place
.
device
)
->
Used
();
return
GetGPUBuddyAllocator
(
place
.
device
)
->
Used
();
}
}
template
<
>
void
Copy
<
platform
::
CPUPlace
,
platform
::
GPUPlace
>
(
platform
::
CPUPlace
dst_place
,
void
*
dst
,
platform
::
GPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
template
<
>
void
Copy
<
platform
::
GPUPlace
,
platform
::
CPUPlace
>
(
platform
::
GPUPlace
dst_place
,
void
*
dst
,
platform
::
CPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
platform
::
SetDeviceId
(
dst_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
template
<
>
void
Copy
<
platform
::
GPUPlace
,
platform
::
GPUPlace
>
(
platform
::
GPUPlace
dst_place
,
void
*
dst
,
platform
::
GPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
if
(
dst_place
==
src_place
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
,
stream
);
}
else
{
platform
::
GpuMemcpyPeer
(
dst
,
dst_place
.
device
,
src
,
src_place
.
device
,
num
,
stream
);
}
}
#endif // PADDLE_ONLY_CPU
#endif // PADDLE_ONLY_CPU
}
// namespace memory
}
// namespace memory
...
...
paddle/memory/memory.h
浏览文件 @
858dea88
...
@@ -29,15 +29,6 @@ void Free(Place, void*);
...
@@ -29,15 +29,6 @@ void Free(Place, void*);
template
<
typename
Place
>
template
<
typename
Place
>
size_t
Used
(
Place
);
size_t
Used
(
Place
);
template
<
typename
DstPlace
,
typename
SrcPlace
>
void
Copy
(
DstPlace
,
void
*
dst
,
SrcPlace
,
const
void
*
src
,
size_t
num
);
#ifndef PADDLE_ONLY_CPU
template
<
typename
DstPlace
,
typename
SrcPlace
>
void
Copy
(
DstPlace
,
void
*
dst
,
SrcPlace
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
);
#endif // PADDLE_ONLY_CPU
template
<
typename
T
,
/* must be POD types */
template
<
typename
T
,
/* must be POD types */
typename
Place
/* platform::GPUPlace or platform::CPUPlace */
,
typename
Place
/* platform::GPUPlace or platform::CPUPlace */
,
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录