Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ffa63974
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ffa63974
编写于
3月 29, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
compare the performance of unpinned memory and pinned memory
上级
58a9f9f7
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
156 addition
and
8 deletion
+156
-8
paddle/fluid/memory/CMakeLists.txt
paddle/fluid/memory/CMakeLists.txt
+12
-8
paddle/fluid/memory/pinned_memory_test.cu
paddle/fluid/memory/pinned_memory_test.cu
+144
-0
未找到文件。
paddle/fluid/memory/CMakeLists.txt
浏览文件 @
ffa63974
...
@@ -4,13 +4,17 @@ cc_library(memory SRCS memory.cc DEPS place enforce)
...
@@ -4,13 +4,17 @@ cc_library(memory SRCS memory.cc DEPS place enforce)
cc_library
(
memcpy SRCS memcpy.cc DEPS place
)
cc_library
(
memcpy SRCS memcpy.cc DEPS place
)
cc_library
(
paddle_memory
cc_library
(
paddle_memory
DEPS
DEPS
memory
memory
memcpy
memcpy
meta_data
meta_data
meta_cache
meta_cache
memory_block
memory_block
buddy_allocator
buddy_allocator
system_allocator
)
system_allocator
)
cc_test
(
memory_test SRCS memory_test.cc DEPS place paddle_memory
)
cc_test
(
memory_test SRCS memory_test.cc DEPS place paddle_memory
)
if
(
WITH_GPU
)
nv_test
(
pinned_memory_test SRCS pinned_memory_test.cu DEPS place paddle_memory
)
endif
()
paddle/fluid/memory/pinned_memory_test.cu
0 → 100644
浏览文件 @
ffa63974
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/memory/detail/memory_block.h"
#include "paddle/fluid/memory/detail/meta_data.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"
#include <gtest/gtest.h>
#include <unordered_map>
template
<
typename
T
>
__global__
void
Kernel
(
T
*
output
,
int
dim
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
dim
)
{
output
[
tid
]
=
output
[
tid
]
*
output
[
tid
]
/
100
;
}
}
template
<
typename
Place
>
void
test_pinned_memory
()
{
Place
cpu_place
;
paddle
::
platform
::
CUDAPlace
cuda_place
;
const
int
data_size
=
4096
;
const
int
iteration
=
10
;
// create event start and end
cudaEvent_t
start_e
,
stop_e
,
copying_e
;
float
elapsedTime
=
0
;
cudaEventCreate
(
&
start_e
);
cudaEventCreate
(
&
stop_e
);
cudaEventCreate
(
&
copying_e
);
// create computation stream, data copying stream
cudaStream_t
computation_stream
,
copying_stream
;
cudaStreamCreate
(
&
computation_stream
);
cudaStreamCreate
(
&
copying_stream
);
// create record event, pinned memory, gpu memory
std
::
vector
<
cudaEvent_t
>
record_event
(
iteration
);
std
::
vector
<
float
*>
input_pinned_mem
(
iteration
);
std
::
vector
<
float
*>
gpu_mem
(
iteration
);
std
::
vector
<
float
*>
output_pinned_mem
(
iteration
);
// initial data
for
(
int
j
=
0
;
j
<
iteration
;
++
j
)
{
cudaEventCreateWithFlags
(
&
record_event
[
j
],
cudaEventDisableTiming
);
cudaEventCreate
(
&
(
record_event
[
j
]));
input_pinned_mem
[
j
]
=
static_cast
<
float
*>
(
paddle
::
memory
::
Alloc
(
cpu_place
,
data_size
*
sizeof
(
float
)));
output_pinned_mem
[
j
]
=
static_cast
<
float
*>
(
paddle
::
memory
::
Alloc
(
cpu_place
,
data_size
*
sizeof
(
float
)));
gpu_mem
[
j
]
=
static_cast
<
float
*>
(
paddle
::
memory
::
Alloc
(
cuda_place
,
data_size
*
sizeof
(
float
)));
for
(
int
k
=
0
;
k
<
data_size
;
++
k
)
{
input_pinned_mem
[
j
][
k
]
=
k
;
}
}
cudaEventRecord
(
start_e
,
computation_stream
);
// computation
for
(
int
m
=
0
;
m
<
30
;
++
m
)
{
for
(
int
i
=
0
;
i
<
iteration
;
++
i
)
{
// cpu -> GPU on computation stream.
// note: this operation is async for pinned memory.
paddle
::
memory
::
Copy
(
cuda_place
,
gpu_mem
[
i
],
cpu_place
,
input_pinned_mem
[
i
],
data_size
*
sizeof
(
float
),
computation_stream
);
// call kernel on computation stream.
Kernel
<<<
4
,
1024
,
0
,
computation_stream
>>>
(
gpu_mem
[
i
],
data_size
);
// record event_computation on computation stream
cudaEventRecord
(
record_event
[
i
],
computation_stream
);
// wait event_computation on copy stream.
// note: this operation is async.
cudaStreamWaitEvent
(
copying_stream
,
record_event
[
i
],
0
);
// copy data GPU->CPU, on copy stream.
// note: this operation is async for pinned memory.
paddle
::
memory
::
Copy
(
cpu_place
,
output_pinned_mem
[
i
],
cuda_place
,
gpu_mem
[
i
],
data_size
*
sizeof
(
float
),
copying_stream
);
}
}
cudaEventRecord
(
copying_e
,
copying_stream
);
cudaStreamWaitEvent
(
computation_stream
,
copying_e
,
0
);
cudaEventRecord
(
stop_e
,
computation_stream
);
cudaEventSynchronize
(
start_e
);
cudaEventSynchronize
(
stop_e
);
cudaEventElapsedTime
(
&
elapsedTime
,
start_e
,
stop_e
);
std
::
cout
<<
cpu_place
<<
" "
<<
"time consume:"
<<
elapsedTime
/
30
<<
std
::
endl
;
for
(
int
l
=
0
;
l
<
iteration
;
++
l
)
{
for
(
int
k
=
0
;
k
<
data_size
;
++
k
)
{
float
temp
=
input_pinned_mem
[
l
][
k
];
temp
=
temp
*
temp
/
100
;
EXPECT_FLOAT_EQ
(
temp
,
output_pinned_mem
[
l
][
k
]);
}
}
// destroy resource
cudaEventDestroy
(
copying_e
);
cudaEventDestroy
(
start_e
);
cudaEventDestroy
(
stop_e
);
for
(
int
j
=
0
;
j
<
10
;
++
j
)
{
cudaEventDestroy
((
record_event
[
j
]));
paddle
::
memory
::
Free
(
cpu_place
,
input_pinned_mem
[
j
]);
paddle
::
memory
::
Free
(
cpu_place
,
output_pinned_mem
[
j
]);
paddle
::
memory
::
Free
(
cuda_place
,
gpu_mem
[
j
]);
}
}
TEST
(
CPUANDCUDAPinned
,
CPUAllocator
)
{
test_pinned_memory
<
paddle
::
platform
::
CPUPlace
>
();
}
TEST
(
CPUANDCUDAPinned
,
CUDAPinnedAllocator
)
{
test_pinned_memory
<
paddle
::
platform
::
CUDAPinnedPlace
>
();
}
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录