Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
6d0d29f6
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d0d29f6
编写于
9月 11, 2017
作者:
Q
qingqing01
提交者:
GitHub
9月 11, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4001 from qingqing01/lod_tensor_py
Correctly use host_vector in LoDTensor and expose LoDTensor to Python.
上级
af523df4
28dc4340
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
201 addition
and
5 deletion
+201
-5
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-0
paddle/framework/lod_tensor.h
paddle/framework/lod_tensor.h
+4
-1
paddle/framework/lod_tensor_test.cu
paddle/framework/lod_tensor_test.cu
+52
-0
paddle/operators/math/im2col_test.cc
paddle/operators/math/im2col_test.cc
+1
-1
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+62
-0
python/paddle/v2/framework/tests/test_tensor.py
python/paddle/v2/framework/tests/test_tensor.py
+81
-3
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
6d0d29f6
...
...
@@ -9,6 +9,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library
(
lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor
)
cc_test
(
lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor
)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
cc_test
(
variable_test SRCS variable_test.cc
)
...
...
paddle/framework/lod_tensor.h
浏览文件 @
6d0d29f6
...
...
@@ -18,8 +18,10 @@
#ifndef PADDLE_ONLY_CPU
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/system/cuda/experimental/pinned_allocator.h>
#endif
#include <glog/logging.h>
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"
...
...
@@ -32,7 +34,8 @@ template <typename T>
using
Vector
=
std
::
vector
<
T
>
;
#else
template
<
typename
T
>
using
Vector
=
thrust
::
host_vector
<
T
>
;
using
Vector
=
thrust
::
host_vector
<
T
,
thrust
::
system
::
cuda
::
experimental
::
pinned_allocator
<
T
>>
;
#endif
using
LoD
=
std
::
vector
<
Vector
<
size_t
>>
;
...
...
paddle/framework/lod_tensor_test.cu
0 → 100644
浏览文件 @
6d0d29f6
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/assert.h"
#include <gtest/gtest.h>
__global__
void
test
(
size_t
*
a
,
int
size
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
a
[
i
]
*=
2
;
}
}
TEST
(
LoDTensor
,
LoDInGPU
)
{
paddle
::
framework
::
Tensor
tensor
;
paddle
::
framework
::
LoDTensor
lod_tensor
;
paddle
::
platform
::
GPUPlace
place
(
0
);
paddle
::
framework
::
LoD
src_lod
;
src_lod
.
push_back
(
std
::
vector
<
size_t
>
{
0
,
2
,
4
,
6
,
8
,
10
,
12
,
14
});
tensor
.
Resize
({
14
,
16
});
tensor
.
mutable_data
<
float
>
(
place
);
lod_tensor
.
set_lod
(
src_lod
);
lod_tensor
.
set_tensor
(
&
tensor
);
CHECK_EQ
(
lod_tensor
.
lod_element
(
0
,
2
),
4
);
CHECK_EQ
(
lod_tensor
.
lod_element
(
0
,
4
),
8
);
auto
lod
=
lod_tensor
.
lod
();
test
<<<
1
,
8
>>>
(
lod
[
0
].
data
(),
lod
[
0
].
size
());
cudaDeviceSynchronize
();
for
(
size_t
i
=
0
;
i
<
src_lod
[
0
].
size
();
++
i
)
{
CHECK_EQ
(
lod
[
0
].
data
()[
i
],
src_lod
[
0
].
data
()[
i
]
*
2
);
}
}
paddle/operators/math/im2col_test.cc
浏览文件 @
6d0d29f6
paddle/pybind/pybind.cc
浏览文件 @
6d0d29f6
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/backward.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
...
...
@@ -58,6 +59,8 @@ namespace paddle {
namespace
framework
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoD
=
framework
::
LoD
;
static
size_t
UniqueIntegerGenerator
()
{
static
std
::
atomic
<
size_t
>
generator
;
...
...
@@ -117,6 +120,60 @@ PYBIND11_PLUGIN(core) {
return
self
.
data
<
float
>
()[
offset
];
});
py
::
class_
<
LoDTensor
>
(
m
,
"LoDTensor"
,
R"DOC(LoD(Leval of Ddetails) Tensor.
The tensor and LoD info should be created before creating the LoDTensor, then
call the set_tensor and set_lod functions to set them.
)DOC"
)
.
def
(
"__init__"
,
[](
LoDTensor
&
instance
,
const
std
::
vector
<
std
::
vector
<
size_t
>>
&
lod
,
Tensor
*
t
)
{
#ifdef PADDLE_ONLY_CPU
new
(
&
instance
)
LoDTensor
(
lod
,
t
);
#else
paddle
::
framework
::
LoD
new_lod
;
new_lod
.
reserve
(
lod
.
size
());
std
::
copy
(
lod
.
begin
(),
lod
.
end
(),
std
::
back_inserter
(
new_lod
));
new
(
&
instance
)
LoDTensor
(
new_lod
,
t
);
#endif
})
.
def
(
"set_tensor"
,
[](
LoDTensor
&
self
,
Tensor
*
tensor
)
{
self
.
set_tensor
(
tensor
);
})
.
def
(
"set_lod"
,
[](
LoDTensor
&
self
,
const
std
::
vector
<
std
::
vector
<
size_t
>>
&
lod
)
{
#ifdef PADDLE_ONLY_CPU
self
.
set_lod
(
lod
);
#else
paddle
::
framework
::
LoD
new_lod
;
new_lod
.
reserve
(
lod
.
size
());
std
::
copy
(
lod
.
begin
(),
lod
.
end
(),
std
::
back_inserter
(
new_lod
));
self
.
set_lod
(
new_lod
);
#endif
})
.
def
(
"tensor"
,
[](
LoDTensor
&
self
)
->
Tensor
&
{
return
self
.
tensor
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"lod"
,
[](
LoDTensor
&
self
)
->
std
::
vector
<
std
::
vector
<
size_t
>>
{
#ifdef PADDLE_ONLY_CPU
return
self
.
lod
();
#else
auto
lod
=
self
.
lod
();
std
::
vector
<
std
::
vector
<
size_t
>>
new_lod
;
new_lod
.
reserve
(
lod
.
size
());
std
::
transform
(
lod
.
begin
(),
lod
.
end
(),
std
::
back_inserter
(
new_lod
),
[](
paddle
::
framework
::
Vector
<
size_t
>
item
)
->
std
::
vector
<
size_t
>
{
std
::
vector
<
size_t
>
v
;
v
.
reserve
(
item
.
size
());
std
::
copy
(
item
.
begin
(),
item
.
end
(),
std
::
back_inserter
(
v
));
return
v
;
});
return
new_lod
;
#endif
});
py
::
class_
<
Variable
>
(
m
,
"Variable"
,
R"DOC(Variable Class.
All parameter, weight, gradient are variables in Paddle.
...
...
@@ -128,6 +185,11 @@ All parameter, weight, gradient are variables in Paddle.
.
def
(
"get_tensor"
,
[](
Variable
&
self
)
->
Tensor
*
{
return
self
.
GetMutable
<
Tensor
>
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"get_lod_tensor"
,
[](
Variable
&
self
)
->
LoDTensor
*
{
return
self
.
GetMutable
<
LoDTensor
>
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"get_net"
,
[](
Variable
&
self
)
->
operators
::
NetOp
*
{
return
self
.
GetMutable
<
operators
::
NetOp
>
();
...
...
python/paddle/v2/framework/tests/test_tensor.py
浏览文件 @
6d0d29f6
...
...
@@ -3,7 +3,7 @@ import unittest
import
numpy
class
Test
Scope
(
unittest
.
TestCase
):
class
Test
Tensor
(
unittest
.
TestCase
):
def
test_int_tensor
(
self
):
scope
=
core
.
Scope
()
var
=
scope
.
new_var
(
"test_tensor"
)
...
...
@@ -20,8 +20,8 @@ class TestScope(unittest.TestCase):
tensor
.
set
(
tensor_array
,
place
)
tensor_array_2
=
numpy
.
array
(
tensor
)
self
.
assertEqual
(
1
.0
,
tensor_array_2
[
3
,
9
])
self
.
assertEqual
(
2
.0
,
tensor_array_2
[
19
,
11
])
self
.
assertEqual
(
1
,
tensor_array_2
[
3
,
9
])
self
.
assertEqual
(
2
,
tensor_array_2
[
19
,
11
])
def
test_float_tensor
(
self
):
scope
=
core
.
Scope
()
...
...
@@ -43,6 +43,84 @@ class TestScope(unittest.TestCase):
self
.
assertAlmostEqual
(
1.0
,
tensor_array_2
[
3
,
9
])
self
.
assertAlmostEqual
(
2.0
,
tensor_array_2
[
19
,
11
])
def
test_int_lod_tensor
(
self
):
places
=
[
core
.
CPUPlace
(),
core
.
GPUPlace
(
0
)]
for
place
in
places
:
scope
=
core
.
Scope
()
var
=
scope
.
new_var
(
"test_tensor"
)
var_lod
=
scope
.
new_var
(
"test_lod_tensor"
)
tensor
=
var
.
get_tensor
()
lod_tensor
=
var_lod
.
get_lod_tensor
()
tensor
.
set_dims
([
4
,
4
,
6
])
tensor
.
alloc_int
(
place
)
array
=
numpy
.
array
(
tensor
)
array
[
0
,
0
,
0
]
=
3
array
[
3
,
3
,
5
]
=
10
tensor
.
set
(
array
,
place
)
lod_tensor
.
set_tensor
(
tensor
)
lod_tensor
.
set_lod
([[
0
,
2
,
4
]])
lod_v
=
numpy
.
array
(
lod_tensor
.
tensor
())
self
.
assertTrue
(
numpy
.
alltrue
(
array
==
lod_v
))
lod
=
lod_tensor
.
lod
()
self
.
assertEqual
(
0
,
lod
[
0
][
0
])
self
.
assertEqual
(
2
,
lod
[
0
][
1
])
self
.
assertEqual
(
4
,
lod
[
0
][
2
])
def
test_float_lod_tensor
(
self
):
places
=
[
core
.
CPUPlace
(),
core
.
GPUPlace
(
0
)]
for
place
in
places
:
scope
=
core
.
Scope
()
var
=
scope
.
new_var
(
"test_tensor"
)
var_lod
=
scope
.
new_var
(
"test_lod_tensor"
)
tensor
=
var
.
get_tensor
()
lod_tensor
=
var_lod
.
get_lod_tensor
()
tensor
.
set_dims
([
5
,
2
,
3
,
4
])
tensor
.
alloc_float
(
place
)
tensor_array
=
numpy
.
array
(
tensor
)
self
.
assertEqual
((
5
,
2
,
3
,
4
),
tensor_array
.
shape
)
tensor_array
[
0
,
0
,
0
,
0
]
=
1.0
tensor_array
[
0
,
0
,
0
,
1
]
=
2.0
tensor
.
set
(
tensor_array
,
place
)
lod_tensor
.
set_tensor
(
tensor
)
lod_v
=
numpy
.
array
(
lod_tensor
.
tensor
())
self
.
assertAlmostEqual
(
1.0
,
lod_v
[
0
,
0
,
0
,
0
])
self
.
assertAlmostEqual
(
2.0
,
lod_v
[
0
,
0
,
0
,
1
])
self
.
assertEqual
(
len
(
lod_tensor
.
lod
()),
0
)
lod_py
=
[[
0
,
2
,
5
],
[
0
,
2
,
4
,
5
]]
lod_tensor
.
set_lod
(
lod_py
)
lod
=
lod_tensor
.
lod
()
self
.
assertListEqual
(
lod_py
,
lod
)
def
test_lod_tensor_init
(
self
):
scope
=
core
.
Scope
()
var
=
scope
.
new_var
(
"test_tensor"
)
place
=
core
.
CPUPlace
()
tensor
=
var
.
get_tensor
()
tensor
.
set_dims
([
5
,
2
,
3
,
4
])
tensor
.
alloc_float
(
place
)
tensor_array
=
numpy
.
array
(
tensor
)
tensor_array
[
0
,
0
,
0
,
0
]
=
1.0
tensor_array
[
0
,
0
,
0
,
1
]
=
2.0
tensor
.
set
(
tensor_array
,
place
)
lod_py
=
[[
0
,
2
,
5
],
[
0
,
2
,
4
,
5
]]
lod_tensor
=
core
.
LoDTensor
(
lod_py
,
tensor
)
lod_v
=
numpy
.
array
(
lod_tensor
.
tensor
())
self
.
assertAlmostEqual
(
1.0
,
lod_v
[
0
,
0
,
0
,
0
])
self
.
assertAlmostEqual
(
2.0
,
lod_v
[
0
,
0
,
0
,
1
])
self
.
assertListEqual
(
lod_py
,
lod_tensor
.
lod
())
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录