Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
ce938ae5
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ce938ae5
编写于
6月 26, 2017
作者:
L
liaogang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
FIX: Pinned memory
上级
db128c45
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
27 addition
and
35 deletion
+27
-35
paddle/memory/README.md
paddle/memory/README.md
+1
-0
paddle/memory/detail/CMakeLists.txt
paddle/memory/detail/CMakeLists.txt
+1
-5
paddle/memory/detail/cpu_allocator.h
paddle/memory/detail/cpu_allocator.h
+21
-18
paddle/memory/detail/cpu_allocator_test.cc
paddle/memory/detail/cpu_allocator_test.cc
+4
-12
未找到文件。
paddle/memory/README.md
浏览文件 @
ce938ae5
...
@@ -97,6 +97,7 @@ class BuddyAllocator {
...
@@ -97,6 +97,7 @@ class BuddyAllocator {
struct
Block
{
struct
Block
{
size_t
size
;
size_t
size
;
Block
*
left
,
right
;
Block
*
left
,
right
;
size_t
index
;
// allocator id
};
};
...
...
};
};
...
...
paddle/memory/detail/CMakeLists.txt
浏览文件 @
ce938ae5
if
(
${
WITH_GPU
}
)
cc_test
(
cpu_allocator_test SRCS cpu_allocator_test.cc
)
nv_test
(
cpu_allocator_test SRCS cpu_allocator_test.cc
)
# nv_test links CUDA, but
else
(
${
WITH_GPU
}
)
cc_test
(
cpu_allocator_test SRCS cpu_allocator_test.cc
)
# cc_test doesn't.
endif
(
${
WITH_GPU
}
)
paddle/memory/detail/cpu_allocator.h
浏览文件 @
ce938ae5
...
@@ -14,20 +14,19 @@ limitations under the License. */
...
@@ -14,20 +14,19 @@ limitations under the License. */
#pragma once
#pragma once
#include <malloc.h> // for malloc and free
#include <stddef.h> // for size_t
#include <stddef.h> // for size_t
#include <cstdlib> // for malloc and free
#ifdef PADDLE_WITH_GPU
#ifndef _WIN32
#include <cuda.h>
#include <sys/mman.h> // for mlock and munlock
#include <cuda_runtime_api.h>
#endif
#endif // PADDLE_WITH_GPU
namespace
paddle
{
namespace
paddle
{
namespace
memory
{
namespace
memory
{
namespace
detail
{
namespace
detail
{
// CPUAllocator<staging=true> calls
cudaMallocHost
, which returns
// CPUAllocator<staging=true> calls
mlock
, which returns
// pinned and
m
locked memory as staging areas for data exchange
// pinned and locked memory as staging areas for data exchange
// between host and device. Allocates too much would reduce the
// between host and device. Allocates too much would reduce the
// amount of memory available to the system for paging. So, by
// amount of memory available to the system for paging. So, by
// default, we should use CPUAllocator<staging=false>.
// default, we should use CPUAllocator<staging=false>.
...
@@ -35,33 +34,37 @@ template <bool staging>
...
@@ -35,33 +34,37 @@ template <bool staging>
class
CPUAllocator
{
class
CPUAllocator
{
public:
public:
void
*
Alloc
(
size_t
size
);
void
*
Alloc
(
size_t
size
);
void
Free
(
void
*
p
);
void
Free
(
void
*
p
,
size_t
size
);
};
};
template
<
>
template
<
>
class
CPUAllocator
<
false
>
{
class
CPUAllocator
<
false
>
{
public:
public:
void
*
Alloc
(
size_t
size
)
{
return
malloc
(
size
);
}
void
*
Alloc
(
size_t
size
)
{
return
std
::
malloc
(
size
);
}
void
Free
(
void
*
p
)
{
free
(
p
);
}
void
Free
(
void
*
p
,
size_t
size
)
{
std
::
free
(
p
);
}
};
};
// If CMake macro PADDLE_WITH_GPU is OFF, C++ compiler won't generate the
// following specialization that depends on the CUDA library.
#ifdef PADDLE_WITH_GPU
template
<
>
template
<
>
class
CPUAllocator
<
true
>
{
class
CPUAllocator
<
true
>
{
public:
public:
void
*
Alloc
(
size_t
size
)
{
void
*
Alloc
(
size_t
size
)
{
void
*
p
;
void
*
p
=
std
::
malloc
(
size
)
;
if
(
cudaMallocHost
(
&
p
,
size
)
!=
cudaSuccess
)
{
if
(
p
==
nullptr
)
{
return
NULL
;
return
p
;
}
}
#ifndef _WIN32
mlock
(
p
,
size
);
#endif
return
p
;
return
p
;
}
}
void
Free
(
void
*
p
)
{
cudaFreeHost
(
p
);
}
void
Free
(
void
*
p
,
size_t
size
)
{
#ifndef _WIN32
munlock
(
p
,
size
);
#endif
std
::
free
(
p
);
}
};
};
#endif // PADDLE_WITH_GPU
}
// namespace detail
}
// namespace detail
}
// namespace memory
}
// namespace memory
...
...
paddle/memory/detail/cpu_allocator_test.cc
浏览文件 @
ce938ae5
...
@@ -19,20 +19,12 @@ TEST(CPUAllocator, NonStaging) {
...
@@ -19,20 +19,12 @@ TEST(CPUAllocator, NonStaging) {
paddle
::
memory
::
detail
::
CPUAllocator
<
false
>
a
;
paddle
::
memory
::
detail
::
CPUAllocator
<
false
>
a
;
void
*
p
=
a
.
Alloc
(
4096
);
void
*
p
=
a
.
Alloc
(
4096
);
EXPECT_NE
(
p
,
nullptr
);
EXPECT_NE
(
p
,
nullptr
);
a
.
Free
(
p
);
a
.
Free
(
p
,
4096
);
}
}
#ifdef PADDLE_WITH_GPU
TEST
(
CPUAllocator
,
Staging
)
{
TEST
(
CPUAllocator
,
Staging
)
{
paddle
::
memory
::
detail
::
CPUAllocator
<
true
>
a
;
paddle
::
memory
::
detail
::
CPUAllocator
<
true
>
a
;
void
*
p
=
a
.
Alloc
(
4096
);
int
devices
;
EXPECT_NE
(
p
,
nullptr
);
if
(
cudaGetDeviceCount
(
&
devices
)
==
cudaSuccess
&&
devices
>
0
)
{
a
.
Free
(
p
,
4096
);
void
*
p
=
a
.
Alloc
(
4096
);
EXPECT_NE
(
p
,
nullptr
);
a
.
Free
(
p
);
}
else
{
EXPECT_EQ
(
a
.
Alloc
(
4096
),
nullptr
);
}
}
}
#endif // PADDLE_WITH_GPU
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录