Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
eb6a941f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
eb6a941f
编写于
11月 14, 2018
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix async_executor interfaces: 1) Remove all protobufs; 2) Stop after each epoch
上级
1d239cc8
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
396 addition
and
403 deletion
+396
-403
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+144
-227
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+35
-42
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+104
-12
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+38
-20
paddle/fluid/pybind/async_executor_py.cc
paddle/fluid/pybind/async_executor_py.cc
+24
-43
paddle/fluid/pybind/async_executor_py.h
paddle/fluid/pybind/async_executor_py.h
+1
-0
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+50
-59
未找到文件。
paddle/fluid/framework/async_executor.cc
浏览文件 @
eb6a941f
...
...
@@ -40,13 +40,8 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
std
::
mutex
ExecutorThreadWorker
::
s_locker_for_pick_file_
;
unsigned
int
ExecutorThreadWorker
::
s_current_file_idx_
=
0
;
size_t
ExecutorThreadWorker
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
ExecutorThreadWorker
::
s_current_epoch_
=
0
;
int
ExecutorThreadWorker
::
s_current_save_epoch_
=
0
;
bool
ExecutorThreadWorker
::
s_is_first_worker_
=
false
;
std
::
vector
<
std
::
string
>
ExecutorThreadWorker
::
s_thread_filelist_
;
bool
AsyncExecutor
::
workers_initialized_
=
false
;
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
...
...
@@ -124,7 +119,6 @@ static void SaveModel(
{{
"X"
,
{
var
->
Name
()}}},
{},
attrs
);
save_op
->
Run
(
*
scope
,
place
);
}
else
{
paralist
.
push_back
(
var
->
Name
());
...
...
@@ -140,15 +134,14 @@ static void SaveModel(
{{
"X"
,
paralist
}},
{},
attrs
);
save_op
->
Run
(
*
scope
,
place
);
}
}
// end SaveModel
void
ExecutorThreadWorker
::
AddTrainFile
(
const
std
::
string
&
file
)
{
s_thread_filelist_
.
push_back
(
file
);
void
ExecutorThreadWorker
::
Reset
()
{
inspect_values_
.
clear
();
}
void
ExecutorThreadWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
op_names_
.
clear
();
...
...
@@ -175,8 +168,12 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
}
}
void
ExecutorThreadWorker
::
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
local_reader_
=
datafeed
;
void
ExecutorThreadWorker
::
SetDataFeed
(
DataFeed
&
datafeed
)
{
if
(
typeid
(
datafeed
)
==
typeid
(
TextClassDataFeed
))
{
local_reader_
.
reset
(
new
TextClassDataFeed
(
dynamic_cast
<
TextClassDataFeed
&>
(
datafeed
)));
local_reader_
->
SetThreadId
(
thread_id_
);
}
}
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
...
...
@@ -186,9 +183,11 @@ void ExecutorThreadWorker::BindingDataFeedMemory() {
}
}
void
ExecutorThreadWorker
::
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
)
{
inspect_var_name_
=
inspect_var_name
;
void
ExecutorThreadWorker
::
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
inspect_var_names_
.
clear
();
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
}
void
ExecutorThreadWorker
::
SetModelParamNames
(
...
...
@@ -196,11 +195,6 @@ void ExecutorThreadWorker::SetModelParamNames(
model_param_names_
=
param_names
;
}
void
ExecutorThreadWorker
::
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
)
{
sparse_comm_data_
=
param_names
;
}
void
ExecutorThreadWorker
::
SetDevice
()
{
static
unsigned
priority
[]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
...
...
@@ -228,138 +222,79 @@ void ExecutorThreadWorker::SetDevice() {
CPU_ZERO
(
&
mask
);
if
((
0
==
sched_getaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
&&
CPU_ISSET
(
proc
,
&
mask
))
{
LOG
(
ERROR
)
<<
"TRACE: Thread "
<<
i
<<
" is running on processor "
<<
proc
<<
"..."
;
LOG
(
ERROR
)
<<
"TRACE: Thread "
<<
i
<<
" is running on processor "
<<
proc
<<
"..."
;
}
}
}
}
void
ExecutorThreadWorker
::
UpdateEpochNum
()
{
s_current_finished_file_cnt_
++
;
if
(
s_current_finished_file_cnt_
>=
s_thread_filelist_
.
size
())
{
s_current_finished_file_cnt_
=
0
;
s_current_epoch_
++
;
}
}
void
ExecutorThreadWorker
::
Train
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
SetDevice
();
const
char
*
ExecutorThreadWorker
::
PickOneFile
()
{
std
::
string
file_to_be_preocessed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
int
inspect_var_num
=
inspect_var_names_
.
size
();
inspect_values_
.
clear
()
;
inspect_values_
.
resize
(
inspect_var_num
,
0
);
if
(
s_current_file_idx_
>=
s_thread_filelist_
.
size
())
{
std
::
random_shuffle
(
s_thread_filelist_
.
begin
(),
s_thread_filelist_
.
end
());
s_current_file_idx_
=
0
;
// s_current_epoch_++; //example: when one file, one thread, it's bug
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
}
file_to_be_preocessed
=
s_thread_filelist_
[
s_current_file_idx_
];
local_reader_
->
WaitNextEpoch
();
int
epoch
=
local_reader_
->
GetCurrentEpoch
();
s_current_file_idx_
++
;
return
file_to_be_preocessed
.
c_str
();
}
LOG
(
ERROR
)
<<
"epoch: "
<<
epoch
;
void
ExecutorThreadWorker
::
Train
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
SetDevice
();
#ifdef LOCAL_PROF
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
// int total_batch = 0;
for
(
auto
&
op
:
ops_
)
{
op_name
.
push_back
(
op
->
Type
());
}
op_total_time
.
resize
(
ops_
.
size
());
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
op_total_time
[
i
]
=
0.0
;
}
#endif
std
::
string
inspect_key
=
"inspect"
;
if
(
!
inspect_var_name_
.
empty
())
{
inspect_key
=
inspect_var_name_
.
substr
(
0
,
inspect_var_name_
.
find_first_of
(
'_'
));
}
for
(
unsigned
i
=
0
;
i
<
max_epoch_
;
++
i
)
{
LOG
(
ERROR
)
<<
"epoch: "
<<
i
;
#ifdef LOCAL_PROF
Timer
timeline
;
double
total_time
=
0.0
;
double
read_time
=
0.0
;
#endif
float
total_inspect
=
0
;
int
batch_num
=
1
;
while
(
i
==
s_current_epoch_
)
{
const
char
*
filename
=
PickOneFile
();
local_reader_
->
SetFile
(
filename
);
while
(
true
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
bool
flag
=
local_reader_
->
ReadBatch
();
if
(
!
flag
)
{
const
char
*
file
=
local_reader_
->
PickOneFile
();
if
(
file
==
NULL
)
{
break
;
}
#ifdef LOCAL_PROF
timeline
.
pause
();
read_time
+=
timeline
.
elapsed_sec
();
total_time
+=
timeline
.
elapsed_sec
();
#endif
if
(
!
local_reader_
->
SetFile
(
file
))
{
break
;
}
while
(
true
)
{
bool
flag
=
local_reader_
->
ReadBatch
();
if
(
!
flag
)
{
break
;
}
for
(
unsigned
int
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
#ifdef LOCAL_PROF
timeline
.
start
();
#endif
ops_
[
i
]
->
Run
(
*
thread_scope_
,
place_
);
#ifdef LOCAL_PROF
timeline
.
pause
();
op_total_time
[
i
]
+=
timeline
.
elapsed_sec
();
total_time
+=
timeline
.
elapsed_sec
();
#endif
}
batch_num
++
;
float
avg_inspect
=
0.0
;
if
(
!
inspect_var_name_
.
empty
()
)
{
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_name_
)
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_names_
[
i
]
)
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
inspect_values_
[
i
]
+=
avg_inspect
;
}
total_inspect
+=
avg_inspect
;
thread_scope_
->
DropKids
();
}
UpdateEpochNum
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
i
+
1
local_reader_
->
UpdateEpochNum
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
epoch
+
1
<<
" called: "
<<
memory
::
memory_usage
(
place_
);
}
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
inspect_values_
[
i
]
/=
batch_num
;
std
::
string
var
=
inspect_var_names_
[
i
].
substr
(
0
,
inspect_var_names_
[
i
].
find_first_of
(
"_"
));
LOG
(
ERROR
)
<<
"mean "
<<
var
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
inspect_values_
[
i
];
}
#ifdef LOCAL_PROF
for
(
int
i
=
0
;
i
<
op_total_time
.
size
();
++
i
)
{
std
::
cerr
<<
"op_name:["
<<
i
<<
"]["
<<
op_name
[
i
]
<<
"]"
<<
" op_mean_time:["
<<
op_total_time
[
i
]
<<
"s]"
<<
std
::
endl
;
}
std
::
cerr
<<
"read time: "
<<
read_time
<<
"s"
<<
std
::
endl
;
#endif
}
#ifdef LOCAL_PROF
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
<<
", total_time: "
<<
total_time
;
#else
LOG
(
ERROR
)
<<
"mean "
<<
inspect_key
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
total_inspect
/
batch_num
;
#endif
if
(
thread_id_
==
0
)
{
char
modelfile
[
1024
];
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
model_prefix_
.
c_str
(),
i
);
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
model_prefix_
.
c_str
(),
epoch
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
// this save_inference_model can only save imdbtask, should make this
// general
...
...
@@ -372,7 +307,6 @@ void ExecutorThreadWorker::Train() {
model_filename
,
true
);
}
}
}
void
ExecutorThreadWorker
::
SetThreadId
(
int
tid
)
{
...
...
@@ -396,7 +330,20 @@ void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_
=
max_epoch
;
}
AsyncExecutor
::
AsyncExecutor
(
const
platform
::
Place
&
place
)
:
place_
(
place
)
{}
AsyncExecutor
::
AsyncExecutor
(
ProgramDesc
&
main_program
,
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
unsigned
int
thread_num
,
const
platform
::
Place
&
place
)
:
thread_num_
(
thread_num
),
place_
(
place
),
main_program_
(
main_program
),
data_feed_
(
data_feed
)
{
model_param_names_
.
clear
();
model_param_names_
.
insert
(
model_param_names_
.
end
(),
param_names
.
begin
(),
param_names
.
end
());
}
void
AsyncExecutor
::
InitRootScope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
...
...
@@ -406,10 +353,6 @@ void AsyncExecutor::SetMaxTrainingEpoch(int max_epoch) {
max_epoch_
=
max_epoch
;
}
void
AsyncExecutor
::
SetDataFeedName
(
const
char
*
feedname
)
{
feed_name_
=
std
::
string
(
feedname
);
}
void
AsyncExecutor
::
SetModelPrefix
(
const
std
::
string
&
model_prefix
)
{
model_prefix_
=
model_prefix
;
}
...
...
@@ -463,60 +406,16 @@ std::unique_ptr<ProgramDesc> AsyncExecutor::LoadDescFromFile(
return
program
;
}
void
AsyncExecutor
::
SetDenseCommTensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
)
{
dense_comm_tensor_
.
resize
(
dense_comm_tensor
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
dense_comm_tensor
.
size
();
++
i
)
{
dense_comm_tensor_
[
i
]
=
dense_comm_tensor
[
i
];
}
}
void
AsyncExecutor
::
SetSparseCommTensor
(
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
)
{
sparse_comm_tensor_
.
resize
(
sparse_comm_tensor
.
size
());
for
(
unsigned
int
i
=
0
;
i
<
sparse_comm_tensor
.
size
();
++
i
)
{
sparse_comm_tensor_
[
i
]
=
sparse_comm_tensor
[
i
];
}
}
void
AsyncExecutor
::
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
)
{
sparse_comm_data_
=
sparse_comm_data
;
LOG
(
INFO
)
<<
"Sparse comm data: "
<<
sparse_comm_data_
.
size
();
}
void
AsyncExecutor
::
SetFileList
(
const
char
*
filelist
)
{
filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
filelist_
.
push_back
(
filename
);
}
fin
.
close
();
}
void
AsyncExecutor
::
SetFileList
(
std
::
vector
<
std
::
string
>
tfiles
)
{
filelist_
.
clear
();
filelist_
.
insert
(
filelist_
.
end
(),
tfiles
.
begin
(),
tfiles
.
end
());
return
;
}
void
AsyncExecutor
::
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
)
{
inspect_var_name_
=
inspect_var_name
;
}
void
AsyncExecutor
::
SetParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
model_param_names_
=
param_names
;
}
void
AsyncExecutor
::
SetThreadNum
(
const
int
thread_num
)
{
thread_num_
=
thread_num
;
void
AsyncExecutor
::
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
inspect_var_names_
.
clear
();
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
}
void
AsyncExecutor
::
PrepareThreads
(
const
ProgramDesc
&
host_program
)
{
workers_
.
resize
(
thread_num_
);
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
].
reset
(
new
ExecutorThreadWorker
);
workers_
[
i
]
->
SetThreadId
(
i
);
workers_
[
i
]
->
CreateThreadOperators
(
host_program
);
...
...
@@ -524,34 +423,31 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
workers_
[
i
]
->
SetPlace
(
place_
);
workers_
[
i
]
->
SetMaxTrainingEpoch
(
max_epoch_
);
workers_
[
i
]
->
CreateThreadScope
(
host_program
);
workers_
[
i
]
->
SetInspectVarName
(
inspect_var_name
_
);
workers_
[
i
]
->
SetInspectVarName
s
(
inspect_var_names
_
);
workers_
[
i
]
->
SetModelParamNames
(
model_param_names_
);
workers_
[
i
]
->
SetSparseCommData
(
sparse_comm_data_
);
workers_
[
i
]
->
SetMainProgram
(
host_program
);
workers_
[
i
]
->
SetModelPrefix
(
model_prefix_
);
}
for
(
unsigned
i
=
0
;
i
<
filelist_
.
size
();
++
i
)
{
// suppose at least one trainer thread here, and
// filelist is static so that we only add filelist once
workers_
[
0
]
->
AddTrainFile
(
filelist_
[
i
]);
}
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
//
// new a datafeed here
std
::
shared_ptr
<
DataFeed
>
local_feed
=
CreateDataFeed
(
feed_name_
.
c_str
());
local_feed
->
Init
();
local_feed
->
SetBatchSize
(
batch_size_
);
workers_
[
i
]
->
SetDataFeed
(
local_feed
);
workers_
[
i
]
->
SetDataFeed
(
data_feed_
);
workers_
[
i
]
->
BindingDataFeedMemory
();
workers_
[
i
]
->
SetThreadId
(
i
);
}
}
void
AsyncExecutor
::
RunAsyncExecutor
(
const
ProgramDesc
&
host_program
)
{
std
::
vector
<
float
>&
AsyncExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
SetInspectVarNames
(
inspect_var_names
);
threads_
.
clear
();
// thread binding here?
PrepareThreads
(
host_program
);
for
(
unsigned
i
=
0
;
i
<
thread_num_
;
++
i
)
{
if
(
workers_initialized_
==
false
)
{
PrepareThreads
(
main_program_
);
workers_initialized_
=
true
;
}
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
->
Reset
();
workers_
[
i
]
->
SetInspectVarNames
(
inspect_var_names
);
threads_
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
Train
,
workers_
[
i
].
get
()));
}
...
...
@@ -559,6 +455,27 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) {
for
(
auto
&
th
:
threads_
)
{
th
.
join
();
}
inspect_values_
.
clear
();
inspect_values_
.
resize
(
inspect_var_names_
.
size
(),
0
);
std
::
vector
<
std
::
vector
<
float
>*>
inspect_value_vectors
;
inspect_value_vectors
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
inspect_value_vectors
[
i
]
=
&
workers_
[
i
]
->
GetInspectValues
();
}
for
(
unsigned
int
i
=
0
;
i
<
inspect_var_names_
.
size
();
++
i
)
{
float
value
=
0.0
;
for
(
int
j
=
0
;
j
<
thread_num_
;
++
j
)
{
value
+=
inspect_value_vectors
[
j
]
->
at
(
i
);
}
value
/=
thread_num_
;
inspect_values_
[
i
]
=
value
;
}
return
inspect_values_
;
}
void
AsyncExecutor
::
LoadInitModel
()
{
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
eb6a941f
...
...
@@ -22,6 +22,7 @@ limitations under the License. */
#include <string>
#include <thread> // NOLINT
#include <vector>
#include <typeinfo>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
...
...
@@ -36,10 +37,9 @@ class ExecutorThreadWorker {
public:
ExecutorThreadWorker
()
{}
~
ExecutorThreadWorker
()
{}
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
void
SetDataFeed
(
const
DataFeed
&
datafeed
);
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
void
SetThreadId
(
int
tid
);
void
CreateThreadOperators
(
const
framework
::
ProgramDesc
&
program
);
void
CreateThreadOperators
(
const
ProgramDesc
&
program
);
void
SetRootScope
(
Scope
*
g_scope
);
void
SetDevice
();
void
AddFidSet
();
...
...
@@ -52,25 +52,16 @@ class ExecutorThreadWorker {
void
SetModelPrefix
(
const
std
::
string
&
prefix
)
{
model_prefix_
=
prefix
;
}
void
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
);
void
SetInspectVarName
s
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
param_names
);
void
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
void
SetDataFeed
(
DataFeed
&
datafeed
);
// NOLINT
void
Train
();
const
char
*
PickOneFile
();
void
UpdateEpochNum
();
void
Reset
();
void
SetDenseCommTensor
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{}
void
Initialize
()
{}
public:
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
vector
<
std
::
string
>
s_thread_filelist_
;
// filelist
static
bool
s_is_first_worker_
;
std
::
vector
<
float
>&
GetInspectValues
()
{
return
inspect_values_
;}
protected:
// thread index
...
...
@@ -88,14 +79,13 @@ class ExecutorThreadWorker {
std
::
vector
<
OperatorBase
*>
ops_
;
// main program for training
std
::
unique_ptr
<
framework
::
ProgramDesc
>
main_program_
;
std
::
unique_ptr
<
ProgramDesc
>
main_program_
;
// binary data reader
std
::
shared
_ptr
<
DataFeed
>
local_reader_
;
std
::
unique
_ptr
<
DataFeed
>
local_reader_
;
std
::
string
inspect_var_name
_
;
std
::
vector
<
std
::
string
>
inspect_var_names
_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
// execution place
platform
::
Place
place_
;
...
...
@@ -105,24 +95,26 @@ class ExecutorThreadWorker {
// a thread scope, father scope is global score which is shared
Scope
*
thread_scope_
;
private:
std
::
vector
<
float
>
inspect_values_
;
};
class
AsyncExecutor
{
public:
explicit
AsyncExecutor
(
const
platform
::
Place
&
place
);
explicit
AsyncExecutor
(
ProgramDesc
&
main_program
,
// NOLINT
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
// NOLINT
unsigned
int
thread_num
,
const
platform
::
Place
&
place
);
virtual
~
AsyncExecutor
()
{}
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromFile
(
const
std
::
string
&
filename
);
void
InitRootScope
(
Scope
*
scope
);
void
SetInspectVarName
(
const
std
::
string
&
inspect_var_name
);
void
SetParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
Scope
*
GetRootScope
()
{
return
root_scope_
;
}
void
SetThreadNum
(
const
int
thread_num
);
void
SetBatchSize
(
const
int
batch_size
)
{
batch_size_
=
batch_size
;
}
void
SetFileList
(
const
char
*
filelist
);
void
SetFileList
(
const
std
::
vector
<
std
::
string
>
filelist
);
void
SetDataFeedName
(
const
char
*
feedname
);
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
...
...
@@ -140,37 +132,38 @@ class AsyncExecutor {
}
void
SetModelPrefix
(
const
std
::
string
&
model_prefix
);
void
SetDenseCommTensor
(
const
std
::
vector
<
std
::
string
>&
dense_comm_tensor
);
void
SetSparseCommTensor
(
const
std
::
vector
<
std
::
string
>&
sparse_comm_tensor
);
void
SetSparseCommData
(
const
std
::
map
<
std
::
string
,
int
>&
sparse_comm_data
);
virtual
void
PrepareThreads
(
const
framework
::
ProgramDesc
&
host_program
);
void
RunStartupProgram
(
const
framework
::
ProgramDesc
&
program
,
framework
::
Scope
*
scope
);
void
RunAsyncExecutor
(
const
ProgramDesc
&
host_program
);
virtual
void
PrepareThreads
(
const
ProgramDesc
&
host_program
);
void
RunStartupProgram
(
const
ProgramDesc
&
program
,
Scope
*
scope
);
std
::
vector
<
float
>&
Run
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
LoadInitModel
();
private:
void
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
public:
unsigned
int
thread_num_
;
int
thread_num_
;
int
max_epoch_
;
int
batch_size_
;
int
comm_batch_
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
vector
<
std
::
string
>
filelist_
;
std
::
string
inspect_var_name_
;
std
::
vector
<
std
::
string
>
inspect_var_names_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
vector
<
std
::
string
>
dense_comm_tensor_
;
std
::
vector
<
std
::
string
>
sparse_comm_tensor_
;
std
::
map
<
std
::
string
,
int
>
sparse_comm_data_
;
std
::
string
model_prefix_
;
std
::
string
model_path_
;
std
::
string
init_prog_file_
;
std
::
string
init_model_file_
;
std
::
string
feed_name_
;
Scope
*
root_scope_
;
platform
::
Place
place_
;
private:
ProgramDesc
&
main_program_
;
TextClassDataFeed
&
data_feed_
;
std
::
vector
<
float
>
inspect_values_
;
private:
static
bool
workers_initialized_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
eb6a941f
...
...
@@ -38,6 +38,16 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace
paddle
{
namespace
framework
{
std
::
vector
<
std
::
string
>
TextClassDataFeed
::
s_filelist_
;
std
::
mutex
TextClassDataFeed
::
s_locker_for_pick_file_
;
unsigned
int
TextClassDataFeed
::
s_current_file_idx_
=
0
;
size_t
TextClassDataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
TextClassDataFeed
::
s_current_epoch_
=
0
;
int
TextClassDataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
TextClassDataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
TextClassDataFeed
::
s_condition_epoch_start_
;
bool
TextClassDataFeed
::
s_epoch_start_flag_
=
false
;
void
TextClassDataFeed
::
Init
()
{
// hard coding for a specific datafeed
feed_vec_
.
resize
(
2
);
...
...
@@ -59,6 +69,12 @@ void TextClassDataFeed::Init() {
label_host_
.
reset
(
new
int
[
10240
],
[](
int
*
p
)
{
delete
[]
p
;});
// max label in a batch
label_ptr_
=
label_host_
.
get
();
field_names_
.
clear
();
}
TextClassDataFeed
::
TextClassDataFeed
()
{
Init
();
}
// todo: use elegant implemention for this function
...
...
@@ -69,6 +85,7 @@ bool TextClassDataFeed::ReadBatch() {
int
inst_idx
=
0
;
offset
.
resize
(
batch_size_
+
1
);
offset
[
0
]
=
0
;
while
(
inst_idx
<
batch_size_
)
{
int
ptr_offset
=
0
;
if
(
file_content_buffer_ptr_
-
file_content_buffer_
>=
file_size_
)
{
...
...
@@ -125,6 +142,12 @@ bool TextClassDataFeed::ReadBatch() {
return
true
;
}
TextClassDataFeed
::
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
)
{
Init
();
SetBatchSize
(
data_feed
.
batch_size_
);
SetFieldNames
(
data_feed
.
field_names_
);
}
void
TextClassDataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
...
...
@@ -133,30 +156,99 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
}
}
void
TextClassDataFeed
::
SetFileList
(
const
char
*
filelist
)
{
s_filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
PADDLE_ENFORCE
(
fin
.
good
(),
"Opening file %s fail"
,
filelist
);
std
::
string
filename
;
while
(
fin
>>
filename
)
{
LOG
(
ERROR
)
<<
"add "
<<
filename
.
c_str
()
<<
" to filelist"
;
s_filelist_
.
push_back
(
filename
);
}
fin
.
close
();
}
void
TextClassDataFeed
::
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
)
{
field_names_
.
clear
();
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names
.
end
());
}
bool
TextClassDataFeed
::
SetFile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
int
filesize
=
ReadWholeFile
(
filename
,
file_content_buffer_
);
// todo , remove magic number
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
return
false
;
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
int
filesize
=
ifs
.
tellg
();
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
ifs
.
read
(
file_content_buffer_
,
filesize
);
if
(
filesize
<
0
||
filesize
>=
1024
*
1024
*
1024
)
{
return
false
;
}
file_content_buffer_ptr_
=
file_content_buffer_
;
file_size_
=
filesize
;
// todo , remove magic number
return
true
;
}
int
TextClassDataFeed
::
ReadWholeFile
(
const
std
::
string
&
filename
,
char
*
buffer
)
{
std
::
ifstream
ifs
(
filename
.
c_str
(),
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
return
-
1
;
void
TextClassDataFeed
::
UpdateEpochNum
()
{
s_current_finished_file_cnt_
++
;
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
s_current_finished_file_cnt_
=
0
;
s_current_epoch_
++
;
#if 1
LOG
(
WARNING
)
<<
"UpdateEpochNum: epoch = "
<<
s_current_epoch_
;
#endif
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
false
;
}
}
}
ifs
.
seekg
(
0
,
std
::
ios
::
end
);
int
file_size
=
ifs
.
tellg
();
ifs
.
seekg
(
0
,
std
::
ios
::
beg
);
ifs
.
read
(
buffer
,
file_size
);
return
file_size
;
void
TextClassDataFeed
::
StartOneEpoch
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
std
::
random_shuffle
(
s_filelist_
.
begin
(),
s_filelist_
.
end
());
s_current_file_idx_
=
0
;
LOG
(
INFO
)
<<
"Beginning epoch "
<<
s_current_epoch_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
true
;
}
s_condition_epoch_start_
.
notify_all
();
}
void
TextClassDataFeed
::
WaitNextEpoch
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_condition_epoch_start_
.
wait
(
lock
,
[]{
return
s_epoch_start_flag_
;});
}
const
char
*
TextClassDataFeed
::
PickOneFile
()
{
std
::
string
file_to_be_processed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
// One epoch has run over
// Wait for next epoch
if
(
s_current_file_idx_
>=
s_filelist_
.
size
())
{
LOG
(
ERROR
)
<<
"thread "
<<
thread_id_
<<
": finish traing for epoch "
<<
s_current_epoch_
+
1
;
return
NULL
;
}
file_to_be_processed
=
s_filelist_
[
s_current_file_idx_
];
s_current_file_idx_
++
;
return
file_to_be_processed
.
c_str
();
}
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
eb6a941f
...
...
@@ -47,24 +47,9 @@ struct Instance {
std
::
vector
<
Gauc
>
gauc_vec
;
};
class
DataFeed
{
DataFeed
()
{}
virtual
~
DataFeed
()
{}
};
class
BlockingQueueDataFeed
:
DataFeed
{
BlockingQueueDataFeed
()
{}
virtual
~
BlockingQueueDataFeed
()
{}
};
class
ThreadedDataFeed
:
DataFeed
{
ThreadedDataFeed
()
{}
virtual
~
ThreadedDataFeed
()
{}
};
class
DataFeed
{
public:
DataFeed
()
{}
DataFeed
()
:
default_batch_size_
(
1
),
batch_size_
(
0
),
thread_id_
(
0
)
{}
virtual
~
DataFeed
()
{}
virtual
void
Init
()
=
0
;
/*
...
...
@@ -93,6 +78,11 @@ class DataFeed {
virtual
void
SetBatchSize
(
int
batch
)
{
default_batch_size_
=
batch
;
}
virtual
int
GetBatchSize
()
{
return
batch_size_
;
}
virtual
void
SetBufferSize
(
int
buffer_size
)
{}
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
StartOneEpoch
()
=
0
;
virtual
void
WaitNextEpoch
()
=
0
;
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
return
feed_vec_
;
...
...
@@ -103,6 +93,9 @@ class DataFeed {
return
feed_vec_
;
}
int
GetThreadId
()
{
return
thread_id_
;}
void
SetThreadId
(
int
thread_id
)
{
thread_id_
=
thread_id
;}
protected:
std
::
vector
<
uint16_t
>
all_slot_ids_
;
std
::
vector
<
uint16_t
>
use_slot_ids_
;
...
...
@@ -110,9 +103,14 @@ class DataFeed {
std
::
vector
<
LoDTensor
*>
feed_vec_
;
int
default_batch_size_
;
int
batch_size_
;
int
thread_id_
;
};
class
TextClassDataFeed
:
public
DataFeed
{
public:
TextClassDataFeed
();
TextClassDataFeed
(
const
TextClassDataFeed
&
data_feed
);
public:
virtual
~
TextClassDataFeed
()
{}
virtual
void
Init
();
...
...
@@ -120,25 +118,45 @@ class TextClassDataFeed : public DataFeed {
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
BindScope
(
Scope
*
scope
)
{}
virtual
bool
SetFile
(
const
char
*
filename
);
virtual
bool
CheckFile
(
const
char
*
filename
)
{
// TODO(xxx)
return
false
;
}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
void
UpdateEpochNum
();
void
StartOneEpoch
();
void
WaitNextEpoch
();
public:
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
public:
static
void
SetFileList
(
const
char
*
filelist
);
private:
const
char
*
PickOneFile
();
private:
int
ReadWholeFile
(
const
std
::
string
&
filename
,
char
*
buffer
);
char
*
file_content_buffer_
;
char
*
file_content_buffer_ptr_
;
int
*
batch_id_buffer_
;
int
*
label_ptr_
;
int
file_size_
;
std
::
vector
<
std
::
string
>
names_
;
std
::
vector
<
std
::
string
>
field_
names_
;
std
::
shared_ptr
<
char
>
file_content_buffer_host_
;
std
::
shared_ptr
<
int
>
batch_id_host_
;
std
::
shared_ptr
<
int
>
label_host_
;
static
std
::
vector
<
std
::
string
>
s_filelist_
;
static
std
::
mutex
s_locker_for_pick_file_
;
static
unsigned
int
s_current_file_idx_
;
static
size_t
s_current_finished_file_cnt_
;
static
unsigned
int
s_current_epoch_
;
static
int
s_current_save_epoch_
;
static
std
::
mutex
s_locker_epoch_start_
;
static
std
::
condition_variable
s_condition_epoch_start_
;
static
bool
s_epoch_start_flag_
;
};
}
// namespace framework
...
...
paddle/fluid/pybind/async_executor_py.cc
浏览文件 @
eb6a941f
...
...
@@ -21,7 +21,10 @@ limitations under the License. */
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include <vector>
#include <string>
#include "paddle/fluid/pybind/async_executor_py.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "paddle/fluid/inference/io.h"
...
...
@@ -29,58 +32,36 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/
pybind/async_executor_py
.h"
#include "paddle/fluid/
framework/data_feed
.h"
namespace
py
=
pybind11
;
namespace
paddle
{
namespace
pybind
{
// Pybind11 bindings for the AsyncExecutor Python API.
// NOTE(review): this span of the diff interleaved the removed protobuf-based
// bindings (AsyncExecutorParameter + config-driven init) with the added
// constructor-based bindings; only the added side is kept here, matching the
// commit's stated intent of removing all protobufs from the interface.
void BindAsyncExecutor(py::module* m) {
  // Abstract base so TextClassDataFeed can be passed where a DataFeed is
  // expected on the Python side.
  py::class_<framework::DataFeed>(*m, "DataFeed");

  py::class_<framework::TextClassDataFeed, framework::DataFeed>(*m,
                                                                "TextDataFeed")
      .def(py::init())
      .def("set_filelist",
           [](framework::TextClassDataFeed& self,
              const char* data_list_file) { self.SetFileList(data_list_file); })
      .def("set_batch_size", &framework::TextClassDataFeed::SetBatchSize)
      .def("set_field_names", &framework::TextClassDataFeed::SetFieldNames)
      .def("start_one_epoch", &framework::TextClassDataFeed::StartOneEpoch);

  py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
      .def(py::init<framework::ProgramDesc&, std::vector<std::string>&,
                    framework::TextClassDataFeed&, unsigned int,
                    const platform::Place&>())
      .def("init_root_scope", &framework::AsyncExecutor::InitRootScope)
      .def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram)
      .def("load_init_model", &framework::AsyncExecutor::LoadInitModel)
      .def("run", &framework::AsyncExecutor::Run);
}  // end BindAsyncExecutor
}  // end namespace pybind
}  // end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */
paddle/fluid/pybind/async_executor_py.h
浏览文件 @
eb6a941f
...
...
@@ -15,6 +15,7 @@
#ifndef PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#define PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
namespace
py
=
pybind11
;
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
eb6a941f
...
...
@@ -21,22 +21,28 @@ from .framework import Program, default_main_program, Variable
from
.
import
core
from
.
import
Executor
# NOTE(review): the diff carried both the old and the new __all__; the new
# (post-commit) export list is kept — AsyncExecutorParameter was removed.
__all__ = ['TextDataFeed', 'AsyncExecutor']

