Commit 5f98d800
Authored December 06, 2018 by wangguibao

AsyncExecutor bugfix: Tensor change to LoDTensor

Parent: f6a877bc
Showing 3 changed files with 18 additions and 71 deletions (+18 -71)

paddle/fluid/framework/data_feed.cc       +13 -30
paddle/fluid/framework/data_feed.h         +1 -30
paddle/fluid/framework/data_feed_test.cc   +4 -11
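The change in a nutshell: feed vars were previously held as either a bare Tensor (dense slots) or a LoDTensor (sparse slots) behind a MixTensor wrapper, and this commit holds every feed var as a LoDTensor instead. That works because a dense slot is just a LoDTensor whose LoD carries no extra structure. Below is a standalone toy sketch of the idea, with illustrative names only (it deliberately avoids the real Paddle headers):

// Toy stand-in for paddle::framework::LoDTensor (illustrative, not the
// real class): a flat buffer plus optional level-of-detail offsets.
// An empty LoD is how a purely dense slot can be represented.
#include <cstdio>
#include <vector>

struct ToyLoDTensor {
  std::vector<float> data;
  std::vector<std::vector<size_t>> lod;  // lod[0] holds instance offsets
  std::vector<int> dims;
};

int main() {
  // Sparse slot: 3 instances holding 2, 1 and 3 feature values.
  ToyLoDTensor sparse;
  sparse.data = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f};
  sparse.lod = {{0, 2, 3, 6}};  // one [begin, end) span per instance
  sparse.dims = {6, 1};

  // Dense slot: batch of 3 instances, 2 values each; LoD left empty.
  ToyLoDTensor dense;
  dense.data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  dense.dims = {3, 2};

  std::printf("sparse instances: %zu\n", sparse.lod[0].size() - 1);
  std::printf("dense has lod: %s\n", dense.lod.empty() ? "no" : "yes");
  return 0;
}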
paddle/fluid/framework/data_feed.cc
@@ -33,11 +33,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
   CheckInit();
   for (size_t i = 0; i < use_slots_.size(); ++i) {
     if (name == use_slots_[i]) {
-      if (use_slots_is_dense_[i]) {
-        feed_vec_[i] = MixTensor(var->GetMutable<Tensor>());
-      } else {
-        feed_vec_[i] = MixTensor(var->GetMutable<LoDTensor>());
-      }
+      feed_vec_[i] = var->GetMutable<LoDTensor>();
     }
   }
 }
@@ -350,34 +346,21 @@ void MultiSlotDataFeed::PutToFeedVec(
     int total_instance = static_cast<int>(offset.back());
     if (type[0] == 'f') {  // float
       const auto& feasign = ins_vec[i].GetFloatData();
-      if (feed_vec_[i].IsDense()) {
-        int size_in_each_batch = total_instance / batch_size_;
-        float* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<float>(
-            {batch_size_, size_in_each_batch}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
-      } else {
-        float* tensor_ptr = feed_vec_[i].GetLoDTensor()->mutable_data<float>(
-            {total_instance, 1}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
-        LoD data_lod{offset};
-        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
-      }
+      float* tensor_ptr = feed_vec_[i]->mutable_data<float>(
+          {total_instance, 1}, platform::CPUPlace());
+      memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(float));
     } else if (type[0] == 'u') {  // uint64
       // no uint64_t type in paddlepaddle
       const auto& feasign = ins_vec[i].GetUint64Data();
-      if (feed_vec_[i].IsDense()) {
-        int size_in_each_batch = total_instance / batch_size_;
-        int64_t* tensor_ptr = feed_vec_[i].GetTensor()->mutable_data<int64_t>(
-            {batch_size_, size_in_each_batch}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
-      } else {
-        int64_t* tensor_ptr =
-            feed_vec_[i].GetLoDTensor()->mutable_data<int64_t>(
-                {total_instance, 1}, platform::CPUPlace());
-        memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
-        LoD data_lod{offset};
-        feed_vec_[i].GetLoDTensor()->set_lod(data_lod);
-      }
+      int64_t* tensor_ptr = feed_vec_[i]->mutable_data<int64_t>(
+          {total_instance, 1}, platform::CPUPlace());
+      memcpy(tensor_ptr, &feasign[0], total_instance * sizeof(int64_t));
     }
+    LoD data_lod{offset};
+    feed_vec_[i]->set_lod(data_lod);
+    if (use_slots_is_dense_[i]) {
+      int dim = total_instance / batch_size_;
+      feed_vec_[i]->Resize({batch_size_, dim});
+    }
   }
 }
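With this hunk applied, the float and uint64 branches share a single fill path: allocate the buffer as {total_instance, 1}, copy in the flat feature values, attach the instance offsets as LoD, and reshape dense slots to {batch_size_, dim} at the end. A minimal standalone sketch of that flow follows, using toy types and an illustrative put_slot helper rather than the real PutToFeedVec:

// Standalone sketch of the unified fill path (toy types; put_slot is an
// illustrative stand-in for MultiSlotDataFeed::PutToFeedVec's per-slot work).
#include <cassert>
#include <vector>

struct ToySlot {
  std::vector<float> data;               // flat value buffer
  std::vector<std::vector<size_t>> lod;  // instance offsets
  std::vector<int> dims;
};

void put_slot(ToySlot* slot, const std::vector<float>& feasign,
              const std::vector<size_t>& offset, int batch_size,
              bool is_dense) {
  int total_instance = static_cast<int>(offset.back());
  // Every slot is first filled with the same shape: {total_instance, 1}.
  slot->data.assign(feasign.begin(), feasign.begin() + total_instance);
  slot->dims = {total_instance, 1};
  // The offsets always go in as LoD, dense or not.
  slot->lod = {offset};
  // Dense slots are then reshaped; total_instance must divide evenly.
  if (is_dense) {
    assert(total_instance % batch_size == 0);
    slot->dims = {batch_size, total_instance / batch_size};
  }
}

int main() {
  ToySlot s;
  put_slot(&s, {1, 2, 3, 4}, {0, 2, 4}, /*batch_size=*/2, /*is_dense=*/true);
  assert(s.dims[0] == 2 && s.dims[1] == 2);
  return 0;
}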
paddle/fluid/framework/data_feed.h
@@ -30,35 +30,6 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-// Pack Tensor type and LoDTensor type into MixTensor type, in order
-// to record either Tensor or LoDTensor information at the same time.
-class MixTensor {
- public:
-  MixTensor() {}
-  explicit MixTensor(LoDTensor* lodtensor) {
-    is_dense_ = false;
-    lodtensor_ = lodtensor;
-  }
-  explicit MixTensor(Tensor* tensor) {
-    is_dense_ = true;
-    tensor_ = tensor;
-  }
-  bool IsDense() { return is_dense_; }
-  LoDTensor* GetLoDTensor() {
-    PADDLE_ENFORCE(!is_dense_, "Let a dense var return a LoDTensor ptr.");
-    return lodtensor_;
-  }
-  Tensor* GetTensor() {
-    PADDLE_ENFORCE(is_dense_, "Let a sparse var return a Tensor ptr.");
-    return tensor_;
-  }
-
- private:
-  bool is_dense_;
-  LoDTensor* lodtensor_;
-  Tensor* tensor_;
-};
-
 // DataFeed is the base virtual class for all ohther DataFeeds.
 // It is used to read files and parse the data for subsequent trainer.
 // Example:
@@ -133,7 +104,7 @@ class DataFeed {
       use_slots_index_;  // -1: not used; >=0: the index of use_slots_
   // The data read by DataFeed will be stored here
-  std::vector<MixTensor> feed_vec_;
+  std::vector<LoDTensor*> feed_vec_;
   // the batch size defined by user
   int default_batch_size_;
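Design note: with every feed var now a LoDTensor (dense slots carry the same offset-based LoD and are simply reshaped after filling), the MixTensor wrapper removed above has no remaining callers, so a plain std::vector<LoDTensor*> suffices here and the IsDense()/GetTensor()/GetLoDTensor() dispatch disappears.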
paddle/fluid/framework/data_feed_test.cc
@@ -152,19 +152,13 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
   const auto& multi_slot_desc = data_feed_desc.multi_slot_desc();
   std::map<std::string, const paddle::framework::LoDTensor*>
       lodtensor_targets;
-  std::map<std::string, const paddle::framework::Tensor*> tensor_targets;
   for (int i = 0; i < multi_slot_desc.slots_size(); ++i) {
     const auto& slot = multi_slot_desc.slots(i);
     if (slot.is_used()) {
       const auto& name = slot.name();
       readers[idx]->AddFeedVar(scope->Var(name), name);
-      if (slot.is_dense()) {
-        tensor_targets[name] =
-            &scope->FindVar(name)->Get<paddle::framework::Tensor>();
-      } else {
-        lodtensor_targets[name] =
-            &scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
-      }
+      lodtensor_targets[name] =
+          &scope->FindVar(name)->Get<paddle::framework::LoDTensor>();
     }
   }
   readers[idx]->Start();
@@ -175,8 +169,9 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
       if (!slot.is_used()) {
         continue;
       }
+      const paddle::framework::LoDTensor* tens =
+          lodtensor_targets[slot.name()];
       if (slot.is_dense()) {  // dense branch
-        const paddle::framework::Tensor* tens = tensor_targets[slot.name()];
         if (slot.type() == "uint64") {
           const int64_t* data = tens->data<int64_t>();
           int batch_size = tens->dims()[0];
@@ -202,8 +197,6 @@ void GetElemSetFromReader(std::vector<MultiTypeSet>* reader_elem_set,
           PADDLE_THROW("Error type in proto file.");
         }
       } else {  // sparse branch
-        const paddle::framework::LoDTensor* tens =
-            lodtensor_targets[slot.name()];
         if (slot.type() == "uint64") {
           const int64_t* data = tens->data<int64_t>();
           for (size_t i = 0; i < tens->NumElements(); ++i) {
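Both test branches now read from the single lodtensor_targets map. A standalone toy sketch of the two read patterns (illustrative types, not the real framework classes): the dense branch takes the batch size from dims()[0], while the sparse branch walks sequences derived from the LoD offsets, which is what NumElements() reports.

// Toy illustration of the two read patterns in the test after the change
// (illustrative types only; num_elements mimics LoDTensor::NumElements()).
#include <cstdio>
#include <vector>

struct ToyLoDTensor {
  std::vector<long long> data;           // uint64 slots are stored as int64
  std::vector<std::vector<size_t>> lod;  // empty for dense slots
  std::vector<int> dims;
};

// Sparse: number of sequences = number of offsets minus one.
size_t num_elements(const ToyLoDTensor& t) { return t.lod[0].size() - 1; }

int main() {
  ToyLoDTensor dense{{1, 2, 3, 4, 5, 6}, {}, {3, 2}};
  ToyLoDTensor sparse{{7, 8, 9}, {{0, 1, 3}}, {3, 1}};
  std::printf("dense batch size: %d\n", dense.dims[0]);  // dims()[0]
  std::printf("sparse sequences: %zu\n", num_elements(sparse));
  return 0;
}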