Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a5b1a0e1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a5b1a0e1
编写于
3月 20, 2019
作者:
X
xujiaqi01
提交者:
dongdaxiang
3月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support multi dataset && add init model && fix bug
上级
3c65cc1b
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
341 addition
and
113 deletion
+341
-113
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+2
-1
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+106
-49
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+25
-12
paddle/fluid/framework/data_set.cc
paddle/fluid/framework/data_set.cc
+67
-15
paddle/fluid/framework/data_set.h
paddle/fluid/framework/data_set.h
+7
-1
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+3
-1
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+83
-14
paddle/fluid/framework/fleet/fleet_wrapper.h
paddle/fluid/framework/fleet/fleet_wrapper.h
+10
-6
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+4
-2
paddle/fluid/pybind/async_executor_py.cc
paddle/fluid/pybind/async_executor_py.cc
+1
-1
paddle/fluid/pybind/data_set_py.cc
paddle/fluid/pybind/data_set_py.cc
+1
-0
paddle/fluid/pybind/fleet_wrapper_py.cc
paddle/fluid/pybind/fleet_wrapper_py.cc
+1
-0
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+12
-5
python/paddle/fluid/incubate/fleet/base/role_maker.py
python/paddle/fluid/incubate/fleet/base/role_maker.py
+2
-2
python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
.../paddle/fluid/incubate/fleet/parameter_server/__init__.py
+17
-4
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
a5b1a0e1
...
...
@@ -155,7 +155,8 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
}
#ifdef PADDLE_WITH_PSLIB
if
(
mode
==
"mpi"
)
{
_pull_dense_thread
->
stop
();
// todo ?
//_pull_dense_thread->stop();
}
#endif
VLOG
(
3
)
<<
"start to run from files in async_executor"
;
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
a5b1a0e1
...
...
@@ -23,15 +23,11 @@ limitations under the License. */
#include "io/shell.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/platform/timer.h"
namespace
paddle
{
namespace
framework
{
std
::
vector
<
std
::
string
>
DataFeed
::
filelist_
;
size_t
DataFeed
::
file_idx_
;
std
::
mutex
DataFeed
::
mutex_for_pick_file_
;
bool
DataFeed
::
finish_set_filelist_
;
void
DataFeed
::
AddFeedVar
(
Variable
*
var
,
const
std
::
string
&
name
)
{
CheckInit
();
for
(
size_t
i
=
0
;
i
<
use_slots_
.
size
();
++
i
)
{
...
...
@@ -42,7 +38,7 @@ void DataFeed::AddFeedVar(Variable* var, const std::string& name) {
}
bool
DataFeed
::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
files
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
*
mutex_for_pick_file_
);
CheckInit
();
// Do not set finish_set_filelist_ flag,
// since a user may set file many times after init reader
...
...
@@ -52,9 +48,8 @@ bool DataFeed::SetFileList(const std::vector<std::string>& files) {
return false;
}
*/
PADDLE_ENFORCE
(
files
.
size
(),
"You have set an empty filelist."
);
//
PADDLE_ENFORCE(files.size(), "You have set an empty filelist.");
filelist_
.
assign
(
files
.
begin
(),
files
.
end
());
file_idx_
=
0
;
finish_set_filelist_
=
true
;
return
true
;
...
...
@@ -66,13 +61,17 @@ void DataFeed::SetBatchSize(int batch_size) {
}
bool
DataFeed
::
PickOneFile
(
std
::
string
*
filename
)
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
mutex_for_pick_file_
);
if
(
file_idx_
==
filelist_
.
size
())
{
PADDLE_ENFORCE
(
mutex_for_pick_file_
!=
nullptr
,
"should call SetFileListMutex before PickOneFile"
);
PADDLE_ENFORCE
(
file_idx_
!=
nullptr
,
"should call SetFileListIndex before PickOneFile"
);
std
::
unique_lock
<
std
::
mutex
>
lock
(
*
mutex_for_pick_file_
);
if
(
*
file_idx_
==
filelist_
.
size
())
{
VLOG
(
3
)
<<
"DataFeed::PickOneFile no more file to pick"
;
return
false
;
}
VLOG
(
3
)
<<
"file_idx_="
<<
file_idx_
;
*
filename
=
filelist_
[
file_idx_
++
];
VLOG
(
3
)
<<
"file_idx_="
<<
*
file_idx_
;
*
filename
=
filelist_
[
(
*
file_idx_
)
++
];
// LOG(ERROR) << "pick file:" << *filename;
return
true
;
}
...
...
@@ -150,7 +149,11 @@ InMemoryDataFeed<T>::InMemoryDataFeed() {
cur_channel_
=
0
;
shuffled_ins_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
shuffled_ins_out_
=
std
::
make_shared
<
paddle
::
framework
::
BlockingQueue
<
T
>>
();
fleet_send_batch_size_
=
10000
;
fleet_send_batch_size_
=
80000
;
memory_data_
=
nullptr
;
mutex_for_update_memory_data_
=
nullptr
;
this
->
file_idx_
=
nullptr
;
this
->
mutex_for_pick_file_
=
nullptr
;
}
template
<
typename
T
>
...
...
@@ -192,6 +195,8 @@ int InMemoryDataFeed<T>::Next() {
out_channel
->
Push
(
std
::
move
(
instance
));
}
DataFeed
::
batch_size_
=
index
;
VLOG
(
3
)
<<
"batch_size_="
<<
DataFeed
::
batch_size_
<<
", thread_id="
<<
thread_id_
;
if
(
DataFeed
::
batch_size_
!=
0
)
{
PutToFeedVec
(
ins_vec
);
}
else
{
...
...
@@ -227,25 +232,22 @@ void InMemoryDataFeed<T>::SetTrainerNum(int trainer_num) {
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
T
ins
;
std
::
vector
<
T
>
ins
;
DeserializeIns
(
&
ins
,
ins_str
);
shuffled_ins_
->
Push
(
std
::
move
(
ins
));
shuffled_ins_
->
Extend
(
std
::
move
(
ins
));
VLOG
(
3
)
<<
"PutInsToChannel put ins num="
<<
ins
.
size
()
<<
" to channel, channel size="
<<
shuffled_ins_
->
Size
()
<<
" thread_id="
<<
thread_id_
;
}
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
FillMemoryDataToChannel
()
{
VLOG
(
3
)
<<
"FillMemoryDataToChannel, thread_id="
<<
thread_id_
;
int64_t
start
=
0
;
int64_t
end
=
0
;
int64_t
size
=
memory_data_
->
size
();
VLOG
(
3
)
<<
"memory_data size="
<<
size
;
for
(
int64_t
i
=
0
;
i
<=
static_cast
<
int64_t
>
(
thread_id_
);
++
i
)
{
int64_t
len
=
size
/
static_cast
<
int64_t
>
(
thread_num_
)
+
(
i
<
(
size
%
static_cast
<
int64_t
>
(
thread_num_
)));
start
=
end
;
end
+=
len
;
}
for
(
int64_t
i
=
start
;
i
<
end
;
++
i
)
{
auto
interval
=
GetMemoryDataInterval
();
VLOG
(
3
)
<<
"memory data size="
<<
memory_data_
->
size
()
<<
", fill data from ["
<<
interval
.
first
<<
", "
<<
interval
.
second
<<
"), thread_id="
<<
thread_id_
;
for
(
int64_t
i
=
interval
.
first
;
i
<
interval
.
second
;
++
i
)
{
T
&
t
=
(
*
memory_data_
)[
i
];
shuffled_ins_
->
Push
(
std
::
move
(
t
));
}
...
...
@@ -256,14 +258,19 @@ void InMemoryDataFeed<T>::FillChannelToMemoryData() {
VLOG
(
3
)
<<
"FillChannelToMemoryData, thread_id="
<<
thread_id_
;
std
::
vector
<
T
>
local_vec
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
channel
=
nullptr
;
std
::
shared_ptr
<
paddle
::
framework
::
BlockingQueue
<
T
>>
pre_channel
=
nullptr
;
if
(
cur_channel_
==
0
)
{
channel
=
shuffled_ins_
;
pre_channel
=
shuffled_ins_out_
;
}
else
{
channel
=
shuffled_ins_out_
;
pre_channel
=
shuffled_ins_
;
}
CHECK
(
channel
!=
nullptr
);
CHECK
(
pre_channel
!=
nullptr
);
CHECK
(
pre_channel
->
Size
()
==
0
);
local_vec
.
resize
(
channel
->
Size
());
for
(
int64_t
i
=
0
;
i
<
channel
->
S
ize
();
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
local_vec
.
s
ize
();
++
i
)
{
channel
->
Pop
(
local_vec
[
i
]);
}
VLOG
(
3
)
<<
"local_vec size="
<<
local_vec
.
size
()
<<
", thread_id="
<<
thread_id_
;
...
...
@@ -289,20 +296,32 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
int
err_no
=
0
;
PrivateQueueDataFeed
<
T
>::
fp_
=
fs_open_read
(
filename
,
&
err_no
,
PrivateQueueDataFeed
<
T
>::
pipe_command_
);
CHECK
(
PrivateQueueDataFeed
<
T
>::
fp_
!=
nullptr
);
__fsetlocking
(
&*
PrivateQueueDataFeed
<
T
>::
fp_
,
FSETLOCKING_BYCALLER
);
T
instance
;
platform
::
Timer
timeline
;
timeline
.
Start
();
while
(
ParseOneInstanceFromPipe
(
&
instance
))
{
local_vec
.
push_back
(
instance
);
}
timeline
.
Pause
();
VLOG
(
3
)
<<
"LoadIntoMemory() read all lines, file="
<<
filename
<<
", thread_id="
<<
thread_id_
;
<<
filename
<<
", cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds, thread_id="
<<
thread_id_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
*
mutex_for_update_memory_data_
);
timeline
.
Start
();
memory_data_
->
insert
(
memory_data_
->
end
(),
local_vec
.
begin
(),
local_vec
.
end
());
std
::
make_move_iterator
(
local_vec
.
begin
()),
std
::
make_move_iterator
(
local_vec
.
end
()));
timeline
.
Pause
();
VLOG
(
3
)
<<
"LoadIntoMemory() memory_data insert, cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds, thread_id="
<<
thread_id_
;
}
std
::
vector
<
T
>
().
swap
(
local_vec
);
local_vec
.
clear
(
);
}
std
::
vector
<
T
>
().
swap
(
local_vec
);
VLOG
(
3
)
<<
"LoadIntoMemory() end, thread_id="
<<
thread_id_
;
}
...
...
@@ -315,30 +334,66 @@ void InMemoryDataFeed<T>::LocalShuffle() {
template
<
typename
T
>
void
InMemoryDataFeed
<
T
>::
GlobalShuffle
()
{
VLOG
(
3
)
<<
"GlobalShuffle(), thread_id="
<<
thread_id_
;
VLOG
(
3
)
<<
"GlobalShuffle()
begin
, thread_id="
<<
thread_id_
;
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
std
::
vector
<
std
::
string
>
send_str_vec
(
trainer_num_
);
for
(
int64_t
i
=
0
;
i
<
memory_data_
->
size
();
++
i
)
{
// todo get ins id
std
::
vector
<
std
::
vector
<
T
*>>
send_vec
(
trainer_num_
);
for
(
auto
&
vec
:
send_vec
)
{
vec
.
reserve
(
fleet_send_batch_size_
);
}
std
::
vector
<
std
::
future
<
int32_t
>>
total_status
;
auto
interval
=
GetMemoryDataInterval
();
VLOG
(
3
)
<<
"global shuffle data from ["
<<
interval
.
first
<<
", "
<<
interval
.
second
<<
"), thread_id="
<<
thread_id_
;
for
(
int64_t
i
=
interval
.
first
;
i
<
interval
.
second
;
++
i
)
{
// if get ins id, can also use hash
// std::string ins_id = memory_data_[i].ins_id;
// todo hash
int64_t
random_num
=
fleet_ptr
->
LocalRandomEngine
()();
int64_t
node_id
=
random_num
%
trainer_num_
;
std
::
string
str
;
SerializeIns
((
*
memory_data_
)[
i
],
&
str
);
send_str_vec
[
node_id
]
+=
str
;
send_vec
[
node_id
].
push_back
(
&
((
*
memory_data_
)[
i
]));
if
(
i
%
fleet_send_batch_size_
==
0
&&
i
!=
0
)
{
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str_vec
[
j
]);
send_str_vec
[
j
]
=
""
;
for
(
int
j
=
0
;
j
<
send_vec
.
size
();
++
j
)
{
std
::
string
send_str
;
SerializeIns
(
send_vec
[
j
],
&
send_str
);
VLOG
(
3
)
<<
"send str_length="
<<
send_str
.
length
()
<<
", ins num="
<<
send_vec
[
j
].
size
()
<<
" to node_id="
<<
j
<<
", thread_id="
<<
thread_id_
;
auto
ret
=
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str
);
VLOG
(
3
)
<<
"end send, thread_id="
<<
thread_id_
;
send_vec
[
j
].
clear
();
total_status
.
push_back
(
std
::
move
(
ret
));
}
}
}
for
(
int
j
=
0
;
j
<
send_str_vec
.
size
();
++
j
)
{
if
(
send_str_vec
[
j
].
length
()
!=
0
)
{
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str_vec
[
j
]);
for
(
int
j
=
0
;
j
<
send_vec
.
size
();
++
j
)
{
if
(
send_vec
[
j
].
size
()
!=
0
)
{
std
::
string
send_str
;
SerializeIns
(
send_vec
[
j
],
&
send_str
);
VLOG
(
3
)
<<
"send str_length="
<<
send_str
.
length
()
<<
" to node_id="
<<
j
<<
", thread_id="
<<
thread_id_
;
auto
ret
=
fleet_ptr
->
SendClientToClientMsg
(
0
,
j
,
send_str
);
VLOG
(
3
)
<<
"end send, thread_id="
<<
thread_id_
;
total_status
.
push_back
(
std
::
move
(
ret
));
}
std
::
vector
<
T
*>
().
swap
(
send_vec
[
j
]);
}
for
(
auto
&
t
:
total_status
)
{
t
.
wait
();
}
VLOG
(
3
)
<<
"GlobalShuffle() end, thread_id="
<<
thread_id_
;
}
template
<
typename
T
>
std
::
pair
<
int64_t
,
int64_t
>
InMemoryDataFeed
<
T
>::
GetMemoryDataInterval
()
{
int64_t
start
=
0
;
int64_t
end
=
0
;
int64_t
size
=
memory_data_
->
size
();
for
(
int64_t
i
=
0
;
i
<=
static_cast
<
int64_t
>
(
thread_id_
);
++
i
)
{
int64_t
len
=
size
/
static_cast
<
int64_t
>
(
thread_num_
)
+
(
i
<
(
size
%
static_cast
<
int64_t
>
(
thread_num_
)));
start
=
end
;
end
+=
len
;
}
return
std
::
make_pair
(
start
,
end
);
}
// explicit instantiation
...
...
@@ -519,7 +574,7 @@ bool MultiSlotDataFeed::ParseOneInstanceFromPipe(
const
char
*
str
=
reader
.
get
();
std
::
string
line
=
std
::
string
(
str
);
VLOG
(
3
)
<<
line
;
//
VLOG(3) << line;
char
*
endptr
=
const_cast
<
char
*>
(
str
);
int
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
...
...
@@ -695,7 +750,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(
const
char
*
str
=
reader
.
get
();
std
::
string
line
=
std
::
string
(
str
);
VLOG
(
3
)
<<
line
;
//
VLOG(3) << line;
char
*
endptr
=
const_cast
<
char
*>
(
str
);
int
pos
=
0
;
for
(
size_t
i
=
0
;
i
<
use_slots_index_
.
size
();
++
i
)
{
...
...
@@ -830,13 +885,15 @@ void MultiSlotInMemoryDataFeed::PutToFeedVec(
// todo serialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
std
::
string
*
str
)
{
const
std
::
vector
<
std
::
vector
<
MultiSlotType
>*>&
ins
,
std
::
string
*
str
)
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Serialize
(
ins
,
str
);
}
// todo deserialize ins in global shuffle
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
MultiSlotType
>*
ins
,
const
std
::
string
&
str
)
{
void
MultiSlotInMemoryDataFeed
::
DeserializeIns
(
std
::
vector
<
std
::
vector
<
MultiSlotType
>>*
ins
,
const
std
::
string
&
str
)
{
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
fleet_ptr
->
Deserialize
(
ins
,
str
);
}
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
a5b1a0e1
...
...
@@ -21,6 +21,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include <sstream>
#include <future>
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
...
...
@@ -52,7 +53,10 @@ namespace framework {
// }
class
DataFeed
{
public:
DataFeed
()
{}
DataFeed
()
{
mutex_for_pick_file_
=
nullptr
;
file_idx_
=
nullptr
;
}
virtual
~
DataFeed
()
{}
virtual
void
Init
(
const
paddle
::
framework
::
DataFeedDesc
&
data_feed_desc
)
=
0
;
virtual
bool
CheckFile
(
const
char
*
filename
)
{
...
...
@@ -89,6 +93,12 @@ class DataFeed {
virtual
void
SetThreadNum
(
int
thread_num
)
{
}
// This function will do nothing at default
virtual
void
SetTrainerNum
(
int
trainer_num
)
{
}
virtual
void
SetFileListMutex
(
std
::
mutex
*
mutex
)
{
mutex_for_pick_file_
=
mutex
;
}
virtual
void
SetFileListIndex
(
size_t
*
file_index
)
{
file_idx_
=
file_index
;
}
virtual
void
LoadIntoMemory
()
{
PADDLE_THROW
(
"This function(LoadIntoMemory) is not implemented."
);
}
...
...
@@ -100,7 +110,9 @@ class DataFeed {
}
// This function will do nothing at default
virtual
void
FillMemoryDataToChannel
()
{
}
// This function will do nothing at default
virtual
void
FillChannelToMemoryData
()
{
}
// This function will do nothing at default
virtual
void
PutInsToChannel
(
const
std
::
string
&
ins_str
)
{
}
protected:
...
...
@@ -116,9 +128,9 @@ class DataFeed {
// safe).
virtual
bool
PickOneFile
(
std
::
string
*
filename
);
st
atic
st
d
::
vector
<
std
::
string
>
filelist_
;
s
tatic
size_t
file_idx_
;
st
atic
std
::
mutex
mutex_for_pick_file_
;
std
::
vector
<
std
::
string
>
filelist_
;
s
ize_t
*
file_idx_
;
st
d
::
mutex
*
mutex_for_pick_file_
;
// the alias of used slots, and its order is determined by
// data_feed_desc(proto object)
...
...
@@ -141,7 +153,7 @@ class DataFeed {
int
batch_size_
;
bool
finish_init_
;
static
bool
finish_set_filelist_
;
bool
finish_set_filelist_
;
bool
finish_start_
;
std
::
string
pipe_command_
;
};
...
...
@@ -215,8 +227,9 @@ class InMemoryDataFeed : public PrivateQueueDataFeed<T> {
virtual
bool
ParseOneInstance
(
T
*
instance
)
=
0
;
virtual
bool
ParseOneInstanceFromPipe
(
T
*
instance
)
=
0
;
virtual
void
PutToFeedVec
(
const
T
&
ins_vec
)
=
0
;
virtual
void
SerializeIns
(
const
T
&
ins
,
std
::
string
*
str
)
=
0
;
virtual
void
DeserializeIns
(
T
*
ins
,
const
std
::
string
&
str
)
=
0
;
virtual
void
SerializeIns
(
const
std
::
vector
<
T
*>&
ins
,
std
::
string
*
str
)
=
0
;
virtual
void
DeserializeIns
(
std
::
vector
<
T
>*
ins
,
const
std
::
string
&
str
)
=
0
;
virtual
std
::
pair
<
int64_t
,
int64_t
>
GetMemoryDataInterval
();
int
thread_id_
;
int
thread_num_
;
...
...
@@ -284,13 +297,13 @@ class MultiSlotType {
std
::
string
DebugString
()
{
std
::
stringstream
ss
;
ss
<<
"type: "
<<
type_
<<
"
\n
"
;
ss
<<
"offset:
\n
"
;
ss
<<
"
\n
type: "
<<
type_
<<
"
\n
"
;
ss
<<
"offset:
"
;
ss
<<
"["
;
for
(
const
size_t
&
i
:
offset_
)
{
ss
<<
offset_
[
i
]
<<
","
;
}
ss
<<
"]
\n
data:
\n
["
;
ss
<<
"]
\n
data:
["
;
if
(
type_
[
0
]
==
'f'
)
{
for
(
const
float
&
i
:
float_feasign_
)
{
ss
<<
i
<<
","
;
...
...
@@ -356,9 +369,9 @@ class MultiSlotInMemoryDataFeed
virtual
bool
ParseOneInstance
(
std
::
vector
<
MultiSlotType
>*
instance
);
virtual
bool
ParseOneInstanceFromPipe
(
std
::
vector
<
MultiSlotType
>*
instance
);
virtual
void
PutToFeedVec
(
const
std
::
vector
<
MultiSlotType
>&
ins_vec
);
virtual
void
SerializeIns
(
const
std
::
vector
<
MultiSlotType
>&
ins
,
virtual
void
SerializeIns
(
const
std
::
vector
<
std
::
vector
<
MultiSlotType
>*
>&
ins
,
std
::
string
*
str
);
virtual
void
DeserializeIns
(
std
::
vector
<
MultiSlotType
>*
ins
,
virtual
void
DeserializeIns
(
std
::
vector
<
std
::
vector
<
MultiSlotType
>
>*
ins
,
const
std
::
string
&
str
);
};
...
...
paddle/fluid/framework/data_set.cc
浏览文件 @
a5b1a0e1
...
...
@@ -18,6 +18,8 @@
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/framework/io/fs.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -25,12 +27,15 @@ namespace framework {
template
<
typename
T
>
DatasetImpl
<
T
>::
DatasetImpl
()
{
thread_num_
=
1
;
trainer_num_
=
1
;
file_idx_
=
0
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
{
VLOG
(
3
)
<<
"filelist size: "
<<
filelist
.
size
();
filelist_
=
filelist
;
file_idx_
=
0
;
/*
int file_cnt = filelist_.size();
if (thread_num_ > file_cnt) {
...
...
@@ -45,19 +50,34 @@ void DatasetImpl<T>::SetFileList(const std::vector<std::string>& filelist) {
// not user friendly
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetThreadNum
(
int
thread_num
)
{
int
file_cnt
=
filelist_
.
size
();
VLOG
(
3
)
<<
"SetThreadNum thread_num="
<<
thread_num
;
//int file_cnt = filelist_.size();
/*
if (file_cnt != 0 && thread_num > file_cnt) {
VLOG(3) << "DataSet thread num = " << thread_num
<< ", file num = " << file_cnt
<< ". Changing DataSet thread num = " << file_cnt;
thread_num = file_cnt;
}
}
*/
thread_num_
=
thread_num
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetTrainerNum
(
int
trainer_num
)
{
trainer_num_
=
trainer_num
;
// should inform reader of trainer_num directly
for
(
auto
reader
:
readers_
)
{
reader
->
SetTrainerNum
(
trainer_num
);
}
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
SetHdfsConfig
(
const
std
::
string
&
fs_name
,
const
std
::
string
&
fs_ugi
)
{
std
::
string
cmd
=
std
::
string
(
"hadoop fs"
);
cmd
+=
" -D fs.default.name="
+
fs_name
;
cmd
+=
" -D hadoop.job.ugi="
+
fs_ugi
;
paddle
::
framework
::
hdfs_set_command
(
cmd
);
}
template
<
typename
T
>
...
...
@@ -75,6 +95,8 @@ DatasetImpl<T>::GetReaders() {
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LoadIntoMemory
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
}
...
...
@@ -86,12 +108,17 @@ void DatasetImpl<T>::LoadIntoMemory() {
for
(
std
::
thread
&
t
:
load_threads
)
{
t
.
join
();
}
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() end"
;
timeline
.
Pause
();
VLOG
(
3
)
<<
"DatasetImpl<T>::LoadIntoMemory() end"
<<
", memory data size="
<<
memory_data_
.
size
()
<<
", cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
LocalShuffle
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() begin"
;
platform
::
Timer
timeline
;
timeline
.
Start
();
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
}
...
...
@@ -107,23 +134,27 @@ void DatasetImpl<T>::LocalShuffle() {
t
.
join
();
}
std
::
vector
<
T
>
().
swap
(
memory_data_
);
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end"
;
timeline
.
Pause
();
VLOG
(
3
)
<<
"DatasetImpl<T>::LocalShuffle() end, cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
GlobalShuffle
()
{
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() begin"
;
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
platform
::
Timer
timeline
;
timeline
.
Start
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
VLOG
(
3
)
<<
"RegisterClientToClientMsgHandler"
;
fleet_ptr
->
RegisterClientToClientMsgHandler
(
0
,
[
this
](
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
->
int
{
return
this
->
ReceiveFromClient
(
msg_type
,
client_id
,
msg
);
});
if
(
readers_
.
size
()
==
0
)
{
CreateReaders
();
}
// if it is not InMemory, memory_data_ is empty
std
::
random_shuffle
(
memory_data_
.
begin
(),
memory_data_
.
end
());
VLOG
(
3
)
<<
"start global shuffle threads"
;
std
::
vector
<
std
::
thread
>
global_shuffle_threads
;
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
...
...
@@ -133,15 +164,32 @@ void DatasetImpl<T>::GlobalShuffle() {
for
(
std
::
thread
&
t
:
global_shuffle_threads
)
{
t
.
join
();
}
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end"
;
std
::
vector
<
T
>
().
swap
(
memory_data_
);
timeline
.
Pause
();
VLOG
(
3
)
<<
"DatasetImpl<T>::GlobalShuffle() end, cost time="
<<
timeline
.
ElapsedSec
()
<<
" seconds"
;
}
template
<
typename
T
>
void
DatasetImpl
<
T
>::
CreateReaders
()
{
VLOG
(
3
)
<<
"Calling CreateReaders()"
;
CHECK
(
thread_num_
>
0
)
<<
"thread_num should > 0"
;
int
file_cnt
=
filelist_
.
size
();
int
memory_data_size
=
memory_data_
.
size
();
if
(
memory_data_size
!=
0
&&
thread_num_
>
memory_data_size
)
{
VLOG
(
3
)
<<
"Dataset thread num = "
<<
thread_num_
<<
", memory data size = "
<<
memory_data_size
<<
". Changing Dataset thread num = "
<<
memory_data_size
;
thread_num_
=
memory_data_size
;
}
else
if
(
file_cnt
!=
0
&&
thread_num_
>
file_cnt
)
{
VLOG
(
3
)
<<
"Dataset thread num = "
<<
thread_num_
<<
", file num = "
<<
file_cnt
<<
". Changing Dataset thread num = "
<<
file_cnt
;
thread_num_
=
file_cnt
;
}
VLOG
(
3
)
<<
"thread_num in Readers: "
<<
thread_num_
;
VLOG
(
3
)
<<
"readers size: "
<<
readers_
.
size
();
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
if
(
readers_
.
size
()
!=
0
)
{
return
;
}
...
...
@@ -154,9 +202,10 @@ void DatasetImpl<T>::CreateReaders() {
readers_
.
back
()
->
SetThreadId
(
i
);
readers_
.
back
()
->
SetThreadNum
(
thread_num_
);
readers_
.
back
()
->
SetTrainerNum
(
trainer_num_
);
readers_
.
back
()
->
SetFileListMutex
(
&
mutex_for_pick_file_
);
readers_
.
back
()
->
SetFileListIndex
(
&
file_idx_
);
readers_
.
back
()
->
SetFileList
(
filelist_
);
}
VLOG
(
3
)
<<
"Filelist size in readers: "
<<
filelist_
.
size
();
readers_
[
0
]
->
SetFileList
(
filelist_
);
}
template
<
typename
T
>
...
...
@@ -184,9 +233,12 @@ void DatasetImpl<T>::DestroyReaders() {
template
<
typename
T
>
int
DatasetImpl
<
T
>::
ReceiveFromClient
(
int
msg_type
,
int
client_id
,
const
std
::
string
&
msg
)
{
// todo random
// int64_t index = paddle::ps::local_random_engine()() % thread_num_;
int64_t
index
=
0
;
VLOG
(
3
)
<<
"ReceiveFromClient msg_type="
<<
msg_type
<<
", client_id="
<<
client_id
<<
", msg length="
<<
msg
.
length
();
auto
fleet_ptr
=
FleetWrapper
::
GetInstance
();
int64_t
index
=
fleet_ptr
->
LocalRandomEngine
()()
%
thread_num_
;
VLOG
(
3
)
<<
"ramdom index="
<<
index
;
readers_
[
index
]
->
PutInsToChannel
(
msg
);
return
0
;
}
...
...
paddle/fluid/framework/data_set.h
浏览文件 @
a5b1a0e1
...
...
@@ -33,6 +33,8 @@ class Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
)
=
0
;
virtual
void
SetThreadNum
(
int
thread_num
)
=
0
;
virtual
void
SetTrainerNum
(
int
trainer_num
)
=
0
;
virtual
void
SetHdfsConfig
(
const
std
::
string
&
fs_name
,
const
std
::
string
&
fs_ugi
)
=
0
;
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
)
=
0
;
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
=
0
;
virtual
int
GetThreadNum
()
=
0
;
...
...
@@ -60,6 +62,8 @@ class DatasetImpl : public Dataset {
virtual
void
SetFileList
(
const
std
::
vector
<
std
::
string
>&
filelist
);
virtual
void
SetThreadNum
(
int
thread_num
);
virtual
void
SetTrainerNum
(
int
trainer_num
);
virtual
void
SetHdfsConfig
(
const
std
::
string
&
fs_name
,
const
std
::
string
&
fs_ugi
);
virtual
void
SetDataFeedDesc
(
const
std
::
string
&
data_feed_desc_str
);
virtual
const
std
::
vector
<
std
::
string
>&
GetFileList
()
{
return
filelist_
;
}
...
...
@@ -85,8 +89,10 @@ class DatasetImpl : public Dataset {
std
::
mutex
mutex_for_update_memory_data_
;
int
thread_num_
;
paddle
::
framework
::
DataFeedDesc
data_feed_desc_
;
std
::
vector
<
std
::
string
>
filelist_
;
int
trainer_num_
;
std
::
vector
<
std
::
string
>
filelist_
;
size_t
file_idx_
;
std
::
mutex
mutex_for_pick_file_
;
};
class
MultiSlotDataset
:
public
DatasetImpl
<
std
::
vector
<
MultiSlotType
>>
{
...
...
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
a5b1a0e1
...
...
@@ -26,12 +26,14 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc,
Dataset
*
dataset
)
{
thread_num_
=
trainer_desc
.
thread_num
();
SetDataset
(
dataset
);
workers_
.
resize
(
thread_num_
);
dataset
->
CreateReaders
();
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
dataset
->
GetReaders
();
thread_num_
=
readers
.
size
();
workers_
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
...
...
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
a5b1a0e1
...
...
@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
#include <utility>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/scope.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -203,6 +204,60 @@ void FleetWrapper::PullDenseVarsSync(
#endif
}
void
FleetWrapper
::
PushDenseParamSync
(
const
ProgramDesc
&
program
,
const
uint64_t
table_id
,
const
std
::
vector
<
std
::
string
>&
var_names
)
{
#ifdef PADDLE_WITH_PSLIB
paddle
::
framework
::
Scope
scope
;
auto
&
block
=
program
.
Block
(
0
);
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
scope
.
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
}
else
{
auto
*
ptr
=
scope
.
Var
(
var
->
Name
());
InitializeVariable
(
ptr
,
var
->
GetType
());
}
}
auto
place
=
platform
::
CPUPlace
();
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
for
(
auto
&
t
:
var_names
)
{
Variable
*
var
=
scope
.
FindVar
(
t
);
CHECK
(
var
!=
nullptr
)
<<
"var["
<<
t
<<
"] not found"
;
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
std
::
vector
<
int64_t
>
dim
;
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Name
()
==
t
)
{
dim
=
var
->
GetShape
();
break
;
}
}
int
cnt
=
1
;
for
(
auto
&
i
:
dim
)
{
cnt
*=
i
;
}
DDim
d
(
std
::
vector
<
int64_t
>
{
cnt
}.
data
(),
1
);
float
*
g
=
tensor
->
mutable_data
<
float
>
(
d
,
place
);
CHECK
(
g
!=
nullptr
)
<<
"var["
<<
t
<<
"] value not initialized"
;
float
init_range
=
0.2
;
int
rown
=
tensor
->
dims
()[
0
];
init_range
/=
sqrt
(
rown
);
std
::
normal_distribution
<
float
>
ndistr
(
0.0
,
1.0
);
for
(
auto
i
=
0u
;
i
<
tensor
->
numel
();
++
i
)
{
g
[
i
]
=
ndistr
(
LocalRandomEngine
())
*
init_range
;
}
paddle
::
ps
::
Region
reg
(
g
,
tensor
->
numel
());
regions
.
emplace_back
(
std
::
move
(
reg
));
auto
push_status
=
pslib_ptr_
->
_worker_ptr
->
push_dense_param
(
regions
.
data
(),
regions
.
size
(),
table_id
);
push_status
.
wait
();
auto
status
=
push_status
.
get
();
CHECK
(
status
==
0
)
<<
"push dense param failed, status["
<<
status
<<
"]"
;
}
#endif
}
void
FleetWrapper
::
PushDenseVarsSync
(
Scope
*
scope
,
const
uint64_t
table_id
,
const
std
::
vector
<
std
::
string
>&
var_names
)
{}
...
...
@@ -269,6 +324,8 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
continue
;
}
LOG
(
WARNING
)
<<
"going to memcpy"
;
CHECK
(
fea_idx
<
(
*
push_values
).
size
());
CHECK
(
fea_idx
<
fea_labels
.
size
());
memcpy
((
*
push_values
)[
fea_idx
].
data
()
+
offset
,
g
,
sizeof
(
float
)
*
emb_dim
);
LOG
(
WARNING
)
<<
"show"
;
...
...
@@ -294,13 +351,13 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
#endif
}
int
FleetWrapper
::
RegisterClientToClientMsgHandler
(
int
msg_type
,
MsgHandlerFunc
handler
)
{
int
FleetWrapper
::
RegisterClientToClientMsgHandler
(
int
msg_type
,
MsgHandlerFunc
handler
)
{
#ifdef PADDLE_WITH_PSLIB
VLOG
(
3
)
<<
"calling FleetWrapper::RegisterClientToClientMsgHandler"
;
VLOG
(
3
)
<<
"pslib_ptr_="
<<
pslib_ptr_
;
VLOG
(
3
)
<<
"_worker_ptr="
<<
pslib_ptr_
->
_worker_ptr
;
pslib_ptr_
->
_worker_ptr
->
registe_client2client_msg_handler
(
msg_type
,
handler
);
return
pslib_ptr_
->
_worker_ptr
->
registe_client2client_msg_handler
(
msg_type
,
handler
);
#else
VLOG
(
0
)
<<
"FleetWrapper::RegisterClientToClientMsgHandler"
<<
" does nothing when no pslib"
;
...
...
@@ -308,15 +365,15 @@ int FleetWrapper::RegisterClientToClientMsgHandler(int msg_type,
return
0
;
}
int
FleetWrapper
::
SendClientToClientMsg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
std
::
future
<
int32_t
>
FleetWrapper
::
SendClientToClientMsg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
)
{
#ifdef PADDLE_WITH_PSLIB
pslib_ptr_
->
_worker_ptr
->
send_client2client_msg
(
msg_type
,
to_client_id
,
msg
);
return
pslib_ptr_
->
_worker_ptr
->
send_client2client_msg
(
msg_type
,
to_client_id
,
msg
);
#else
VLOG
(
0
)
<<
"FleetWrapper::SendClientToClientMsg"
<<
" does nothing when no pslib"
;
#endif
return
0
;
return
std
::
future
<
int32_t
>
()
;
}
std
::
default_random_engine
&
FleetWrapper
::
LocalRandomEngine
()
{
...
...
@@ -336,10 +393,12 @@ std::default_random_engine& FleetWrapper::LocalRandomEngine() {
}
template
<
typename
T
>
void
FleetWrapper
::
Serialize
(
const
T
&
t
,
std
::
string
*
str
)
{
void
FleetWrapper
::
Serialize
(
const
std
::
vector
<
T
*>
&
t
,
std
::
string
*
str
)
{
#ifdef PADDLE_WITH_PSLIB
paddle
::
ps
::
BinaryArchive
ar
;
ar
<<
t
;
for
(
size_t
i
=
0
;
i
<
t
.
size
();
++
i
)
{
ar
<<
*
(
t
[
i
]);
}
*
str
=
std
::
string
(
ar
.
buffer
(),
ar
.
length
());
#else
VLOG
(
0
)
<<
"FleetWrapper::Serialize does nothing when no pslib"
;
...
...
@@ -347,20 +406,30 @@ void FleetWrapper::Serialize(const T& t, std::string* str) {
}
template
<
typename
T
>
void
FleetWrapper
::
Deserialize
(
T
*
t
,
const
std
::
string
&
str
)
{
void
FleetWrapper
::
Deserialize
(
std
::
vector
<
T
>
*
t
,
const
std
::
string
&
str
)
{
#ifdef PADDLE_WITH_PSLIB
if
(
str
.
length
()
==
0
)
{
return
;
}
paddle
::
ps
::
BinaryArchive
ar
;
ar
.
set_read_buffer
(
const_cast
<
char
*>
(
str
.
c_str
()),
str
.
length
(),
nullptr
);
*
t
=
ar
.
get
<
T
>
();
if
(
ar
.
cursor
()
==
ar
.
finish
())
{
return
;
}
while
(
ar
.
cursor
()
<
ar
.
finish
())
{
t
->
push_back
(
ar
.
get
<
T
>
());
}
CHECK
(
ar
.
cursor
()
==
ar
.
finish
());
VLOG
(
3
)
<<
"Deserialize size "
<<
t
->
size
();
#else
VLOG
(
0
)
<<
"FleetWrapper::Deserialize does nothing when no pslib"
;
#endif
}
template
void
FleetWrapper
::
Serialize
<
std
::
vector
<
MultiSlotType
>
>
(
const
std
::
vector
<
MultiSlotType
>&
,
std
::
string
*
);
template
void
FleetWrapper
::
Deserialize
(
std
::
vector
<
MultiSlotType
>
*
,
const
std
::
string
&
);
const
std
::
vector
<
std
::
vector
<
MultiSlotType
>*
>&
,
std
::
string
*
);
template
void
FleetWrapper
::
Deserialize
<
std
::
vector
<
MultiSlotType
>
>
(
std
::
vector
<
std
::
vector
<
MultiSlotType
>>*
,
const
std
::
string
&
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/fleet/fleet_wrapper.h
浏览文件 @
a5b1a0e1
...
...
@@ -27,6 +27,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
#include "paddle/fluid/framework/program_desc.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -71,6 +72,10 @@ class FleetWrapper {
const
std
::
vector
<
std
::
string
>&
var_names
,
std
::
vector
<::
std
::
future
<
int32_t
>>*
pull_dense_status
);
void
PushDenseParamSync
(
const
ProgramDesc
&
program
,
const
uint64_t
table_id
,
const
std
::
vector
<
std
::
string
>&
var_names
);
// Push dense variables to server in async mode
// Param<in>: scope, table_id, var_names,
// Param<out>: push_sparse_status
...
...
@@ -119,16 +124,15 @@ class FleetWrapper {
typedef
std
::
function
<
int32_t
(
int
,
int
,
const
std
::
string
&
)
>
MsgHandlerFunc
;
int
RegisterClientToClientMsgHandler
(
int
msg_type
,
MsgHandlerFunc
handler
);
int
SendClientToClientMsg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
);
std
::
future
<
int32_t
>
SendClientToClientMsg
(
int
msg_type
,
int
to_client_id
,
const
std
::
string
&
msg
);
std
::
default_random_engine
&
LocalRandomEngine
();
template
<
typename
T
>
void
Serialize
(
const
T
&
t
,
std
::
string
*
str
);
void
Serialize
(
const
std
::
vector
<
T
*>
&
t
,
std
::
string
*
str
);
template
<
typename
T
>
void
Deserialize
(
T
*
t
,
const
std
::
string
&
str
);
void
Deserialize
(
std
::
vector
<
T
>*
t
,
const
std
::
string
&
str
);
static
std
::
shared_ptr
<
FleetWrapper
>
GetInstance
()
{
if
(
NULL
==
s_instance_
)
{
s_instance_
.
reset
(
new
paddle
::
framework
::
FleetWrapper
());
...
...
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
a5b1a0e1
...
...
@@ -26,13 +26,15 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_
=
trainer_desc
.
thread_num
();
SetDataset
(
dataset
);
// get filelist from trainer_desc here
workers_
.
resize
(
thread_num_
);
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
dataset
->
CreateReaders
();
VLOG
(
3
)
<<
"readers created"
;
const
std
::
vector
<
std
::
shared_ptr
<
paddle
::
framework
::
DataFeed
>>
readers
=
dataset
->
GetReaders
();
VLOG
(
3
)
<<
"readers num: "
<<
readers
.
size
();
// change thread num to readers num
thread_num_
=
readers
.
size
();
VLOG
(
3
)
<<
"worker thread num: "
<<
thread_num_
;
workers_
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
=
DeviceWorkerFactory
::
CreateDeviceWorker
(
trainer_desc
.
device_worker_name
());
...
...
paddle/fluid/pybind/async_executor_py.cc
浏览文件 @
a5b1a0e1
...
...
@@ -49,7 +49,7 @@ void BindAsyncExecutor(py::module* m) {
new
framework
::
AsyncExecutor
(
scope
,
place
));
}))
.
def
(
"run_from_files"
,
&
framework
::
AsyncExecutor
::
RunFromFile
)
.
def
(
"run_from_dataset"
,
&
framework
::
AsyncExecutor
::
RunFromDataset
)
//
.def("run_from_dataset", &framework::AsyncExecutor::RunFromDataset)
.
def
(
"init_server"
,
&
framework
::
AsyncExecutor
::
InitServer
)
.
def
(
"init_worker"
,
&
framework
::
AsyncExecutor
::
InitWorker
)
.
def
(
"start_server"
,
&
framework
::
AsyncExecutor
::
StartServer
)
...
...
paddle/fluid/pybind/data_set_py.cc
浏览文件 @
a5b1a0e1
...
...
@@ -50,6 +50,7 @@ void BindDataset(py::module* m) {
.
def
(
"set_filelist"
,
&
framework
::
Dataset
::
SetFileList
)
.
def
(
"set_thread_num"
,
&
framework
::
Dataset
::
SetThreadNum
)
.
def
(
"set_trainer_num"
,
&
framework
::
Dataset
::
SetTrainerNum
)
.
def
(
"set_hdfs_config"
,
&
framework
::
Dataset
::
SetHdfsConfig
)
.
def
(
"set_data_feed_desc"
,
&
framework
::
Dataset
::
SetDataFeedDesc
)
.
def
(
"load_into_memory"
,
&
framework
::
Dataset
::
LoadIntoMemory
)
.
def
(
"local_shuffle"
,
&
framework
::
Dataset
::
LocalShuffle
)
...
...
paddle/fluid/pybind/fleet_wrapper_py.cc
浏览文件 @
a5b1a0e1
...
...
@@ -47,6 +47,7 @@ void BindFleetWrapper(py::module* m) {
.
def
(
"init_server"
,
&
framework
::
FleetWrapper
::
InitServer
)
.
def
(
"run_server"
,
&
framework
::
FleetWrapper
::
RunServer
)
.
def
(
"init_worker"
,
&
framework
::
FleetWrapper
::
InitWorker
)
.
def
(
"init_model"
,
&
framework
::
FleetWrapper
::
PushDenseParamSync
)
.
def
(
"stop_server"
,
&
framework
::
FleetWrapper
::
StopServer
)
.
def
(
"gather_servers"
,
&
framework
::
FleetWrapper
::
GatherServers
);
}
// end FleetWrapper
...
...
python/paddle/fluid/dataset.py
浏览文件 @
a5b1a0e1
...
...
@@ -86,6 +86,9 @@ class DatasetBase(object):
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
)
def
set_hdfs_config
(
self
,
fs_name
,
fs_ugi
):
self
.
dataset
.
set_hdfs_config
(
fs_name
,
fs_ugi
)
def
_prepare_to_run
(
self
):
self
.
dataset
.
set_data_feed_desc
(
self
.
desc
())
...
...
@@ -115,11 +118,15 @@ class InMemoryDataset(DatasetBase):
def
local_shuffle
(
self
):
self
.
dataset
.
local_shuffle
()
def
global_shuffle
(
self
):
from
.distributed
import
ps_instance
instance
=
ps_instance
.
PaddlePSInstance
(
1
,
2
)
self
.
dataset
.
set_trainer_num
(
instance
.
get_worker_num
())
def
global_shuffle
(
self
,
fleet
=
None
):
trainer_num
=
1
if
fleet
is
not
None
:
fleet
.
fleet_instance
.
role_maker_
.
barrier_worker
()
trainer_num
=
fleet
.
worker_num
()
self
.
dataset
.
set_trainer_num
(
trainer_num
)
self
.
dataset
.
global_shuffle
()
if
fleet
is
not
None
:
fleet
.
fleet_instance
.
role_maker_
.
barrier_worker
()
class
QueueDataset
(
DatasetBase
):
...
...
@@ -130,5 +137,5 @@ class QueueDataset(DatasetBase):
def
local_shuffle
(
self
):
pass
def
global_shuffle
(
self
):
def
global_shuffle
(
self
,
fleet
=
None
):
pass
python/paddle/fluid/incubate/fleet/base/role_maker.py
浏览文件 @
a5b1a0e1
...
...
@@ -170,7 +170,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if
self
.
_check_role_generation
():
if
self
.
is_worker
():
return
self
.
get_size
()
return
self
.
get_size
()
/
2
;
return
0
def
server_num
(
self
):
...
...
@@ -179,7 +179,7 @@ class MPISymetricRoleMaker(MPIRoleMaker):
"""
if
self
.
_check_role_generation
():
if
self
.
is_server
():
return
self
.
get_size
()
return
self
.
get_size
()
/
2
;
return
0
def
worker_index
(
self
):
...
...
python/paddle/fluid/incubate/fleet/parameter_server/__init__.py
浏览文件 @
a5b1a0e1
...
...
@@ -43,7 +43,7 @@ class Fleet(object):
save_pserver_model(): save model parameters in pserver, called from a server node
Example:
.. code-block:: python
import paddle.fluid.incubate.fleet.parameter_server as fleet
from my_model import bow_net
...
...
@@ -58,7 +58,7 @@ class Fleet(object):
fleet.init_worker() # init worker should be called before training
# do other things like training
elif fleet.is_server():
fleet.init_pserver()
fleet.init_pserver()
fleet.stop()
"""
...
...
@@ -75,7 +75,7 @@ class Fleet(object):
"""
init(): which should be called only once in user's python scripts. init() will initialize
FleetWrapper in CPP, it will also initialize a RoleMaker which is used for identifying
current node's role, e.g. worker, server, etc.
current node's role, e.g. worker, server, etc.
"""
if
not
self
.
is_initialized_
:
self
.
role_maker_
=
MPISymetricRoleMaker
()
...
...
@@ -122,7 +122,7 @@ class Fleet(object):
print
(
"You should run DistributedOptimizer.minimize() first"
)
sys
.
exit
(
-
1
)
def
init_worker
(
self
):
def
init_worker
(
self
,
program
):
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
...
...
@@ -143,6 +143,19 @@ class Fleet(object):
self
.
role_maker_
.
get_rank
())
self
.
role_maker_
.
barrier_all
()
self
.
role_maker_
.
barrier_worker
()
if
self
.
role_maker_
.
is_first_worker
():
tables
=
self
.
_dist_desc
.
trainer_param
.
dense_table
.
_values
for
i
in
range
(
0
,
len
(
tables
)):
table
=
tables
[
i
];
var_name_list
=
[]
for
i
in
range
(
0
,
len
(
table
.
dense_variable_name
)):
var_name_list
.
append
(
table
.
dense_variable_name
[
i
])
#print "table id ", table.table_id
#print "var_name_list ", var_name_list
self
.
_fleet_ptr
.
init_model
(
program
.
desc
,
int
(
table
.
table_id
),
var_name_list
)
self
.
role_maker_
.
barrier_worker
()
else
:
print
(
"You should run DistributedOptimizer.minimize() first"
)
sys
.
exit
(
-
1
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录