Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
eb6a941f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eb6a941f
编写于
11月 14, 2018
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch
上级
1d239cc8
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
396 addition
and
403 deletion
+396
-403
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+144
-227
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+35
-42
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+104
-12
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+38
-20
paddle/fluid/pybind/async_executor_py.cc
paddle/fluid/pybind/async_executor_py.cc
+24
-43
paddle/fluid/pybind/async_executor_py.h
paddle/fluid/pybind/async_executor_py.h
+1
-0
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+50
-59
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
eb6a941f
...
@@ -40,13 +40,8 @@ limitations under the License. */
...
@@ -40,13 +40,8 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
mutex
ExecutorThreadWorker
::
s_locker_for_pick_file_
;
unsigned
int
ExecutorThreadWorker
::
s_current_file_idx_
=
0
;
bool
AsyncExecutor
::
workers_initialized_
=
false
;
size_t
ExecutorThreadWorker
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
ExecutorThreadWorker
::
s_current_epoch_
=
0
;
int
ExecutorThreadWorker
::
s_current_save_epoch_
=
0
;
bool
ExecutorThreadWorker
::
s_is_first_worker_
=
false
;
std
::
vector
<
std
::
string
>
ExecutorThreadWorker
::
s_thread_filelist_
;
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
...
@@ -124,7 +119,6 @@ static void SaveModel(
...
@@ -124,7 +119,6 @@ static void SaveModel(
{{
"X"
,
{
var
->
Name
()}}},
{{
"X"
,
{
var
->
Name
()}}},
{},
{},
attrs
);
attrs
);
save_op
->
Run
(
*
scope
,
place
);
save_op
->
Run
(
*
scope
,
place
);
}
else
{
}
else
{
paralist
.
push_back
(
var
->
Name
());
paralist
.
push_back
(
var
->
Name
());
...
@@ -140,15 +134,14 @@ static void SaveModel(
...
@@ -140,15 +134,14 @@ static void SaveModel(
{{
"X"
,
paralist
}},
{{
"X"
,
paralist
}},
{},
{},
attrs
);
attrs
);
save_op
->
Run
(
*
scope
,
place
);
save_op
->
Run
(
*
scope
,
place
);
}
}
}
// end SaveModel
}
// end SaveModel
void
ExecutorThreadWorker
::
Reset
()
{
void
ExecutorThreadWorker
::
AddTrainFile
(
const
std
::
string
&
file
)
{
inspect_values_
.
clear
();
s_thread_filelist_
.
push_back
(
file
);
}
}
void
ExecutorThreadWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
void
ExecutorThreadWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
auto
&
block
=
program
.
Block
(
0
);
op_names_
.
clear
();
op_names_
.
clear
();
...
@@ -175,8 +168,12 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
...
@@ -175,8 +168,12 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
}
}
}
}
void
ExecutorThreadWorker
::
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
void
ExecutorThreadWorker
::
SetDataFeed
(
DataFeed
&
datafeed
)
{
local_reader_
=
datafeed
;
if
(
typeid
(
datafeed
)
==
typeid
(
TextClassDataFeed
))
{
local_reader_
.
reset
(
new
TextClassDataFeed
(
dynamic_cast
<
TextClassDataFeed
&>
(
datafeed
)));
local_reader_
->
SetThreadId
(
thread_id_
);
}
}
}
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
...
@@ -186,9 +183,11 @@ void ExecutorThreadWorker::BindingDataFeedMemory() {
...
@@ -186,9 +183,11 @@ void ExecutorThreadWorker::BindingDataFeedMemory() {
}
}
}
}
void
ExecutorThreadWorker
::
SetInspectVarName
(
void
ExecutorThreadWorker
::
SetInspectVarNames
(
const
std
::
string
&
inspect_var_name
)
{
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
inspect_var_name_
=
inspect_var_name
;
inspect_var_names_
.
clear
();
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
}
}
void
ExecutorThreadWorker
::
SetModelParamNames
(
void
ExecutorThreadWorker
::
SetModelParamNames
(
...
@@ -196,11 +195,6 @@ void ExecutorThreadWorker::SetModelParamNames(
...
@@ -196,11 +195,6 @@ void ExecutorThreadWorker::SetModelParamNames(
model_param_names_
=
param_names
;
model_param_names_
=
param_names
;
}
}
void
ExecutorThreadWorker
::
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
)
{
sparse_comm_data_
=
param_names
;
}
void
ExecutorThreadWorker
::
SetDevice
()
{
void
ExecutorThreadWorker
::
SetDevice
()
{
static
unsigned
priority
[]
=
{
static
unsigned
priority
[]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
0
,
1
,
2
,
3
,
4
,
5
,
...
@@ -228,138 +222,79 @@ void ExecutorThreadWorker::SetDevice() {
...
@@ -228,138 +222,79 @@ void ExecutorThreadWorker::SetDevice() {
CPU_ZERO
(
&
mask
);
CPU_ZERO
(
&
mask
);
if
((
0
==
sched_getaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
if
((
0
==
sched_getaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
&&
CPU_ISSET
(
proc
,
&
mask
))
{
&&
CPU_ISSET
(
proc
,
&
mask
))
{
LOG
(
ERROR
)
<<
"TRACE: Thread "
<<
i
<<
" is running on processor "
<<
proc
<<
"..."
;
LOG
(
ERROR
)
<<
"TRACE: Thread "
<<
i
<<
" is running on processor "
<<
proc
<<
"..."
;
}
}
}
}
}
}
}
}
void
ExecutorThreadWorker
::
UpdateEpochNum
()
{
s_current_finished_file_cnt_
++
;
if
(
s_current_finished_file_cnt_
>=
s_thread_filelist_
.
size
())
{
void
ExecutorThreadWorker
::
Train
()
{
s_current_finished_file_cnt_
=
0
;
LOG
(
ERROR
)
<<
"begin to train"
;
s_current_epoch_
++
;
SetDevice
();
}
}
const
char
*
ExecutorThreadWorker
::
PickOneFile
()
{
int
inspect_var_num
=
inspect_var_names_
.
size
();
std
::
string
file_to_be_preocessed
;
inspect_values_
.
clear
()
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
inspect_values_
.
resize
(
inspect_var_num
,
0
);
if
(
s_current_file_idx_
>=
s_thread_filelist_
.
size
())
{
local_reader_
->
WaitNextEpoch
();
std
::
random_shuffle
(
s_thread_filelist_
.
begin
(),
int
epoch
=
local_reader_
->
GetCurrentEpoch
();
s_thread_filelist_
.
end
());
s_current_file_idx_
=
0
;
// s_current_epoch_++; //example: when one file, one thread, it's bug
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
}
file_to_be_preocessed
=
s_thread_filelist_
[
s_current_file_idx_
];
s_current_file_idx_
++
;
LOG
(
ERROR
)
<<
"epoch: "
<<
epoch
;
return
file_to_be_preocessed
.
c_str
();
}
void
ExecutorThreadWorker
::
Train
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
SetDevice
();
#ifdef LOCAL_PROF
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
// int total_batch = 0;
for
(
auto
&
op
:
ops_
)
{
op_name
.
push_back
(
op
->
Type
());
}
op_total_time
.
resize
(
ops_
.
size
());
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
op_total_time
[
i
]
=
0.0
;
}
#endif
std
::
string
inspect_key
=
"inspect"
;
if
(
!
inspect_var_name_
.
empty
())
{
inspect_key
=
inspect_var_name_
.
substr
(
0
,
inspect_var_name_
.
find_first_of
(
'_'
));
}
for
(
unsigned
i
=
0
;
i
<
max_epoch_
;
++
i
)
{
LOG
(
ERROR
)
<<
"epoch: "
<<
i
;
#ifdef LOCAL_PROF
Timer
timeline
;
double
total_time
=
0.0
;
double
read_time
=
0.0
;
#endif
float
total_inspect
=
0
;
int
batch_num
=
1
;
int
batch_num
=
1
;
while
(
i
==
s_current_epoch_
)
{
const
char
*
filename
=
PickOneFile
();
local_reader_
->
SetFile
(
filename
);
while
(
true
)
{
while
(
true
)
{
#ifdef LOCAL_PROF
const
char
*
file
=
local_reader_
->
PickOneFile
();
timeline
.
start
();
if
(
file
==
NULL
)
{
#endif
bool
flag
=
local_reader_
->
ReadBatch
();
if
(
!
flag
)
{
break
;
break
;
}
}
#ifdef LOCAL_PROF
timeline
.
pause
();
if
(
!
local_reader_
->
SetFile
(
file
))
{
read_time
+=
timeline
.
elapsed_sec
();
break
;
total_time
+=
timeline
.
elapsed_sec
();
}
#endif
while
(
true
)
{
bool
flag
=
local_reader_
->
ReadBatch
();
if
(
!
flag
)
{
if
(
!
flag
)
{
break
;
break
;
}
}
for
(
unsigned
int
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
ops_
[
i
]
->
Run
(
*
thread_scope_
,
place_
);
ops_
[
i
]
->
Run
(
*
thread_scope_
,
place_
);
#ifdef LOCAL_PROF
timeline
.
pause
();
op_total_time
[
i
]
+=
timeline
.
elapsed_sec
();
total_time
+=
timeline
.
elapsed_sec
();
#endif
}
}
batch_num
++
;
batch_num
++
;
float
avg_inspect
=
0.0
;
float
avg_inspect
=
0.0
;
if
(
!
inspect_var_name_
.
empty
()
)
{
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_name_
)
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_names_
[
i
]
)
->
GetMutable
<
LoDTensor
>
()
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
->
data
<
float
>
()[
0
];
inspect_values_
[
i
]
+=
avg_inspect
;
}
}
total_inspect
+=
avg_inspect
;
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
}
}
UpdateEpochNum
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
i
+
1
local_reader_
->
UpdateEpochNum
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
epoch
+
1
<<
" called: "
<<
memory
::
memory_usage
(
place_
);
<<
" called: "
<<
memory
::
memory_usage
(
place_
);
}
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
inspect_values_
[
i
]
/=
batch_num
;
std
::
string
var
=
inspect_var_names_
[
i
].
substr
(
0
,
inspect_var_names_
[
i
].
find_first_of
(
"_"
));
LOG
(
ERROR
)
<<
"mean "
<<
var
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
inspect_values_
[
i
];
}
#ifdef LOCAL_PROF
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
std
::
cerr
<<
"op_name:["
<<
i
<<
"]["
<<
op_name
[
i
]
<<
"]"
<<
" op_mean_time:["
<<
op_total_time
[
i
]
<<
"s]"
<<
std
::
endl
;
}
std
::
cerr
<<
"read time: "
<<
read_time
<<
"s"
<<
std
::
endl
;
#endif
}
#ifdef LOCAL_PROF
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
<<
", total_time: "
<<
total_time
;
#else
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
;
#endif
if
(
thread_id_
==
0
)
{
if
(
thread_id_
==
0
)
{
char
modelfile
[
1024
];
char
modelfile
[
1024
];
snprintf
(
&
modelfile
[
0
],
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
sizeof
(
modelfile
),
model_prefix_
.
c_str
(),
epoch
);
"%s_epoch%d.model"
,
model_prefix_
.
c_str
(),
i
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
// this save_inference_model can only save imdbtask, should make this
// this save_inference_model can only save imdbtask, should make this
// general
// general
...
@@ -372,7 +307,6 @@ void ExecutorThreadWorker::Train() {
...
@@ -372,7 +307,6 @@ void ExecutorThreadWorker::Train() {
model_filename
,
model_filename
,
true
);
true
);
}
}
}
}
}
void
ExecutorThreadWorker
::
SetThreadId
(
int
tid
)
{
void
ExecutorThreadWorker
::
SetThreadId
(
int
tid
)
{
...
@@ -396,7 +330,20 @@ void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
...
@@ -396,7 +330,20 @@ void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_
=
max_epoch
;
max_epoch_
=
max_epoch
;
}
}
AsyncExecutor
::
AsyncExecutor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
AsyncExecutor
::
AsyncExecutor
(
ProgramDesc
&
main_program
,
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
unsigned
int
thread_num
,
const
platform
::
Place
&
place
)
:
thread_num_
(
thread_num
),
place_
(
place
),
main_program_
(
main_program
),
data_feed_
(
data_feed
)
{
model_param_names_
.
clear
();
model_param_names_
.
insert
(
model_param_names_
.
end
(),
param_names
.
begin
(),
param_names
.
end
());
}
void
AsyncExecutor
::
InitRootScope
(
Scope
*
scope
)
{
void
AsyncExecutor
::
InitRootScope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
root_scope_
=
scope
;
...
@@ -406,10 +353,6 @@ void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
...
@@ -406,10 +353,6 @@ void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_
=
max_epoch
;
max_epoch_
=
max_epoch
;
}
}
void
AsyncExecutor
::
SetDataFeedName
(
const
char
*
feedname
)
{
feed_name_
=
std
::
string
(
feedname
);
}
void
AsyncExecutor
::
SetModelPrefix
(
const
std
::
string
&
model_prefix
)
{
void
AsyncExecutor
::
SetModelPrefix
(
const
std
::
string
&
model_prefix
)
{
model_prefix_
=
model_prefix
;
model_prefix_
=
model_prefix
;
}
}
...
@@ -463,60 +406,16 @@ std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile(
...
@@ -463,60 +406,16 @@ std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile(
return
program
;
return
program
;
}
}
void
AsyncExecutor
::
SetDenseCommTensor
(
void
AsyncExecutor
::
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
)
{
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
dense_comm_tensor_
.
resize
(
dense_comm_tensor
.
size
());
inspect_var_names_
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
dense_comm_tensor
.
size
();
++
i
)
{
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
dense_comm_tensor_
[
i
]
=
dense_comm_tensor
[
i
];
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
}
}
void
AsyncExecutor
::
SetSparseCommTensor
(
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
)
{
sparse_comm_tensor_
.
resize
(
sparse_comm_tensor
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
sparse_comm_tensor
.
size
();
++
i
)
{
sparse_comm_tensor_
[
i
]
=
sparse_comm_tensor
[
i
];
}
}
void
AsyncExecutor
::
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
)
{
sparse_comm_data_
=
sparse_comm_data
;
LOG
(
INFO
)
<<
"Sparse comm data: "
<<
sparse_comm_data_
.
size
();
}
void
AsyncExecutor
::
SetFileList
(
const
char
*
filelist
)
{
filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
filelist_
.
push_back
(
filename
);
}
fin
.
close
();
}
void
AsyncExecutor
::
SetFileList
(
std
::
vector
<
std
::
string
>
tfiles
)
{
filelist_
.
clear
();
filelist_
.
insert
(
filelist_
.
end
(),
tfiles
.
begin
(),
tfiles
.
end
());
return
;
}
void
AsyncExecutor
::
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
)
{
inspect_var_name_
=
inspect_var_name
;
}
void
AsyncExecutor
::
SetParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
model_param_names_
=
param_names
;
}
void
AsyncExecutor
::
SetThreadNum
(
const
int
thread_num
)
{
thread_num_
=
thread_num
;
}
}
void
AsyncExecutor
::
PrepareThreads
(
const
ProgramDesc
&
host_program
)
{
void
AsyncExecutor
::
PrepareThreads
(
const
ProgramDesc
&
host_program
)
{
workers_
.
resize
(
thread_num_
);
workers_
.
resize
(
thread_num_
);
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
].
reset
(
new
ExecutorThreadWorker
);
workers_
[
i
].
reset
(
new
ExecutorThreadWorker
);
workers_
[
i
]
->
SetThreadId
(
i
);
workers_
[
i
]
->
SetThreadId
(
i
);
workers_
[
i
]
->
CreateThreadOperators
(
host_program
);
workers_
[
i
]
->
CreateThreadOperators
(
host_program
);
...
@@ -524,34 +423,31 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
...
@@ -524,34 +423,31 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_
[
i
]
->
SetPlace
(
place_
);
workers_
[
i
]
->
SetPlace
(
place_
);
workers_
[
i
]
->
SetMaxTrainingEpoch
(
max_epoch_
);
workers_
[
i
]
->
SetMaxTrainingEpoch
(
max_epoch_
);
workers_
[
i
]
->
CreateThreadScope
(
host_program
);
workers_
[
i
]
->
CreateThreadScope
(
host_program
);
workers_
[
i
]
->
SetInspectVarName
(
inspect_var_name
_
);
workers_
[
i
]
->
SetInspectVarName
s
(
inspect_var_names
_
);
workers_
[
i
]
->
SetModelParamNames
(
model_param_names_
);
workers_
[
i
]
->
SetModelParamNames
(
model_param_names_
);
workers_
[
i
]
->
SetSparseCommData
(
sparse_comm_data_
);
workers_
[
i
]
->
SetMainProgram
(
host_program
);
workers_
[
i
]
->
SetMainProgram
(
host_program
);
workers_
[
i
]
->
SetModelPrefix
(
model_prefix_
);
workers_
[
i
]
->
SetModelPrefix
(
model_prefix_
);
}
//
for
(
unsigned
i
=
0
;
i
<
filelist_
.
size
();
++
i
)
{
// suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once
workers_
[
0
]
->
AddTrainFile
(
filelist_
[
i
]);
}
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
// new a datafeed here
// new a datafeed here
std
::
shared_ptr
<
DataFeed
>
local_feed
=
CreateDataFeed
(
feed_name_
.
c_str
());
workers_
[
i
]
->
SetDataFeed
(
data_feed_
);
local_feed
->
Init
();
local_feed
->
SetBatchSize
(
batch_size_
);
workers_
[
i
]
->
SetDataFeed
(
local_feed
);
workers_
[
i
]
->
BindingDataFeedMemory
();
workers_
[
i
]
->
BindingDataFeedMemory
();
workers_
[
i
]
->
SetThreadId
(
i
);
}
}
}
}
void
AsyncExecutor
::
RunAsyncExecutor
(
const
ProgramDesc
&
host_program
)
{
std
::
vector
<
float
>&
AsyncExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
SetInspectVarNames
(
inspect_var_names
);
threads_
.
clear
();
// thread binding here?
// thread binding here?
PrepareThreads
(
host_program
);
if
(
workers_initialized_
==
false
)
{
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
PrepareThreads
(
main_program_
);
workers_initialized_
=
true
;
}
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
->
Reset
();
workers_
[
i
]
->
SetInspectVarNames
(
inspect_var_names
);
threads_
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
Train
,
threads_
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
Train
,
workers_
[
i
].
get
()));
workers_
[
i
].
get
()));
}
}
...
@@ -559,6 +455,27 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) {
...
@@ -559,6 +455,27 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) {
for
(
auto
&
th
:
threads_
)
{
for
(
auto
&
th
:
threads_
)
{
th
.
join
();
th
.
join
();
}
}
inspect_values_
.
clear
();
inspect_values_
.
resize
(
inspect_var_names_
.
size
(),
0
);
std
::
vector
<
std
::
vector
<
float
>*>
inspect_value_vectors
;
inspect_value_vectors
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
inspect_value_vectors
[
i
]
=
&
workers_
[
i
]
->
GetInspectValues
();
}
for
(
unsigned
int
i
=
0
;
i
<
inspect_var_names_
.
size
();
++
i
)
{
float
value
=
0.0
;
for
(
int
j
=
0
;
j
<
thread_num_
;
++
j
)
{
value
+=
inspect_value_vectors
[
j
]
->
at
(
i
);
}
value
/=
thread_num_
;
inspect_values_
[
i
]
=
value
;
}
return
inspect_values_
;
}
}
void
AsyncExecutor
::
LoadInitModel
()
{
void
AsyncExecutor
::
LoadInitModel
()
{
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
eb6a941f
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include <string>
#include <string>
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include <typeinfo>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
...
@@ -36,10 +37,9 @@ class ExecutorThreadWorker {
...
@@ -36,10 +37,9 @@ class ExecutorThreadWorker {
public:
public:
ExecutorThreadWorker
()
{}
ExecutorThreadWorker
()
{}
~
ExecutorThreadWorker
()
{}
~
ExecutorThreadWorker
()
{}
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
void
SetDataFeed
(
const
DataFeed
&
datafeed
);
void
SetThreadId
(
int
tid
);
void
SetThreadId
(
int
tid
);
void
CreateThreadOperators
(
const
framework
::
ProgramDesc
&
program
);
void
CreateThreadOperators
(
const
ProgramDesc
&
program
);
void
SetRootScope
(
Scope
*
g_scope
);
void
SetRootScope
(
Scope
*
g_scope
);
void
SetDevice
();
void
SetDevice
();
void
AddFidSet
();
void
AddFidSet
();
...
@@ -52,25 +52,16 @@ class ExecutorThreadWorker {
...
@@ -52,25 +52,16 @@ class ExecutorThreadWorker {
void
SetModelPrefix
(
const
std
::
string
&
prefix
)
{
model_prefix_
=
prefix
;
}
void
SetModelPrefix
(
const
std
::
string
&
prefix
)
{
model_prefix_
=
prefix
;
}
void
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
);
void
SetInspectVarName
s
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
);
void
SetDataFeed
(
DataFeed
&
datafeed
);
// NOLINT
void
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
void
Train
();
void
Train
();
const
char
*
PickOneFile
();
const
char
*
PickOneFile
();
void
UpdateEpochNum
();
void
UpdateEpochNum
();
void
Reset
();
void
SetDenseCommTensor
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{}
void
Initialize
()
{}
void
Initialize
()
{}
std
::
vector
<
float
>&
GetInspectValues
()
{
return
inspect_values_
;}
public:
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
vector
<
std
::
string
>
s_thread_filelist_
;
// filelist
static
bool
s_is_first_worker_
;
protected:
protected:
// thread index
// thread index
...
@@ -88,14 +79,13 @@ class ExecutorThreadWorker {
...
@@ -88,14 +79,13 @@ class ExecutorThreadWorker {
std
::
vector
<
OperatorBase
*>
ops_
;
std
::
vector
<
OperatorBase
*>
ops_
;
// main program for training
// main program for training
std
::
unique_ptr
<
framework
::
ProgramDesc
>
main_program_
;
std
::
unique_ptr
<
ProgramDesc
>
main_program_
;
// binary data reader
// binary data reader
std
::
shared
_ptr
<
DataFeed
>
local_reader_
;
std
::
unique
_ptr
<
DataFeed
>
local_reader_
;
std
::
string
inspect_var_name
_
;
std
::
vector
<
std
::
string
>
inspect_var_names
_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
// execution place
// execution place
platform
::
Place
place_
;
platform
::
Place
place_
;
...
@@ -105,24 +95,26 @@ class ExecutorThreadWorker {
...
@@ -105,24 +95,26 @@ class ExecutorThreadWorker {
// a thread scope, father scope is global score which is shared
// a thread scope, father scope is global score which is shared
Scope
*
thread_scope_
;
Scope
*
thread_scope_
;
private:
std
::
vector
<
float
>
inspect_values_
;
};
};
class
AsyncExecutor
{
class
AsyncExecutor
{
public:
public:
explicit
AsyncExecutor
(
const
platform
::
Place
&
place
);
explicit
AsyncExecutor
(
ProgramDesc
&
main_program
,
// NOLINT
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
// NOLINT
unsigned
int
thread_num
,
const
platform
::
Place
&
place
);
virtual
~
AsyncExecutor
()
{}
virtual
~
AsyncExecutor
()
{}
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromFile
(
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromFile
(
const
std
::
string
&
filename
);
const
std
::
string
&
filename
);
void
InitRootScope
(
Scope
*
scope
);
void
InitRootScope
(
Scope
*
scope
);
void
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
);
void
SetParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
Scope
*
GetRootScope
()
{
return
root_scope_
;
}
Scope
*
GetRootScope
()
{
return
root_scope_
;
}
void
SetThreadNum
(
const
int
thread_num
);
void
SetBatchSize
(
const
int
batch_size
)
{
batch_size_
=
batch_size
;
}
void
SetBatchSize
(
const
int
batch_size
)
{
batch_size_
=
batch_size
;
}
void
SetFileList
(
const
char
*
filelist
);
void
SetFileList
(
const
std
::
vector
<
std
::
string
>
filelist
);
void
SetDataFeedName
(
const
char
*
feedname
);
void
SetCommBatch
(
int
comm_batch
)
{
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
comm_batch_
=
comm_batch
;
}
}
...
@@ -140,37 +132,38 @@ class AsyncExecutor {
...
@@ -140,37 +132,38 @@ class AsyncExecutor {
}
}
void
SetModelPrefix
(
const
std
::
string
&
model_prefix
);
void
SetModelPrefix
(
const
std
::
string
&
model_prefix
);
void
SetDenseCommTensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
);
virtual
void
PrepareThreads
(
const
ProgramDesc
&
host_program
);
void
SetSparseCommTensor
(
void
RunStartupProgram
(
const
ProgramDesc
&
program
,
Scope
*
scope
);
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
);
std
::
vector
<
float
>&
Run
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
);
virtual
void
PrepareThreads
(
const
framework
::
ProgramDesc
&
host_program
);
void
RunStartupProgram
(
const
framework
::
ProgramDesc
&
program
,
framework
::
Scope
*
scope
);
void
RunAsyncExecutor
(
const
ProgramDesc
&
host_program
);
void
LoadInitModel
();
void
LoadInitModel
();
private:
void
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
public:
public:
unsigned
int
thread_num_
;
int
thread_num_
;
int
max_epoch_
;
int
max_epoch_
;
int
batch_size_
;
int
batch_size_
;
int
comm_batch_
;
int
comm_batch_
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers_
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
vector
<
std
::
string
>
filelist_
;
std
::
vector
<
std
::
string
>
inspect_var_names_
;
std
::
string
inspect_var_name_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
vector
<
std
::
string
>
dense_comm_tensor_
;
std
::
vector
<
std
::
string
>
sparse_comm_tensor_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
std
::
string
model_prefix_
;
std
::
string
model_prefix_
;
std
::
string
model_path_
;
std
::
string
model_path_
;
std
::
string
init_prog_file_
;
std
::
string
init_prog_file_
;
std
::
string
init_model_file_
;
std
::
string
init_model_file_
;
std
::
string
feed_name_
;
Scope
*
root_scope_
;
Scope
*
root_scope_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
private:
ProgramDesc
&
main_program_
;
TextClassDataFeed
&
data_feed_
;
std
::
vector
<
float
>
inspect_values_
;
private:
static
bool
workers_initialized_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
eb6a941f
...
@@ -38,6 +38,16 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
...
@@ -38,6 +38,16 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
vector
<
std
::
string
>
TextClassDataFeed
::
s_filelist_
;
std
::
mutex
TextClassDataFeed
::
s_locker_for_pick_file_
;
unsigned
int
TextClassDataFeed
::
s_current_file_idx_
=
0
;
size_t
TextClassDataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
TextClassDataFeed
::
s_current_epoch_
=
0
;
int
TextClassDataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
TextClassDataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
TextClassDataFeed
::
s_condition_epoch_start_
;
bool
TextClassDataFeed
::
s_epoch_start_flag_
=
false
;
void
TextClassDataFeed
::
Init
()
{
void
TextClassDataFeed
::
Init
()
{
// hard coding for a specific datafeed
// hard coding for a specific datafeed
feed_vec_
.
resize
(
2
);
feed_vec_
.
resize
(
2
);
...
@@ -59,6 +69,12 @@ void TextClassDataFeed::Init() {
...
@@ -59,6 +69,12 @@ void TextClassDataFeed::Init() {
label_host_
.
reset
(
new
int
[
10240
],
label_host_
.
reset
(
new
int
[
10240
],
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
label_ptr_
=
label_host_
.
get
();
label_ptr_
=
label_host_
.
get
();
field_names_
.
clear
();
}
TextClassDataFeed
::
TextClassDataFeed
()
{
Init
();
}
}
// todo: use elegant implemention for this function
// todo: use elegant implemention for this function
...
@@ -69,6 +85,7 @@ bool TextClassDataFeed::ReadBatch() {
...
@@ -69,6 +85,7 @@ bool TextClassDataFeed::ReadBatch() {
int
inst_idx
=
0
;
int
inst_idx
=
0
;
offset
.
resize
(
batch_size_
+
1
);
offset
.
resize
(
batch_size_
+
1
);
offset
[
0
]
=
0
;
offset
[
0
]
=
0
;
while
(
inst_idx
<
batch_size_
)
{
while
(
inst_idx
<
batch_size_
)
{
int
ptr_offset
=
0
;
int
ptr_offset
=
0
;
if
(
file_content_buffer_ptr_
-
file_content_buffer_
>=
file_size_
)
{
if
(
file_content_buffer_ptr_
-
file_content_buffer_
>=
file_size_
)
{
...
@@ -125,6 +142,12 @@ bool TextClassDataFeed::ReadBatch() {
...
@@ -125,6 +142,12 @@ bool TextClassDataFeed::ReadBatch() {
return
true
;
return
true
;
}
}
TextClassDataFeed
::
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
)
{
Init
();
SetBatchSize
(
data_feed
.
batch_size_
);
SetFieldNames
(
data_feed
.
field_names_
);
}
void
TextClassDataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
void
TextClassDataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
if
(
name
==
use_slot_alias_
[
i
])
{
...
@@ -133,30 +156,99 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
...
@@ -133,30 +156,99 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
}
}
}
}
void
TextClassDataFeed
::
SetFileList
(
const
char
*
filelist
)
{
s_filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
PADDLE_ENFORCE
(
fin
.
good
(),
"Opening file %s fail"
,
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
s_filelist_
.
push_back
(
filename
);
}
fin
.
close
();
}
void
TextClassDataFeed
::
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
)
{
field_names_
.
clear
();
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names
.
end
());
}
bool
TextClassDataFeed
::
SetFile
(
const
char
*
filename
)
{
bool
TextClassDataFeed
::
SetFile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
// termnum termid termid ... termid label
int
filesize
=
ReadWholeFile
(
filename
,
file_content_buffer_
);
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
// todo , remove magic number
if
(
ifs
.
fail
())
{
return
false
;
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
int
filesize
=
ifs
.
tellg
();
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
ifs
.
read
(
file_content_buffer_
,
filesize
);
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
return
false
;
return
false
;
}
}
file_content_buffer_ptr_
=
file_content_buffer_
;
file_content_buffer_ptr_
=
file_content_buffer_
;
file_size_
=
filesize
;
file_size_
=
filesize
;
// todo , remove magic number
return
true
;
return
true
;
}
}
int
TextClassDataFeed
::
ReadWholeFile
(
const
std
::
string
&
filename
,
void
TextClassDataFeed
::
UpdateEpochNum
()
{
char
*
buffer
)
{
s_current_finished_file_cnt_
++
;
std
::
ifstream
ifs
(
filename
.
c_str
(),
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
return
-
1
;
s_current_finished_file_cnt_
=
0
;
s_current_epoch_
++
;
#if 1
LOG
(
WARNING
)
<<
"UpdateEpochNum: epoch = "
<<
s_current_epoch_
;
#endif
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
false
;
}
}
}
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
void
TextClassDataFeed
::
StartOneEpoch
()
{
int
file_size
=
ifs
.
tellg
();
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
std
::
random_shuffle
(
s_filelist_
.
begin
(),
s_filelist_
.
end
());
ifs
.
read
(
buffer
,
file_size
);
s_current_file_idx_
=
0
;
return
file_size
;
LOG
(
INFO
)
<<
"Beginning epoch "
<<
s_current_epoch_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
true
;
}
s_condition_epoch_start_
.
notify_all
();
}
void
TextClassDataFeed
::
WaitNextEpoch
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_condition_epoch_start_
.
wait
(
lock
,
[]{
return
s_epoch_start_flag_
;});
}
const
char
*
TextClassDataFeed
::
PickOneFile
()
{
std
::
string
file_to_be_processed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
// One epoch has run over
// Wait for next epoch
if
(
s_current_file_idx_
>=
s_filelist_
.
size
())
{
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
return
NULL
;
}
file_to_be_processed
=
s_filelist_
[
s_current_file_idx_
];
s_current_file_idx_
++
;
return
file_to_be_processed
.
c_str
();
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
eb6a941f
...
@@ -47,24 +47,9 @@ struct Instance {
...
@@ -47,24 +47,9 @@ struct Instance {
std
::
vector
<
Gauc
>
gauc_vec
;
std
::
vector
<
Gauc
>
gauc_vec
;
};
};
class
DataFeed
{
DataFeed
()
{}
virtual
~
DataFeed
()
{}
};
class
BlockingQueueDataFeed
:
DataFeed
{
BlockingQueueDataFeed
()
{}
virtual
~
BlockingQueueDataFeed
()
{}
};
class
ThreadedDataFeed
:
DataFeed
{
ThreadedDataFeed
()
{}
virtual
~
ThreadedDataFeed
()
{}
};
class
DataFeed
{
class
DataFeed
{
public:
public:
DataFeed
()
{}
DataFeed
()
:
default_batch_size_
(
1
),
batch_size_
(
0
),
thread_id_
(
0
)
{}
virtual
~
DataFeed
()
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
()
=
0
;
virtual
void
Init
()
=
0
;
/*
/*
...
@@ -93,6 +78,11 @@ class DataFeed {
...
@@ -93,6 +78,11 @@ class DataFeed {
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
virtual
void
SetBufferSize
(
int
buffer_size
)
{}
virtual
void
SetBufferSize
(
int
buffer_size
)
{}
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
StartOneEpoch
()
=
0
;
virtual
void
WaitNextEpoch
()
=
0
;
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
return
feed_vec_
;
return
feed_vec_
;
...
@@ -103,6 +93,9 @@ class DataFeed {
...
@@ -103,6 +93,9 @@ class DataFeed {
return
feed_vec_
;
return
feed_vec_
;
}
}
int
GetThreadId
()
{
return
thread_id_
;}
void
SetThreadId
(
int
thread_id
)
{
thread_id_
=
thread_id
;}
protected:
protected:
std
::
vector
<
uint16_t
>
all_slot_ids_
;
std
::
vector
<
uint16_t
>
all_slot_ids_
;
std
::
vector
<
uint16_t
>
use_slot_ids_
;
std
::
vector
<
uint16_t
>
use_slot_ids_
;
...
@@ -110,9 +103,14 @@ class DataFeed {
...
@@ -110,9 +103,14 @@ class DataFeed {
std
::
vector
<
LoDTensor
*>
feed_vec_
;
std
::
vector
<
LoDTensor
*>
feed_vec_
;
int
default_batch_size_
;
int
default_batch_size_
;
int
batch_size_
;
int
batch_size_
;
int
thread_id_
;
};
};
class
TextClassDataFeed
:
public
DataFeed
{
class
TextClassDataFeed
:
public
DataFeed
{
public:
TextClassDataFeed
();
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
);
public:
public:
virtual
~
TextClassDataFeed
()
{}
virtual
~
TextClassDataFeed
()
{}
virtual
void
Init
();
virtual
void
Init
();
...
@@ -120,25 +118,45 @@ class TextClassDataFeed : public DataFeed {
...
@@ -120,25 +118,45 @@ class TextClassDataFeed : public DataFeed {
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
BindScope
(
Scope
*
scope
)
{}
virtual
void
BindScope
(
Scope
*
scope
)
{}
virtual
bool
SetFile
(
const
char
*
filename
);
virtual
bool
SetFile
(
const
char
*
filename
);
virtual
bool
CheckFile
(
const
char
*
filename
)
{
virtual
bool
CheckFile
(
const
char
*
filename
)
{
// TODO(xxx)
// TODO(xxx)
return
false
;
return
false
;
}
}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
void
UpdateEpochNum
();
void
StartOneEpoch
();
void
WaitNextEpoch
();
public:
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
public:
static
void
SetFileList
(
const
char
*
filelist
);
private:
const
char
*
PickOneFile
();
private:
private:
int
ReadWholeFile
(
const
std
::
string
&
filename
,
char
*
buffer
);
char
*
file_content_buffer_
;
char
*
file_content_buffer_
;
char
*
file_content_buffer_ptr_
;
char
*
file_content_buffer_ptr_
;
int
*
batch_id_buffer_
;
int
*
batch_id_buffer_
;
int
*
label_ptr_
;
int
*
label_ptr_
;
int
file_size_
;
int
file_size_
;
std
::
vector
<
std
::
string
>
names_
;
std
::
vector
<
std
::
string
>
field_
names_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
static
std
::
vector
<
std
::
string
>
s_filelist_
;
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
mutex
s_locker_epoch_start_
;
static
std
::
condition_variable
s_condition_epoch_start_
;
static
bool
s_epoch_start_flag_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/pybind/async_executor_py.cc
浏览文件 @
eb6a941f
...
@@ -21,7 +21,10 @@ limitations under the License. */
...
@@ -21,7 +21,10 @@ limitations under the License. */
#ifdef _XOPEN_SOURCE
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#endif
#include <vector>
#include <string>
#include "paddle/fluid/pybind/async_executor_py.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
...
@@ -29,58 +32,36 @@ limitations under the License. */
...
@@ -29,58 +32,36 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/
pybind/async_executor_py
.h"
#include "paddle/fluid/
framework/data_feed
.h"
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
paddle
{
namespace
pybind
{
namespace
pybind
{
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
py
::
class_
<
paddle
::
AsyncExecutorParameter
>
(
*
m
,
"AsyncExecutorParameter"
)
py
::
class_
<
framework
::
DataFeed
>
(
*
m
,
"DataFeed"
);
.
def
(
py
::
init
<>
())
py
::
class_
<
framework
::
TextClassDataFeed
,
.
def
(
"parse"
,
framework
::
DataFeed
>
(
*
m
,
"TextDataFeed"
)
[](
paddle
::
AsyncExecutorParameter
&
self
,
const
std
::
string
&
conf_file
)
{
.
def
(
py
::
init
())
int
file_descriptor
=
open
(
conf_file
.
c_str
(),
O_RDONLY
);
.
def
(
"set_filelist"
,
google
::
protobuf
::
io
::
FileInputStream
file_input
(
file_descriptor
);
[]
(
framework
::
TextClassDataFeed
&
self
,
const
char
*
data_list_file
)
{
google
::
protobuf
::
TextFormat
::
Parse
(
&
file_input
,
&
self
);
self
.
SetFileList
(
data_list_file
);
close
(
file_descriptor
);
})
}
.
def
(
"set_batch_size"
,
&
framework
::
TextClassDataFeed
::
SetBatchSize
)
);
.
def
(
"set_field_names"
,
&
framework
::
TextClassDataFeed
::
SetFieldNames
)
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
.
def
(
"start_one_epoch"
,
&
framework
::
TextClassDataFeed
::
StartOneEpoch
);
.
def
(
py
::
init
<
const
platform
::
Place
&>
())
.
def
(
"init"
,
[](
framework
::
AsyncExecutor
&
self
,
paddle
::
AsyncExecutorParameter
&
parameter
,
framework
::
Scope
*
scope
)
{
paddle
::
BaseParameter
base_param
=
parameter
.
base_param
();
// TODO Extract parameter list from python side, instead of
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
// providing them in confgurations manually
.
def
(
py
::
init
<
framework
::
ProgramDesc
&
,
std
::
vector
<
std
::
string
>
param_names
;
std
::
vector
<
std
::
string
>&
,
for
(
int
i
=
0
;
i
<
base_param
.
model_param_names_size
();
++
i
)
{
framework
::
TextClassDataFeed
&
,
param_names
.
push_back
(
base_param
.
model_param_names
(
i
));
unsigned
int
,
}
const
platform
::
Place
&>
())
paddle
::
framework
::
InitDevices
(
false
);
.
def
(
"init_root_scope"
,
&
framework
::
AsyncExecutor
::
InitRootScope
)
self
.
InitRootScope
(
scope
);
self
.
SetThreadNum
(
base_param
.
thread_num
());
self
.
SetMaxTrainingEpoch
(
base_param
.
max_epoch
());
self
.
SetFileList
(
base_param
.
filelist
().
c_str
());
self
.
SetBatchSize
(
base_param
.
batch_size
());
self
.
SetDataFeedName
(
base_param
.
datafeed_class
().
c_str
());
self
.
SetInspectVarName
(
base_param
.
inspect_var_name
());
self
.
SetParamNames
(
param_names
);
self
.
SetModelPath
(
base_param
.
model_path
());
self
.
SetModelPrefix
(
base_param
.
model_prefix
());
self
.
SetInitProgFile
(
base_param
.
init_prog_file
());
self
.
SetInitModelFile
(
base_param
.
init_model_file
());
return
;
}
)
.
def
(
"run_startup_program"
,
&
framework
::
AsyncExecutor
::
RunStartupProgram
)
.
def
(
"run_startup_program"
,
&
framework
::
AsyncExecutor
::
RunStartupProgram
)
.
def
(
"load_init_model"
,
&
framework
::
AsyncExecutor
::
LoadInitModel
)
.
def
(
"run"
,
&
framework
::
AsyncExecutor
::
Run
);
.
def
(
"run"
,
&
framework
::
AsyncExecutor
::
RunAsyncExecutor
);
}
// end BindAsyncExecutor
}
// end BindAsyncExecutor
}
// end namespace
framework
}
// end namespace
pybind
}
// end namespace paddle
}
// end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */
/* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */
paddle/fluid/pybind/async_executor_py.h
浏览文件 @
eb6a941f
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#ifndef PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#ifndef PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#define PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#define PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#include "pybind11/pybind11.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
eb6a941f
...
@@ -21,22 +21,28 @@ from .framework import Program, default_main_program, Variable
...
@@ -21,22 +21,28 @@ from .framework import Program, default_main_program, Variable
from
.
import
core
from
.
import
core
from
.
import
Executor
from
.
import
Executor
__all__
=
[
'
AsyncExecutorParameter
'
,
'AsyncExecutor'
]
__all__
=
[
'
TextDataFeed
'
,
'AsyncExecutor'
]
g_scope
=
core
.
Scope
()
g_scope
=
core
.
Scope
()
class
AsyncExecutorParameter
(
object
):
class
TextDataFeed
():
"""
AsyncExecutor configure parameter
Args:
None
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
parameter
=
core
.
AsyncExecutorParameter
()
self
.
feed
=
core
.
TextDataFeed
()
def
set_filelist
(
self
,
filelist
):
self
.
feed
.
set_filelist
(
filelist
)
def
set_batch_size
(
self
,
batch_size
):
self
.
feed
.
set_batch_size
(
batch_size
)
def
set_field_names
(
self
,
field_names
):
if
isinstance
(
field_names
,
Variable
):
field_names
=
[
field_names
]
self
.
feed
.
set_field_names
(
field_names
)
def
parse
(
self
,
conf_file
):
def
start_an_epoch
(
self
):
self
.
parameter
.
parse
(
conf_file
)
self
.
feed
.
start_one_epoch
(
)
class
AsyncExecutor
(
object
):
class
AsyncExecutor
(
object
):
"""
"""
...
@@ -50,39 +56,31 @@ class AsyncExecutor(object):
...
@@ -50,39 +56,31 @@ class AsyncExecutor(object):
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
async_executor_parameter
,
program
,
place
,
param_names
,
scope
):
data_feed
,
if
not
isinstance
(
async_executor_parameter
,
AsyncExecutorParameter
):
thread_num
,
raise
TypeError
(
place
=
None
,
"AsyncExecutor requires AsyncExecutorParameter as its parameter. "
scope
=
None
):
"But you passed in %s"
%
s
(
type
(
async_executor_parameter
))
if
program
is
None
:
)
program
=
default_main_program
()
program_desc
=
program
.
desc
self
.
place
=
place
p
=
core
.
Place
()
p
.
set_place
(
place
)
self
.
executor
=
core
.
AsyncExecutor
(
p
)
self
.
executor
.
init
(
async_executor_parameter
.
parameter
,
scope
)
self
.
_closed
=
False
self
.
parameter
=
async_executor_parameter
.
parameter
def
close
(
self
):
if
not
isinstance
(
data_feed
,
TextDataFeed
):
"""
raise
ValueError
(
"data_feed for AsyncExecutor.run() type error"
)
Close this executor.
You can no long use this executor after calling this method.
if
place
is
None
:
For the distributed training, this method would free the resource on PServers related to
place
=
core
.
CPUPlace
()
the current Trainer.
if
not
isinstance
(
place
,
core
.
CPUPlace
):
raise
ValueError
(
"AsyncExecutor only supports CPU device"
)
if
isinstance
(
param_names
,
Variable
):
param_names
=
[
param_names
]
p
=
core
.
Place
()
p
.
set_place
(
place
)
self
.
executor
=
core
.
AsyncExecutor
(
program_desc
,
param_names
,
data_feed
.
feed
,
thread_num
,
p
)
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if
not
self
.
_closed
:
self
.
_closed
=
True
def
run_startup_program
(
self
,
def
run_startup_program
(
self
,
program
=
None
,
program
=
None
,
scope
=
None
):
scope
=
None
):
...
@@ -95,7 +93,7 @@ class AsyncExecutor(object):
...
@@ -95,7 +93,7 @@ class AsyncExecutor(object):
self
.
executor
.
run_startup_program
(
program_desc
,
scope
)
self
.
executor
.
run_startup_program
(
program_desc
,
scope
)
def
run
(
self
,
program
=
None
,
scope
=
None
):
def
run
(
self
,
inspect_vars
,
scope
=
None
):
"""
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
Python executor takes a program, add feed operators and fetch operators to this program according
...
@@ -138,23 +136,16 @@ class AsyncExecutor(object):
...
@@ -138,23 +136,16 @@ class AsyncExecutor(object):
>>> feed={'X': x},
>>> feed={'X': x},
>>> fetch_list=[loss.name])
>>> fetch_list=[loss.name])
"""
"""
if
inspect_vars
is
not
None
:
if
self
.
_closed
:
if
isinstance
(
inspect_vars
,
Variable
):
raise
RuntimeError
(
"Attempted to use a closed Executor"
)
inspect_vars
=
[
inspect_vars
]
inspect_var_names
=
[
var
.
name
for
var
in
inspect_vars
]
if
program
is
None
:
program
=
default_main_program
()
program_desc
=
program
.
desc
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"Executor requires Program as its Parameter. But you passed in %s"
%
(
type
(
program
)))
if
scope
is
None
:
if
scope
is
None
:
scope
=
g_scope
scope
=
g_scope
self
.
executor
.
run
(
program
.
desc
)
self
.
executor
.
init_root_scope
(
scope
)
evaluation
=
self
.
executor
.
run
(
inspect_var_names
)
return
evaluation
def
load_init_model
(
self
):
return
self
.
executor
.
load_init_model
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录