Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
65e5aebd
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
65e5aebd
编写于
7月 25, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix mixed_vector_test
上级
da035fc6
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
76 addition
and
63 deletion
+76
-63
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/mixed_vector_test.cc
paddle/fluid/framework/mixed_vector_test.cc
+0
-62
paddle/fluid/framework/mixed_vector_test.cu
paddle/fluid/framework/mixed_vector_test.cu
+75
-0
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
65e5aebd
...
...
@@ -23,7 +23,7 @@ endif()
cc_test
(
eigen_test SRCS eigen_test.cc DEPS tensor
)
if
(
WITH_GPU
)
nv_test
(
mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor
)
nv_test
(
mixed_vector_test SRCS mixed_vector_test.cc
mixed_vector_test.cu
DEPS place memory device_context tensor
)
else
()
cc_test
(
mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor
)
endif
()
...
...
paddle/fluid/framework/mixed_vector_test.cc
浏览文件 @
65e5aebd
...
...
@@ -12,18 +12,11 @@
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/gpu_info.h"
template
<
typename
T
>
using
vec
=
paddle
::
framework
::
Vector
<
T
>
;
...
...
@@ -77,58 +70,3 @@ TEST(mixed_vector, Resize) {
vec
.
push_back
(
0
);
vec
.
push_back
(
0
);
}
#ifdef PADDLE_WITH_CUDA
static
__global__
void
multiply_10
(
int
*
ptr
)
{
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ptr
[
i
]
*=
10
;
}
}
cudaStream_t
GetCUDAStream
(
paddle
::
platform
::
CUDAPlace
place
)
{
return
reinterpret_cast
<
const
paddle
::
platform
::
CUDADeviceContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
}
TEST
(
mixed_vector
,
GPU_VECTOR
)
{
vec
<
int
>
tmp
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
tmp
.
push_back
(
i
);
}
ASSERT_EQ
(
tmp
.
size
(),
10UL
);
paddle
::
platform
::
CUDAPlace
gpu
(
0
);
multiply_10
<<<
1
,
1
,
0
,
GetCUDAStream
(
gpu
)
>>>
(
tmp
.
MutableData
(
gpu
));
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ASSERT_EQ
(
tmp
[
i
],
i
*
10
);
}
}
TEST
(
mixed_vector
,
MultiGPU
)
{
if
(
paddle
::
platform
::
GetCUDADeviceCount
()
<
2
)
{
LOG
(
WARNING
)
<<
"Skip mixed_vector.MultiGPU since there are not multiple "
"GPUs in your machine."
;
return
;
}
vec
<
int
>
tmp
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
tmp
.
push_back
(
i
);
}
ASSERT_EQ
(
tmp
.
size
(),
10UL
);
paddle
::
platform
::
CUDAPlace
gpu0
(
0
);
paddle
::
platform
::
SetDeviceId
(
0
);
multiply_10
<<<
1
,
1
,
0
,
GetCUDAStream
(
gpu0
)
>>>
(
tmp
.
MutableData
(
gpu0
));
paddle
::
platform
::
CUDAPlace
gpu1
(
1
);
auto
*
gpu1_ptr
=
tmp
.
MutableData
(
gpu1
);
paddle
::
platform
::
SetDeviceId
(
1
);
multiply_10
<<<
1
,
1
,
0
,
GetCUDAStream
(
gpu1
)
>>>
(
gpu1_ptr
);
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ASSERT_EQ
(
tmp
[
i
],
i
*
100
);
}
}
#endif
paddle/fluid/framework/mixed_vector_test.cu
0 → 100644
浏览文件 @
65e5aebd
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/gpu_info.h"
template
<
typename
T
>
using
vec
=
paddle
::
framework
::
Vector
<
T
>
;
static
__global__
void
multiply_10
(
int
*
ptr
)
{
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ptr
[
i
]
*=
10
;
}
}
cudaStream_t
GetCUDAStream
(
paddle
::
platform
::
CUDAPlace
place
)
{
return
reinterpret_cast
<
const
paddle
::
platform
::
CUDADeviceContext
*>
(
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
))
->
stream
();
}
TEST
(
mixed_vector
,
GPU_VECTOR
)
{
vec
<
int
>
tmp
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
tmp
.
push_back
(
i
);
}
ASSERT_EQ
(
tmp
.
size
(),
10UL
);
paddle
::
platform
::
CUDAPlace
gpu
(
0
);
multiply_10
<<<
1
,
1
,
0
,
GetCUDAStream
(
gpu
)
>>>
(
tmp
.
MutableData
(
gpu
));
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ASSERT_EQ
(
tmp
[
i
],
i
*
10
);
}
}
TEST
(
mixed_vector
,
MultiGPU
)
{
if
(
paddle
::
platform
::
GetCUDADeviceCount
()
<
2
)
{
LOG
(
WARNING
)
<<
"Skip mixed_vector.MultiGPU since there are not multiple "
"GPUs in your machine."
;
return
;
}
vec
<
int
>
tmp
;
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
tmp
.
push_back
(
i
);
}
ASSERT_EQ
(
tmp
.
size
(),
10UL
);
paddle
::
platform
::
CUDAPlace
gpu0
(
0
);
paddle
::
platform
::
SetDeviceId
(
0
);
multiply_10
<<<
1
,
1
,
0
,
GetCUDAStream
(
gpu0
)
>>>
(
tmp
.
MutableData
(
gpu0
));
paddle
::
platform
::
CUDAPlace
gpu1
(
1
);
auto
*
gpu1_ptr
=
tmp
.
MutableData
(
gpu1
);
paddle
::
platform
::
SetDeviceId
(
1
);
multiply_10
<<<
1
,
1
,
0
,
GetCUDAStream
(
gpu1
)
>>>
(
gpu1_ptr
);
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ASSERT_EQ
(
tmp
[
i
],
i
*
100
);
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录