Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
3cea00bd
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3cea00bd
编写于
3月 12, 2019
作者:
X
xujiaqi01
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
store memory data in Dataset && fix bug
上级
ff87698a
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
356 addition
and
70 deletion
+356
-70
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+110
-21
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+38
-5
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+78
-24
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+41
-7
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+62
-0
paddle/fluid/framework/fleet/fleet_wrapper.h
paddle/fluid/framework/fleet/fleet_wrapper.h
+14
-0
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+9
-9
python/paddle/fluid/__init__.py
python/paddle/fluid/__init__.py
+2
-2
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+2
-2
未找到文件。
paddle/fluid/framework/data_feed.cc
浏览文件 @
3cea00bd
...
@@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) {
...
@@ -68,8 +68,10 @@ void DataFeed::SetBatchSize(int batch_size) {
bool
DataFeed
::
PickOneFile
(
std
::
string
*
filename
)
{
bool
DataFeed
::
PickOneFile
(
std
::
string
*
filename
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
if
(
file_idx_
==
filelist_
.
size
())
{
if
(
file_idx_
==
filelist_
.
size
())
{
VLOG
(
3
)
<<
"DataFeed::PickOneFile no more file to pick"
;
return
false
;
return
false
;
}
}
VLOG
(
3
)
<<
"file_idx_="
<<
file_idx_
;
*
filename
=
filelist_
[
file_idx_
++
];
*
filename
=
filelist_
[
file_idx_
++
];
// LOG(ERROR) << "pick file:" << *filename;
// LOG(ERROR) << "pick file:" << *filename;
return
true
;
return
true
;
...
@@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
...
@@ -146,17 +148,18 @@ template class PrivateQueueDataFeed<std::vector<MultiSlotType>>;
template
<
typename
T
>
template
<
typename
T
>
InMemoryDataFeed
<
T
>::
InMemoryDataFeed
()
{
InMemoryDataFeed
<
T
>::
InMemoryDataFeed
()
{
cur_channel_
=
0
;
cur_channel_
=
0
;
shuffled_ins_
=
nullptr
;
shuffled_ins_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
shuffled_ins_out_
=
nullptr
;
shuffled_ins_out_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
fleet_send_batch_size_
=
10000
;
}
}
template
<
typename
T
>
template
<
typename
T
>
bool
InMemoryDataFeed
<
T
>::
Start
()
{
bool
InMemoryDataFeed
<
T
>::
Start
()
{
DataFeed
::
CheckSetFileList
();
DataFeed
::
CheckSetFileList
();
if
(
memory_data_
.
size
()
!
=
0
)
{
if
(
shuffled_ins_
->
Size
()
==
0
&&
shuffled_ins_out_
->
Size
()
=
=
0
)
{
CHECK_EQ
(
cur_channel_
,
0
);
FillMemoryDataToChannel
(
);
shuffled_ins_
->
Extend
(
std
::
move
(
memory_data_
)
);
//std::unique_lock<std::mutex> lock(*mutex_for_update_memory_data_
);
std
::
vector
<
T
>
().
swap
(
memory_data_
);
//
std::vector<T>().swap(memory_data_);
}
}
DataFeed
::
finish_start_
=
true
;
DataFeed
::
finish_start_
=
true
;
return
true
;
return
true
;
...
@@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() {
...
@@ -196,6 +199,31 @@ int InMemoryDataFeed<T>::Next() {
return
DataFeed
::
batch_size_
;
return
DataFeed
::
batch_size_
;
}
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetMemoryData
(
void
*
memory_data
)
{
memory_data_
=
static_cast
<
std
::
vector
<
T
>*>
(
memory_data
);
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetMemoryDataMutex
(
std
::
mutex
*
mutex
)
{
mutex_for_update_memory_data_
=
mutex
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetThreadId
(
int
thread_id
)
{
thread_id_
=
thread_id
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetThreadNum
(
int
thread_num
)
{
thread_num_
=
thread_num
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
T
ins
;
T
ins
;
...
@@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
...
@@ -203,11 +231,54 @@ void InMemoryDataFeed<T>::PutInsToChannel(const std::string& ins_str) {
shuffled_ins_
->
Push
(
std
::
move
(
ins
));
shuffled_ins_
->
Push
(
std
::
move
(
ins
));
}
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
FillMemoryDataToChannel
()
{
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::FillMemoryDataToChannel, thread_id="
<<
thread_id_
;
int64_t
start
=
0
;
int64_t
end
=
0
;
int64_t
size
=
memory_data_
->
size
();
VLOG
(
3
)
<<
"memory_data size="
<<
size
;
for
(
int64_t
i
=
0
;
i
<=
static_cast
<
int64_t
>
(
thread_id_
);
++
i
)
{
int64_t
len
=
size
/
static_cast
<
int64_t
>
(
thread_num_
)
+
(
i
<
(
size
%
static_cast
<
int64_t
>
(
thread_num_
)));
start
=
end
;
end
+=
len
;
}
for
(
int64_t
i
=
start
;
i
<
end
;
++
i
)
{
T
&
t
=
(
*
memory_data_
)[
i
];
shuffled_ins_
->
Push
(
std
::
move
(
t
));
}
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
FillChannelToMemoryData
()
{
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::FillChannelToMemoryData, thread_id="
<<
thread_id_
;
std
::
vector
<
T
>
local_vec
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
channel
=
nullptr
;
if
(
cur_channel_
==
0
)
{
channel
=
shuffled_ins_
;
}
else
{
channel
=
shuffled_ins_out_
;
}
CHECK
(
channel
!=
nullptr
);
local_vec
.
reserve
(
channel
->
Size
());
for
(
int64_t
i
=
0
;
i
<
channel
->
Size
();
++
i
)
{
channel
->
Pop
(
local_vec
[
i
]);
}
std
::
unique_lock
<
std
::
mutex
>
lock
(
*
mutex_for_update_memory_data_
);
lock
.
lock
();
memory_data_
->
insert
(
memory_data_
->
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
lock
.
unlock
();
std
::
vector
<
T
>
().
swap
(
local_vec
);
}
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
LoadIntoMemory
()
{
void
InMemoryDataFeed
<
T
>::
LoadIntoMemory
()
{
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LoadIntoMemory() begin, thread_id="
<<
thread_id_
;
std
::
vector
<
T
>
local_vec
;
std
::
vector
<
T
>
local_vec
;
std
::
string
filename
;
std
::
string
filename
;
while
(
DataFeed
::
PickOneFile
(
&
filename
))
{
while
(
DataFeed
::
PickOneFile
(
&
filename
))
{
VLOG
(
3
)
<<
"PickOneFile, filename="
<<
filename
<<
", thread_id="
<<
thread_id_
;
int
err_no
=
0
;
int
err_no
=
0
;
PrivateQueueDataFeed
<
T
>::
fp_
=
PrivateQueueDataFeed
<
T
>::
fp_
=
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
...
@@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
...
@@ -216,35 +287,50 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
local_vec
.
push_back
(
instance
);
local_vec
.
push_back
(
instance
);
}
}
memory_data_
.
insert
(
memory_data_
.
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LoadIntoMemory() read all lines, thread_id="
<<
thread_id_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
mutex_for_update_memory_data_
);
memory_data_
->
insert
(
memory_data_
->
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
}
std
::
vector
<
T
>
().
swap
(
local_vec
);
std
::
vector
<
T
>
().
swap
(
local_vec
);
}
}
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LoadIntoMemory() end, thread_id="
<<
thread_id_
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
LocalShuffle
()
{
void
InMemoryDataFeed
<
T
>::
LocalShuffle
()
{
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LocalShuffle() begin, thread_id="
<<
thread_id_
;
FillMemoryDataToChannel
();
VLOG
(
3
)
<<
"InMemoryDataFeed<T>::LocalShuffle() end, thread_id="
<<
thread_id_
;
}
}
// todo global shuffle
/*
template
<
typename
T
>
template
<
typename
T
>
void InMemoryDataFeed<T>::GlobalShuffle(int trainer_num) {
void
InMemoryDataFeed
<
T
>::
GlobalShuffle
()
{
std::random_shuffle(memory_data_.begin(), memory_data_.end());
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
for (int64_t i = 0; i < memory_data_.size(); ++i) {
std
::
vector
<
std
::
string
>
send_str_vec
(
trainer_num_
);
for
(
int64_t
i
=
0
;
i
<
memory_data_
->
size
();
++
i
)
{
// todo get ins id
// todo get ins id
//std::string ins_id = memory_data_[i].ins_id;
//std::string ins_id = memory_data_[i].ins_id;
// todo hash
// todo hash
int64_t hash_id = paddle::ps::local_random_engine()();
//
int64_t hash_id = paddle::ps::local_random_engine()();
//int64_t hash_id = hash(ins_id)
;
int64_t
hash_id
=
0
;
int64_t
node_id
=
hash_id
%
trainer_num_
;
int64_t
node_id
=
hash_id
%
trainer_num_
;
std
::
string
str
;
std
::
string
str
;
SerializeIns(memory_data_[i], str);
SerializeIns
((
*
memory_data_
)[
i
],
str
);
auto fleet_ptr = FleetWrapper::GetInstance();
send_str_vec
[
node_id
]
+=
str
;
auto ret = fleet_ptr->send_client2client_msg(0, node_id, str);
if
(
i
%
fleet_send_batch_size_
==
0
&&
i
!=
0
)
{
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
fleet_ptr
->
send_client2client_msg
(
0
,
j
,
send_str_vec
[
j
]);
send_str_vec
[
j
]
=
""
;
}
}
}
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
if
(
send_str_vec
[
j
].
length
()
!=
0
)
{
fleet_ptr
->
send_client2client_msg
(
0
,
j
,
send_str_vec
[
j
]);
}
}
}
}
}
*/
// explicit instantiation
// explicit instantiation
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
template
class
InMemoryDataFeed
<
std
::
vector
<
MultiSlotType
>
>
;
...
@@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
...
@@ -646,6 +732,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstance(
if
(
getline
(
file_
,
line
))
{
if
(
getline
(
file_
,
line
))
{
int
use_slots_num
=
use_slots_
.
size
();
int
use_slots_num
=
use_slots_
.
size
();
instance
->
resize
(
use_slots_num
);
instance
->
resize
(
use_slots_num
);
VLOG
(
3
)
<<
line
;
// parse line
// parse line
const
char
*
str
=
line
.
c_str
();
const
char
*
str
=
line
.
c_str
();
char
*
endptr
=
const_cast
<
char
*>
(
str
);
char
*
endptr
=
const_cast
<
char
*>
(
str
);
...
@@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
...
@@ -735,12 +822,14 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle
// todo serialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
&
str
)
{
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
&
str
)
{
return
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Serialize
(
ins
,
str
);
}
}
// todo deserialize ins in global shuffle
// todo deserialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>&
ins
,
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>&
ins
,
const
std
::
string
&
str
)
{
const
std
::
string
&
str
)
{
return
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Deserialize
(
ins
,
str
);
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
3cea00bd
...
@@ -20,6 +20,7 @@ limitations under the License. */
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include <sstream>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
@@ -78,17 +79,33 @@ class DataFeed {
...
@@ -78,17 +79,33 @@ class DataFeed {
// This function is used for binding feed_vec memory
// This function is used for binding feed_vec memory
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
);
virtual
void
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
);
// This function will do nothing at default
virtual
void
SetMemoryData
(
void
*
memory_data
)
{
}
// This function will do nothing at default
virtual
void
SetMemoryDataMutex
(
std
::
mutex
*
mutex
)
{
}
// This function will do nothing at default
virtual
void
SetThreadId
(
int
thread_id
)
{
}
// This function will do nothing at default
virtual
void
SetThreadNum
(
int
thread_num
)
{
}
// This function will do nothing at default
virtual
void
SetTrainerNum
(
int
trainer_num
)
{
}
virtual
void
LoadIntoMemory
()
{
virtual
void
LoadIntoMemory
()
{
PADDLE_THROW
(
"This function(LoadIntoMemory) is not implemented."
);
PADDLE_THROW
(
"This function(LoadIntoMemory) is not implemented."
);
}
}
virtual
void
LocalShuffle
()
{
virtual
void
LocalShuffle
()
{
PADDLE_THROW
(
"This function(LocalShuffle) is not implemented."
);
PADDLE_THROW
(
"This function(LocalShuffle) is not implemented."
);
}
}
virtual
void
GlobalShuffle
(
int
trainer_num
)
{
virtual
void
GlobalShuffle
()
{
PADDLE_THROW
(
"This function(GlobalShuffle) is not implemented."
);
PADDLE_THROW
(
"This function(GlobalShuffle) is not implemented."
);
}
}
virtual
void
FillMemoryDataToChannel
()
{
PADDLE_THROW
(
"This function(FillMemoryDataToChannel) is not implemented."
);
}
virtual
void
FillChannelToMemoryData
()
{
PADDLE_THROW
(
"This function(FillChannelToMemoryData) is not implemented."
);
}
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
PADDLE_THROW
(
"This function(PutToChannel) is not implemented."
);
PADDLE_THROW
(
"This function(Put
Ins
ToChannel) is not implemented."
);
}
}
protected:
protected:
...
@@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
...
@@ -181,13 +198,20 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
public:
public:
InMemoryDataFeed
();
InMemoryDataFeed
();
virtual
~
InMemoryDataFeed
()
{}
virtual
~
InMemoryDataFeed
()
{}
virtual
void
Init
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
bool
Start
();
virtual
bool
Start
();
virtual
int
Next
();
virtual
int
Next
();
virtual
void
SetMemoryData
(
void
*
memory_data
);
virtual
void
SetMemoryDataMutex
(
std
::
mutex
*
mutex
);
virtual
void
SetThreadId
(
int
thread_id
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
);
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
);
virtual
void
FillMemoryDataToChannel
();
virtual
void
FillChannelToMemoryData
();
virtual
void
LoadIntoMemory
();
virtual
void
LoadIntoMemory
();
virtual
void
LocalShuffle
();
virtual
void
LocalShuffle
();
// todo global shuffle
virtual
void
GlobalShuffle
();
//virtual void GlobalShuffle(int trainer_num);
protected:
protected:
virtual
void
AddInstanceToInsVec
(
T
*
vec_ins
,
const
T
&
instance
,
int
index
)
=
0
;
virtual
void
AddInstanceToInsVec
(
T
*
vec_ins
,
const
T
&
instance
,
int
index
)
=
0
;
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
...
@@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
...
@@ -196,13 +220,18 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual
void
SerializeIns
(
const
T
&
ins
,
std
::
string
&
str
)
=
0
;
virtual
void
SerializeIns
(
const
T
&
ins
,
std
::
string
&
str
)
=
0
;
virtual
void
DeserializeIns
(
T
&
ins
,
const
std
::
string
&
str
)
=
0
;
virtual
void
DeserializeIns
(
T
&
ins
,
const
std
::
string
&
str
)
=
0
;
std
::
vector
<
T
>
memory_data_
;
int
thread_id_
;
int
thread_num_
;
int
trainer_num_
;
std
::
vector
<
T
>*
memory_data_
;
std
::
mutex
*
mutex_for_update_memory_data_
;
// when read ins, we put ins from one channel to the other,
// when read ins, we put ins from one channel to the other,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// and when finish reading, we set cur_channel = 1 - cur_channel,
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
// so if cur_channel=0, all data are in shuffled_ins_, else shuffled_ins_out_
int
cur_channel_
;
int
cur_channel_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_out_
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
shuffled_ins_out_
;
int64_t
fleet_send_batch_size_
;
};
};
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
// This class define the data type of instance(ins_vec) in MultiSlotDataFeed
...
@@ -226,6 +255,7 @@ class MultiSlotType {
...
@@ -226,6 +255,7 @@ class MultiSlotType {
offset_
[
0
]
=
0
;
offset_
[
0
]
=
0
;
}
}
const
std
::
vector
<
size_t
>&
GetOffset
()
const
{
return
offset_
;
}
const
std
::
vector
<
size_t
>&
GetOffset
()
const
{
return
offset_
;
}
std
::
vector
<
size_t
>&
MutableOffset
()
{
return
offset_
;
}
void
AddValue
(
const
float
v
)
{
void
AddValue
(
const
float
v
)
{
CheckFloat
();
CheckFloat
();
float_feasign_
.
push_back
(
v
);
float_feasign_
.
push_back
(
v
);
...
@@ -248,8 +278,11 @@ class MultiSlotType {
...
@@ -248,8 +278,11 @@ class MultiSlotType {
}
}
}
}
const
std
::
vector
<
float
>&
GetFloatData
()
const
{
return
float_feasign_
;
}
const
std
::
vector
<
float
>&
GetFloatData
()
const
{
return
float_feasign_
;
}
std
::
vector
<
float
>&
MutableFloatData
()
{
return
float_feasign_
;
}
const
std
::
vector
<
uint64_t
>&
GetUint64Data
()
const
{
return
uint64_feasign_
;
}
const
std
::
vector
<
uint64_t
>&
GetUint64Data
()
const
{
return
uint64_feasign_
;
}
std
::
vector
<
uint64_t
>&
MutableUint64Data
()
{
return
uint64_feasign_
;
}
const
std
::
string
&
GetType
()
const
{
return
type_
;
}
const
std
::
string
&
GetType
()
const
{
return
type_
;
}
std
::
string
&
MutableType
()
{
return
type_
;
}
private:
private:
void
CheckType
(
const
std
::
string
&
type
)
const
{
void
CheckType
(
const
std
::
string
&
type
)
const
{
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
3cea00bd
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
* See the License for the specific language governing permissions and
* See the License for the specific language governing permissions and
* limitations under the License. */
* limitations under the License. */
#include <random>
#include "paddle/fluid/framework/data_set.h"
#include "paddle/fluid/framework/data_set.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/message.h"
#include "google/protobuf/message.h"
...
@@ -21,23 +22,27 @@
...
@@ -21,23 +22,27 @@
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
Dataset
::
Dataset
()
{
thread_num_
=
1
;
}
template
<
typename
T
>
DatasetImpl
<
T
>::
DatasetImpl
()
{
thread_num_
=
1
;
}
void
Dataset
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
filelist_
=
filelist
;
filelist_
=
filelist
;
/*
int file_cnt = filelist_.size();
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
if (thread_num_ > file_cnt) {
VLOG(1) << "DataSet thread num = " << thread_num_
VLOG(1) << "DataSet thread num = " << thread_num_
<< ", file num = " << file_cnt
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
<< ". Changing DataSet thread num = " << file_cnt;
thread_num_ = file_cnt;
thread_num_ = file_cnt;
}
}
*/
}
}
// buggy here, a user should set filelist first before this function
// buggy here, a user should set filelist first before this function
// not user friendly
// not user friendly
void
Dataset
::
SetThreadNum
(
int
thread_num
)
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetThreadNum
(
int
thread_num
)
{
int
file_cnt
=
filelist_
.
size
();
int
file_cnt
=
filelist_
.
size
();
if
(
file_cnt
!=
0
&&
thread_num
>
file_cnt
)
{
if
(
file_cnt
!=
0
&&
thread_num
>
file_cnt
)
{
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num
VLOG
(
1
)
<<
"DataSet thread num = "
<<
thread_num
...
@@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) {
...
@@ -48,19 +53,24 @@ void Dataset::SetThreadNum(int thread_num) {
thread_num_
=
thread_num
;
thread_num_
=
thread_num
;
}
}
void
Dataset
::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
}
void
Dataset
::
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
{
google
::
protobuf
::
TextFormat
::
ParseFromString
(
data_feed_desc_str
,
google
::
protobuf
::
TextFormat
::
ParseFromString
(
data_feed_desc_str
,
&
data_feed_desc_
);
&
data_feed_desc_
);
}
}
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
template
<
typename
T
>
Dataset
::
GetReaders
()
{
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
DatasetImpl
<
T
>::
GetReaders
()
{
return
readers_
;
return
readers_
;
}
}
void
Dataset
::
LoadIntoMemory
()
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LoadIntoMemory
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() begin"
;
if
(
readers_
.
size
()
==
0
)
{
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
CreateReaders
();
}
}
...
@@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() {
...
@@ -72,12 +82,18 @@ void Dataset::LoadIntoMemory() {
for
(
std
::
thread
&
t
:
load_threads
)
{
for
(
std
::
thread
&
t
:
load_threads
)
{
t
.
join
();
t
.
join
();
}
}
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() end"
;
}
}
void
Dataset
::
LocalShuffle
()
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LocalShuffle
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() begin"
;
if
(
readers_
.
size
()
==
0
)
{
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
CreateReaders
();
}
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
std
::
vector
<
std
::
thread
>
local_shuffle_threads
;
std
::
vector
<
std
::
thread
>
local_shuffle_threads
;
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
local_shuffle_threads
.
push_back
(
std
::
thread
(
local_shuffle_threads
.
push_back
(
std
::
thread
(
...
@@ -86,30 +102,37 @@ void Dataset::LocalShuffle() {
...
@@ -86,30 +102,37 @@ void Dataset::LocalShuffle() {
for
(
std
::
thread
&
t
:
local_shuffle_threads
)
{
for
(
std
::
thread
&
t
:
local_shuffle_threads
)
{
t
.
join
();
t
.
join
();
}
}
std
::
vector
<
T
>
().
swap
(
memory_data_
);
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end"
;
}
}
// todo global shuffle
template
<
typename
T
>
void
Dataset
::
GlobalShuffle
()
{
void
DatasetImpl
<
T
>::
GlobalShuffle
()
{
/*
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() begin"
;
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
registe_client2client_msg_handler
(
0
,
fleet_ptr
->
registe_client2client_msg_handler
(
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
});
});
if (readers_.size() == 0) {
CreateReaders();
}
std
::
vector
<
std
::
thread
>
global_shuffle_threads
;
std
::
vector
<
std
::
thread
>
global_shuffle_threads
;
for (int64_t i = 0; i < thread_num_; ++i) {
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
global_shuffle_threads.push_back(std::thread(&paddle::framework::DataFeed::GlobalShuffle,
global_shuffle_threads
.
push_back
(
readers_[i].get(), trainer_num_));
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
GlobalShuffle
,
readers_
[
i
].
get
()));
}
}
for
(
std
::
thread
&
t
:
global_shuffle_threads
)
{
for
(
std
::
thread
&
t
:
global_shuffle_threads
)
{
t
.
join
();
t
.
join
();
}*/
}
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end"
;
}
}
void
Dataset
::
CreateReaders
()
{
template
<
typename
T
>
void
DatasetImpl
<
T
>::
CreateReaders
()
{
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
...
@@ -118,22 +141,53 @@ void Dataset::CreateReaders() {
...
@@ -118,22 +141,53 @@ void Dataset::CreateReaders() {
return
;
return
;
}
}
VLOG
(
3
)
<<
"data feed class name: "
<<
data_feed_desc_
.
name
();
VLOG
(
3
)
<<
"data feed class name: "
<<
data_feed_desc_
.
name
();
for
(
int
64_t
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
readers_
.
push_back
(
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc_
.
name
()));
readers_
.
push_back
(
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc_
.
name
()));
readers_
.
back
()
->
Init
(
data_feed_desc_
);
readers_
.
back
()
->
Init
(
data_feed_desc_
);
readers_
.
back
()
->
SetMemoryData
(
&
memory_data_
);
readers_
.
back
()
->
SetMemoryDataMutex
(
&
mutex_for_update_memory_data_
);
readers_
.
back
()
->
SetThreadId
(
i
);
readers_
.
back
()
->
SetThreadNum
(
thread_num_
);
readers_
.
back
()
->
SetTrainerNum
(
trainer_num_
);
}
}
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
readers_
[
0
]
->
SetFileList
(
filelist_
);
readers_
[
0
]
->
SetFileList
(
filelist_
);
}
}
int
Dataset
::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
template
<
typename
T
>
void
DatasetImpl
<
T
>::
DestroyReaders
()
{
VLOG
(
3
)
<<
"Calling DestroyReaders()"
;
// clear memory_data_ before fill it
// because if LoadIntoMemory but no Shuffle,
// memory_data_ has empty data which has been std::move to channel
if
(
memory_data_
.
size
()
!=
0
)
{
std
::
vector
<
T
>
().
swap
(
memory_data_
);
}
std
::
vector
<
std
::
thread
>
fill_threads
;
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
fill_threads
.
push_back
(
std
::
thread
(
&
paddle
::
framework
::
DataFeed
::
FillChannelToMemoryData
,
readers_
[
i
].
get
()));
}
for
(
std
::
thread
&
t
:
fill_threads
)
{
t
.
join
();
}
std
::
vector
<
std
::
string
>
().
swap
(
filelist_
);
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
().
swap
(
readers_
);
}
template
<
typename
T
>
int
DatasetImpl
<
T
>::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
{
const
std
::
string
&
msg
)
{
//
can also use hash
//
todo random
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t
index
=
0
;
int64_t
index
=
0
;
readers_
[
index
]
->
PutInsToChannel
(
msg
);
readers_
[
index
]
->
PutInsToChannel
(
msg
);
return
0
;
return
0
;
}
}
// explicit instantiation
template
class
DatasetImpl
<
std
::
vector
<
MultiSlotType
>
>
;
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/data_set.h
浏览文件 @
3cea00bd
...
@@ -28,8 +28,33 @@ namespace framework {
...
@@ -28,8 +28,33 @@ namespace framework {
class
Dataset
{
class
Dataset
{
public:
public:
Dataset
();
Dataset
()
{};
virtual
~
Dataset
()
{}
virtual
~
Dataset
()
{};
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
=
0
;
virtual
void
SetThreadNum
(
int
thread_num
)
=
0
;
virtual
void
SetTrainerNum
(
int
trainer_num
)
=
0
;
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
=
0
;
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
=
0
;
virtual
int
GetThreadNum
()
=
0
;
virtual
int
GetTrainerNum
()
=
0
;
virtual
const
paddle
::
framework
::
DataFeedDesc
&
GetDataFeedDesc
()
=
0
;
virtual
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
GetReaders
()
=
0
;
virtual
void
LoadIntoMemory
()
=
0
;
virtual
void
LocalShuffle
()
=
0
;
virtual
void
GlobalShuffle
()
=
0
;
virtual
void
CreateReaders
()
=
0
;
virtual
void
DestroyReaders
()
=
0
;
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
=
0
;
};
template
<
typename
T
>
class
DatasetImpl
:
public
Dataset
{
public:
DatasetImpl
();
virtual
~
DatasetImpl
()
{}
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetThreadNum
(
int
thread_num
);
...
@@ -43,25 +68,34 @@ class Dataset {
...
@@ -43,25 +68,34 @@ class Dataset {
return
data_feed_desc_
;
return
data_feed_desc_
;
}
}
virtual
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
virtual
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>&
GetReaders
();
GetReaders
();
virtual
void
LoadIntoMemory
();
virtual
void
LoadIntoMemory
();
virtual
void
LocalShuffle
();
virtual
void
LocalShuffle
();
// todo global shuffle
virtual
void
GlobalShuffle
();
virtual
void
GlobalShuffle
();
virtual
void
CreateReaders
();
virtual
void
CreateReaders
();
virtual
void
DestroyReaders
();
protected:
protected:
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
virtual
int
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
);
const
std
::
string
&
msg
);
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers_
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers_
;
std
::
vector
<
T
>
memory_data_
;
std
::
mutex
mutex_for_update_memory_data_
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>>
shuffled_ins_vec_
;
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>>
shuffled_ins_out_vec_
;
int
thread_num_
;
int
thread_num_
;
std
::
string
fs_name_
;
std
::
string
fs_ugi_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
std
::
vector
<
std
::
string
>
filelist_
;
std
::
vector
<
std
::
string
>
filelist_
;
int
trainer_num_
;
int
trainer_num_
;
};
};
class
MultiSlotDataset
:
public
DatasetImpl
<
std
::
vector
<
MultiSlotType
>>
{
public:
MultiSlotDataset
()
{}
virtual
~
MultiSlotDataset
()
{}
};
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
3cea00bd
...
@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
...
@@ -27,6 +27,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
...
@@ -35,6 +36,30 @@ const uint32_t MAX_FEASIGN_NUM = 1024 * 100 * 100;
std
::
shared_ptr
<
FleetWrapper
>
FleetWrapper
::
s_instance_
=
NULL
;
std
::
shared_ptr
<
FleetWrapper
>
FleetWrapper
::
s_instance_
=
NULL
;
bool
FleetWrapper
::
is_initialized_
=
false
;
bool
FleetWrapper
::
is_initialized_
=
false
;
#ifdef PADDLE_WITH_PSLIB
template
<
class
AR
>
paddle
::
ps
::
Archive
<
AR
>&
operator
<<
(
paddle
::
ps
::
Archive
<
AR
>&
ar
,
const
MultiSlotType
&
ins
)
{
ar
<<
ins
.
GetType
();
ar
<<
ins
.
GetOffset
();
ar
<<
ins
.
GetFloatData
();
ar
<<
ins
.
GetUint64Data
();
return
ar
;
}
template
<
class
AR
>
paddle
::
ps
::
Archive
<
AR
>&
operator
>>
(
paddle
::
ps
::
Archive
<
AR
>&
ar
,
MultiSlotType
&
ins
)
{
ar
>>
ins
.
MutableType
();
ar
>>
ins
.
MutableOffset
();
ar
>>
ins
.
MutableFloatData
();
ar
>>
ins
.
MutableUint64Data
();
return
ar
;
}
#endif
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
FleetWrapper
::
pslib_ptr_
=
NULL
;
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
FleetWrapper
::
pslib_ptr_
=
NULL
;
#endif
#endif
...
@@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
...
@@ -266,5 +291,42 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
#endif
}
}
// todo registe_client2client_msg_handler
int
FleetWrapper
::
registe_client2client_msg_handler
(
int
msg_type
,
MsgHandlerFunc
handler
)
{
return
0
;
}
// todo send_client2client_msg
int
FleetWrapper
::
send_client2client_msg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
return
0
;
}
template
<
typename
T
>
void
FleetWrapper
::
Serialize
(
const
T
&
t
,
std
::
string
&
str
)
{
#ifdef PADDLE_WITH_PSLIB
paddle
::
ps
::
BinaryArchive
ar
;
ar
<<
t
;
str
=
std
::
string
(
ar
.
buffer
(),
ar
.
length
());
#else
VLOG
(
0
)
<<
"FleetWrapper::Serialize do nothing when no pslib"
;
#endif
}
template
<
typename
T
>
void
FleetWrapper
::
Deserialize
(
T
&
t
,
const
std
::
string
&
str
)
{
#ifdef PADDLE_WITH_PSLIB
paddle
::
ps
::
BinaryArchive
ar
;
ar
.
set_read_buffer
(
const_cast
<
char
*>
(
str
.
c_str
()),
str
.
length
(),
nullptr
);
t
=
ar
.
get
<
T
>
();
#else
VLOG
(
0
)
<<
"FleetWrapper::Deserialize do nothing when no pslib"
;
#endif
}
template
void
FleetWrapper
::
Serialize
<
std
::
vector
<
MultiSlotType
>
>
(
const
std
::
vector
<
MultiSlotType
>&
,
std
::
string
&
);
template
void
FleetWrapper
::
Deserialize
(
std
::
vector
<
MultiSlotType
>
&
,
const
std
::
string
&
);
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/fleet/fleet_wrapper.h
浏览文件 @
3cea00bd
...
@@ -17,7 +17,11 @@ limitations under the License. */
...
@@ -17,7 +17,11 @@ limitations under the License. */
#include <memory>
#include <memory>
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
#include <pslib.h>
#include <pslib.h>
#include <archive.h>
#endif
#endif
#include <random>
#include <atomic>
#include <time.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
...
@@ -110,6 +114,16 @@ class FleetWrapper {
...
@@ -110,6 +114,16 @@ class FleetWrapper {
uint64_t
RunServer
();
uint64_t
RunServer
();
void
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
);
void
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
);
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
int
registe_client2client_msg_handler
(
int
msg_type
,
MsgHandlerFunc
handler
);
int
send_client2client_msg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
);
std
::
default_random_engine
&
local_random_engine
();
template
<
typename
T
>
void
Serialize
(
const
T
&
t
,
std
::
string
&
str
);
template
<
typename
T
>
void
Deserialize
(
T
&
t
,
const
std
::
string
&
str
);
static
std
::
shared_ptr
<
FleetWrapper
>
GetInstance
()
{
static
std
::
shared_ptr
<
FleetWrapper
>
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
framework
::
FleetWrapper
());
s_instance_
.
reset
(
new
paddle
::
framework
::
FleetWrapper
());
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
3cea00bd
...
@@ -41,17 +41,17 @@ namespace paddle {
...
@@ -41,17 +41,17 @@ namespace paddle {
namespace
pybind
{
namespace
pybind
{
void
BindDataset
(
py
::
module
*
m
)
{
void
BindDataset
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
Dataset
>
(
*
m
,
"
Dataset"
)
py
::
class_
<
framework
::
MultiSlotDataset
>
(
*
m
,
"MultiSlot
Dataset"
)
.
def
(
py
::
init
([]()
{
.
def
(
py
::
init
([]()
{
return
std
::
unique_ptr
<
framework
::
Dataset
>
(
new
framework
::
Dataset
());
return
std
::
unique_ptr
<
framework
::
MultiSlotDataset
>
(
new
framework
::
MultiSlot
Dataset
());
}))
}))
.
def
(
"set_filelist"
,
&
framework
::
Dataset
::
SetFileList
)
.
def
(
"set_filelist"
,
&
framework
::
MultiSlot
Dataset
::
SetFileList
)
.
def
(
"set_thread_num"
,
&
framework
::
Dataset
::
SetThreadNum
)
.
def
(
"set_thread_num"
,
&
framework
::
MultiSlot
Dataset
::
SetThreadNum
)
.
def
(
"set_trainer_num"
,
&
framework
::
Dataset
::
SetTrainerNum
)
.
def
(
"set_trainer_num"
,
&
framework
::
MultiSlot
Dataset
::
SetTrainerNum
)
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"set_data_feed_desc"
,
&
framework
::
MultiSlot
Dataset
::
SetDataFeedDesc
)
.
def
(
"load_into_memory"
,
&
framework
::
Dataset
::
LoadIntoMemory
)
.
def
(
"load_into_memory"
,
&
framework
::
MultiSlot
Dataset
::
LoadIntoMemory
)
.
def
(
"local_shuffle"
,
&
framework
::
Dataset
::
LocalShuffle
)
.
def
(
"local_shuffle"
,
&
framework
::
MultiSlot
Dataset
::
LocalShuffle
)
.
def
(
"global_shuffle"
,
&
framework
::
Dataset
::
GlobalShuffle
);
.
def
(
"global_shuffle"
,
&
framework
::
MultiSlot
Dataset
::
GlobalShuffle
);
}
}
}
// end namespace pybind
}
// end namespace pybind
...
...
python/paddle/fluid/__init__.py
浏览文件 @
3cea00bd
...
@@ -30,7 +30,7 @@ from .dataset import *
...
@@ -30,7 +30,7 @@ from .dataset import *
from
.
import
async_executor
from
.
import
async_executor
from
.async_executor
import
*
from
.async_executor
import
*
from
.
import
trainer
from
.
import
trainer
_desc
from
.
import
inferencer
from
.
import
inferencer
from
.
import
io
from
.
import
io
...
@@ -67,7 +67,7 @@ from . import install_check
...
@@ -67,7 +67,7 @@ from . import install_check
Tensor
=
LoDTensor
Tensor
=
LoDTensor
__all__
=
framework
.
__all__
+
executor
.
__all__
+
\
__all__
=
framework
.
__all__
+
executor
.
__all__
+
\
trainer
.
__all__
+
inferencer
.
__all__
+
transpiler
.
__all__
+
\
trainer
_desc
.
__all__
+
inferencer
.
__all__
+
transpiler
.
__all__
+
\
parallel_executor
.
__all__
+
lod_tensor
.
__all__
+
\
parallel_executor
.
__all__
+
lod_tensor
.
__all__
+
\
data_feed_desc
.
__all__
+
async_executor
.
__all__
+
compiler
.
__all__
+
[
data_feed_desc
.
__all__
+
async_executor
.
__all__
+
compiler
.
__all__
+
[
'io'
,
'io'
,
...
...
python/paddle/fluid/dataset.py
浏览文件 @
3cea00bd
...
@@ -37,7 +37,7 @@ class DatasetBase(object):
...
@@ -37,7 +37,7 @@ class DatasetBase(object):
# to decide whether we need create in memory instance
# to decide whether we need create in memory instance
self
.
proto_desc
=
data_feed_pb2
.
DataFeedDesc
()
self
.
proto_desc
=
data_feed_pb2
.
DataFeedDesc
()
self
.
proto_desc
.
pipe_command
=
"cat"
self
.
proto_desc
.
pipe_command
=
"cat"
self
.
dataset
=
core
.
Dataset
()
self
.
dataset
=
core
.
MultiSlot
Dataset
()
self
.
thread_num
=
0
self
.
thread_num
=
0
def
set_pipe_command
(
self
,
pipe_command
):
def
set_pipe_command
(
self
,
pipe_command
):
...
@@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase):
...
@@ -109,7 +109,7 @@ class InMemoryDataset(DatasetBase):
self
.
proto_desc
.
name
=
"MultiSlotInMemoryDataFeed"
self
.
proto_desc
.
name
=
"MultiSlotInMemoryDataFeed"
def
load_into_memory
(
self
):
def
load_into_memory
(
self
):
_prepare_to_run
()
self
.
_prepare_to_run
()
self
.
dataset
.
load_into_memory
()
self
.
dataset
.
load_into_memory
()
def
local_shuffle
(
self
):
def
local_shuffle
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录