Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a5b1a0e1
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a5b1a0e1
编写于
3月 20, 2019
作者:
X
xujiaqi01
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support multi dataset && add init model && fix bug
上级
3c65cc1b
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
341 addition
and
113 deletion
+341
-113
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+2
-1
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+106
-49
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+25
-12
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+67
-15
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+7
-1
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+3
-1
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+83
-14
paddle/fluid/framework/fleet/fleet_wrapper.h
paddle/fluid/framework/fleet/fleet_wrapper.h
+10
-6
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+4
-2
paddle/fluid/pybind/async_executor_py.cc
paddle/fluid/pybind/async_executor_py.cc
+1
-1
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+1
-0
paddle/fluid/pybind/fleet_wrapper_py.cc
paddle/fluid/pybind/fleet_wrapper_py.cc
+1
-0
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+12
-5
python/paddle/fluid/incubate/fleet/base/role_maker.py
python/paddle/fluid/incubate/fleet/base/role_maker.py
+2
-2
python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
.../paddle/fluid/incubate/fleet/parameter_server/__init__.py
+17
-4
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
a5b1a0e1
...
@@ -155,7 +155,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
...
@@ -155,7 +155,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
}
}
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
if
(
mode
==
"mpi"
)
{
if
(
mode
==
"mpi"
)
{
_pull_dense_thread
->
stop
();
// todo ?
//_pull_dense_thread->stop();
}
}
#endif
#endif
VLOG
(
3
)
<<
"start to run from files in async_executor"
;
VLOG
(
3
)
<<
"start to run from files in async_executor"
;
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
a5b1a0e1
...
@@ -23,15 +23,11 @@ limitations under the License. */
...
@@ -23,15 +23,11 @@ limitations under the License. */
#include "io/shell.h"
#include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
vector
<
std
::
string
>
DataFeed
::
filelist_
;
size_t
DataFeed
::
file_idx_
;
std
::
mutex
DataFeed
::
mutex_for_pick_file_
;
bool
DataFeed
::
finish_set_filelist_
;
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
CheckInit
();
CheckInit
();
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
...
@@ -42,7 +38,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
...
@@ -42,7 +38,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
}
}
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
*
mutex_for_pick_file_
);
CheckInit
();
CheckInit
();
// Do not set finish_set_filelist_ flag,
// Do not set finish_set_filelist_ flag,
// since a user may set file many times after init reader
// since a user may set file many times after init reader
...
@@ -52,9 +48,8 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
...
@@ -52,9 +48,8 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
return false;
return false;
}
}
*/
*/
PADDLE_ENFORCE
(
files
.
size
(),
"You have set an empty filelist."
);
//
PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
file_idx_
=
0
;
finish_set_filelist_
=
true
;
finish_set_filelist_
=
true
;
return
true
;
return
true
;
...
@@ -66,13 +61,17 @@ void DataFeed::SetBatchSize(int batch_size) {
...
@@ -66,13 +61,17 @@ void DataFeed::SetBatchSize(int batch_size) {
}
}
bool
DataFeed
::
PickOneFile
(
std
::
string
*
filename
)
{
bool
DataFeed
::
PickOneFile
(
std
::
string
*
filename
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
PADDLE_ENFORCE
(
mutex_for_pick_file_
!=
nullptr
,
if
(
file_idx_
==
filelist_
.
size
())
{
"should call SetFileListMutex before PickOneFile"
);
PADDLE_ENFORCE
(
file_idx_
!=
nullptr
,
"should call SetFileListIndex before PickOneFile"
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
*
mutex_for_pick_file_
);
if
(
*
file_idx_
==
filelist_
.
size
())
{
VLOG
(
3
)
<<
"DataFeed::PickOneFile no more file to pick"
;
VLOG
(
3
)
<<
"DataFeed::PickOneFile no more file to pick"
;
return
false
;
return
false
;
}
}
VLOG
(
3
)
<<
"file_idx_="
<<
file_idx_
;
VLOG
(
3
)
<<
"file_idx_="
<<
*
file_idx_
;
*
filename
=
filelist_
[
file_idx_
++
];
*
filename
=
filelist_
[
(
*
file_idx_
)
++
];
// LOG(ERROR) << "pick file:" << *filename;
// LOG(ERROR) << "pick file:" << *filename;
return
true
;
return
true
;
}
}
...
@@ -150,7 +149,11 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
...
@@ -150,7 +149,11 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_
=
0
;
cur_channel_
=
0
;
shuffled_ins_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
shuffled_ins_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
shuffled_ins_out_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
shuffled_ins_out_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
fleet_send_batch_size_
=
10000
;
fleet_send_batch_size_
=
80000
;
memory_data_
=
nullptr
;
mutex_for_update_memory_data_
=
nullptr
;
this
->
file_idx_
=
nullptr
;
this
->
mutex_for_pick_file_
=
nullptr
;
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -192,6 +195,8 @@ int InMemoryDataFeed<T>::Next() {
...
@@ -192,6 +195,8 @@ int InMemoryDataFeed<T>::Next() {
out_channel
->
Push
(
std
::
move
(
instance
));
out_channel
->
Push
(
std
::
move
(
instance
));
}
}
DataFeed
::
batch_size_
=
index
;
DataFeed
::
batch_size_
=
index
;
VLOG
(
3
)
<<
"batch_size_="
<<
DataFeed
::
batch_size_
<<
", thread_id="
<<
thread_id_
;
if
(
DataFeed
::
batch_size_
!=
0
)
{
if
(
DataFeed
::
batch_size_
!=
0
)
{
PutToFeedVec
(
ins_vec
);
PutToFeedVec
(
ins_vec
);
}
else
{
}
else
{
...
@@ -227,25 +232,22 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
...
@@ -227,25 +232,22 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
T
ins
;
std
::
vector
<
T
>
ins
;
DeserializeIns
(
&
ins
,
ins_str
);
DeserializeIns
(
&
ins
,
ins_str
);
shuffled_ins_
->
Push
(
std
::
move
(
ins
));
shuffled_ins_
->
Extend
(
std
::
move
(
ins
));
VLOG
(
3
)
<<
"PutInsToChannel put ins num="
<<
ins
.
size
()
<<
" to channel, channel size="
<<
shuffled_ins_
->
Size
()
<<
" thread_id="
<<
thread_id_
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
FillMemoryDataToChannel
()
{
void
InMemoryDataFeed
<
T
>::
FillMemoryDataToChannel
()
{
VLOG
(
3
)
<<
"FillMemoryDataToChannel, thread_id="
<<
thread_id_
;
VLOG
(
3
)
<<
"FillMemoryDataToChannel, thread_id="
<<
thread_id_
;
int64_t
start
=
0
;
auto
interval
=
GetMemoryDataInterval
();
int64_t
end
=
0
;
VLOG
(
3
)
<<
"memory data size="
<<
memory_data_
->
size
()
int64_t
size
=
memory_data_
->
size
();
<<
", fill data from ["
<<
interval
.
first
<<
", "
VLOG
(
3
)
<<
"memory_data size="
<<
size
;
<<
interval
.
second
<<
"), thread_id="
<<
thread_id_
;
for
(
int64_t
i
=
0
;
i
<=
static_cast
<
int64_t
>
(
thread_id_
);
++
i
)
{
for
(
int64_t
i
=
interval
.
first
;
i
<
interval
.
second
;
++
i
)
{
int64_t
len
=
size
/
static_cast
<
int64_t
>
(
thread_num_
)
+
(
i
<
(
size
%
static_cast
<
int64_t
>
(
thread_num_
)));
start
=
end
;
end
+=
len
;
}
for
(
int64_t
i
=
start
;
i
<
end
;
++
i
)
{
T
&
t
=
(
*
memory_data_
)[
i
];
T
&
t
=
(
*
memory_data_
)[
i
];
shuffled_ins_
->
Push
(
std
::
move
(
t
));
shuffled_ins_
->
Push
(
std
::
move
(
t
));
}
}
...
@@ -256,14 +258,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
...
@@ -256,14 +258,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
VLOG
(
3
)
<<
"FillChannelToMemoryData, thread_id="
<<
thread_id_
;
VLOG
(
3
)
<<
"FillChannelToMemoryData, thread_id="
<<
thread_id_
;
std
::
vector
<
T
>
local_vec
;
std
::
vector
<
T
>
local_vec
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
channel
=
nullptr
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
channel
=
nullptr
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
pre_channel
=
nullptr
;
if
(
cur_channel_
==
0
)
{
if
(
cur_channel_
==
0
)
{
channel
=
shuffled_ins_
;
channel
=
shuffled_ins_
;
pre_channel
=
shuffled_ins_out_
;
}
else
{
}
else
{
channel
=
shuffled_ins_out_
;
channel
=
shuffled_ins_out_
;
pre_channel
=
shuffled_ins_
;
}
}
CHECK
(
channel
!=
nullptr
);
CHECK
(
channel
!=
nullptr
);
CHECK
(
pre_channel
!=
nullptr
);
CHECK
(
pre_channel
->
Size
()
==
0
);
local_vec
.
resize
(
channel
->
Size
());
local_vec
.
resize
(
channel
->
Size
());
for
(
int64_t
i
=
0
;
i
<
channel
->
S
ize
();
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
local_vec
.
s
ize
();
++
i
)
{
channel
->
Pop
(
local_vec
[
i
]);
channel
->
Pop
(
local_vec
[
i
]);
}
}
VLOG
(
3
)
<<
"local_vec size="
<<
local_vec
.
size
()
<<
", thread_id="
<<
thread_id_
;
VLOG
(
3
)
<<
"local_vec size="
<<
local_vec
.
size
()
<<
", thread_id="
<<
thread_id_
;
...
@@ -289,20 +296,32 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
...
@@ -289,20 +296,32 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
int
err_no
=
0
;
int
err_no
=
0
;
PrivateQueueDataFeed
<
T
>::
fp_
=
PrivateQueueDataFeed
<
T
>::
fp_
=
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
CHECK
(
PrivateQueueDataFeed
<
T
>::
fp_
!=
nullptr
);
__fsetlocking
(
&*
PrivateQueueDataFeed
<
T
>::
fp_
,
FSETLOCKING_BYCALLER
);
__fsetlocking
(
&*
PrivateQueueDataFeed
<
T
>::
fp_
,
FSETLOCKING_BYCALLER
);
T
instance
;
T
instance
;
platform
::
Timer
timeline
;
timeline
.
Start
();
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
local_vec
.
push_back
(
instance
);
local_vec
.
push_back
(
instance
);
}
}
timeline
.
Pause
();
VLOG
(
3
)
<<
"LoadIntoMemory() read all lines, file="
VLOG
(
3
)
<<
"LoadIntoMemory() read all lines, file="
<<
filename
<<
", thread_id="
<<
thread_id_
;
<<
filename
<<
", cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds, thread_id="
<<
thread_id_
;
{
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
mutex_for_update_memory_data_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
mutex_for_update_memory_data_
);
timeline
.
Start
();
memory_data_
->
insert
(
memory_data_
->
end
(),
memory_data_
->
insert
(
memory_data_
->
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
std
::
make_move_iterator
(
local_vec
.
begin
()),
std
::
make_move_iterator
(
local_vec
.
end
()));
timeline
.
Pause
();
VLOG
(
3
)
<<
"LoadIntoMemory() memory_data insert, cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds, thread_id="
<<
thread_id_
;
}
}
std
::
vector
<
T
>
().
swap
(
local_vec
);
local_vec
.
clear
(
);
}
}
std
::
vector
<
T
>
().
swap
(
local_vec
);
VLOG
(
3
)
<<
"LoadIntoMemory() end, thread_id="
<<
thread_id_
;
VLOG
(
3
)
<<
"LoadIntoMemory() end, thread_id="
<<
thread_id_
;
}
}
...
@@ -315,30 +334,66 @@ void InMemoryDataFeed<T>::LocalShuffle() {
...
@@ -315,30 +334,66 @@ void InMemoryDataFeed<T>::LocalShuffle() {
template
<
typename
T
>
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
GlobalShuffle
()
{
void
InMemoryDataFeed
<
T
>::
GlobalShuffle
()
{
VLOG
(
3
)
<<
"GlobalShuffle(), thread_id="
<<
thread_id_
;
VLOG
(
3
)
<<
"GlobalShuffle()
begin
, thread_id="
<<
thread_id_
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
std
::
vector
<
std
::
string
>
send_str_vec
(
trainer_num_
);
std
::
vector
<
std
::
vector
<
T
*>>
send_vec
(
trainer_num_
);
for
(
int64_t
i
=
0
;
i
<
memory_data_
->
size
();
++
i
)
{
for
(
auto
&
vec
:
send_vec
)
{
// todo get ins id
vec
.
reserve
(
fleet_send_batch_size_
);
}
std
::
vector
<
std
::
future
<
int32_t
>>
total_status
;
auto
interval
=
GetMemoryDataInterval
();
VLOG
(
3
)
<<
"global shuffle data from ["
<<
interval
.
first
<<
", "
<<
interval
.
second
<<
"), thread_id="
<<
thread_id_
;
for
(
int64_t
i
=
interval
.
first
;
i
<
interval
.
second
;
++
i
)
{
// if get ins id, can also use hash
// std::string ins_id = memory_data_[i].ins_id;
// std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t
random_num
=
fleet_ptr
->
LocalRandomEngine
()();
int64_t
random_num
=
fleet_ptr
->
LocalRandomEngine
()();
int64_t
node_id
=
random_num
%
trainer_num_
;
int64_t
node_id
=
random_num
%
trainer_num_
;
std
::
string
str
;
send_vec
[
node_id
].
push_back
(
&
((
*
memory_data_
)[
i
]));
SerializeIns
((
*
memory_data_
)[
i
],
&
str
);
send_str_vec
[
node_id
]
+=
str
;
if
(
i
%
fleet_send_batch_size_
==
0
&&
i
!=
0
)
{
if
(
i
%
fleet_send_batch_size_
==
0
&&
i
!=
0
)
{
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
send_vec
.
size
();
++
j
)
{
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str_vec
[
j
]);
std
::
string
send_str
;
send_str_vec
[
j
]
=
""
;
SerializeIns
(
send_vec
[
j
],
&
send_str
);
VLOG
(
3
)
<<
"send str_length="
<<
send_str
.
length
()
<<
", ins num="
<<
send_vec
[
j
].
size
()
<<
" to node_id="
<<
j
<<
", thread_id="
<<
thread_id_
;
auto
ret
=
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str
);
VLOG
(
3
)
<<
"end send, thread_id="
<<
thread_id_
;
send_vec
[
j
].
clear
();
total_status
.
push_back
(
std
::
move
(
ret
));
}
}
}
}
}
}
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
for
(
int
j
=
0
;
j
<
send_vec
.
size
();
++
j
)
{
if
(
send_str_vec
[
j
].
length
()
!=
0
)
{
if
(
send_vec
[
j
].
size
()
!=
0
)
{
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str_vec
[
j
]);
std
::
string
send_str
;
SerializeIns
(
send_vec
[
j
],
&
send_str
);
VLOG
(
3
)
<<
"send str_length="
<<
send_str
.
length
()
<<
" to node_id="
<<
j
<<
", thread_id="
<<
thread_id_
;
auto
ret
=
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str
);
VLOG
(
3
)
<<
"end send, thread_id="
<<
thread_id_
;
total_status
.
push_back
(
std
::
move
(
ret
));
}
}
std
::
vector
<
T
*>
().
swap
(
send_vec
[
j
]);
}
for
(
auto
&
t
:
total_status
)
{
t
.
wait
();
}
}
VLOG
(
3
)
<<
"GlobalShuffle() end, thread_id="
<<
thread_id_
;
}
template
<
typename
T
>
std
::
pair
<
int64_t
,
int64_t
>
InMemoryDataFeed
<
T
>::
GetMemoryDataInterval
()
{
int64_t
start
=
0
;
int64_t
end
=
0
;
int64_t
size
=
memory_data_
->
size
();
for
(
int64_t
i
=
0
;
i
<=
static_cast
<
int64_t
>
(
thread_id_
);
++
i
)
{
int64_t
len
=
size
/
static_cast
<
int64_t
>
(
thread_num_
)
+
(
i
<
(
size
%
static_cast
<
int64_t
>
(
thread_num_
)));
start
=
end
;
end
+=
len
;
}
return
std
::
make_pair
(
start
,
end
);
}
}
// explicit instantiation
// explicit instantiation
...
@@ -519,7 +574,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
...
@@ -519,7 +574,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
const
char
*
str
=
reader
.
get
();
const
char
*
str
=
reader
.
get
();
std
::
string
line
=
std
::
string
(
str
);
std
::
string
line
=
std
::
string
(
str
);
VLOG
(
3
)
<<
line
;
//
VLOG(3) << line;
char
*
endptr
=
const_cast
<
char
*>
(
str
);
char
*
endptr
=
const_cast
<
char
*>
(
str
);
int
pos
=
0
;
int
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
...
@@ -695,7 +750,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
...
@@ -695,7 +750,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
const
char
*
str
=
reader
.
get
();
const
char
*
str
=
reader
.
get
();
std
::
string
line
=
std
::
string
(
str
);
std
::
string
line
=
std
::
string
(
str
);
VLOG
(
3
)
<<
line
;
//
VLOG(3) << line;
char
*
endptr
=
const_cast
<
char
*>
(
str
);
char
*
endptr
=
const_cast
<
char
*>
(
str
);
int
pos
=
0
;
int
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
...
@@ -830,13 +885,15 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
...
@@ -830,13 +885,15 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle
// todo serialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
*
str
)
{
const
std
::
vector
<
std
::
vector
<
MultiSlotType
>*>&
ins
,
std
::
string
*
str
)
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Serialize
(
ins
,
str
);
fleet_ptr
->
Serialize
(
ins
,
str
);
}
}
// todo deserialize ins in global shuffle
// todo deserialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>*
ins
,
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
const
std
::
string
&
str
)
{
std
::
vector
<
std
::
vector
<
MultiSlotType
>>*
ins
,
const
std
::
string
&
str
)
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Deserialize
(
ins
,
str
);
fleet_ptr
->
Deserialize
(
ins
,
str
);
}
}
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
a5b1a0e1
...
@@ -21,6 +21,7 @@ limitations under the License. */
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include <sstream>
#include <sstream>
#include <future>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
@@ -52,7 +53,10 @@ namespace framework {
...
@@ -52,7 +53,10 @@ namespace framework {
// }
// }
class
DataFeed
{
class
DataFeed
{
public:
public:
DataFeed
()
{}
DataFeed
()
{
mutex_for_pick_file_
=
nullptr
;
file_idx_
=
nullptr
;
}
virtual
~
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
void
Init
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
bool
CheckFile
(
const
char
*
filename
)
{
virtual
bool
CheckFile
(
const
char
*
filename
)
{
...
@@ -89,6 +93,12 @@ class DataFeed {
...
@@ -89,6 +93,12 @@ class DataFeed {
virtual
void
SetThreadNum
(
int
thread_num
)
{
}
virtual
void
SetThreadNum
(
int
thread_num
)
{
}
// This function will do nothing at default
// This function will do nothing at default
virtual
void
SetTrainerNum
(
int
trainer_num
)
{
}
virtual
void
SetTrainerNum
(
int
trainer_num
)
{
}
virtual
void
SetFileListMutex
(
std
::
mutex
*
mutex
)
{
mutex_for_pick_file_
=
mutex
;
}
virtual
void
SetFileListIndex
(
size_t
*
file_index
)
{
file_idx_
=
file_index
;
}
virtual
void
LoadIntoMemory
()
{
virtual
void
LoadIntoMemory
()
{
PADDLE_THROW
(
"This function(LoadIntoMemory) is not implemented."
);
PADDLE_THROW
(
"This function(LoadIntoMemory) is not implemented."
);
}
}
...
@@ -100,7 +110,9 @@ class DataFeed {
...
@@ -100,7 +110,9 @@ class DataFeed {
}
}
// This function will do nothing at default
// This function will do nothing at default
virtual
void
FillMemoryDataToChannel
()
{
}
virtual
void
FillMemoryDataToChannel
()
{
}
// This function will do nothing at default
virtual
void
FillChannelToMemoryData
()
{
}
virtual
void
FillChannelToMemoryData
()
{
}
// This function will do nothing at default
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
}
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
}
protected:
protected:
...
@@ -116,9 +128,9 @@ class DataFeed {
...
@@ -116,9 +128,9 @@ class DataFeed {
// safe).
// safe).
virtual
bool
PickOneFile
(
std
::
string
*
filename
);
virtual
bool
PickOneFile
(
std
::
string
*
filename
);
st
atic
st
d
::
vector
<
std
::
string
>
filelist_
;
std
::
vector
<
std
::
string
>
filelist_
;
s
tatic
size_t
file_idx_
;
s
ize_t
*
file_idx_
;
st
atic
std
::
mutex
mutex_for_pick_file_
;
st
d
::
mutex
*
mutex_for_pick_file_
;
// the alias of used slots, and its order is determined by
// the alias of used slots, and its order is determined by
// data_feed_desc(proto object)
// data_feed_desc(proto object)
...
@@ -141,7 +153,7 @@ class DataFeed {
...
@@ -141,7 +153,7 @@ class DataFeed {
int
batch_size_
;
int
batch_size_
;
bool
finish_init_
;
bool
finish_init_
;
static
bool
finish_set_filelist_
;
bool
finish_set_filelist_
;
bool
finish_start_
;
bool
finish_start_
;
std
::
string
pipe_command_
;
std
::
string
pipe_command_
;
};
};
...
@@ -215,8 +227,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
...
@@ -215,8 +227,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
virtual
bool
ParseOneInstanceFromPipe
(
T
*
instance
)
=
0
;
virtual
bool
ParseOneInstanceFromPipe
(
T
*
instance
)
=
0
;
virtual
void
PutToFeedVec
(
const
T
&
ins_vec
)
=
0
;
virtual
void
PutToFeedVec
(
const
T
&
ins_vec
)
=
0
;
virtual
void
SerializeIns
(
const
T
&
ins
,
std
::
string
*
str
)
=
0
;
virtual
void
SerializeIns
(
const
std
::
vector
<
T
*>&
ins
,
std
::
string
*
str
)
=
0
;
virtual
void
DeserializeIns
(
T
*
ins
,
const
std
::
string
&
str
)
=
0
;
virtual
void
DeserializeIns
(
std
::
vector
<
T
>*
ins
,
const
std
::
string
&
str
)
=
0
;
virtual
std
::
pair
<
int64_t
,
int64_t
>
GetMemoryDataInterval
();
int
thread_id_
;
int
thread_id_
;
int
thread_num_
;
int
thread_num_
;
...
@@ -284,13 +297,13 @@ class MultiSlotType {
...
@@ -284,13 +297,13 @@ class MultiSlotType {
std
::
string
DebugString
()
{
std
::
string
DebugString
()
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
"type: "
<<
type_
<<
"
\n
"
;
ss
<<
"
\n
type: "
<<
type_
<<
"
\n
"
;
ss
<<
"offset:
\n
"
;
ss
<<
"offset:
"
;
ss
<<
"["
;
ss
<<
"["
;
for
(
const
size_t
&
i
:
offset_
)
{
for
(
const
size_t
&
i
:
offset_
)
{
ss
<<
offset_
[
i
]
<<
","
;
ss
<<
offset_
[
i
]
<<
","
;
}
}
ss
<<
"]
\n
data:
\n
["
;
ss
<<
"]
\n
data:
["
;
if
(
type_
[
0
]
==
'f'
)
{
if
(
type_
[
0
]
==
'f'
)
{
for
(
const
float
&
i
:
float_feasign_
)
{
for
(
const
float
&
i
:
float_feasign_
)
{
ss
<<
i
<<
","
;
ss
<<
i
<<
","
;
...
@@ -356,9 +369,9 @@ class MultiSlotInMemoryDataFeed
...
@@ -356,9 +369,9 @@ class MultiSlotInMemoryDataFeed
virtual
bool
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>*
instance
);
virtual
bool
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>*
instance
);
virtual
bool
ParseOneInstanceFromPipe
(
std
::
vector
<
MultiSlotType
>*
instance
);
virtual
bool
ParseOneInstanceFromPipe
(
std
::
vector
<
MultiSlotType
>*
instance
);
virtual
void
PutToFeedVec
(
const
std
::
vector
<
MultiSlotType
>&
ins_vec
);
virtual
void
PutToFeedVec
(
const
std
::
vector
<
MultiSlotType
>&
ins_vec
);
virtual
void
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
virtual
void
SerializeIns
(
const
std
::
vector
<
std
::
vector
<
MultiSlotType
>*
>&
ins
,
std
::
string
*
str
);
std
::
string
*
str
);
virtual
void
DeserializeIns
(
std
::
vector
<
MultiSlotType
>*
ins
,
virtual
void
DeserializeIns
(
std
::
vector
<
std
::
vector
<
MultiSlotType
>
>*
ins
,
const
std
::
string
&
str
);
const
std
::
string
&
str
);
};
};
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
a5b1a0e1
...
@@ -18,6 +18,8 @@
...
@@ -18,6 +18,8 @@
#include "google/protobuf/message.h"
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -25,12 +27,15 @@ namespace framework {
...
@@ -25,12 +27,15 @@ namespace framework {
template
<
typename
T
>
template
<
typename
T
>
DatasetImpl
<
T
>::
DatasetImpl
()
{
DatasetImpl
<
T
>::
DatasetImpl
()
{
thread_num_
=
1
;
thread_num_
=
1
;
trainer_num_
=
1
;
file_idx_
=
0
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
void
DatasetImpl
<
T
>::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
filelist_
=
filelist
;
filelist_
=
filelist
;
file_idx_
=
0
;
/*
/*
int file_cnt = filelist_.size();
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
if (thread_num_ > file_cnt) {
...
@@ -45,19 +50,34 @@ void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
...
@@ -45,19 +50,34 @@ void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
// not user friendly
// not user friendly
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetThreadNum
(
int
thread_num
)
{
void
DatasetImpl
<
T
>::
SetThreadNum
(
int
thread_num
)
{
int
file_cnt
=
filelist_
.
size
();
VLOG
(
3
)
<<
"SetThreadNum thread_num="
<<
thread_num
;
//int file_cnt = filelist_.size();
/*
if (file_cnt != 0 && thread_num > file_cnt) {
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(3) << "DataSet thread num = " << thread_num
VLOG(3) << "DataSet thread num = " << thread_num
<< ", file num = " << file_cnt
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
thread_num = file_cnt;
}
}
*/
thread_num_
=
thread_num
;
thread_num_
=
thread_num
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
void
DatasetImpl
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
trainer_num_
=
trainer_num
;
// should inform reader of trainer_num directly
for
(
auto
reader
:
readers_
)
{
reader
->
SetTrainerNum
(
trainer_num
);
}
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetHdfsConfig
(
const
std
::
string
&
fs_name
,
const
std
::
string
&
fs_ugi
)
{
std
::
string
cmd
=
std
::
string
(
"hadoop fs"
);
cmd
+=
" -D fs.default.name="
+
fs_name
;
cmd
+=
" -D hadoop.job.ugi="
+
fs_ugi
;
paddle
::
framework
::
hdfs_set_command
(
cmd
);
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -75,6 +95,8 @@ DatasetImpl<T>::GetReaders() {
...
@@ -75,6 +95,8 @@ DatasetImpl<T>::GetReaders() {
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LoadIntoMemory
()
{
void
DatasetImpl
<
T
>::
LoadIntoMemory
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() begin"
;
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
if
(
readers_
.
size
()
==
0
)
{
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
CreateReaders
();
}
}
...
@@ -86,12 +108,17 @@ void DatasetImpl<T>::LoadIntoMemory() {
...
@@ -86,12 +108,17 @@ void DatasetImpl<T>::LoadIntoMemory() {
for
(
std
::
thread
&
t
:
load_threads
)
{
for
(
std
::
thread
&
t
:
load_threads
)
{
t
.
join
();
t
.
join
();
}
}
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() end"
;
timeline
.
Pause
();
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() end"
<<
", memory data size="
<<
memory_data_
.
size
()
<<
", cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LocalShuffle
()
{
void
DatasetImpl
<
T
>::
LocalShuffle
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() begin"
;
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
if
(
readers_
.
size
()
==
0
)
{
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
CreateReaders
();
}
}
...
@@ -107,23 +134,27 @@ void DatasetImpl<T>::LocalShuffle() {
...
@@ -107,23 +134,27 @@ void DatasetImpl<T>::LocalShuffle() {
t
.
join
();
t
.
join
();
}
}
std
::
vector
<
T
>
().
swap
(
memory_data_
);
std
::
vector
<
T
>
().
swap
(
memory_data_
);
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end"
;
timeline
.
Pause
();
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end, cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
GlobalShuffle
()
{
void
DatasetImpl
<
T
>::
GlobalShuffle
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() begin"
;
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() begin"
;
if
(
readers_
.
size
()
==
0
)
{
platform
::
Timer
timeline
;
CreateReaders
();
timeline
.
Start
();
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
VLOG
(
3
)
<<
"RegisterClientToClientMsgHandler"
;
VLOG
(
3
)
<<
"RegisterClientToClientMsgHandler"
;
fleet_ptr
->
RegisterClientToClientMsgHandler
(
fleet_ptr
->
RegisterClientToClientMsgHandler
(
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
});
});
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
VLOG
(
3
)
<<
"start global shuffle threads"
;
VLOG
(
3
)
<<
"start global shuffle threads"
;
std
::
vector
<
std
::
thread
>
global_shuffle_threads
;
std
::
vector
<
std
::
thread
>
global_shuffle_threads
;
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
...
@@ -133,15 +164,32 @@ void DatasetImpl<T>::GlobalShuffle() {
...
@@ -133,15 +164,32 @@ void DatasetImpl<T>::GlobalShuffle() {
for
(
std
::
thread
&
t
:
global_shuffle_threads
)
{
for
(
std
::
thread
&
t
:
global_shuffle_threads
)
{
t
.
join
();
t
.
join
();
}
}
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end"
;
std
::
vector
<
T
>
().
swap
(
memory_data_
);
timeline
.
Pause
();
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end, cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
}
template
<
typename
T
>
template
<
typename
T
>
void
DatasetImpl
<
T
>::
CreateReaders
()
{
void
DatasetImpl
<
T
>::
CreateReaders
()
{
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
int
file_cnt
=
filelist_
.
size
();
int
memory_data_size
=
memory_data_
.
size
();
if
(
memory_data_size
!=
0
&&
thread_num_
>
memory_data_size
)
{
VLOG
(
3
)
<<
"Dataset thread num = "
<<
thread_num_
<<
", memory data size = "
<<
memory_data_size
<<
". Changing Dataset thread num = "
<<
memory_data_size
;
thread_num_
=
memory_data_size
;
}
else
if
(
file_cnt
!=
0
&&
thread_num_
>
file_cnt
)
{
VLOG
(
3
)
<<
"Dataset thread num = "
<<
thread_num_
<<
", file num = "
<<
file_cnt
<<
". Changing Dataset thread num = "
<<
file_cnt
;
thread_num_
=
file_cnt
;
}
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
VLOG
(
3
)
<<
"readers size: "
<<
readers_
.
size
();
VLOG
(
3
)
<<
"readers size: "
<<
readers_
.
size
();
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
if
(
readers_
.
size
()
!=
0
)
{
if
(
readers_
.
size
()
!=
0
)
{
return
;
return
;
}
}
...
@@ -154,9 +202,10 @@ void DatasetImpl<T>::CreateReaders() {
...
@@ -154,9 +202,10 @@ void DatasetImpl<T>::CreateReaders() {
readers_
.
back
()
->
SetThreadId
(
i
);
readers_
.
back
()
->
SetThreadId
(
i
);
readers_
.
back
()
->
SetThreadNum
(
thread_num_
);
readers_
.
back
()
->
SetThreadNum
(
thread_num_
);
readers_
.
back
()
->
SetTrainerNum
(
trainer_num_
);
readers_
.
back
()
->
SetTrainerNum
(
trainer_num_
);
readers_
.
back
()
->
SetFileListMutex
(
&
mutex_for_pick_file_
);
readers_
.
back
()
->
SetFileListIndex
(
&
file_idx_
);
readers_
.
back
()
->
SetFileList
(
filelist_
);
}
}
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
readers_
[
0
]
->
SetFileList
(
filelist_
);
}
}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -184,9 +233,12 @@ void DatasetImpl<T>::DestroyReaders() {
...
@@ -184,9 +233,12 @@ void DatasetImpl<T>::DestroyReaders() {
template
<
typename
T
>
template
<
typename
T
>
int
DatasetImpl
<
T
>::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
int
DatasetImpl
<
T
>::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
{
const
std
::
string
&
msg
)
{
// todo random
VLOG
(
3
)
<<
"ReceiveFromClient msg_type="
<<
msg_type
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
<<
", client_id="
<<
client_id
<<
", msg length="
int64_t
index
=
0
;
<<
msg
.
length
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
int64_t
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
thread_num_
;
VLOG
(
3
)
<<
"ramdom index="
<<
index
;
readers_
[
index
]
->
PutInsToChannel
(
msg
);
readers_
[
index
]
->
PutInsToChannel
(
msg
);
return
0
;
return
0
;
}
}
...
...
paddle/fluid/framework/data_set.h
浏览文件 @
a5b1a0e1
...
@@ -33,6 +33,8 @@ class Dataset {
...
@@ -33,6 +33,8 @@ class Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
=
0
;
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
=
0
;
virtual
void
SetThreadNum
(
int
thread_num
)
=
0
;
virtual
void
SetThreadNum
(
int
thread_num
)
=
0
;
virtual
void
SetTrainerNum
(
int
trainer_num
)
=
0
;
virtual
void
SetTrainerNum
(
int
trainer_num
)
=
0
;
virtual
void
SetHdfsConfig
(
const
std
::
string
&
fs_name
,
const
std
::
string
&
fs_ugi
)
=
0
;
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
=
0
;
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
=
0
;
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
=
0
;
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
=
0
;
virtual
int
GetThreadNum
()
=
0
;
virtual
int
GetThreadNum
()
=
0
;
...
@@ -60,6 +62,8 @@ class DatasetImpl : public Dataset {
...
@@ -60,6 +62,8 @@ class DatasetImpl : public Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
SetHdfsConfig
(
const
std
::
string
&
fs_name
,
const
std
::
string
&
fs_ugi
);
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
);
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
);
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
{
return
filelist_
;
}
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
{
return
filelist_
;
}
...
@@ -85,8 +89,10 @@ class DatasetImpl : public Dataset {
...
@@ -85,8 +89,10 @@ class DatasetImpl : public Dataset {
std
::
mutex
mutex_for_update_memory_data_
;
std
::
mutex
mutex_for_update_memory_data_
;
int
thread_num_
;
int
thread_num_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
std
::
vector
<
std
::
string
>
filelist_
;
int
trainer_num_
;
int
trainer_num_
;
std
::
vector
<
std
::
string
>
filelist_
;
size_t
file_idx_
;
std
::
mutex
mutex_for_pick_file_
;
};
};
class
MultiSlotDataset
:
public
DatasetImpl
<
std
::
vector
<
MultiSlotType
>>
{
class
MultiSlotDataset
:
public
DatasetImpl
<
std
::
vector
<
MultiSlotType
>>
{
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
a5b1a0e1
...
@@ -26,12 +26,14 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
...
@@ -26,12 +26,14 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset
*
dataset
)
{
Dataset
*
dataset
)
{
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
SetDataset
(
dataset
);
SetDataset
(
dataset
);
workers_
.
resize
(
thread_num_
);
dataset
->
CreateReaders
();
dataset
->
CreateReaders
();
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
dataset
->
GetReaders
();
dataset
->
GetReaders
();
thread_num_
=
readers
.
size
();
workers_
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
trainer_desc
.
device_worker_name
());
...
...
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
a5b1a0e1
...
@@ -29,6 +29,7 @@ limitations under the License. */
...
@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -203,6 +204,60 @@ void FleetWrapper::PullDenseVarsSync(
...
@@ -203,6 +204,60 @@ void FleetWrapper::PullDenseVarsSync(
#endif
#endif
}
}
void
FleetWrapper
::
PushDenseParamSync
(
const
ProgramDesc
&
program
,
const
uint64_t
table_id
,
const
std
::
vector
<
std
::
string
>&
var_names
)
{
#ifdef PADDLE_WITH_PSLIB
paddle
::
framework
::
Scope
scope
;
auto
&
block
=
program
.
Block
(
0
);
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
scope
.
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
}
else
{
auto
*
ptr
=
scope
.
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
}
}
auto
place
=
platform
::
CPUPlace
();
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
for
(
auto
&
t
:
var_names
)
{
Variable
*
var
=
scope
.
FindVar
(
t
);
CHECK
(
var
!=
nullptr
)
<<
"var["
<<
t
<<
"] not found"
;
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
std
::
vector
<
int64_t
>
dim
;
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Name
()
==
t
)
{
dim
=
var
->
GetShape
();
break
;
}
}
int
cnt
=
1
;
for
(
auto
&
i
:
dim
)
{
cnt
*=
i
;
}
DDim
d
(
std
::
vector
<
int64_t
>
{
cnt
}.
data
(),
1
);
float
*
g
=
tensor
->
mutable_data
<
float
>
(
d
,
place
);
CHECK
(
g
!=
nullptr
)
<<
"var["
<<
t
<<
"] value not initialized"
;
float
init_range
=
0.2
;
int
rown
=
tensor
->
dims
()[
0
];
init_range
/=
sqrt
(
rown
);
std
::
normal_distribution
<
float
>
ndistr
(
0.0
,
1.0
);
for
(
auto
i
=
0u
;
i
<
tensor
->
numel
();
++
i
)
{
g
[
i
]
=
ndistr
(
LocalRandomEngine
())
*
init_range
;
}
paddle
::
ps
::
Region
reg
(
g
,
tensor
->
numel
());
regions
.
emplace_back
(
std
::
move
(
reg
));
auto
push_status
=
pslib_ptr_
->
_worker_ptr
->
push_dense_param
(
regions
.
data
(),
regions
.
size
(),
table_id
);
push_status
.
wait
();
auto
status
=
push_status
.
get
();
CHECK
(
status
==
0
)
<<
"push dense param failed, status["
<<
status
<<
"]"
;
}
#endif
}
void
FleetWrapper
::
PushDenseVarsSync
(
void
FleetWrapper
::
PushDenseVarsSync
(
Scope
*
scope
,
const
uint64_t
table_id
,
Scope
*
scope
,
const
uint64_t
table_id
,
const
std
::
vector
<
std
::
string
>&
var_names
)
{}
const
std
::
vector
<
std
::
string
>&
var_names
)
{}
...
@@ -269,6 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
...
@@ -269,6 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
continue
;
continue
;
}
}
LOG
(
WARNING
)
<<
"going to memcpy"
;
LOG
(
WARNING
)
<<
"going to memcpy"
;
CHECK
(
fea_idx
<
(
*
push_values
).
size
());
CHECK
(
fea_idx
<
fea_labels
.
size
());
memcpy
((
*
push_values
)[
fea_idx
].
data
()
+
offset
,
g
,
memcpy
((
*
push_values
)[
fea_idx
].
data
()
+
offset
,
g
,
sizeof
(
float
)
*
emb_dim
);
sizeof
(
float
)
*
emb_dim
);
LOG
(
WARNING
)
<<
"show"
;
LOG
(
WARNING
)
<<
"show"
;
...
@@ -294,13 +351,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
...
@@ -294,13 +351,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
#endif
}
}
int
FleetWrapper
::
RegisterClientToClientMsgHandler
(
int
msg_type
,
int
FleetWrapper
::
RegisterClientToClientMsgHandler
(
MsgHandlerFunc
handler
)
{
int
msg_type
,
MsgHandlerFunc
handler
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
VLOG
(
3
)
<<
"calling FleetWrapper::RegisterClientToClientMsgHandler"
;
VLOG
(
3
)
<<
"calling FleetWrapper::RegisterClientToClientMsgHandler"
;
VLOG
(
3
)
<<
"pslib_ptr_="
<<
pslib_ptr_
;
VLOG
(
3
)
<<
"pslib_ptr_="
<<
pslib_ptr_
;
VLOG
(
3
)
<<
"_worker_ptr="
<<
pslib_ptr_
->
_worker_ptr
;
VLOG
(
3
)
<<
"_worker_ptr="
<<
pslib_ptr_
->
_worker_ptr
;
pslib_ptr_
->
_worker_ptr
->
registe_client2client_msg_handler
(
msg_type
,
handler
);
return
pslib_ptr_
->
_worker_ptr
->
registe_client2client_msg_handler
(
msg_type
,
handler
);
#else
#else
VLOG
(
0
)
<<
"FleetWrapper::RegisterClientToClientMsgHandler"
VLOG
(
0
)
<<
"FleetWrapper::RegisterClientToClientMsgHandler"
<<
" does nothing when no pslib"
;
<<
" does nothing when no pslib"
;
...
@@ -308,15 +365,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
...
@@ -308,15 +365,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
return
0
;
return
0
;
}
}
int
FleetWrapper
::
SendClientToClientMsg
(
int
msg_type
,
int
to_client_id
,
std
::
future
<
int32_t
>
FleetWrapper
::
SendClientToClientMsg
(
const
std
::
string
&
msg
)
{
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_
->
_worker_ptr
->
send_client2client_msg
(
msg_type
,
to_client_id
,
msg
);
return
pslib_ptr_
->
_worker_ptr
->
send_client2client_msg
(
msg_type
,
to_client_id
,
msg
);
#else
#else
VLOG
(
0
)
<<
"FleetWrapper::SendClientToClientMsg"
VLOG
(
0
)
<<
"FleetWrapper::SendClientToClientMsg"
<<
" does nothing when no pslib"
;
<<
" does nothing when no pslib"
;
#endif
#endif
return
0
;
return
std
::
future
<
int32_t
>
()
;
}
}
std
::
default_random_engine
&
FleetWrapper
::
LocalRandomEngine
()
{
std
::
default_random_engine
&
FleetWrapper
::
LocalRandomEngine
()
{
...
@@ -336,10 +393,12 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
...
@@ -336,10 +393,12 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
}
}
template
<
typename
T
>
template
<
typename
T
>
void
FleetWrapper
::
Serialize
(
const
T
&
t
,
std
::
string
*
str
)
{
void
FleetWrapper
::
Serialize
(
const
std
::
vector
<
T
*>
&
t
,
std
::
string
*
str
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
paddle
::
ps
::
BinaryArchive
ar
;
paddle
::
ps
::
BinaryArchive
ar
;
ar
<<
t
;
for
(
size_t
i
=
0
;
i
<
t
.
size
();
++
i
)
{
ar
<<
*
(
t
[
i
]);
}
*
str
=
std
::
string
(
ar
.
buffer
(),
ar
.
length
());
*
str
=
std
::
string
(
ar
.
buffer
(),
ar
.
length
());
#else
#else
VLOG
(
0
)
<<
"FleetWrapper::Serialize does nothing when no pslib"
;
VLOG
(
0
)
<<
"FleetWrapper::Serialize does nothing when no pslib"
;
...
@@ -347,20 +406,30 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
...
@@ -347,20 +406,30 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
}
}
template
<
typename
T
>
template
<
typename
T
>
void
FleetWrapper
::
Deserialize
(
T
*
t
,
const
std
::
string
&
str
)
{
void
FleetWrapper
::
Deserialize
(
std
::
vector
<
T
>
*
t
,
const
std
::
string
&
str
)
{
#ifdef PADDLE_WITH_PSLIB
#ifdef PADDLE_WITH_PSLIB
if
(
str
.
length
()
==
0
)
{
return
;
}
paddle
::
ps
::
BinaryArchive
ar
;
paddle
::
ps
::
BinaryArchive
ar
;
ar
.
set_read_buffer
(
const_cast
<
char
*>
(
str
.
c_str
()),
str
.
length
(),
nullptr
);
ar
.
set_read_buffer
(
const_cast
<
char
*>
(
str
.
c_str
()),
str
.
length
(),
nullptr
);
*
t
=
ar
.
get
<
T
>
();
if
(
ar
.
cursor
()
==
ar
.
finish
())
{
return
;
}
while
(
ar
.
cursor
()
<
ar
.
finish
())
{
t
->
push_back
(
ar
.
get
<
T
>
());
}
CHECK
(
ar
.
cursor
()
==
ar
.
finish
());
VLOG
(
3
)
<<
"Deserialize size "
<<
t
->
size
();
#else
#else
VLOG
(
0
)
<<
"FleetWrapper::Deserialize does nothing when no pslib"
;
VLOG
(
0
)
<<
"FleetWrapper::Deserialize does nothing when no pslib"
;
#endif
#endif
}
}
template
void
FleetWrapper
::
Serialize
<
std
::
vector
<
MultiSlotType
>
>
(
template
void
FleetWrapper
::
Serialize
<
std
::
vector
<
MultiSlotType
>
>
(
const
std
::
vector
<
MultiSlotType
>&
,
std
::
string
*
);
const
std
::
vector
<
std
::
vector
<
MultiSlotType
>*
>&
,
std
::
string
*
);
template
void
FleetWrapper
::
Deserialize
(
std
::
vector
<
MultiSlotType
>
*
,
template
void
FleetWrapper
::
Deserialize
<
std
::
vector
<
MultiSlotType
>
>
(
const
std
::
string
&
);
std
::
vector
<
std
::
vector
<
MultiSlotType
>>*
,
const
std
::
string
&
);
}
// end namespace framework
}
// end namespace framework
}
// end namespace paddle
}
// end namespace paddle
paddle/fluid/framework/fleet/fleet_wrapper.h
浏览文件 @
a5b1a0e1
...
@@ -27,6 +27,7 @@ limitations under the License. */
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -71,6 +72,10 @@ class FleetWrapper {
...
@@ -71,6 +72,10 @@ class FleetWrapper {
const
std
::
vector
<
std
::
string
>&
var_names
,
const
std
::
vector
<
std
::
string
>&
var_names
,
std
::
vector
<::
std
::
future
<
int32_t
>>*
pull_dense_status
);
std
::
vector
<::
std
::
future
<
int32_t
>>*
pull_dense_status
);
void
PushDenseParamSync
(
const
ProgramDesc
&
program
,
const
uint64_t
table_id
,
const
std
::
vector
<
std
::
string
>&
var_names
);
// Push dense variables to server in async mode
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<in>: scope, table_id, var_names,
// Param<out>: push_sparse_status
// Param<out>: push_sparse_status
...
@@ -119,16 +124,15 @@ class FleetWrapper {
...
@@ -119,16 +124,15 @@ class FleetWrapper {
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
int
RegisterClientToClientMsgHandler
(
int
msg_type
,
MsgHandlerFunc
handler
);
int
RegisterClientToClientMsgHandler
(
int
msg_type
,
MsgHandlerFunc
handler
);
int
SendClientToClientMsg
(
int
msg_type
,
std
::
future
<
int32_t
>
SendClientToClientMsg
(
int
msg_type
,
int
to_client_id
,
int
to_client_id
,
const
std
::
string
&
msg
);
const
std
::
string
&
msg
);
std
::
default_random_engine
&
LocalRandomEngine
();
std
::
default_random_engine
&
LocalRandomEngine
();
template
<
typename
T
>
template
<
typename
T
>
void
Serialize
(
const
T
&
t
,
std
::
string
*
str
);
void
Serialize
(
const
std
::
vector
<
T
*>
&
t
,
std
::
string
*
str
);
template
<
typename
T
>
template
<
typename
T
>
void
Deserialize
(
T
*
t
,
const
std
::
string
&
str
);
void
Deserialize
(
std
::
vector
<
T
>*
t
,
const
std
::
string
&
str
);
static
std
::
shared_ptr
<
FleetWrapper
>
GetInstance
()
{
static
std
::
shared_ptr
<
FleetWrapper
>
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
framework
::
FleetWrapper
());
s_instance_
.
reset
(
new
paddle
::
framework
::
FleetWrapper
());
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
a5b1a0e1
...
@@ -26,13 +26,15 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
...
@@ -26,13 +26,15 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_
=
trainer_desc
.
thread_num
();
thread_num_
=
trainer_desc
.
thread_num
();
SetDataset
(
dataset
);
SetDataset
(
dataset
);
// get filelist from trainer_desc here
// get filelist from trainer_desc here
workers_
.
resize
(
thread_num_
);
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
dataset
->
CreateReaders
();
dataset
->
CreateReaders
();
VLOG
(
3
)
<<
"readers created"
;
VLOG
(
3
)
<<
"readers created"
;
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
dataset
->
GetReaders
();
dataset
->
GetReaders
();
VLOG
(
3
)
<<
"readers num: "
<<
readers
.
size
();
VLOG
(
3
)
<<
"readers num: "
<<
readers
.
size
();
// change thread num to readers num
thread_num_
=
readers
.
size
();
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
workers_
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
trainer_desc
.
device_worker_name
());
...
...
paddle/fluid/pybind/async_executor_py.cc
浏览文件 @
a5b1a0e1
...
@@ -49,7 +49,7 @@ void BindAsyncExecutor(py::module* m) {
...
@@ -49,7 +49,7 @@ void BindAsyncExecutor(py::module* m) {
new
framework
::
AsyncExecutor
(
scope
,
place
));
new
framework
::
AsyncExecutor
(
scope
,
place
));
}))
}))
.
def
(
"run_from_files"
,
&
framework
::
AsyncExecutor
::
RunFromFile
)
.
def
(
"run_from_files"
,
&
framework
::
AsyncExecutor
::
RunFromFile
)
.
def
(
"run_from_dataset"
,
&
framework
::
AsyncExecutor
::
RunFromDataset
)
//
.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
.
def
(
"init_server"
,
&
framework
::
AsyncExecutor
::
InitServer
)
.
def
(
"init_server"
,
&
framework
::
AsyncExecutor
::
InitServer
)
.
def
(
"init_worker"
,
&
framework
::
AsyncExecutor
::
InitWorker
)
.
def
(
"init_worker"
,
&
framework
::
AsyncExecutor
::
InitWorker
)
.
def
(
"start_server"
,
&
framework
::
AsyncExecutor
::
StartServer
)
.
def
(
"start_server"
,
&
framework
::
AsyncExecutor
::
StartServer
)
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
a5b1a0e1
...
@@ -50,6 +50,7 @@ void BindDataset(py::module* m) {
...
@@ -50,6 +50,7 @@ void BindDataset(py::module* m) {
.
def
(
"set_filelist"
,
&
framework
::
Dataset
::
SetFileList
)
.
def
(
"set_filelist"
,
&
framework
::
Dataset
::
SetFileList
)
.
def
(
"set_thread_num"
,
&
framework
::
Dataset
::
SetThreadNum
)
.
def
(
"set_thread_num"
,
&
framework
::
Dataset
::
SetThreadNum
)
.
def
(
"set_trainer_num"
,
&
framework
::
Dataset
::
SetTrainerNum
)
.
def
(
"set_trainer_num"
,
&
framework
::
Dataset
::
SetTrainerNum
)
.
def
(
"set_hdfs_config"
,
&
framework
::
Dataset
::
SetHdfsConfig
)
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"load_into_memory"
,
&
framework
::
Dataset
::
LoadIntoMemory
)
.
def
(
"load_into_memory"
,
&
framework
::
Dataset
::
LoadIntoMemory
)
.
def
(
"local_shuffle"
,
&
framework
::
Dataset
::
LocalShuffle
)
.
def
(
"local_shuffle"
,
&
framework
::
Dataset
::
LocalShuffle
)
...
...
paddle/fluid/pybind/fleet_wrapper_py.cc
浏览文件 @
a5b1a0e1
...
@@ -47,6 +47,7 @@ void BindFleetWrapper(py::module* m) {
...
@@ -47,6 +47,7 @@ void BindFleetWrapper(py::module* m) {
.
def
(
"init_server"
,
&
framework
::
FleetWrapper
::
InitServer
)
.
def
(
"init_server"
,
&
framework
::
FleetWrapper
::
InitServer
)
.
def
(
"run_server"
,
&
framework
::
FleetWrapper
::
RunServer
)
.
def
(
"run_server"
,
&
framework
::
FleetWrapper
::
RunServer
)
.
def
(
"init_worker"
,
&
framework
::
FleetWrapper
::
InitWorker
)
.
def
(
"init_worker"
,
&
framework
::
FleetWrapper
::
InitWorker
)
.
def
(
"init_model"
,
&
framework
::
FleetWrapper
::
PushDenseParamSync
)
.
def
(
"stop_server"
,
&
framework
::
FleetWrapper
::
StopServer
)
.
def
(
"stop_server"
,
&
framework
::
FleetWrapper
::
StopServer
)
.
def
(
"gather_servers"
,
&
framework
::
FleetWrapper
::
GatherServers
);
.
def
(
"gather_servers"
,
&
framework
::
FleetWrapper
::
GatherServers
);
}
// end FleetWrapper
}
// end FleetWrapper
...
...
python/paddle/fluid/dataset.py
浏览文件 @
a5b1a0e1
...
@@ -86,6 +86,9 @@ class DatasetBase(object):
...
@@ -86,6 +86,9 @@ class DatasetBase(object):
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
)
def
set_hdfs_config
(
self
,
fs_name
,
fs_ugi
):
self
.
dataset
.
set_hdfs_config
(
fs_name
,
fs_ugi
)
def
_prepare_to_run
(
self
):
def
_prepare_to_run
(
self
):
self
.
dataset
.
set_data_feed_desc
(
self
.
desc
())
self
.
dataset
.
set_data_feed_desc
(
self
.
desc
())
...
@@ -115,11 +118,15 @@ class InMemoryDataset(DatasetBase):
...
@@ -115,11 +118,15 @@ class InMemoryDataset(DatasetBase):
def
local_shuffle
(
self
):
def
local_shuffle
(
self
):
self
.
dataset
.
local_shuffle
()
self
.
dataset
.
local_shuffle
()
def
global_shuffle
(
self
):
def
global_shuffle
(
self
,
fleet
=
None
):
from
.distributed
import
ps_instance
trainer_num
=
1
instance
=
ps_instance
.
PaddlePSInstance
(
1
,
2
)
if
fleet
is
not
None
:
self
.
dataset
.
set_trainer_num
(
instance
.
get_worker_num
())
fleet
.
fleet_instance
.
role_maker_
.
barrier_worker
()
trainer_num
=
fleet
.
worker_num
()
self
.
dataset
.
set_trainer_num
(
trainer_num
)
self
.
dataset
.
global_shuffle
()
self
.
dataset
.
global_shuffle
()
if
fleet
is
not
None
:
fleet
.
fleet_instance
.
role_maker_
.
barrier_worker
()
class
QueueDataset
(
DatasetBase
):
class
QueueDataset
(
DatasetBase
):
...
@@ -130,5 +137,5 @@ class QueueDataset(DatasetBase):
...
@@ -130,5 +137,5 @@ class QueueDataset(DatasetBase):
def
local_shuffle
(
self
):
def
local_shuffle
(
self
):
pass
pass
def
global_shuffle
(
self
):
def
global_shuffle
(
self
,
fleet
=
None
):
pass
pass
python/paddle/fluid/incubate/fleet/base/role_maker.py
浏览文件 @
a5b1a0e1
...
@@ -170,7 +170,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
...
@@ -170,7 +170,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
"""
if
self
.
_check_role_generation
():
if
self
.
_check_role_generation
():
if
self
.
is_worker
():
if
self
.
is_worker
():
return
self
.
get_size
()
return
self
.
get_size
()
/
2
;
return
0
return
0
def
server_num
(
self
):
def
server_num
(
self
):
...
@@ -179,7 +179,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
...
@@ -179,7 +179,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
"""
if
self
.
_check_role_generation
():
if
self
.
_check_role_generation
():
if
self
.
is_server
():
if
self
.
is_server
():
return
self
.
get_size
()
return
self
.
get_size
()
/
2
;
return
0
return
0
def
worker_index
(
self
):
def
worker_index
(
self
):
...
...
python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
浏览文件 @
a5b1a0e1
...
@@ -43,7 +43,7 @@ class Fleet(object):
...
@@ -43,7 +43,7 @@ class Fleet(object):
save_pserver_model(): save model parameters in pserver, called from a server node
save_pserver_model(): save model parameters in pserver, called from a server node
Example:
Example:
.. code-block:: python
.. code-block:: python
import paddle.fluid.incubate.fleet.parameter_server as fleet
import paddle.fluid.incubate.fleet.parameter_server as fleet
from my_model import bow_net
from my_model import bow_net
...
@@ -58,7 +58,7 @@ class Fleet(object):
...
@@ -58,7 +58,7 @@ class Fleet(object):
fleet.init_worker() # init worker should be called before training
fleet.init_worker() # init worker should be called before training
# do other things like training
# do other things like training
elif fleet.is_server():
elif fleet.is_server():
fleet.init_pserver()
fleet.init_pserver()
fleet.stop()
fleet.stop()
"""
"""
...
@@ -75,7 +75,7 @@ class Fleet(object):
...
@@ -75,7 +75,7 @@ class Fleet(object):
"""
"""
init(): which should be called only once in user's python scripts. init() will initialize
init(): which should be called only once in user's python scripts. init() will initialize
FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying
FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying
current node's role, e.g. worker, server, etc.
current node's role, e.g. worker, server, etc.
"""
"""
if
not
self
.
is_initialized_
:
if
not
self
.
is_initialized_
:
self
.
role_maker_
=
MPISymetricRoleMaker
()
self
.
role_maker_
=
MPISymetricRoleMaker
()
...
@@ -122,7 +122,7 @@ class Fleet(object):
...
@@ -122,7 +122,7 @@ class Fleet(object):
print
(
"You should run DistributedOptimizer.minimize() first"
)
print
(
"You should run DistributedOptimizer.minimize() first"
)
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
def
init_worker
(
self
):
def
init_worker
(
self
,
program
):
"""
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
should call init_worker() to initialize global information about worker and connect
...
@@ -143,6 +143,19 @@ class Fleet(object):
...
@@ -143,6 +143,19 @@ class Fleet(object):
self
.
role_maker_
.
get_rank
())
self
.
role_maker_
.
get_rank
())
self
.
role_maker_
.
barrier_all
()
self
.
role_maker_
.
barrier_all
()
self
.
role_maker_
.
barrier_worker
()
self
.
role_maker_
.
barrier_worker
()
if
self
.
role_maker_
.
is_first_worker
():
tables
=
self
.
_dist_desc
.
trainer_param
.
dense_table
.
_values
for
i
in
range
(
0
,
len
(
tables
)):
table
=
tables
[
i
];
var_name_list
=
[]
for
i
in
range
(
0
,
len
(
table
.
dense_variable_name
)):
var_name_list
.
append
(
table
.
dense_variable_name
[
i
])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self
.
_fleet_ptr
.
init_model
(
program
.
desc
,
int
(
table
.
table_id
),
var_name_list
)
self
.
role_maker_
.
barrier_worker
()
else
:
else
:
print
(
"You should run DistributedOptimizer.minimize() first"
)
print
(
"You should run DistributedOptimizer.minimize() first"
)
sys
.
exit
(
-
1
)
sys
.
exit
(
-
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录