Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
91fc8f35
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91fc8f35
编写于
11月 19, 2018
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Interface rework
上级
274ec6a1
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
238 addition
and
599 deletion
+238
-599
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+12
-23
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+68
-341
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+18
-103
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+23
-34
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+8
-8
paddle/fluid/framework/data_feed.proto
paddle/fluid/framework/data_feed.proto
+2
-1
paddle/fluid/framework/data_feed_factory.cc
paddle/fluid/framework/data_feed_factory.cc
+14
-9
paddle/fluid/framework/data_feed_factory.h
paddle/fluid/framework/data_feed_factory.h
+3
-2
paddle/fluid/framework/executor_thread_worker.cc
paddle/fluid/framework/executor_thread_worker.cc
+31
-3
paddle/fluid/framework/executor_thread_worker.h
paddle/fluid/framework/executor_thread_worker.h
+8
-1
paddle/fluid/pybind/async_executor_py.cc
paddle/fluid/pybind/async_executor_py.cc
+16
-20
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+35
-54
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
91fc8f35
...
...
@@ -36,7 +36,7 @@ add_subdirectory(details)
endif
(
NOT WIN32
)
# ddim lib
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
async_executor_p
aram SRCS async_executor_param
.proto
)
proto_library
(
async_executor_p
roto SRCS data_feed
.proto
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
...
...
@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
cc_library
(
ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto
)
if
(
NOT WIN32
)
cc_library
(
ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler
)
endif
(
NOT WIN32
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
if
(
NOT WIN32
)
py_proto_compile
(
framework_py_proto SRCS framework.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
if
(
NOT WIN32
)
add_custom_command
(
TARGET framework_py_proto POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND cp *.py
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
else
(
NOT WIN32
)
string
(
REPLACE
"/"
"
\\
"
proto_dstpath
"
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/"
)
add_custom_command
(
TARGET framework_py_proto POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND copy /Y *.py
${
proto_dstpath
}
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
add_custom_command
(
TARGET framework_py_proto POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND cp *.py
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
endif
(
NOT WIN32
)
cc_library
(
lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor
)
...
...
@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
if
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator
)
else
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass
)
endif
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
endif
()
...
...
@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
endif
()
# NOT WIN32
cc_library
(
async_executor
SRCS async_executor.cc data_feed.cc datafeed_creator.cc
SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc
DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method graph_to_program_pass
async_executor_p
aram
)
async_executor_p
roto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/async_executor.cc
浏览文件 @
91fc8f35
...
...
@@ -36,43 +36,13 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/pybind/pybind.h"
namespace
paddle
{
namespace
framework
{
bool
AsyncExecutor
::
workers_initialized_
=
false
;
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
var
->
GetMutable
<
LoDTensor
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
SELECTED_ROWS
)
{
var
->
GetMutable
<
SelectedRows
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
FEED_MINIBATCH
)
{
var
->
GetMutable
<
FeedFetchList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
FETCH_LIST
)
{
var
->
GetMutable
<
FeedFetchList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
STEP_SCOPES
)
{
var
->
GetMutable
<
std
::
vector
<
Scope
>>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_RANK_TABLE
)
{
var
->
GetMutable
<
LoDRankTable
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
var
->
GetMutable
<
LoDTensorArray
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
PLACE_LIST
)
{
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
// GetMutable will be called in operator
}
else
{
PADDLE_THROW
(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]"
,
var_type
);
}
}
static
void
ReadBinaryFile
(
const
std
::
string
&
filename
,
std
::
string
*
content
)
{
std
::
string
&
contents
=
*
content
;
...
...
@@ -139,343 +109,100 @@ static void SaveModel(
}
}
// end SaveModel
void
ExecutorThreadWorker
::
Reset
()
{
inspect_values_
.
clear
();
}
void
ExecutorThreadWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
op_names_
.
clear
();
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
std
::
unique_ptr
<
OperatorBase
>
local_op
=
OpRegistry
::
CreateOp
(
*
op_desc
);
op_names_
.
push_back
(
op_desc
->
Type
());
OperatorBase
*
local_op_ptr
=
local_op
.
release
();
ops_
.
push_back
(
local_op_ptr
);
continue
;
}
}
void
ExecutorThreadWorker
::
CreateThreadScope
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
thread_scope_
=
&
root_scope_
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
}
else
{
auto
*
ptr
=
thread_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
}
}
}
void
ExecutorThreadWorker
::
SetDataFeed
(
DataFeed
&
datafeed
)
{
if
(
typeid
(
datafeed
)
==
typeid
(
TextClassDataFeed
))
{
local_reader_
.
reset
(
new
TextClassDataFeed
(
dynamic_cast
<
TextClassDataFeed
&>
(
datafeed
)));
local_reader_
->
SetThreadId
(
thread_id_
);
}
}
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
local_reader_
->
GetUseSlotAlias
();
for
(
auto
name
:
input_feed
)
{
local_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
AsyncExecutor
::
AsyncExecutor
(
Scope
&
scope
,
const
platform
::
Place
&
place
)
:
root_scope_
(
scope
),
place_
(
place
)
{}
void
ExecutorThreadWorker
::
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
inspect_var_names_
.
clear
();
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
void
AsyncExecutor
::
CreateThreads
(
ExecutorThreadWorker
*
worker
,
const
ProgramDesc
&
main_program
,
const
std
::
shared_ptr
<
DataFeed
>&
reader
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
&
root_scope
,
const
int
thread_index
)
{
worker
->
SetThreadId
(
thread_index
);
worker
->
SetRootScope
(
&
root_scope
);
worker
->
CreateThreadResource
(
main_program
,
place_
);
worker
->
SetDataFeed
(
reader
);
worker
->
SetFetchVarNames
(
fetch_var_names
);
worker
->
BindingDataFeedMemory
();
}
void
ExecutorThreadWorker
::
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
model_param_names_
=
param_names
;
}
void
ExecutorThreadWorker
::
SetDevice
()
{
static
unsigned
priority
[]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
};
unsigned
int
i
=
this
->
thread_id_
;
if
(
i
<
sizeof
(
priority
)
/
sizeof
(
unsigned
))
{
unsigned
proc
=
priority
[
i
];
cpu_set_t
mask
;
CPU_ZERO
(
&
mask
);
CPU_SET
(
proc
,
&
mask
);
if
(
-
1
==
sched_setaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
{
LOG
(
ERROR
)
<<
"WARNING: Failed to set thread affinity for thread "
<<
i
;
}
else
{
CPU_ZERO
(
&
mask
);
if
((
0
==
sched_getaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
&&
CPU_ISSET
(
proc
,
&
mask
))
{
LOG
(
ERROR
)
<<
"TRACE: Thread "
<<
i
<<
" is running on processor "
<<
proc
<<
"..."
;
}
}
}
}
void
ExecutorThreadWorker
::
Train
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
SetDevice
();
int
inspect_var_num
=
inspect_var_names_
.
size
();
inspect_values_
.
clear
();
inspect_values_
.
resize
(
inspect_var_num
,
0
);
local_reader_
->
WaitNextEpoch
();
int
epoch
=
local_reader_
->
GetCurrentEpoch
();
LOG
(
ERROR
)
<<
"epoch: "
<<
epoch
;
int
batch_num
=
1
;
while
(
true
)
{
const
char
*
file
=
local_reader_
->
PickOneFile
();
if
(
file
==
NULL
)
{
break
;
}
if
(
!
local_reader_
->
SetFile
(
file
))
{
break
;
}
while
(
true
)
{
bool
flag
=
local_reader_
->
ReadBatch
();
if
(
!
flag
)
{
break
;
}
for
(
unsigned
int
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
ops_
[
i
]
->
Run
(
*
thread_scope_
,
place_
);
}
batch_num
++
;
float
avg_inspect
=
0.0
;
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_names_
[
i
])
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
inspect_values_
[
i
]
+=
avg_inspect
;
}
thread_scope_
->
DropKids
();
}
local_reader_
->
UpdateEpochNum
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
epoch
+
1
<<
" called: "
<<
memory
::
memory_usage
(
place_
);
}
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
inspect_values_
[
i
]
/=
batch_num
;
std
::
string
var
=
inspect_var_names_
[
i
].
substr
(
0
,
inspect_var_names_
[
i
].
find_first_of
(
"_"
));
LOG
(
ERROR
)
<<
"mean "
<<
var
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
inspect_values_
[
i
];
}
if
(
thread_id_
==
0
)
{
char
modelfile
[
1024
];
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
model_prefix_
.
c_str
(),
epoch
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
// this save_inference_model can only save imdbtask, should make this
// general
//
// currently comment it
LOG
(
ERROR
)
<<
"Going to save model "
<<
modelfile
;
SaveModel
(
main_program_
,
thread_scope_
,
model_param_names_
,
model_filename
,
true
);
}
}
void
ExecutorThreadWorker
::
SetThreadId
(
int
tid
)
{
thread_id_
=
tid
;
}
void
ExecutorThreadWorker
::
SetPlace
(
const
platform
::
Place
&
place
)
{
place_
=
place
;
}
void
ExecutorThreadWorker
::
SetMainProgram
(
const
ProgramDesc
&
main_program_desc
)
{
main_program_
.
reset
(
new
ProgramDesc
(
main_program_desc
));
}
void
ExecutorThreadWorker
::
SetRootScope
(
Scope
*
g_scope
)
{
root_scope_
=
g_scope
;
}
void
ExecutorThreadWorker
::
SetMaxTrainingEpoch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
}
AsyncExecutor
::
AsyncExecutor
(
ProgramDesc
&
main_program
,
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
unsigned
int
thread_num
,
const
platform
::
Place
&
place
)
:
thread_num_
(
thread_num
),
place_
(
place
),
main_program_
(
main_program
),
data_feed_
(
data_feed
)
{
model_param_names_
.
clear
();
model_param_names_
.
insert
(
model_param_names_
.
end
(),
param_names
.
begin
(),
param_names
.
end
());
}
void
AsyncExecutor
::
InitRootScope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
AsyncExecutor
::
SetMaxTrainingEpoch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
void
AsyncExecutor
::
CheckFiles
(
const
std
::
vector
<
std
::
string
>&
files
)
{
// function for user to check file formats
// should be exposed to users
}
void
AsyncExecutor
::
SetModelPrefix
(
const
std
::
string
&
model_prefix
)
{
model_prefix_
=
model_prefix
;
}
void
AsyncExecutor
::
RunStartupProgram
(
const
ProgramDesc
&
program
,
Scope
*
scope
)
{
auto
&
block
=
program
.
Block
(
0
);
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
scope
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
// LOGERR("Persistable Var Name:%s", var->Name().c_str());
}
}
std
::
vector
<
float
>
AsyncExecutor
::
RunFromFile
(
const
ProgramDesc
&
main_program
,
const
DataFeedDesc
&
data_feed_desc
,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
)
{
std
::
vector
<
std
::
thread
>
threads
;
std
::
map
<
std
::
string
,
int
>
param_dict
;
std
::
vector
<
OperatorBase
*>
ops
;
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
std
::
vector
<
std
::
string
>
param_name_vec
=
op_desc
->
OutputArgumentNames
();
bool
need_to_run
=
false
;
for
(
auto
&
name
:
param_name_vec
)
{
if
(
param_dict
.
find
(
name
)
==
param_dict
.
end
())
{
param_dict
[
name
]
=
1
;
need_to_run
=
true
;
}
}
if
(
need_to_run
)
{
std
::
unique_ptr
<
OperatorBase
>
local_op
=
OpRegistry
::
CreateOp
(
*
op_desc
);
OperatorBase
*
local_op_ptr
=
local_op
.
release
();
ops
.
push_back
(
local_op_ptr
);
}
/*
readerDesc: protobuf description for reader initlization
argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
reader:
1) each thread has a reader, reader will read input data and
put it into input queue
2) each reader has a Next() iterface, that can fetch an instance
from the input queue
*/
// todo: should be factory method for creating datafeed
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>
>
readers
;
readers
.
resize
(
thread_num
);
for
(
unsigned
int
i
=
0
;
i
<
readers
.
size
();
++
i
)
{
readers
[
i
]
=
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc
.
name
());
}
// LOGERR("There are %d parameters in startup program, %d op needs to run",
// param_dict.size(), ops.size());
for
(
auto
&
op
:
ops
)
{
op
->
Run
(
*
scope
,
place_
);
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers
;
workers
.
resize
(
thread_num
);
for
(
auto
&
worker
:
workers
)
{
worker
.
reset
(
new
ExecutorThreadWorker
);
}
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for
(
auto
&
op
:
ops
)
{
delete
op
;
}
// LOGERR("run startup program done.");
}
std
::
unique_ptr
<
ProgramDesc
>
AsyncExecutor
::
LoadDescFromFile
(
const
std
::
string
&
f
)
{
std
::
string
program_desc_str
;
ReadBinaryFile
(
f
,
&
program_desc_str
);
std
::
unique_ptr
<
ProgramDesc
>
program
(
new
ProgramDesc
(
program_desc_str
));
return
program
;
}
void
AsyncExecutor
::
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
inspect_var_names_
.
clear
();
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
}
void
AsyncExecutor
::
PrepareThreads
(
const
ProgramDesc
&
host_program
)
{
workers_
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
].
reset
(
new
ExecutorThreadWorker
);
workers_
[
i
]
->
SetThreadId
(
i
);
workers_
[
i
]
->
CreateThreadOperators
(
host_program
);
workers_
[
i
]
->
SetRootScope
(
root_scope_
);
workers_
[
i
]
->
SetPlace
(
place_
);
workers_
[
i
]
->
SetMaxTrainingEpoch
(
max_epoch_
);
workers_
[
i
]
->
CreateThreadScope
(
host_program
);
workers_
[
i
]
->
SetInspectVarNames
(
inspect_var_names_
);
workers_
[
i
]
->
SetModelParamNames
(
model_param_names_
);
workers_
[
i
]
->
SetMainProgram
(
host_program
);
workers_
[
i
]
->
SetModelPrefix
(
model_prefix_
);
//
// new a datafeed here
workers_
[
i
]
->
SetDataFeed
(
data_feed_
);
workers_
[
i
]
->
BindingDataFeedMemory
();
// prepare thread resource here
for
(
int
thidx
=
0
;
thidx
<
thread_num
;
++
thidx
)
{
CreateThreads
(
workers
[
thidx
].
get
(),
main_program
,
readers
[
thidx
],
fetch_var_names
,
root_scope_
,
thidx
);
}
}
std
::
vector
<
float
>&
AsyncExecutor
::
Run
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
SetInspectVarNames
(
inspect_var_names
);
threads_
.
clear
();
// thread binding here?
if
(
workers_initialized_
==
false
)
{
PrepareThreads
(
main_program_
);
workers_initialized_
=
true
;
}
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
->
Reset
();
workers_
[
i
]
->
SetInspectVarNames
(
inspect_var_names
);
threads_
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
Train
,
workers_
[
i
].
get
()));
// start executing ops in multiple threads
for
(
int
thidx
=
0
;
thidx
<
thread_num
;
++
thidx
)
{
threads
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
TrainFiles
,
workers
[
thidx
].
get
()));
}
for
(
auto
&
th
:
threads
_
)
{
for
(
auto
&
th
:
threads
)
{
th
.
join
();
}
inspect_values_
.
clear
();
inspect_values_
.
resize
(
inspect_var_names_
.
size
(),
0
);
std
::
vector
<
float
>
fetch_values
;
fetch_values
.
resize
(
fetch_var_names
.
size
(),
0
);
std
::
vector
<
std
::
vector
<
float
>*>
inspect
_value_vectors
;
inspect_value_vectors
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num
_
;
++
i
)
{
inspect_value_vectors
[
i
]
=
&
workers_
[
i
]
->
GetInspect
Values
();
std
::
vector
<
std
::
vector
<
float
>*>
fetch
_value_vectors
;
fetch_value_vectors
.
resize
(
thread_num
);
for
(
int
i
=
0
;
i
<
thread_num
;
++
i
)
{
fetch_value_vectors
[
i
]
=
&
workers
[
i
]
->
GetFetch
Values
();
}
for
(
unsigned
int
i
=
0
;
i
<
inspect_var_names_
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
fetch_var_names
.
size
();
++
i
)
{
float
value
=
0.0
;
for
(
int
j
=
0
;
j
<
thread_num
_
;
++
j
)
{
value
+=
inspect
_value_vectors
[
j
]
->
at
(
i
);
for
(
int
j
=
0
;
j
<
thread_num
;
++
j
)
{
value
+=
fetch
_value_vectors
[
j
]
->
at
(
i
);
}
value
/=
thread_num
_
;
inspect_values_
[
i
]
=
value
;
value
/=
thread_num
;
fetch_values
[
i
]
=
value
;
}
return
inspect_values_
;
return
fetch_values
;
}
void
AsyncExecutor
::
LoadInitModel
()
{
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
91fc8f35
...
...
@@ -23,7 +23,8 @@ limitations under the License. */
#include <thread> // NOLINT
#include <vector>
#include <typeinfo>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
...
...
@@ -31,93 +32,13 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
class
ExecutorThreadWorker
{
public:
ExecutorThreadWorker
()
{}
~
ExecutorThreadWorker
()
{}
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
void
SetThreadId
(
int
tid
);
void
CreateThreadOperators
(
const
ProgramDesc
&
program
);
void
SetRootScope
(
Scope
*
g_scope
);
void
SetDevice
();
void
AddFidSet
();
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
void
AddTrainFile
(
const
std
::
string
&
filename
);
void
SetMainProgram
(
const
ProgramDesc
&
main_program_desc
);
void
SetPlace
(
const
paddle
::
platform
::
Place
&
place
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
void
BindingDataFeedMemory
();
void
SetModelPrefix
(
const
std
::
string
&
prefix
)
{
model_prefix_
=
prefix
;
}
void
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetDataFeed
(
DataFeed
&
datafeed
);
// NOLINT
void
Train
();
const
char
*
PickOneFile
();
void
UpdateEpochNum
();
void
Reset
();
void
Initialize
()
{}
std
::
vector
<
float
>&
GetInspectValues
()
{
return
inspect_values_
;}
protected:
// thread index
int
thread_id_
;
// max epoch for each thread
unsigned
int
max_epoch_
;
// instances learned currently
int
comm_batch_
;
std
::
string
model_prefix_
;
std
::
vector
<
std
::
string
>
op_names_
;
// local ops for forward and backward
std
::
vector
<
OperatorBase
*>
ops_
;
// main program for training
std
::
unique_ptr
<
ProgramDesc
>
main_program_
;
// binary data reader
std
::
unique_ptr
<
DataFeed
>
local_reader_
;
std
::
vector
<
std
::
string
>
inspect_var_names_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
// execution place
platform
::
Place
place_
;
// root scope for model parameters
Scope
*
root_scope_
;
// a thread scope, father scope is global score which is shared
Scope
*
thread_scope_
;
private:
std
::
vector
<
float
>
inspect_values_
;
};
class
AsyncExecutor
{
public:
explicit
AsyncExecutor
(
ProgramDesc
&
main_program
,
// NOLINT
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
// NOLINT
unsigned
int
thread_num
,
const
platform
::
Place
&
place
);
explicit
AsyncExecutor
(
Scope
&
scope
,
const
platform
::
Place
&
place
);
// NOLINT
virtual
~
AsyncExecutor
()
{}
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromFile
(
const
std
::
string
&
filename
);
void
InitRootScope
(
Scope
*
scope
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
Scope
*
GetRootScope
()
{
return
root_scope_
;
}
void
SetBatchSize
(
const
int
batch_size
)
{
batch_size_
=
batch_size
;
}
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
Scope
*
GetRootScope
()
{
return
&
root_scope_
;
}
void
SetModelPath
(
const
std
::
string
&
model_path
)
{
model_path_
=
model_path
;
...
...
@@ -132,38 +53,32 @@ class AsyncExecutor {
}
void
SetModelPrefix
(
const
std
::
string
&
model_prefix
);
virtual
void
PrepareThreads
(
const
ProgramDesc
&
host_program
);
void
RunStartupProgram
(
const
ProgramDesc
&
program
,
Scope
*
scope
);
std
::
vector
<
float
>&
Run
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
std
::
vector
<
float
>
RunFromFile
(
const
ProgramDesc
&
main_program
,
const
DataFeedDesc
&
data_feed_desc
,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
const
std
::
vector
<
std
::
string
>&
fetch_names
);
void
CheckFiles
(
const
std
::
vector
<
std
::
string
>&
files
);
void
LoadInitModel
();
private:
void
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
CreateThreads
(
ExecutorThreadWorker
*
worker
,
const
ProgramDesc
&
main_program
,
const
std
::
shared_ptr
<
DataFeed
>&
reader
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
&
root_scope
,
// NOLINT
const
int
thread_index
);
public:
int
thread_num_
;
int
max_epoch_
;
int
batch_size_
;
int
comm_batch_
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
vector
<
std
::
string
>
inspect_var_names_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
string
model_prefix_
;
std
::
string
model_path_
;
std
::
string
init_prog_file_
;
std
::
string
init_model_file_
;
Scope
*
root_scope_
;
Scope
&
root_scope_
;
platform
::
Place
place_
;
private:
ProgramDesc
&
main_program_
;
TextClassDataFeed
&
data_feed_
;
std
::
vector
<
float
>
inspect_values_
;
private:
static
bool
workers_initialized_
;
};
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
91fc8f35
...
...
@@ -38,17 +38,17 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace
paddle
{
namespace
framework
{
std
::
vector
<
std
::
string
>
TextClass
DataFeed
::
s_filelist_
;
std
::
mutex
TextClass
DataFeed
::
s_locker_for_pick_file_
;
unsigned
int
TextClass
DataFeed
::
s_current_file_idx_
=
0
;
size_t
TextClass
DataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
TextClass
DataFeed
::
s_current_epoch_
=
0
;
int
TextClass
DataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
TextClass
DataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
TextClass
DataFeed
::
s_condition_epoch_start_
;
bool
TextClass
DataFeed
::
s_epoch_start_flag_
=
false
;
void
TextClass
DataFeed
::
Init
()
{
std
::
vector
<
std
::
string
>
MultiSlot
DataFeed
::
s_filelist_
;
std
::
mutex
MultiSlot
DataFeed
::
s_locker_for_pick_file_
;
unsigned
int
MultiSlot
DataFeed
::
s_current_file_idx_
=
0
;
size_t
MultiSlot
DataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
MultiSlot
DataFeed
::
s_current_epoch_
=
0
;
int
MultiSlot
DataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
MultiSlot
DataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
MultiSlot
DataFeed
::
s_condition_epoch_start_
;
bool
MultiSlot
DataFeed
::
s_epoch_start_flag_
=
false
;
void
MultiSlot
DataFeed
::
Init
()
{
// hard coding for a specific datafeed
feed_vec_
.
resize
(
2
);
// feed_vec_[0].reset(new LoDTensor);
...
...
@@ -73,12 +73,12 @@ void TextClassDataFeed::Init() {
field_names_
.
clear
();
}
TextClassDataFeed
::
TextClass
DataFeed
()
{
MultiSlotDataFeed
::
MultiSlot
DataFeed
()
{
Init
();
}
// todo: use elegant implemention for this function
bool
TextClass
DataFeed
::
ReadBatch
()
{
bool
MultiSlot
DataFeed
::
ReadBatch
()
{
paddle
::
framework
::
Vector
<
size_t
>
offset
;
int
tlen
=
0
;
int
llen
=
0
;
...
...
@@ -142,13 +142,13 @@ bool TextClassDataFeed::ReadBatch() {
return
true
;
}
TextClassDataFeed
::
TextClassDataFeed
(
const
TextClass
DataFeed
&
data_feed
)
{
MultiSlotDataFeed
::
MultiSlotDataFeed
(
const
MultiSlot
DataFeed
&
data_feed
)
{
Init
();
SetBatchSize
(
data_feed
.
batch_size_
);
SetFieldNames
(
data_feed
.
field_names_
);
}
void
TextClass
DataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
void
MultiSlot
DataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
...
...
@@ -156,7 +156,7 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
}
}
void
TextClass
DataFeed
::
SetFileList
(
const
char
*
filelist
)
{
void
MultiSlot
DataFeed
::
SetFileList
(
const
char
*
filelist
)
{
s_filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
PADDLE_ENFORCE
(
fin
.
good
(),
...
...
@@ -170,14 +170,14 @@ void TextClassDataFeed::SetFileList(const char* filelist) {
fin
.
close
();
}
void
TextClass
DataFeed
::
SetFieldNames
(
void
MultiSlot
DataFeed
::
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
)
{
field_names_
.
clear
();
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names
.
end
());
}
bool
TextClass
DataFeed
::
SetFile
(
const
char
*
filename
)
{
bool
MultiSlot
DataFeed
::
SetFile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
...
...
@@ -198,7 +198,7 @@ bool TextClassDataFeed::SetFile(const char* filename) {
return
true
;
}
void
TextClass
DataFeed
::
UpdateEpochNum
()
{
void
MultiSlot
DataFeed
::
UpdateEpochNum
()
{
s_current_finished_file_cnt_
++
;
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
...
...
@@ -214,25 +214,14 @@ void TextClassDataFeed::UpdateEpochNum() {
}
}
void
TextClassDataFeed
::
StartOneEpoch
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
std
::
random_shuffle
(
s_filelist_
.
begin
(),
s_filelist_
.
end
());
s_current_file_idx_
=
0
;
LOG
(
INFO
)
<<
"Beginning epoch "
<<
s_current_epoch_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
true
;
}
s_condition_epoch_start_
.
notify_all
();
void
MultiSlotDataFeed
::
Start
()
{
}
void
TextClassDataFeed
::
WaitNextEpoch
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_condition_epoch_start_
.
wait
(
lock
,
[]{
return
s_epoch_start_flag_
;});
int
MultiSlotDataFeed
::
Next
()
{
return
0
;
}
const
char
*
TextClass
DataFeed
::
PickOneFile
()
{
const
char
*
MultiSlot
DataFeed
::
PickOneFile
()
{
std
::
string
file_to_be_processed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
91fc8f35
...
...
@@ -81,8 +81,8 @@ class DataFeed {
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
Start
OneEpoch
()
=
0
;
virtual
void
WaitNextEpoch
()
=
0
;
virtual
void
Start
()
=
0
;
virtual
int
Next
()
=
0
;
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
return
feed_vec_
;
...
...
@@ -106,13 +106,13 @@ class DataFeed {
int
thread_id_
;
};
class
TextClass
DataFeed
:
public
DataFeed
{
class
MultiSlot
DataFeed
:
public
DataFeed
{
public:
TextClass
DataFeed
();
TextClassDataFeed
(
const
TextClass
DataFeed
&
data_feed
);
MultiSlot
DataFeed
();
MultiSlotDataFeed
(
const
MultiSlot
DataFeed
&
data_feed
);
public:
virtual
~
TextClass
DataFeed
()
{}
virtual
~
MultiSlot
DataFeed
()
{}
virtual
void
Init
();
virtual
bool
ReadBatch
();
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
...
...
@@ -125,8 +125,8 @@ class TextClassDataFeed : public DataFeed {
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
void
UpdateEpochNum
();
void
Start
OneEpoch
();
void
WaitNextEpoch
();
void
Start
();
int
Next
();
public:
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
...
...
paddle/fluid/framework/data_feed.proto
浏览文件 @
91fc8f35
...
...
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax
=
"proto2"
;
package
paddle
;
package
paddle
.
framework
;
message
DataFeedDesc
{
optional
string
name
=
1
;
optional
int32
batch
=
2
[
default
=
32
];
repeated
string
field_names
=
3
;
}
paddle/fluid/framework/data_feed_factory.cc
浏览文件 @
91fc8f35
...
...
@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
framework
{
typedef
shared_ptr
<
DataFeed
>
(
*
Createdata_feedFunction
)();
typedef
s
td
::
s
hared_ptr
<
DataFeed
>
(
*
Createdata_feedFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdata_feedFunction
>
data_feedMap
;
data_feedMap
g_data_feed_map
;
#define REGISTER_DATAFEED_CLASS(data_feed_class) \
namespace { \
shared_ptr<DataFeed> Creator_##data_feed_class() { \
return shared_ptr<DataFeed>(new data_feed_class); \
s
td::s
hared_ptr<DataFeed> Creator_##data_feed_class() { \
return s
td::s
hared_ptr<DataFeed>(new data_feed_class); \
} \
class __Registerer_##data_feed_class { \
public: \
...
...
@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map;
} // namespace
string
DataFeedFactory
::
DataFeedTypeList
()
{
string
data_feed_types
;
st
d
::
st
ring
DataFeedFactory
::
DataFeedTypeList
()
{
st
d
::
st
ring
data_feed_types
;
for
(
auto
iter
=
g_data_feed_map
.
begin
();
iter
!=
g_data_feed_map
.
end
();
++
iter
)
{
if
(
iter
!=
g_data_feed_map
.
begin
())
{
...
...
@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() {
return
data_feed_types
;
}
shared_ptr
<
DataFeed
>
DataFeedFactory
::
CreateDataFeed
(
const
char
*
data_feed_class
)
{
if
(
g_data_feed_map
.
count
(
string
(
data_feed_class
)
)
<
1
)
{
s
td
::
s
hared_ptr
<
DataFeed
>
DataFeedFactory
::
CreateDataFeed
(
std
::
string
data_feed_class
)
{
if
(
g_data_feed_map
.
count
(
data_feed_class
)
<
1
)
{
exit
(
-
1
);
}
return
g_data_feed_map
[
data_feed_class
]();
...
...
paddle/fluid/framework/data_feed_factory.h
浏览文件 @
91fc8f35
...
...
@@ -16,14 +16,15 @@ limitations under the License. */
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#include <string>
#include "paddle/framework/data_feed.h"
#include <memory>
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
framework
{
class
DataFeedFactory
{
public:
static
std
::
string
DataFeedTypeList
();
static
s
hared_ptr
<
DataFeed
>
CreateDataFeed
(
const
char
*
data_feed_class
);
static
s
td
::
shared_ptr
<
DataFeed
>
CreateDataFeed
(
std
::
string
data_feed_class
);
};
}
// namespace framework
}
// namespace paddle
...
...
paddle/fluid/framework/executor_thread_worker.cc
浏览文件 @
91fc8f35
...
...
@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource(
void
ExecutorThreadWorker
::
CreateThreadScope
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
PADDLE_ENFORCE_NOT_NULL
(
root_scope_
,
"root_scope should be set before creating thread scope"
);
thread_scope_
=
&
root_scope_
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
...
...
@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
void
ExecutorThreadWorker
::
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
local
_reader_
=
datafeed
;
thread
_reader_
=
datafeed
;
}
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
thread_reader_
->
GetUseSlotAlias
();
for
(
auto
name
:
input_feed
)
{
local
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
thread
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
void
ExecutorThreadWorker
::
SetFetchVarNames
(
const
std
::
vector
<
std
::
string
>&
fetch_var_names
)
{
fetch_var_names_
.
clear
();
fetch_var_names_
.
insert
(
fetch_var_names_
.
end
(),
fetch_var_names
.
begin
(),
fetch_var_names
.
end
());
}
void
ExecutorThreadWorker
::
SetDevice
()
{
// at most 48 threads binding currently
static
unsigned
priority
[]
=
{
...
...
@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() {
void
ExecutorThreadWorker
::
TrainFiles
()
{
// todo: configurable
SetDevice
();
int
fetch_var_num
=
fetch_var_names_
.
size
();
fetch_values_
.
clear
();
fetch_values_
.
resize
(
fetch_var_num
,
0
);
thread_reader_
->
Start
();
while
(
int
cur_batch
=
thread_reader_
->
Next
())
{
int
cur_batch
;
while
((
cur_batch
=
thread_reader_
->
Next
())
>
0
)
{
// executor run here
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
}
float
avg_inspect
=
0.0
;
for
(
int
i
=
0
;
i
<
fetch_var_num
;
++
i
)
{
avg_inspect
=
thread_scope_
->
FindVar
(
fetch_var_names_
[
i
])
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
fetch_values_
[
i
]
+=
avg_inspect
;
}
thread_scope_
->
DropKids
();
}
}
...
...
paddle/fluid/framework/executor_thread_worker.h
浏览文件 @
91fc8f35
...
...
@@ -43,6 +43,9 @@ class ExecutorThreadWorker {
void
SetDevice
();
void
BindingDataFeedMemory
();
void
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
void
TrainFiles
();
void
SetFetchVarNames
(
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
std
::
vector
<
float
>&
GetFetchValues
()
{
return
fetch_values_
;}
private:
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
...
...
@@ -66,9 +69,13 @@ class ExecutorThreadWorker {
Scope
*
root_scope_
;
// a thread scope, father scope is global score which is shared
Scope
*
thread_scope_
;
private:
std
::
vector
<
std
::
string
>
fetch_var_names_
;
std
::
vector
<
float
>
fetch_values_
;
};
}
// namespace framework
}
// namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_
ASYNC_EXECUTO
R_H_
#endif // PADDLE_FLUID_FRAMEWORK_
EXECUTOR_THREAD_WORKE
R_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
paddle/fluid/pybind/async_executor_py.cc
浏览文件 @
91fc8f35
...
...
@@ -30,36 +30,32 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/
async_executor_param
.pb.h"
#include "paddle/fluid/framework/
data_feed
.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
namespace
py
=
pybind11
;
namespace
pd
=
paddle
::
framework
;
namespace
paddle
{
namespace
pybind
{
using
set_name_func
=
void
(
pd
::
DataFeedDesc
::*
)(
const
std
::
string
&
);
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
DataFeed
>
(
*
m
,
"DataFeed"
);
py
::
class_
<
framework
::
TextClassDataFeed
,
framework
::
DataFeed
>
(
*
m
,
"TextDataFeed"
)
.
def
(
py
::
init
())
.
def
(
"set_filelist"
,
[]
(
framework
::
TextClassDataFeed
&
self
,
const
char
*
data_list_file
)
{
self
.
SetFileList
(
data_list_file
);
})
.
def
(
"set_batch_size"
,
&
framework
::
TextClassDataFeed
::
SetBatchSize
)
.
def
(
"set_field_names"
,
&
framework
::
TextClassDataFeed
::
SetFieldNames
)
.
def
(
"start_one_epoch"
,
&
framework
::
TextClassDataFeed
::
StartOneEpoch
);
py
::
class_
<
pd
::
DataFeedDesc
>
(
*
m
,
"DataFeedDesc"
)
.
def
(
pybind11
::
init
<>
())
.
def
(
"set_name"
,
(
set_name_func
)
&
pd
::
DataFeedDesc
::
set_name
)
.
def
(
"set_batch"
,
&
pd
::
DataFeedDesc
::
set_batch
)
.
def
(
"set_field_names"
,
[]
(
pd
::
DataFeedDesc
&
self
,
const
std
::
vector
<
std
::
string
>
&
fields
)
{
for
(
auto
field
:
fields
)
{
self
.
add_field_names
(
field
);
}
});
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
.
def
(
py
::
init
<
framework
::
ProgramDesc
&
,
std
::
vector
<
std
::
string
>&
,
framework
::
TextClassDataFeed
&
,
unsigned
int
,
const
platform
::
Place
&>
())
.
def
(
"init_root_scope"
,
&
framework
::
AsyncExecutor
::
InitRootScope
)
.
def
(
"run_startup_program"
,
&
framework
::
AsyncExecutor
::
RunStartupProgram
)
.
def
(
"run"
,
&
framework
::
AsyncExecutor
::
Run
);
.
def
(
py
::
init
<
pd
::
Scope
&
,
const
platform
::
Place
&>
())
.
def
(
"run_from_files"
,
&
framework
::
AsyncExecutor
::
RunFromFile
)
.
def
(
"check_file"
,
&
framework
::
AsyncExecutor
::
CheckFiles
);
}
// end BindAsyncExecutor
}
// end namespace pybind
}
// end namespace paddle
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
91fc8f35
...
...
@@ -19,30 +19,26 @@ import contextlib
import
six
from
.framework
import
Program
,
default_main_program
,
Variable
from
.
import
core
from
.
import
Executor
from
.
executor
import
global_scope
__all__
=
[
'
Tex
tDataFeed'
,
'AsyncExecutor'
]
__all__
=
[
'
MultiSlo
tDataFeed'
,
'AsyncExecutor'
]
g_scope
=
core
.
Scope
()
class
TextDataFeed
(
):
class
DataFeedDesc
(
object
):
def
__init__
(
self
):
self
.
feed
=
core
.
TextDataFeed
()
def
set_filelist
(
self
,
filelist
):
self
.
feed
.
set_filelist
(
filelist
)
self
.
desc
=
core
.
DataFeedDesc
()
def
set_batch_size
(
self
,
batch_size
):
self
.
feed
.
set_batch_size
(
batch_size
)
def
set_field_names
(
self
,
field_names
):
if
isinstance
(
field_names
,
Variable
):
self
.
desc
.
set_batch
(
batch_size
)
def
set_field_name
(
self
,
field_names
):
if
isinstance
(
field_names
,
str
):
field_names
=
[
field_names
]
self
.
desc
.
set_field_names
(
field_names
)
self
.
feed
.
set_field_names
(
field_names
)
def
start_an_epoch
(
self
):
self
.
feed
.
start_one_epoch
(
)
class
MultiSlotDataFeed
(
DataFeedDesc
):
def
__init__
(
self
):
super
(
MultiSlotDataFeed
,
self
).
__init__
()
self
.
desc
.
set_name
(
"MultiSlotDataFeed"
)
class
AsyncExecutor
(
object
):
"""
...
...
@@ -55,45 +51,19 @@ class AsyncExecutor(object):
They has the exactly same arguments, and expected the same results.
"""
def
__init__
(
self
,
program
,
param_names
,
data_feed
,
thread_num
,
place
=
None
,
scope
=
None
):
if
program
is
None
:
program
=
default_main_program
()
program_desc
=
program
.
desc
if
not
isinstance
(
data_feed
,
TextDataFeed
):
raise
ValueError
(
"data_feed for AsyncExecutor.run() type error"
)
def
__init__
(
self
,
place
=
None
):
if
place
is
None
:
place
=
core
.
CPUPlace
()
if
not
isinstance
(
place
,
core
.
CPUPlace
):
raise
ValueError
(
"AsyncExecutor only supports CPU device"
)
if
isinstance
(
param_names
,
Variable
):
param_names
=
[
param_names
]
p
=
core
.
Place
()
p
.
set_place
(
place
)
self
.
executor
=
core
.
AsyncExecutor
(
program_desc
,
param_names
,
data_feed
.
feed
,
thread_num
,
p
)
def
run_startup_program
(
self
,
program
=
None
,
scope
=
None
):
if
program
is
None
:
program
=
default_startup_program
()
program_desc
=
program
.
_get_desc
()
if
scope
is
None
:
scope
=
g_scope
scope
=
global_scope
()
self
.
executor
=
core
.
AsyncExecutor
(
scope
,
p
)
self
.
executor
.
run_startup_program
(
program_desc
,
scope
)
def
run
(
self
,
inspect_vars
,
scope
=
None
):
def
run
(
self
,
program
,
data_feed
,
filelist
,
thread_num
,
fetch
):
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
...
...
@@ -136,16 +106,27 @@ class AsyncExecutor(object):
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if
inspect_vars
is
not
None
:
if
isinstance
(
inspect_vars
,
Variable
):
inspect_vars
=
[
inspect_vars
]
inspect_var_names
=
[
var
.
name
for
var
in
inspect_vars
]
if
program
is
None
:
program
=
default_main_program
()
program_desc
=
program
.
desc
if
data_feed
is
None
:
raise
ValueError
(
'ValueError: data_feed should be provided'
)
if
filelist
is
None
:
raise
ValueError
(
'ValueError: filelist should be provided'
)
if
isinstance
(
filelist
,
str
):
filelist
=
[
filelist
]
if
scope
is
None
:
scope
=
g_scope
if
not
isinstance
(
thread_num
,
int
)
:
raise
TypeError
(
'TypeError: thread_num should be a positive number'
)
self
.
executor
.
init_root_scope
(
scope
)
if
fetch
is
not
None
:
if
isinstance
(
fetch
,
Variable
):
fetch
=
[
fetch
]
fetch_var_names
=
[
var
.
name
for
var
in
fetch
]
evaluation
=
self
.
executor
.
run
(
inspect
_var_names
)
evaluation
=
self
.
executor
.
run
_from_files
(
program_desc
,
data_feed
.
desc
,
filelist
,
thread_num
,
fetch
_var_names
)
return
evaluation
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录