# Module-level default scope, used by AsyncExecutor.run() when the caller
# does not supply one.
g_scope = core.Scope()
class TextDataFeed(object):
    """Python wrapper over the C++ ``core.TextDataFeed`` reader.

    Holds the underlying feed object in ``self.feed`` and forwards every
    configuration call to it; ``AsyncExecutor`` consumes ``self.feed``
    directly.
    """

    def __init__(self):
        # The C++ data-feed object that does the actual file reading.
        self.feed = core.TextDataFeed()

    def set_filelist(self, filelist):
        """Forward the list of data files to the underlying feed."""
        self.feed.set_filelist(filelist)

    def set_batch_size(self, batch_size):
        """Forward the mini-batch size to the underlying feed."""
        self.feed.set_batch_size(batch_size)

    def set_field_names(self, field_names):
        """Set input field names; a single Variable is wrapped in a list."""
        if isinstance(field_names, Variable):
            field_names = [field_names]
        self.feed.set_field_names(field_names)

    def start_an_epoch(self):
        """Tell the underlying feed to begin one epoch of data."""
        self.feed.start_one_epoch()
class
AsyncExecutor
(
object
):
"""
...
...
@@ -50,39 +56,31 @@ class AsyncExecutor(object):
"""
def
__init__
(
self
,
async_executor_parameter
,
place
,
scope
):
if
not
isinstance
(
async_executor_parameter
,
AsyncExecutorParameter
):
raise
TypeError
(
"AsyncExecutor requires AsyncExecutorParameter as its parameter. "
"But you passed in %s"
%
s
(
type
(
async_executor_parameter
))
)
self
.
place
=
place
p
=
core
.
Place
()
p
.
set_place
(
place
)
self
.
executor
=
core
.
AsyncExecutor
(
p
)
self
.
executor
.
init
(
async_executor_parameter
.
parameter
,
scope
)
self
.
_closed
=
False
self
.
parameter
=
async_executor_parameter
.
parameter
program
,
param_names
,
data_feed
,
thread_num
,
place
=
None
,
scope
=
None
):
if
program
is
None
:
program
=
default_main_program
()
program_desc
=
program
.
desc
def
close
(
self
):
"""
Close this executor.
if
not
isinstance
(
data_feed
,
TextDataFeed
):
raise
ValueError
(
"data_feed for AsyncExecutor.run() type error"
)
You can no long use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to
the current Trainer.
if
place
is
None
:
place
=
core
.
CPUPlace
()
if
not
isinstance
(
place
,
core
.
CPUPlace
):
raise
ValueError
(
"AsyncExecutor only supports CPU device"
)
if
isinstance
(
param_names
,
Variable
):
param_names
=
[
param_names
]
p
=
core
.
Place
()
p
.
set_place
(
place
)
self
.
executor
=
core
.
AsyncExecutor
(
program_desc
,
param_names
,
data_feed
.
feed
,
thread_num
,
p
)
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if
not
self
.
_closed
:
self
.
_closed
=
True
def
run_startup_program
(
self
,
program
=
None
,
scope
=
None
):
...
...
@@ -95,7 +93,7 @@ class AsyncExecutor(object):
self
.
executor
.
run_startup_program
(
program_desc
,
scope
)
def
run
(
self
,
program
=
None
,
scope
=
None
):
def
run
(
self
,
inspect_vars
,
scope
=
None
):
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
...
...
@@ -138,23 +136,16 @@ class AsyncExecutor(object):
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if
self
.
_closed
:
raise
RuntimeError
(
"Attempted to use a closed Executor"
)
if
program
is
None
:
program
=
default_main_program
()
program_desc
=
program
.
desc
if
not
isinstance
(
program
,
Program
):
raise
TypeError
(
"Executor requires Program as its Parameter. But you passed in %s"
%
(
type
(
program
)))
if
inspect_vars
is
not
None
:
if
isinstance
(
inspect_vars
,
Variable
):
inspect_vars
=
[
inspect_vars
]
inspect_var_names
=
[
var
.
name
for
var
in
inspect_vars
]
if
scope
is
None
:
scope
=
g_scope
self
.
executor
.
run
(
program
.
desc
)
self
.
executor
.
init_root_scope
(
scope
)
evaluation
=
self
.
executor
.
run
(
inspect_var_names
)
return
evaluation
# Method of AsyncExecutor (the enclosing class body is elided in this view).
def load_init_model(self):
    """Load the initial model by delegating to the C++ AsyncExecutor.

    Returns whatever the underlying ``core.AsyncExecutor.load_init_model``
    returns.
    """
    return self.executor.load_init_model()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录