Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3cea00bd
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3cea00bd
编写于
3月 12, 2019
作者:
X
xujiaqi01
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
store memory data in Dataset && fix bug
上级
ff87698a
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
356 addition
and
70 deletion
+356
-70
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+110
-21
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+38
-5
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+78
-24
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+41
-7
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+62
-0
paddle/fluid/framework/fleet/fleet_wrapper.h
paddle/fluid/framework/fleet/fleet_wrapper.h
+14
-0
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+9
-9
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-2
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+2
-2
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
3cea00bd
...
@@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) {
...
@@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) {
bool
DataFeed
::
PickOneFile
(
std
::
string
*
filename
)
{
bool
DataFeed
::
PickOneFile
(
std
::
string
*
filename
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
if
(
file_idx_
==
filelist_
.
size
())
{
if
(
file_idx_
==
filelist_
.
size
())
{
VLOG
(
3
)
<<
"DataFeed::PickOneFile no more file to pick"
;
return
false
;
return
false
;
}
}
VLOG
(
3
)
<<
"file_idx_="
<<
file_idx_
;
*
filename
=
filelist_
[
file_idx_
++
];
*
filename
=
filelist_
[
file_idx_
++
];
// LOG(ERROR) << "pick file:" << *filename;
// LOG(ERROR) << "pick file:" << *filename;
return
true
;
return
true
;
...
@@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
...
@@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
template
<
typename
T
>
template
<
typename
T
>
InMemoryDataFeed
<
T
>::
InMemoryDataFeed
()
{
InMemoryDataFeed
<
T
>::
InMemoryDataFeed
()
{
cur_channel_
=
0
;
cur_channel_
=
0
;
shuffled_ins_
=
nullptr
;
shuffled_ins_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
shuffled_ins_out_
=
nullptr
;
shuffled_ins_out_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
fleet_send_batch_size_
=
10000
;
}
}
template
<
typename
T
>
template
<
typename
T
>
bool
InMemoryDataFeed
<
T
>::
Start
()
{
bool
InMemoryDataFeed
<
T
>::
Start
()
{
DataFeed
::
CheckSetFileList
();
DataFeed
::
CheckSetFileList
();
if
(
memory_data_
.
size
()
!
=
0
)
{
if
(
shuffled_ins_
->
Size
()
==
0
&&
shuffled_ins_out_
->
Size
()
=
=
0
)
{
CHECK_EQ
(
cur_channel_
,
0
);
FillMemoryDataToChannel
(
);
shuffled_ins_
->
Extend
(
std
::
move
(
memory_data_
)
);
//std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_
);
std
::
vector
<
T
>
().
swap
(
memory_data_
);
//
std::vector<T>().swap(memory_data_);
}
}
DataFeed
::
finish_start_
=
true
;
DataFeed
::
finish_start_
=
true
;
return
true
;
return
true
;
...
@@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() {
...
@@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() {
return
DataFeed
::
batch_size_
;
return
DataFeed
::
batch_size_
;
}
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetMemoryData
(
void
*
memory_data
)
{
memory_data_
=
static_cast
<
std
::
vector
<
T
>*>
(
memory_data
);
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetMemoryDataMutex
(
std
::
mutex
*
mutex
)
{
mutex_for_update_memory_data_
=
mutex
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetThreadId
(
int
thread_id
)
{
thread_id_
=
thread_id
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetThreadNum
(
int
thread_num
)
{
thread_num_
=
thread_num
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
T
ins
;
T
ins
;
...
@@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
...
@@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
shuffled_ins_
->
Push
(
std
::
move
(
ins
));
shuffled_ins_
->
Push
(
std
::
move
(
ins
));
}
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
FillMemoryDataToChannel
()
{
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id="
<<
thread_id_
;
int64_t
start
=
0
;
int64_t
end
=
0
;
int64_t
size
=
memory_data_
->
size
();
VLOG
(
3
)
<<
"memory_data size="
<<
size
;
for
(
int64_t
i
=
0
;
i
<=
static_cast
<
int64_t
>
(
thread_id_
);
++
i
)
{
int64_t
len
=
size
/
static_cast
<
int64_t
>
(
thread_num_
)
+
(
i
<
(
size
%
static_cast
<
int64_t
>
(
thread_num_
)));
start
=
end
;
end
+=
len
;
}
for
(
int64_t
i
=
start
;
i
<
end
;
++
i
)
{
T
&
t
=
(
*
memory_data_
)[
i
];
shuffled_ins_
->
Push
(
std
::
move
(
t
));
}
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
FillChannelToMemoryData
()
{
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id="
<<
thread_id_
;
std
::
vector
<
T
>
local_vec
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
channel
=
nullptr
;
if
(
cur_channel_
==
0
)
{
channel
=
shuffled_ins_
;
}
else
{
channel
=
shuffled_ins_out_
;
}
CHECK
(
channel
!=
nullptr
);
local_vec
.
reserve
(
channel
->
Size
());
for
(
int64_t
i
=
0
;
i
<
channel
->
Size
();
++
i
)
{
channel
->
Pop
(
local_vec
[
i
]);
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
*
mutex_for_update_memory_data_
);
lock
.
lock
();
memory_data_
->
insert
(
memory_data_
->
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
lock
.
unlock
();
std
::
vector
<
T
>
().
swap
(
local_vec
);
}
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
LoadIntoMemory
()
{
void
InMemoryDataFeed
<
T
>::
LoadIntoMemory
()
{
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id="
<<
thread_id_
;
std
::
vector
<
T
>
local_vec
;
std
::
vector
<
T
>
local_vec
;
std
::
string
filename
;
std
::
string
filename
;
while
(
DataFeed
::
PickOneFile
(
&
filename
))
{
while
(
DataFeed
::
PickOneFile
(
&
filename
))
{
VLOG
(
3
)
<<
"PickOneFile, filename="
<<
filename
<<
", thread_id="
<<
thread_id_
;
int
err_no
=
0
;
int
err_no
=
0
;
PrivateQueueDataFeed
<
T
>::
fp_
=
PrivateQueueDataFeed
<
T
>::
fp_
=
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
...
@@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
...
@@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
local_vec
.
push_back
(
instance
);
local_vec
.
push_back
(
instance
);
}
}
memory_data_
.
insert
(
memory_data_
.
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id="
<<
thread_id_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
mutex_for_update_memory_data_
);
memory_data_
->
insert
(
memory_data_
->
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
}
std
::
vector
<
T
>
().
swap
(
local_vec
);
std
::
vector
<
T
>
().
swap
(
local_vec
);
}
}
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id="
<<
thread_id_
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
LocalShuffle
()
{
void
InMemoryDataFeed
<
T
>::
LocalShuffle
()
{
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LocalShuffle() begin, thread_id="
<<
thread_id_
;
FillMemoryDataToChannel
();
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LocalShuffle() end, thread_id="
<<
thread_id_
;
}
}
// todo global shuffle
/*
template
<
typename
T
>
template
<
typename
T
>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
void
InMemoryDataFeed
<
T
>::
GlobalShuffle
()
{
std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
for (int64_t i = 0; i < memory_data_.size(); ++i) {
std
::
vector
<
std
::
string
>
send_str_vec
(
trainer_num_
);
for
(
int64_t
i
=
0
;
i
<
memory_data_
->
size
();
++
i
)
{
// todo get ins id
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id)
;
int64_t
hash_id
=
0
;
int64_t
node_id
=
hash_id
%
trainer_num_
;
int64_t
node_id
=
hash_id
%
trainer_num_
;
std
::
string
str
;
std
::
string
str
;
SerializeIns(memory_data_[i], str);
SerializeIns
((
*
memory_data_
)[
i
],
str
);
auto fleet_ptr = FleetWrapper::GetInstance();
send_str_vec
[
node_id
]
+=
str
;
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
if
(
i
%
fleet_send_batch_size_
==
0
&&
i
!=
0
)
{
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
fleet_ptr
->
send_client2client_msg
(
0
,
j
,
send_str_vec
[
j
]);
send_str_vec
[
j
]
=
""
;
}
}
}
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
if
(
send_str_vec
[
j
].
length
()
!=
0
)
{
fleet_ptr
->
send_client2client_msg
(
0
,
j
,
send_str_vec
[
j
]);
}
}
}
}
}
*/
// explicit instantiation
// explicit instantiation
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
...
@@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
...
@@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
if
(
getline
(
file_
,
line
))
{
if
(
getline
(
file_
,
line
))
{
int
use_slots_num
=
use_slots_
.
size
();
int
use_slots_num
=
use_slots_
.
size
();
instance
->
resize
(
use_slots_num
);
instance
->
resize
(
use_slots_num
);
VLOG
(
3
)
<<
line
;
// parse line
// parse line
const
char
*
str
=
line
.
c_str
();
const
char
*
str
=
line
.
c_str
();
char
*
endptr
=
const_cast
<
char
*>
(
str
);
char
*
endptr
=
const_cast
<
char
*>
(
str
);
...
@@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
...
@@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle
// todo serialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
&
str
)
{
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
&
str
)
{
return
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Serialize
(
ins
,
str
);
}
}
// todo deserialize ins in global shuffle
// todo deserialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>&
ins
,
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>&
ins
,
const
std
::
string
&
str
)
{
const
std
::
string
&
str
)
{
return
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Deserialize
(
ins
,
str
);
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
3cea00bd
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include <sstream>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
@@ -78,17 +79,33 @@ class DataFeed {
...
@@ -78,17 +79,33 @@ class DataFeed {
// This function is used for binding feed_vec memory
// This function is used for binding feed_vec memory
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
);
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
);
// This function will do nothing at default
virtual
void
SetMemoryData
(
void
*
memory_data
)
{
}
// This function will do nothing at default
virtual
void
SetMemoryDataMutex
(
std
::
mutex
*
mutex
)
{
}
// This function will do nothing at default
virtual
void
SetThreadId
(
int
thread_id
)
{
}
// This function will do nothing at default
virtual
void
SetThreadNum
(
int
thread_num
)
{
}
// This function will do nothing at default
virtual
void
SetTrainerNum
(
int
trainer_num
)
{
}
virtual
void
LoadIntoMemory
()
{
virtual
void
LoadIntoMemory
()
{
PADDLE_THROW
(
"This function(LoadIntoMemory) is not implemented."
);
PADDLE_THROW
(
"This function(LoadIntoMemory) is not implemented."
);
}
}
virtual
void
LocalShuffle
()
{
virtual
void
LocalShuffle
()
{
PADDLE_THROW
(
"This function(LocalShuffle) is not implemented."
);
PADDLE_THROW
(
"This function(LocalShuffle) is not implemented."
);
}
}
virtual
void
GlobalShuffle
(
int
trainer_num
)
{
virtual
void
GlobalShuffle
()
{
PADDLE_THROW
(
"This function(GlobalShuffle) is not implemented."
);
PADDLE_THROW
(
"This function(GlobalShuffle) is not implemented."
);
}
}
virtual
void
FillMemoryDataToChannel
()
{
PADDLE_THROW
(
"This function(FillMemoryDataToChannel) is not implemented."
);
}
virtual
void
FillChannelToMemoryData
()
{
PADDLE_THROW
(
"This function(FillChannelToMemoryData) is not implemented."
);
}
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
PADDLE_THROW
(
"This function(PutToChannel) is not implemented."
);
PADDLE_THROW
(
"This function(Put
Ins
ToChannel) is not implemented."
);
}
}
protected:
protected:
...
@@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
...
@@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
public:
InMemoryDataFeed
();
InMemoryDataFeed
();
virtual
~
InMemoryDataFeed
()
{}
virtual
~
InMemoryDataFeed
()
{}
virtual
void
Init
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
bool
Start
();
virtual
bool
Start
();
virtual
int
Next
();
virtual
int
Next
();
virtual
void
SetMemoryData
(
void
*
memory_data
);
virtual
void
SetMemoryDataMutex
(
std
::
mutex
*
mutex
);
virtual
void
SetThreadId
(
int
thread_id
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
);
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
);
virtual
void
FillMemoryDataToChannel
();
virtual
void
FillChannelToMemoryData
();
virtual
void
LoadIntoMemory
();
virtual
void
LoadIntoMemory
();
virtual
void
LocalShuffle
();
virtual
void
LocalShuffle
();
// todo global shuffle
virtual
void
GlobalShuffle
();
//virtual void GlobalShuffle(int trainer_num);
protected:
protected:
virtual
void
AddInstanceToInsVec
(
T
*
vec_ins
,
const
T
&
instance
,
int
index
)
=
0
;
virtual
void
AddInstanceToInsVec
(
T
*
vec_ins
,
const
T
&
instance
,
int
index
)
=
0
;
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
...
@@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
...
@@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual
void
SerializeIns
(
const
T
&
ins
,
std
::
string
&
str
)
=
0
;
virtual
void
SerializeIns
(
const
T
&
ins
,
std
::
string
&
str
)
=
0
;
virtual
void
DeserializeIns
(
T
&
ins
,
const
std
::
string
&
str
)
=
0
;
virtual
void
DeserializeIns
(
T
&
ins
,
const
std
::
string
&
str
)
=
0
;
std
::
vector
<
T
>
memory_data_
;
int
thread_id_
;
int
thread_num_
;
int
trainer_num_
;
std
::
vector
<
T
>*
memory_data_
;
std
::
mutex
*
mutex_for_update_memory_data_
;
// when read ins, we put ins from one channel to the other,
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int
cur_channel_
;
int
cur_channel_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_out_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_out_
;
int64_t
fleet_send_batch_size_
;
};
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
...
@@ -226,6 +255,7 @@ class MultiSlotType {
...
@@ -226,6 +255,7 @@ class MultiSlotType {
offset_
[
0
]
=
0
;
offset_
[
0
]
=
0
;
}
}
const
std
::
vector
<
size_t
>&
GetOffset
()
const
{
return
offset_
;
}
const
std
::
vector
<
size_t
>&
GetOffset
()
const
{
return
offset_
;
}
std
::
vector
<
size_t
>&
MutableOffset
()
{
return
offset_
;
}
void
AddValue
(
const
float
v
)
{
void
AddValue
(
const
float
v
)
{
CheckFloat
();
CheckFloat
();
float_feasign_
.
push_back
(
v
);
float_feasign_
.
push_back
(
v
);
...
@@ -248,8 +278,11 @@ class MultiSlotType {
...
@@ -248,8 +278,11 @@ class MultiSlotType {
}
}
}
}
const
std
::
vector
<
float
>&
GetFloatData
()
const
{
return
float_feasign_
;
}
const
std
::
vector
<
float
>&
GetFloatData
()
const
{
return
float_feasign_
;
}
std
::
vector
<
float
>&
MutableFloatData
()
{
return
float_feasign_
;
}
const
std
::
vector
<
uint64_t
>&
GetUint64Data
()
const
{
return
uint64_feasign_
;
}
const
std
::
vector
<
uint64_t
>&
GetUint64Data
()
const
{
return
uint64_feasign_
;
}
std
::
vector
<
uint64_t
>&
MutableUint64Data
()
{
return
uint64_feasign_
;
}
const
std
::
string
&
GetType
()
const
{
return
type_
;
}
const
std
::
string
&
GetType
()
const
{
return
type_
;
}
std
::
string
&
MutableType
()
{
return
type_
;
}
private:
private:
void
CheckType
(
const
std
::
string
&
type
)
const
{
void
CheckType
(
const
std
::
string
&
type
)
const
{
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
3cea00bd
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License. */
* limitations under the License. */
#include <random>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_set.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message.h"
...
@@ -21,23 +22,27 @@
...
@@ -21,23 +22,27 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
Dataset
::
Dataset
()
{
thread_num_
=
1
;
}
template
<
typename
T
>
DatasetImpl
<
T
>::
DatasetImpl
()
{
thread_num_
=
1
;
}
void
Dataset
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
filelist_
=
filelist
;
filelist_
=
filelist
;
/*
int file_cnt = filelist_.size();
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_
VLOG(1) << "DataSet thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
<< ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt;
thread_num_ = file_cnt;
}
}
*/
}
}
// buggy here, a user should set filelist first before this function
// buggy here, a user should set filelist first before this function
// not user friendly
// not user friendly
void
Dataset
::
SetThreadNum
(
int
thread_num
)
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetThreadNum
(
int
thread_num
)
{
int
file_cnt
=
filelist_
.
size
();
int
file_cnt
=
filelist_
.
size
();
if
(
file_cnt
!=
0
&&
thread_num
>
file_cnt
)
{
if
(
file_cnt
!=
0
&&
thread_num
>
file_cnt
)
{
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num
...
@@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) {
...
@@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) {
thread_num_
=
thread_num
;
thread_num_
=
thread_num
;
}
}
void
Dataset
::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
void
Dataset
::
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
{
google
::
protobuf
::
TextFormat
::
ParseFromString
(
data_feed_desc_str
,
google
::
protobuf
::
TextFormat
::
ParseFromString
(
data_feed_desc_str
,
&
data_feed_desc_
);
&
data_feed_desc_
);
}
}
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
template
<
typename
T
>
Dataset
::
GetReaders
()
{
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
DatasetImpl
<
T
>::
GetReaders
()
{
return
readers_
;
return
readers_
;
}
}
void
Dataset
::
LoadIntoMemory
()
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LoadIntoMemory
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() begin"
;
if
(
readers_
.
size
()
==
0
)
{
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
CreateReaders
();
}
}
...
@@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() {
...
@@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() {
for
(
std
::
thread
&
t
:
load_threads
)
{
for
(
std
::
thread
&
t
:
load_threads
)
{
t
.
join
();
t
.
join
();
}
}
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() end"
;
}
}
void
Dataset
::
LocalShuffle
()
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LocalShuffle
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() begin"
;
if
(
readers_
.
size
()
==
0
)
{
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
CreateReaders
();
}
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
std
::
vector
<
std
::
thread
>
local_shuffle_threads
;
std
::
vector
<
std
::
thread
>
local_shuffle_threads
;
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
local_shuffle_threads
.
push_back
(
std
::
thread
(
local_shuffle_threads
.
push_back
(
std
::
thread
(
...
@@ -86,30 +102,37 @@ void Dataset::LocalShuffle() {
...
@@ -86,30 +102,37 @@ void Dataset::LocalShuffle() {
for
(
std
::
thread
&
t
:
local_shuffle_threads
)
{
for
(
std
::
thread
&
t
:
local_shuffle_threads
)
{
t
.
join
();
t
.
join
();
}
}
std
::
vector
<
T
>
().
swap
(
memory_data_
);
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end"
;
}
}
// todo global shuffle
template
<
typename
T
>
void
Dataset
::
GlobalShuffle
()
{
void
DatasetImpl
<
T
>::
GlobalShuffle
()
{
/*
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() begin"
;
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
registe_client2client_msg_handler
(
0
,
fleet_ptr
->
registe_client2client_msg_handler
(
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
});
});
if (readers_.size() == 0) {
CreateReaders();
}
std
::
vector
<
std
::
thread
>
global_shuffle_threads
;
std
::
vector
<
std
::
thread
>
global_shuffle_threads
;
for (int64_t i = 0; i < thread_num_; ++i) {
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle,
global_shuffle_threads
.
push_back
(
readers_[i].get(), trainer_num_));
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
GlobalShuffle
,
readers_
[
i
].
get
()));
}
}
for
(
std
::
thread
&
t
:
global_shuffle_threads
)
{
for
(
std
::
thread
&
t
:
global_shuffle_threads
)
{
t
.
join
();
t
.
join
();
}*/
}
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end"
;
}
}
void
Dataset
::
CreateReaders
()
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
CreateReaders
()
{
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
...
@@ -118,22 +141,53 @@ void Dataset::CreateReaders() {
...
@@ -118,22 +141,53 @@ void Dataset::CreateReaders() {
return
;
return
;
}
}
VLOG
(
3
)
<<
"data feed class name: "
<<
data_feed_desc_
.
name
();
VLOG
(
3
)
<<
"data feed class name: "
<<
data_feed_desc_
.
name
();
for
(
int
64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
readers_
.
push_back
(
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc_
.
name
()));
readers_
.
push_back
(
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc_
.
name
()));
readers_
.
back
()
->
Init
(
data_feed_desc_
);
readers_
.
back
()
->
Init
(
data_feed_desc_
);
readers_
.
back
()
->
SetMemoryData
(
&
memory_data_
);
readers_
.
back
()
->
SetMemoryDataMutex
(
&
mutex_for_update_memory_data_
);
readers_
.
back
()
->
SetThreadId
(
i
);
readers_
.
back
()
->
SetThreadNum
(
thread_num_
);
readers_
.
back
()
->
SetTrainerNum
(
trainer_num_
);
}
}
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
readers_
[
0
]
->
SetFileList
(
filelist_
);
readers_
[
0
]
->
SetFileList
(
filelist_
);
}
}
int
Dataset
::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
template
<
typename
T
>
void
DatasetImpl
<
T
>::
DestroyReaders
()
{
VLOG
(
3
)
<<
"Calling DestroyReaders()"
;
// clear memory_data_ before fill it
// because if LoadIntoMemory but no Shuffle,
// memory_data_ has empty data which has been std::move to channel
if
(
memory_data_
.
size
()
!=
0
)
{
std
::
vector
<
T
>
().
swap
(
memory_data_
);
}
std
::
vector
<
std
::
thread
>
fill_threads
;
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
fill_threads
.
push_back
(
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
FillChannelToMemoryData
,
readers_
[
i
].
get
()));
}
for
(
std
::
thread
&
t
:
fill_threads
)
{
t
.
join
();
}
std
::
vector
<
std
::
string
>
().
swap
(
filelist_
);
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
().
swap
(
readers_
);
}
template
<
typename
T
>
int
DatasetImpl
<
T
>::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
{
const
std
::
string
&
msg
)
{
//
can also use hash
//
todo random
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t
index
=
0
;
int64_t
index
=
0
;
readers_
[
index
]
->
PutInsToChannel
(
msg
);
readers_
[
index
]
->
PutInsToChannel
(
msg
);
return
0
;
return
0
;
}
}
// explicit instantiation
template
class
DatasetImpl
<
std
::
vector
<
MultiSlotType
>
>
;
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/data_set.h
浏览文件 @
3cea00bd
...
@@ -28,8 +28,33 @@ namespace framework {
...
@@ -28,8 +28,33 @@ namespace framework {
class
Dataset
{
class
Dataset
{
public:
public:
Dataset
();
Dataset
()
{};
virtual
~
Dataset
()
{}
virtual
~
Dataset
()
{};
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
=
0
;
virtual
void
SetThreadNum
(
int
thread_num
)
=
0
;
virtual
void
SetTrainerNum
(
int
trainer_num
)
=
0
;
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
=
0
;
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
=
0
;
virtual
int
GetThreadNum
()
=
0
;
virtual
int
GetTrainerNum
()
=
0
;
virtual
const
paddle
::
framework
::
DataFeedDesc
&
GetDataFeedDesc
()
=
0
;
virtual
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
GetReaders
()
=
0
;
virtual
void
LoadIntoMemory
()
=
0
;
virtual
void
LocalShuffle
()
=
0
;
virtual
void
GlobalShuffle
()
=
0
;
virtual
void
CreateReaders
()
=
0
;
virtual
void
DestroyReaders
()
=
0
;
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
=
0
;
};
template
<
typename
T
>
class
DatasetImpl
:
public
Dataset
{
public:
DatasetImpl
();
virtual
~
DatasetImpl
()
{}
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetThreadNum
(
int
thread_num
);
...
@@ -43,25 +68,34 @@ class Dataset {
...
@@ -43,25 +68,34 @@ class Dataset {
return
data_feed_desc_
;
return
data_feed_desc_
;
}
}
virtual
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
virtual
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
GetReaders
();
GetReaders
();
virtual
void
LoadIntoMemory
();
virtual
void
LoadIntoMemory
();
virtual
void
LocalShuffle
();
virtual
void
LocalShuffle
();
// todo global shuffle
virtual
void
GlobalShuffle
();
virtual
void
GlobalShuffle
();
virtual
void
CreateReaders
();
virtual
void
CreateReaders
();
virtual
void
DestroyReaders
();
protected:
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
);
const
std
::
string
&
msg
);
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers_
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers_
;
std
::
vector
<
T
>
memory_data_
;
std
::
mutex
mutex_for_update_memory_data_
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>>
shuffled_ins_vec_
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>>
shuffled_ins_out_vec_
;
int
thread_num_
;
int
thread_num_
;
std
::
string
fs_name_
;
std
::
string
fs_ugi_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
std
::
vector
<
std
::
string
>
filelist_
;
std
::
vector
<
std
::
string
>
filelist_
;
int
trainer_num_
;
int
trainer_num_
;
};
};
class
MultiSlotDataset
:
public
DatasetImpl
<
std
::
vector
<
MultiSlotType
>>
{
public:
MultiSlotDataset
()
{}
virtual
~
MultiSlotDataset
()
{}
};
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
3cea00bd
...
@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
...
@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
...
@@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std
::
shared_ptr
<
FleetWrapper
>
FleetWrapper
::
s_instance_
=
NULL
;
std
::
shared_ptr
<
FleetWrapper
>
FleetWrapper
::
s_instance_
=
NULL
;
bool
FleetWrapper
::
is_initialized_
=
false
;
bool
FleetWrapper
::
is_initialized_
=
false
;
#ifdef PADDLE_WITH_PSLIB
template
<
class
AR
>
paddle
::
ps
::
Archive
<
AR
>&
operator
<<
(
paddle
::
ps
::
Archive
<
AR
>&
ar
,
const
MultiSlotType
&
ins
)
{
ar
<<
ins
.
GetType
();
ar
<<
ins
.
GetOffset
();
ar
<<
ins
.
GetFloatData
();
ar
<<
ins
.
GetUint64Data
();
return
ar
;
}
template
<
class
AR
>
paddle
::
ps
::
Archive
<
AR
>&
operator
>>
(
paddle
::
ps
::
Archive
<
AR
>&
ar
,
MultiSlotType
&
ins
)
{
ar
>>
ins
.
MutableType
();
ar
>>
ins
.
MutableOffset
();
ar
>>
ins
.
MutableFloatData
();
ar
>>
ins
.
MutableUint64Data
();
return
ar
;
}
#endif
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
FleetWrapper
::
pslib_ptr_
=
NULL
;
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
FleetWrapper
::
pslib_ptr_
=
NULL
;
#endif
#endif
...
@@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
...
@@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
#endif
}
}
// todo registe_client2client_msg_handler
int
FleetWrapper
::
registe_client2client_msg_handler
(
int
msg_type
,
MsgHandlerFunc
handler
)
{
return
0
;
}
// todo send_client2client_msg
int
FleetWrapper
::
send_client2client_msg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
return
0
;
}
template
<
typename
T
>
void
FleetWrapper
::
Serialize
(
const
T
&
t
,
std
::
string
&
str
)
{
#ifdef PADDLE_WITH_PSLIB
paddle
::
ps
::
BinaryArchive
ar
;
ar
<<
t
;
str
=
std
::
string
(
ar
.
buffer
(),
ar
.
length
());
#else
VLOG
(
0
)
<<
"FleetWrapper::Serialize do nothing when no pslib"
;
#endif
}
template
<
typename
T
>
void
FleetWrapper
::
Deserialize
(
T
&
t
,
const
std
::
string
&
str
)
{
#ifdef PADDLE_WITH_PSLIB
paddle
::
ps
::
BinaryArchive
ar
;
ar
.
set_read_buffer
(
const_cast
<
char
*>
(
str
.
c_str
()),
str
.
length
(),
nullptr
);
t
=
ar
.
get
<
T
>
();
#else
VLOG
(
0
)
<<
"FleetWrapper::Deserialize do nothing when no pslib"
;
#endif
}
template
void
FleetWrapper
::
Serialize
<
std
::
vector
<
MultiSlotType
>
>
(
const
std
::
vector
<
MultiSlotType
>&
,
std
::
string
&
);
template
void
FleetWrapper
::
Deserialize
(
std
::
vector
<
MultiSlotType
>
&
,
const
std
::
string
&
);
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/fleet/fleet_wrapper.h
浏览文件 @
3cea00bd
...
@@ -17,7 +17,11 @@ limitations under the License. */
...
@@ -17,7 +17,11 @@ limitations under the License. */
#include <memory>
#include <memory>
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#include <pslib.h>
#include <archive.h>
#endif
#endif
#include <random>
#include <atomic>
#include <time.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -110,6 +114,16 @@ class FleetWrapper {
...
@@ -110,6 +114,16 @@ class FleetWrapper {
uint64_t
RunServer
();
uint64_t
RunServer
();
void
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
);
void
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
);
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
int
registe_client2client_msg_handler
(
int
msg_type
,
MsgHandlerFunc
handler
);
int
send_client2client_msg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
);
std
::
default_random_engine
&
local_random_engine
();
template
<
typename
T
>
void
Serialize
(
const
T
&
t
,
std
::
string
&
str
);
template
<
typename
T
>
void
Deserialize
(
T
&
t
,
const
std
::
string
&
str
);
static
std
::
shared_ptr
<
FleetWrapper
>
GetInstance
()
{
static
std
::
shared_ptr
<
FleetWrapper
>
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
framework
::
FleetWrapper
());
s_instance_
.
reset
(
new
paddle
::
framework
::
FleetWrapper
());
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
3cea00bd
...
@@ -41,17 +41,17 @@ namespace paddle {
...
@@ -41,17 +41,17 @@ namespace paddle {
namespace
pybind
{
namespace
pybind
{
void
BindDataset
(
py
::
module
*
m
)
{
void
BindDataset
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
Dataset
>
(
*
m
,
"
Dataset"
)
py
::
class_
<
framework
::
MultiSlotDataset
>
(
*
m
,
"MultiSlot
Dataset"
)
.
def
(
py
::
init
([]()
{
.
def
(
py
::
init
([]()
{
return
std
::
unique_ptr
<
framework
::
Dataset
>
(
new
framework
::
Dataset
());
return
std
::
unique_ptr
<
framework
::
MultiSlotDataset
>
(
new
framework
::
MultiSlot
Dataset
());
}))
}))
.
def
(
"set_filelist"
,
&
framework
::
Dataset
::
SetFileList
)
.
def
(
"set_filelist"
,
&
framework
::
MultiSlot
Dataset
::
SetFileList
)
.
def
(
"set_thread_num"
,
&
framework
::
Dataset
::
SetThreadNum
)
.
def
(
"set_thread_num"
,
&
framework
::
MultiSlot
Dataset
::
SetThreadNum
)
.
def
(
"set_trainer_num"
,
&
framework
::
Dataset
::
SetTrainerNum
)
.
def
(
"set_trainer_num"
,
&
framework
::
MultiSlot
Dataset
::
SetTrainerNum
)
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"set_data_feed_desc"
,
&
framework
::
MultiSlot
Dataset
::
SetDataFeedDesc
)
.
def
(
"load_into_memory"
,
&
framework
::
Dataset
::
LoadIntoMemory
)
.
def
(
"load_into_memory"
,
&
framework
::
MultiSlot
Dataset
::
LoadIntoMemory
)
.
def
(
"local_shuffle"
,
&
framework
::
Dataset
::
LocalShuffle
)
.
def
(
"local_shuffle"
,
&
framework
::
MultiSlot
Dataset
::
LocalShuffle
)
.
def
(
"global_shuffle"
,
&
framework
::
Dataset
::
GlobalShuffle
);
.
def
(
"global_shuffle"
,
&
framework
::
MultiSlot
Dataset
::
GlobalShuffle
);
}
}
}
// end namespace pybind
}
// end namespace pybind
...
...
python/paddle/fluid/__init__.py
浏览文件 @
3cea00bd
...
@@ -30,7 +30,7 @@ from .dataset import *
...
@@ -30,7 +30,7 @@ from .dataset import *
from
.
import
async_executor
from
.
import
async_executor
from
.async_executor
import
*
from
.async_executor
import
*
from
.
import
trainer
from
.
import
trainer
_desc
from
.
import
inferencer
from
.
import
inferencer
from
.
import
io
from
.
import
io
...
@@ -67,7 +67,7 @@ from . import install_check
...
@@ -67,7 +67,7 @@ from . import install_check
Tensor
=
LoDTensor
Tensor
=
LoDTensor
__all__
=
framework
.
__all__
+
executor
.
__all__
+
\
__all__
=
framework
.
__all__
+
executor
.
__all__
+
\
trainer
.
__all__
+
inferencer
.
__all__
+
transpiler
.
__all__
+
\
trainer
_desc
.
__all__
+
inferencer
.
__all__
+
transpiler
.
__all__
+
\
parallel_executor
.
__all__
+
lod_tensor
.
__all__
+
\
parallel_executor
.
__all__
+
lod_tensor
.
__all__
+
\
data_feed_desc
.
__all__
+
async_executor
.
__all__
+
compiler
.
__all__
+
[
data_feed_desc
.
__all__
+
async_executor
.
__all__
+
compiler
.
__all__
+
[
'io'
,
'io'
,
...
...
python/paddle/fluid/dataset.py
浏览文件 @
3cea00bd
...
@@ -37,7 +37,7 @@ class DatasetBase(object):
...
@@ -37,7 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance
# to decide whether we need create in memory instance
self
.
proto_desc
=
data_feed_pb2
.
DataFeedDesc
()
self
.
proto_desc
=
data_feed_pb2
.
DataFeedDesc
()
self
.
proto_desc
.
pipe_command
=
"cat"
self
.
proto_desc
.
pipe_command
=
"cat"
self
.
dataset
=
core
.
Dataset
()
self
.
dataset
=
core
.
MultiSlot
Dataset
()
self
.
thread_num
=
0
self
.
thread_num
=
0
def
set_pipe_command
(
self
,
pipe_command
):
def
set_pipe_command
(
self
,
pipe_command
):
...
@@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase):
...
@@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase):
self
.
proto_desc
.
name
=
"MultiSlotInMemoryDataFeed"
self
.
proto_desc
.
name
=
"MultiSlotInMemoryDataFeed"
def
load_into_memory
(
self
):
def
load_into_memory
(
self
):
_prepare_to_run
()
self
.
_prepare_to_run
()
self
.
dataset
.
load_into_memory
()
self
.
dataset
.
load_into_memory
()
def
local_shuffle
(
self
):
def
local_shuffle
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录