Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
5022b14d
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5022b14d
编写于
7月 24, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix mixed tensor compile and add cpu unit test
上级
e011e34a
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
51 addition
and
26 deletion
+51
-26
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+6
-1
paddle/fluid/framework/mixed_vector.h
paddle/fluid/framework/mixed_vector.h
+5
-5
paddle/fluid/framework/mixed_vector_test.cc
paddle/fluid/framework/mixed_vector_test.cc
+40
-20
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
5022b14d
...
...
@@ -22,7 +22,12 @@ endif()
cc_test
(
eigen_test SRCS eigen_test.cc DEPS tensor
)
nv_test
(
mixed_vector_test SRCS mixed_vector_test.cu DEPS place memory device_context tensor
)
if
(
WITH_GPU
)
nv_test
(
mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor
)
else
()
cc_test
(
mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor
)
endif
()
cc_library
(
lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto recordio
)
cc_test
(
lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory
)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
...
...
paddle/fluid/framework/mixed_vector.h
浏览文件 @
5022b14d
...
...
@@ -16,6 +16,7 @@
#include <algorithm>
#include <initializer_list>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/tensor.h"
...
...
@@ -386,13 +387,14 @@ template <typename T>
class
CPUVector
:
public
std
::
vector
<
T
,
std
::
allocator
<
T
>>
{
public:
CPUVector
()
:
std
::
vector
<
T
>
()
{}
CPUVector
(
size_t
count
,
const
T
&
value
=
T
())
explicit
CPUVector
(
size_t
count
,
const
T
&
value
=
T
())
:
std
::
vector
<
T
>
(
count
,
value
)
{}
CPUVector
(
std
::
initializer_list
<
T
>
init
)
:
std
::
vector
<
T
>
(
init
)
{}
CPUVector
(
const
std
::
vector
<
T
>
&
other
)
:
std
::
vector
<
T
>
(
other
)
{}
explicit
CPUVector
(
const
std
::
vector
<
T
>
&
other
)
:
std
::
vector
<
T
>
(
other
)
{}
explicit
CPUVector
(
const
CPUVector
<
T
>
&
other
)
:
std
::
vector
<
T
>
(
other
)
{}
CPUVector
(
CPUVector
<
T
>
&&
other
)
:
std
::
vector
<
T
>
(
std
::
move
(
other
))
{}
CPUVector
(
std
::
vector
<
T
>
&&
other
)
:
std
::
vector
<
T
>
(
std
::
move
(
other
))
{}
explicit
CPUVector
(
std
::
vector
<
T
>
&&
other
)
:
std
::
vector
<
T
>
(
std
::
move
(
other
))
{}
CPUVector
&
operator
=
(
const
CPUVector
&
other
)
{
this
->
assign
(
other
.
begin
(),
other
.
end
());
return
*
this
;
...
...
@@ -410,8 +412,6 @@ class CPUVector : public std::vector<T, std::allocator<T>> {
return
os
;
}
void
resize
(
size_t
size
)
{
this
->
resize
(
size
);
}
T
&
operator
[](
size_t
id
)
{
return
this
->
at
(
id
);
}
const
T
&
operator
[](
size_t
id
)
const
{
return
this
->
at
(
id
);
}
...
...
paddle/fluid/framework/mixed_vector_test.c
u
→
paddle/fluid/framework/mixed_vector_test.c
c
浏览文件 @
5022b14d
...
...
@@ -11,8 +11,15 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h"
...
...
@@ -41,6 +48,38 @@ TEST(mixed_vector, CPU_VECTOR) {
}
}
TEST
(
mixed_vector
,
InitWithCount
)
{
paddle
::
framework
::
Vector
<
int
>
vec
(
10
,
10
);
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ASSERT_EQ
(
vec
[
i
],
10
);
}
}
TEST
(
mixed_vector
,
ForEach
)
{
vec
<
int
>
tmp
;
for
(
auto
&
v
:
tmp
)
{
VLOG
(
3
)
<<
v
;
}
}
TEST
(
mixed_vector
,
Reserve
)
{
paddle
::
framework
::
Vector
<
int
>
vec
;
vec
.
reserve
(
1
);
vec
.
push_back
(
0
);
vec
.
push_back
(
0
);
vec
.
push_back
(
0
);
}
TEST
(
mixed_vector
,
Resize
)
{
paddle
::
framework
::
Vector
<
int
>
vec
;
vec
.
resize
(
1
);
vec
.
push_back
(
0
);
vec
.
push_back
(
0
);
vec
.
push_back
(
0
);
}
#ifdef PADDLE_WITH_CUDA
static
__global__
void
multiply_10
(
int
*
ptr
)
{
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ptr
[
i
]
*=
10
;
...
...
@@ -92,23 +131,4 @@ TEST(mixed_vector, MultiGPU) {
}
}
TEST
(
mixed_vector
,
InitWithCount
)
{
paddle
::
framework
::
Vector
<
int
>
vec
(
10
,
10
);
for
(
int
i
=
0
;
i
<
10
;
++
i
)
{
ASSERT_EQ
(
vec
[
i
],
10
);
}
}
TEST
(
mixed_vector
,
ForEach
)
{
vec
<
int
>
tmp
;
for
(
auto
&
v
:
tmp
)
{
}
}
TEST
(
mixed_vector
,
Reserve
)
{
paddle
::
framework
::
Vector
<
int
>
vec
;
vec
.
reserve
(
1
);
vec
.
push_back
(
0
);
vec
.
push_back
(
0
);
vec
.
push_back
(
0
);
}
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录