Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
91fc8f35
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
91fc8f35
编写于
11月 19, 2018
作者:
W
wangguibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Interface rework
上级
274ec6a1
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
238 addition
and
599 deletion
+238
-599
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+12
-23
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+68
-341
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+18
-103
paddle/fluid/framework/data_feed.cc
paddle/fluid/framework/data_feed.cc
+23
-34
paddle/fluid/framework/data_feed.h
paddle/fluid/framework/data_feed.h
+8
-8
paddle/fluid/framework/data_feed.proto
paddle/fluid/framework/data_feed.proto
+2
-1
paddle/fluid/framework/data_feed_factory.cc
paddle/fluid/framework/data_feed_factory.cc
+14
-9
paddle/fluid/framework/data_feed_factory.h
paddle/fluid/framework/data_feed_factory.h
+3
-2
paddle/fluid/framework/executor_thread_worker.cc
paddle/fluid/framework/executor_thread_worker.cc
+31
-3
paddle/fluid/framework/executor_thread_worker.h
paddle/fluid/framework/executor_thread_worker.h
+8
-1
paddle/fluid/pybind/async_executor_py.cc
paddle/fluid/pybind/async_executor_py.cc
+16
-20
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+35
-54
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
91fc8f35
...
@@ -36,7 +36,7 @@ add_subdirectory(details)
...
@@ -36,7 +36,7 @@ add_subdirectory(details)
endif
(
NOT WIN32
)
endif
(
NOT WIN32
)
# ddim lib
# ddim lib
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
async_executor_p
aram SRCS async_executor_param
.proto
)
proto_library
(
async_executor_p
roto SRCS data_feed
.proto
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost
)
cc_library
(
ddim SRCS ddim.cc DEPS eigen3 boost
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
cc_test
(
ddim_test SRCS ddim_test.cc DEPS ddim
)
...
@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version)
...
@@ -138,31 +138,23 @@ cc_test(version_test SRCS version_test.cc DEPS version)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog version
)
cc_library
(
ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto
)
cc_library
(
ngraph_bridge SRCS ngraph_bridge.cc DEPS operator framework_proto
)
if
(
NOT WIN32
)
cc_library
(
ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
cc_library
(
ngraph_operator SRCS ngraph_operator.cc DEPS ngraph_bridge operator op_info device_context tensor scope glog
shape_inference data_transform lod_tensor profiler
)
shape_inference data_transform lod_tensor profiler
)
endif
(
NOT WIN32
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
nv_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
if
(
NOT WIN32
)
py_proto_compile
(
framework_py_proto SRCS framework.proto
)
py_proto_compile
(
framework_py_proto SRCS framework.proto
)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_custom_target
(
framework_py_proto_init ALL COMMAND
${
CMAKE_COMMAND
}
-E touch __init__.py
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
add_dependencies
(
framework_py_proto framework_py_proto_init
)
if
(
NOT WIN32
)
add_custom_command
(
TARGET framework_py_proto POST_BUILD
add_custom_command
(
TARGET framework_py_proto POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND cp *.py
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/
COMMAND cp *.py
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
else
(
NOT WIN32
)
string
(
REPLACE
"/"
"
\\
"
proto_dstpath
"
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto/"
)
add_custom_command
(
TARGET framework_py_proto POST_BUILD
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
PADDLE_BINARY_DIR
}
/python/paddle/fluid/proto
COMMAND copy /Y *.py
${
proto_dstpath
}
COMMENT
"Copy generated python proto into directory paddle/fluid/proto."
WORKING_DIRECTORY
${
CMAKE_CURRENT_BINARY_DIR
}
)
endif
(
NOT WIN32
)
endif
(
NOT WIN32
)
cc_library
(
lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor
)
cc_library
(
lod_rank_table SRCS lod_rank_table.cc DEPS lod_tensor
)
...
@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE)
...
@@ -176,11 +168,7 @@ if(WITH_DISTRIBUTE)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set
(
DISTRIBUTE_COMPILE_FLAGS
"-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor"
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
set_source_files_properties
(
executor.cc PROPERTIES COMPILE_FLAGS
${
DISTRIBUTE_COMPILE_FLAGS
}
)
else
()
else
()
if
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass ngraph_operator
)
else
(
NOT WIN32
)
cc_library
(
executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass
)
endif
(
NOT WIN32
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
cc_test
(
test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op
)
endif
()
endif
()
...
@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
...
@@ -192,10 +180,11 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
endif
()
# NOT WIN32
endif
()
# NOT WIN32
cc_library
(
async_executor
cc_library
(
async_executor
SRCS async_executor.cc data_feed.cc datafeed_creator.cc
SRCS async_executor.cc data_feed.cc data_feed_factory.cc
executor_thread_worker.cc
DEPS op_registry device_context scope framework_proto glog
DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method graph_to_program_pass
lod_rank_table feed_fetch_method graph_to_program_pass
async_executor_p
aram
)
async_executor_p
roto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_library
(
prune SRCS prune.cc DEPS framework_proto
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
cc_test
(
prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context
)
...
...
paddle/fluid/framework/async_executor.cc
浏览文件 @
91fc8f35
...
@@ -36,43 +36,13 @@ limitations under the License. */
...
@@ -36,43 +36,13 @@ limitations under the License. */
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/pybind.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
bool
AsyncExecutor
::
workers_initialized_
=
false
;
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
)
{
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR
)
{
var
->
GetMutable
<
LoDTensor
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
SELECTED_ROWS
)
{
var
->
GetMutable
<
SelectedRows
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
FEED_MINIBATCH
)
{
var
->
GetMutable
<
FeedFetchList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
FETCH_LIST
)
{
var
->
GetMutable
<
FeedFetchList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
STEP_SCOPES
)
{
var
->
GetMutable
<
std
::
vector
<
Scope
>>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_RANK_TABLE
)
{
var
->
GetMutable
<
LoDRankTable
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
var
->
GetMutable
<
LoDTensorArray
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
PLACE_LIST
)
{
var
->
GetMutable
<
platform
::
PlaceList
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
READER
)
{
var
->
GetMutable
<
ReaderHolder
>
();
}
else
if
(
var_type
==
proto
::
VarType
::
RAW
)
{
// GetMutable will be called in operator
}
else
{
PADDLE_THROW
(
"Variable type %d is not in "
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, RAW]"
,
var_type
);
}
}
static
void
ReadBinaryFile
(
const
std
::
string
&
filename
,
static
void
ReadBinaryFile
(
const
std
::
string
&
filename
,
std
::
string
*
content
)
{
std
::
string
*
content
)
{
std
::
string
&
contents
=
*
content
;
std
::
string
&
contents
=
*
content
;
...
@@ -139,343 +109,100 @@ static void SaveModel(
...
@@ -139,343 +109,100 @@ static void SaveModel(
}
}
}
// end SaveModel
}
// end SaveModel
void
ExecutorThreadWorker
::
Reset
()
{
AsyncExecutor
::
AsyncExecutor
(
Scope
&
scope
,
const
platform
::
Place
&
place
)
inspect_values_
.
clear
();
:
root_scope_
(
scope
),
place_
(
place
)
{}
}
void
ExecutorThreadWorker
::
CreateThreadOperators
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
op_names_
.
clear
();
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
std
::
unique_ptr
<
OperatorBase
>
local_op
=
OpRegistry
::
CreateOp
(
*
op_desc
);
op_names_
.
push_back
(
op_desc
->
Type
());
OperatorBase
*
local_op_ptr
=
local_op
.
release
();
ops_
.
push_back
(
local_op_ptr
);
continue
;
}
}
void
ExecutorThreadWorker
::
CreateThreadScope
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
thread_scope_
=
&
root_scope_
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
auto
*
ptr
=
root_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
}
else
{
auto
*
ptr
=
thread_scope_
->
Var
(
var
->
Name
());
CreateTensor
(
ptr
,
var
->
GetType
());
}
}
}
void
ExecutorThreadWorker
::
SetDataFeed
(
DataFeed
&
datafeed
)
{
if
(
typeid
(
datafeed
)
==
typeid
(
TextClassDataFeed
))
{
local_reader_
.
reset
(
new
TextClassDataFeed
(
dynamic_cast
<
TextClassDataFeed
&>
(
datafeed
)));
local_reader_
->
SetThreadId
(
thread_id_
);
}
}
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
local_reader_
->
GetUseSlotAlias
();
for
(
auto
name
:
input_feed
)
{
local_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
void
ExecutorThreadWorker
::
SetInspectVarNames
(
void
AsyncExecutor
::
CreateThreads
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
ExecutorThreadWorker
*
worker
,
inspect_var_names_
.
clear
();
const
ProgramDesc
&
main_program
,
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
const
std
::
shared_ptr
<
DataFeed
>&
reader
,
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
&
root_scope
,
const
int
thread_index
)
{
worker
->
SetThreadId
(
thread_index
);
worker
->
SetRootScope
(
&
root_scope
);
worker
->
CreateThreadResource
(
main_program
,
place_
);
worker
->
SetDataFeed
(
reader
);
worker
->
SetFetchVarNames
(
fetch_var_names
);
worker
->
BindingDataFeedMemory
();
}
}
void
ExecutorThreadWorker
::
SetModelParamNames
(
void
AsyncExecutor
::
CheckFiles
(
const
std
::
vector
<
std
::
string
>&
param_names
)
{
const
std
::
vector
<
std
::
string
>&
files
)
{
model_param_names_
=
param_names
;
// function for user to check file formats
}
// should be exposed to users
void
ExecutorThreadWorker
::
SetDevice
()
{
static
unsigned
priority
[]
=
{
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
,
10
,
11
,
12
,
13
,
14
,
15
,
16
,
17
,
18
,
19
,
20
,
21
,
22
,
23
,
24
,
25
,
26
,
27
,
28
,
29
,
30
,
31
,
32
,
33
,
34
,
35
,
36
,
37
,
38
,
39
,
40
,
41
,
42
,
43
,
44
,
45
,
46
,
47
};
unsigned
int
i
=
this
->
thread_id_
;
if
(
i
<
sizeof
(
priority
)
/
sizeof
(
unsigned
))
{
unsigned
proc
=
priority
[
i
];
cpu_set_t
mask
;
CPU_ZERO
(
&
mask
);
CPU_SET
(
proc
,
&
mask
);
if
(
-
1
==
sched_setaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
{
LOG
(
ERROR
)
<<
"WARNING: Failed to set thread affinity for thread "
<<
i
;
}
else
{
CPU_ZERO
(
&
mask
);
if
((
0
==
sched_getaffinity
(
0
,
sizeof
(
mask
),
&
mask
))
&&
CPU_ISSET
(
proc
,
&
mask
))
{
LOG
(
ERROR
)
<<
"TRACE: Thread "
<<
i
<<
" is running on processor "
<<
proc
<<
"..."
;
}
}
}
}
void
ExecutorThreadWorker
::
Train
()
{
LOG
(
ERROR
)
<<
"begin to train"
;
SetDevice
();
int
inspect_var_num
=
inspect_var_names_
.
size
();
inspect_values_
.
clear
();
inspect_values_
.
resize
(
inspect_var_num
,
0
);
local_reader_
->
WaitNextEpoch
();
int
epoch
=
local_reader_
->
GetCurrentEpoch
();
LOG
(
ERROR
)
<<
"epoch: "
<<
epoch
;
int
batch_num
=
1
;
while
(
true
)
{
const
char
*
file
=
local_reader_
->
PickOneFile
();
if
(
file
==
NULL
)
{
break
;
}
if
(
!
local_reader_
->
SetFile
(
file
))
{
break
;
}
while
(
true
)
{
bool
flag
=
local_reader_
->
ReadBatch
();
if
(
!
flag
)
{
break
;
}
for
(
unsigned
int
i
=
0
;
i
<
ops_
.
size
();
++
i
)
{
ops_
[
i
]
->
Run
(
*
thread_scope_
,
place_
);
}
batch_num
++
;
float
avg_inspect
=
0.0
;
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
avg_inspect
=
thread_scope_
->
FindVar
(
inspect_var_names_
[
i
])
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
inspect_values_
[
i
]
+=
avg_inspect
;
}
thread_scope_
->
DropKids
();
}
local_reader_
->
UpdateEpochNum
();
LOG
(
ERROR
)
<<
"memory used after epoch "
<<
epoch
+
1
<<
" called: "
<<
memory
::
memory_usage
(
place_
);
}
for
(
int
i
=
0
;
i
<
inspect_var_num
;
++
i
)
{
inspect_values_
[
i
]
/=
batch_num
;
std
::
string
var
=
inspect_var_names_
[
i
].
substr
(
0
,
inspect_var_names_
[
i
].
find_first_of
(
"_"
));
LOG
(
ERROR
)
<<
"mean "
<<
var
.
c_str
()
<<
" of epoch "
<<
i
+
1
<<
": "
<<
inspect_values_
[
i
];
}
if
(
thread_id_
==
0
)
{
char
modelfile
[
1024
];
snprintf
(
&
modelfile
[
0
],
sizeof
(
modelfile
),
"%s_epoch%d.model"
,
model_prefix_
.
c_str
(),
epoch
);
std
::
string
model_filename
=
std
::
string
(
modelfile
);
// this save_inference_model can only save imdbtask, should make this
// general
//
// currently comment it
LOG
(
ERROR
)
<<
"Going to save model "
<<
modelfile
;
SaveModel
(
main_program_
,
thread_scope_
,
model_param_names_
,
model_filename
,
true
);
}
}
void
ExecutorThreadWorker
::
SetThreadId
(
int
tid
)
{
thread_id_
=
tid
;
}
void
ExecutorThreadWorker
::
SetPlace
(
const
platform
::
Place
&
place
)
{
place_
=
place
;
}
void
ExecutorThreadWorker
::
SetMainProgram
(
const
ProgramDesc
&
main_program_desc
)
{
main_program_
.
reset
(
new
ProgramDesc
(
main_program_desc
));
}
void
ExecutorThreadWorker
::
SetRootScope
(
Scope
*
g_scope
)
{
root_scope_
=
g_scope
;
}
void
ExecutorThreadWorker
::
SetMaxTrainingEpoch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
}
AsyncExecutor
::
AsyncExecutor
(
ProgramDesc
&
main_program
,
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
unsigned
int
thread_num
,
const
platform
::
Place
&
place
)
:
thread_num_
(
thread_num
),
place_
(
place
),
main_program_
(
main_program
),
data_feed_
(
data_feed
)
{
model_param_names_
.
clear
();
model_param_names_
.
insert
(
model_param_names_
.
end
(),
param_names
.
begin
(),
param_names
.
end
());
}
void
AsyncExecutor
::
InitRootScope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
AsyncExecutor
::
SetMaxTrainingEpoch
(
int
max_epoch
)
{
max_epoch_
=
max_epoch
;
}
}
void
AsyncExecutor
::
SetModelPrefix
(
const
std
::
string
&
model_prefix
)
{
void
AsyncExecutor
::
SetModelPrefix
(
const
std
::
string
&
model_prefix
)
{
model_prefix_
=
model_prefix
;
model_prefix_
=
model_prefix
;
}
}
void
AsyncExecutor
::
RunStartupProgram
(
const
ProgramDesc
&
program
,
std
::
vector
<
float
>
AsyncExecutor
::
RunFromFile
(
Scope
*
scope
)
{
const
ProgramDesc
&
main_program
,
auto
&
block
=
program
.
Block
(
0
);
const
DataFeedDesc
&
data_feed_desc
,
for
(
auto
&
var
:
block
.
AllVars
())
{
const
std
::
vector
<
std
::
string
>&
filelist
,
if
(
var
->
Persistable
())
{
const
int
thread_num
,
auto
*
ptr
=
scope
->
Var
(
var
->
Name
());
const
std
::
vector
<
std
::
string
>&
fetch_var_names
)
{
CreateTensor
(
ptr
,
var
->
GetType
());
std
::
vector
<
std
::
thread
>
threads
;
// LOGERR("Persistable Var Name:%s", var->Name().c_str());
}
}
std
::
map
<
std
::
string
,
int
>
param_dict
;
/*
std
::
vector
<
OperatorBase
*>
ops
;
readerDesc: protobuf description for reader initlization
for
(
auto
&
op_desc
:
block
.
AllOps
())
{
argument: class_name, batch_size, use_slot, queue_size, buffer_size, padding_index
std
::
vector
<
std
::
string
>
param_name_vec
=
op_desc
->
OutputArgumentNames
();
bool
need_to_run
=
false
;
reader:
for
(
auto
&
name
:
param_name_vec
)
{
1) each thread has a reader, reader will read input data and
if
(
param_dict
.
find
(
name
)
==
param_dict
.
end
())
{
put it into input queue
param_dict
[
name
]
=
1
;
2) each reader has a Next() iterface, that can fetch an instance
need_to_run
=
true
;
from the input queue
}
*/
}
// todo: should be factory method for creating datafeed
if
(
need_to_run
)
{
std
::
vector
<
std
::
shared_ptr
<
DataFeed
>
>
readers
;
std
::
unique_ptr
<
OperatorBase
>
local_op
=
OpRegistry
::
CreateOp
(
*
op_desc
);
readers
.
resize
(
thread_num
);
OperatorBase
*
local_op_ptr
=
local_op
.
release
();
for
(
unsigned
int
i
=
0
;
i
<
readers
.
size
();
++
i
)
{
ops
.
push_back
(
local_op_ptr
);
readers
[
i
]
=
DataFeedFactory
::
CreateDataFeed
(
data_feed_desc
.
name
());
}
}
}
// LOGERR("There are %d parameters in startup program, %d op needs to run",
// param_dict.size(), ops.size());
for
(
auto
&
op
:
ops
)
{
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers
;
op
->
Run
(
*
scope
,
place_
);
workers
.
resize
(
thread_num
);
for
(
auto
&
worker
:
workers
)
{
worker
.
reset
(
new
ExecutorThreadWorker
);
}
}
// LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
for
(
auto
&
op
:
ops
)
{
delete
op
;
}
// LOGERR("run startup program done.");
}
std
::
unique_ptr
<
ProgramDesc
>
AsyncExecutor
::
LoadDescFromFile
(
// prepare thread resource here
const
std
::
string
&
f
)
{
for
(
int
thidx
=
0
;
thidx
<
thread_num
;
++
thidx
)
{
std
::
string
program_desc_str
;
CreateThreads
(
workers
[
thidx
].
get
(),
main_program
,
ReadBinaryFile
(
f
,
&
program_desc_str
);
readers
[
thidx
],
fetch_var_names
,
root_scope_
,
thidx
);
std
::
unique_ptr
<
ProgramDesc
>
program
(
new
ProgramDesc
(
program_desc_str
));
return
program
;
}
void
AsyncExecutor
::
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
inspect_var_names_
.
clear
();
inspect_var_names_
.
insert
(
inspect_var_names_
.
end
(),
inspect_var_names
.
begin
(),
inspect_var_names
.
end
());
}
void
AsyncExecutor
::
PrepareThreads
(
const
ProgramDesc
&
host_program
)
{
workers_
.
resize
(
thread_num_
);
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
].
reset
(
new
ExecutorThreadWorker
);
workers_
[
i
]
->
SetThreadId
(
i
);
workers_
[
i
]
->
CreateThreadOperators
(
host_program
);
workers_
[
i
]
->
SetRootScope
(
root_scope_
);
workers_
[
i
]
->
SetPlace
(
place_
);
workers_
[
i
]
->
SetMaxTrainingEpoch
(
max_epoch_
);
workers_
[
i
]
->
CreateThreadScope
(
host_program
);
workers_
[
i
]
->
SetInspectVarNames
(
inspect_var_names_
);
workers_
[
i
]
->
SetModelParamNames
(
model_param_names_
);
workers_
[
i
]
->
SetMainProgram
(
host_program
);
workers_
[
i
]
->
SetModelPrefix
(
model_prefix_
);
//
// new a datafeed here
workers_
[
i
]
->
SetDataFeed
(
data_feed_
);
workers_
[
i
]
->
BindingDataFeedMemory
();
}
}
}
std
::
vector
<
float
>&
AsyncExecutor
::
Run
(
// start executing ops in multiple threads
const
std
::
vector
<
std
::
string
>&
inspect_var_names
)
{
for
(
int
thidx
=
0
;
thidx
<
thread_num
;
++
thidx
)
{
SetInspectVarNames
(
inspect_var_names
);
threads
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
TrainFiles
,
threads_
.
clear
();
workers
[
thidx
].
get
()));
// thread binding here?
if
(
workers_initialized_
==
false
)
{
PrepareThreads
(
main_program_
);
workers_initialized_
=
true
;
}
for
(
int
i
=
0
;
i
<
thread_num_
;
++
i
)
{
workers_
[
i
]
->
Reset
();
workers_
[
i
]
->
SetInspectVarNames
(
inspect_var_names
);
threads_
.
push_back
(
std
::
thread
(
&
ExecutorThreadWorker
::
Train
,
workers_
[
i
].
get
()));
}
}
for
(
auto
&
th
:
threads
_
)
{
for
(
auto
&
th
:
threads
)
{
th
.
join
();
th
.
join
();
}
}
inspect_values_
.
clear
();
std
::
vector
<
float
>
fetch_values
;
inspect_values_
.
resize
(
inspect_var_names_
.
size
(),
0
);
fetch_values
.
resize
(
fetch_var_names
.
size
(),
0
);
std
::
vector
<
std
::
vector
<
float
>*>
inspect
_value_vectors
;
std
::
vector
<
std
::
vector
<
float
>*>
fetch
_value_vectors
;
inspect_value_vectors
.
resize
(
thread_num_
);
fetch_value_vectors
.
resize
(
thread_num
);
for
(
int
i
=
0
;
i
<
thread_num
_
;
++
i
)
{
for
(
int
i
=
0
;
i
<
thread_num
;
++
i
)
{
inspect_value_vectors
[
i
]
=
&
workers_
[
i
]
->
GetInspect
Values
();
fetch_value_vectors
[
i
]
=
&
workers
[
i
]
->
GetFetch
Values
();
}
}
for
(
unsigned
int
i
=
0
;
i
<
inspect_var_names_
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
fetch_var_names
.
size
();
++
i
)
{
float
value
=
0.0
;
float
value
=
0.0
;
for
(
int
j
=
0
;
j
<
thread_num
_
;
++
j
)
{
for
(
int
j
=
0
;
j
<
thread_num
;
++
j
)
{
value
+=
inspect
_value_vectors
[
j
]
->
at
(
i
);
value
+=
fetch
_value_vectors
[
j
]
->
at
(
i
);
}
}
value
/=
thread_num
_
;
value
/=
thread_num
;
inspect_values_
[
i
]
=
value
;
fetch_values
[
i
]
=
value
;
}
}
return
inspect_values_
;
return
fetch_values
;
}
}
void
AsyncExecutor
::
LoadInitModel
()
{
void
AsyncExecutor
::
LoadInitModel
()
{
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
91fc8f35
...
@@ -23,7 +23,8 @@ limitations under the License. */
...
@@ -23,7 +23,8 @@ limitations under the License. */
#include <thread> // NOLINT
#include <thread> // NOLINT
#include <vector>
#include <vector>
#include <typeinfo>
#include <typeinfo>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
...
@@ -31,93 +32,13 @@ limitations under the License. */
...
@@ -31,93 +32,13 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
void
CreateTensor
(
Variable
*
var
,
proto
::
VarType
::
Type
var_type
);
class
ExecutorThreadWorker
{
public:
ExecutorThreadWorker
()
{}
~
ExecutorThreadWorker
()
{}
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
void
SetThreadId
(
int
tid
);
void
CreateThreadOperators
(
const
ProgramDesc
&
program
);
void
SetRootScope
(
Scope
*
g_scope
);
void
SetDevice
();
void
AddFidSet
();
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
void
AddTrainFile
(
const
std
::
string
&
filename
);
void
SetMainProgram
(
const
ProgramDesc
&
main_program_desc
);
void
SetPlace
(
const
paddle
::
platform
::
Place
&
place
);
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
void
BindingDataFeedMemory
();
void
SetModelPrefix
(
const
std
::
string
&
prefix
)
{
model_prefix_
=
prefix
;
}
void
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
SetModelParamNames
(
const
std
::
vector
<
std
::
string
>&
param_names
);
void
SetDataFeed
(
DataFeed
&
datafeed
);
// NOLINT
void
Train
();
const
char
*
PickOneFile
();
void
UpdateEpochNum
();
void
Reset
();
void
Initialize
()
{}
std
::
vector
<
float
>&
GetInspectValues
()
{
return
inspect_values_
;}
protected:
// thread index
int
thread_id_
;
// max epoch for each thread
unsigned
int
max_epoch_
;
// instances learned currently
int
comm_batch_
;
std
::
string
model_prefix_
;
std
::
vector
<
std
::
string
>
op_names_
;
// local ops for forward and backward
std
::
vector
<
OperatorBase
*>
ops_
;
// main program for training
std
::
unique_ptr
<
ProgramDesc
>
main_program_
;
// binary data reader
std
::
unique_ptr
<
DataFeed
>
local_reader_
;
std
::
vector
<
std
::
string
>
inspect_var_names_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
// execution place
platform
::
Place
place_
;
// root scope for model parameters
Scope
*
root_scope_
;
// a thread scope, father scope is global score which is shared
Scope
*
thread_scope_
;
private:
std
::
vector
<
float
>
inspect_values_
;
};
class
AsyncExecutor
{
class
AsyncExecutor
{
public:
public:
explicit
AsyncExecutor
(
ProgramDesc
&
main_program
,
// NOLINT
explicit
AsyncExecutor
(
Scope
&
scope
,
const
platform
::
Place
&
place
);
// NOLINT
const
std
::
vector
<
std
::
string
>&
param_names
,
TextClassDataFeed
&
data_feed
,
// NOLINT
unsigned
int
thread_num
,
const
platform
::
Place
&
place
);
virtual
~
AsyncExecutor
()
{}
virtual
~
AsyncExecutor
()
{}
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromFile
(
static
std
::
unique_ptr
<
ProgramDesc
>
LoadDescFromFile
(
const
std
::
string
&
filename
);
const
std
::
string
&
filename
);
void
InitRootScope
(
Scope
*
scope
);
Scope
*
GetRootScope
()
{
return
&
root_scope_
;
}
void
SetMaxTrainingEpoch
(
const
int
max_epoch
);
Scope
*
GetRootScope
()
{
return
root_scope_
;
}
void
SetBatchSize
(
const
int
batch_size
)
{
batch_size_
=
batch_size
;
}
void
SetCommBatch
(
int
comm_batch
)
{
comm_batch_
=
comm_batch
;
}
void
SetModelPath
(
const
std
::
string
&
model_path
)
{
void
SetModelPath
(
const
std
::
string
&
model_path
)
{
model_path_
=
model_path
;
model_path_
=
model_path
;
...
@@ -132,38 +53,32 @@ class AsyncExecutor {
...
@@ -132,38 +53,32 @@ class AsyncExecutor {
}
}
void
SetModelPrefix
(
const
std
::
string
&
model_prefix
);
void
SetModelPrefix
(
const
std
::
string
&
model_prefix
);
virtual
void
PrepareThreads
(
const
ProgramDesc
&
host_program
);
void
RunStartupProgram
(
const
ProgramDesc
&
program
,
Scope
*
scope
);
void
RunStartupProgram
(
const
ProgramDesc
&
program
,
Scope
*
scope
);
std
::
vector
<
float
>&
Run
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
std
::
vector
<
float
>
RunFromFile
(
const
ProgramDesc
&
main_program
,
const
DataFeedDesc
&
data_feed_desc
,
const
std
::
vector
<
std
::
string
>&
filelist
,
const
int
thread_num
,
const
std
::
vector
<
std
::
string
>&
fetch_names
);
void
CheckFiles
(
const
std
::
vector
<
std
::
string
>&
files
);
void
LoadInitModel
();
void
LoadInitModel
();
private:
private:
void
SetInspectVarNames
(
const
std
::
vector
<
std
::
string
>&
inspect_var_names
);
void
CreateThreads
(
ExecutorThreadWorker
*
worker
,
const
ProgramDesc
&
main_program
,
const
std
::
shared_ptr
<
DataFeed
>&
reader
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
&
root_scope
,
// NOLINT
const
int
thread_index
);
public:
public:
int
thread_num_
;
int
max_epoch_
;
int
batch_size_
;
int
comm_batch_
;
std
::
vector
<
std
::
shared_ptr
<
ExecutorThreadWorker
>
>
workers_
;
std
::
vector
<
std
::
thread
>
threads_
;
std
::
vector
<
std
::
string
>
inspect_var_names_
;
std
::
vector
<
std
::
string
>
model_param_names_
;
std
::
string
model_prefix_
;
std
::
string
model_prefix_
;
std
::
string
model_path_
;
std
::
string
model_path_
;
std
::
string
init_prog_file_
;
std
::
string
init_prog_file_
;
std
::
string
init_model_file_
;
std
::
string
init_model_file_
;
Scope
*
root_scope_
;
Scope
&
root_scope_
;
platform
::
Place
place_
;
platform
::
Place
place_
;
private:
ProgramDesc
&
main_program_
;
TextClassDataFeed
&
data_feed_
;
std
::
vector
<
float
>
inspect_values_
;
private:
static
bool
workers_initialized_
;
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/data_feed.cc
浏览文件 @
91fc8f35
...
@@ -38,17 +38,17 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
...
@@ -38,17 +38,17 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
std
::
vector
<
std
::
string
>
TextClass
DataFeed
::
s_filelist_
;
std
::
vector
<
std
::
string
>
MultiSlot
DataFeed
::
s_filelist_
;
std
::
mutex
TextClass
DataFeed
::
s_locker_for_pick_file_
;
std
::
mutex
MultiSlot
DataFeed
::
s_locker_for_pick_file_
;
unsigned
int
TextClass
DataFeed
::
s_current_file_idx_
=
0
;
unsigned
int
MultiSlot
DataFeed
::
s_current_file_idx_
=
0
;
size_t
TextClass
DataFeed
::
s_current_finished_file_cnt_
=
0
;
size_t
MultiSlot
DataFeed
::
s_current_finished_file_cnt_
=
0
;
unsigned
int
TextClass
DataFeed
::
s_current_epoch_
=
0
;
unsigned
int
MultiSlot
DataFeed
::
s_current_epoch_
=
0
;
int
TextClass
DataFeed
::
s_current_save_epoch_
=
0
;
int
MultiSlot
DataFeed
::
s_current_save_epoch_
=
0
;
std
::
mutex
TextClass
DataFeed
::
s_locker_epoch_start_
;
std
::
mutex
MultiSlot
DataFeed
::
s_locker_epoch_start_
;
std
::
condition_variable
TextClass
DataFeed
::
s_condition_epoch_start_
;
std
::
condition_variable
MultiSlot
DataFeed
::
s_condition_epoch_start_
;
bool
TextClass
DataFeed
::
s_epoch_start_flag_
=
false
;
bool
MultiSlot
DataFeed
::
s_epoch_start_flag_
=
false
;
void
TextClass
DataFeed
::
Init
()
{
void
MultiSlot
DataFeed
::
Init
()
{
// hard coding for a specific datafeed
// hard coding for a specific datafeed
feed_vec_
.
resize
(
2
);
feed_vec_
.
resize
(
2
);
// feed_vec_[0].reset(new LoDTensor);
// feed_vec_[0].reset(new LoDTensor);
...
@@ -73,12 +73,12 @@ void TextClassDataFeed::Init() {
...
@@ -73,12 +73,12 @@ void TextClassDataFeed::Init() {
field_names_
.
clear
();
field_names_
.
clear
();
}
}
TextClassDataFeed
::
TextClass
DataFeed
()
{
MultiSlotDataFeed
::
MultiSlot
DataFeed
()
{
Init
();
Init
();
}
}
// todo: use elegant implemention for this function
// todo: use elegant implemention for this function
bool
TextClass
DataFeed
::
ReadBatch
()
{
bool
MultiSlot
DataFeed
::
ReadBatch
()
{
paddle
::
framework
::
Vector
<
size_t
>
offset
;
paddle
::
framework
::
Vector
<
size_t
>
offset
;
int
tlen
=
0
;
int
tlen
=
0
;
int
llen
=
0
;
int
llen
=
0
;
...
@@ -142,13 +142,13 @@ bool TextClassDataFeed::ReadBatch() {
...
@@ -142,13 +142,13 @@ bool TextClassDataFeed::ReadBatch() {
return
true
;
return
true
;
}
}
TextClassDataFeed
::
TextClassDataFeed
(
const
TextClass
DataFeed
&
data_feed
)
{
MultiSlotDataFeed
::
MultiSlotDataFeed
(
const
MultiSlot
DataFeed
&
data_feed
)
{
Init
();
Init
();
SetBatchSize
(
data_feed
.
batch_size_
);
SetBatchSize
(
data_feed
.
batch_size_
);
SetFieldNames
(
data_feed
.
field_names_
);
SetFieldNames
(
data_feed
.
field_names_
);
}
}
void
TextClass
DataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
void
MultiSlot
DataFeed
::
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
for
(
unsigned
int
i
=
0
;
i
<
use_slot_alias_
.
size
();
++
i
)
{
if
(
name
==
use_slot_alias_
[
i
])
{
if
(
name
==
use_slot_alias_
[
i
])
{
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
feed_vec_
[
i
]
=
feed
->
GetMutable
<
LoDTensor
>
();
...
@@ -156,7 +156,7 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
...
@@ -156,7 +156,7 @@ void TextClassDataFeed::AddFeedVar(Variable* feed, const std::string& name) {
}
}
}
}
void
TextClass
DataFeed
::
SetFileList
(
const
char
*
filelist
)
{
void
MultiSlot
DataFeed
::
SetFileList
(
const
char
*
filelist
)
{
s_filelist_
.
clear
();
s_filelist_
.
clear
();
std
::
ifstream
fin
(
filelist
);
std
::
ifstream
fin
(
filelist
);
PADDLE_ENFORCE
(
fin
.
good
(),
PADDLE_ENFORCE
(
fin
.
good
(),
...
@@ -170,14 +170,14 @@ void TextClassDataFeed::SetFileList(const char* filelist) {
...
@@ -170,14 +170,14 @@ void TextClassDataFeed::SetFileList(const char* filelist) {
fin
.
close
();
fin
.
close
();
}
}
void
TextClass
DataFeed
::
SetFieldNames
(
void
MultiSlot
DataFeed
::
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
)
{
const
std
::
vector
<
std
::
string
>&
field_names
)
{
field_names_
.
clear
();
field_names_
.
clear
();
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names_
.
insert
(
field_names_
.
end
(),
field_names
.
begin
(),
field_names
.
end
());
field_names
.
end
());
}
}
bool
TextClass
DataFeed
::
SetFile
(
const
char
*
filename
)
{
bool
MultiSlot
DataFeed
::
SetFile
(
const
char
*
filename
)
{
// termnum termid termid ... termid label
// termnum termid termid ... termid label
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
std
::
ifstream
ifs
(
filename
,
std
::
ios
::
binary
);
if
(
ifs
.
fail
())
{
if
(
ifs
.
fail
())
{
...
@@ -198,7 +198,7 @@ bool TextClassDataFeed::SetFile(const char* filename) {
...
@@ -198,7 +198,7 @@ bool TextClassDataFeed::SetFile(const char* filename) {
return
true
;
return
true
;
}
}
void
TextClass
DataFeed
::
UpdateEpochNum
()
{
void
MultiSlot
DataFeed
::
UpdateEpochNum
()
{
s_current_finished_file_cnt_
++
;
s_current_finished_file_cnt_
++
;
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
if
(
s_current_finished_file_cnt_
>=
s_filelist_
.
size
())
{
...
@@ -214,25 +214,14 @@ void TextClassDataFeed::UpdateEpochNum() {
...
@@ -214,25 +214,14 @@ void TextClassDataFeed::UpdateEpochNum() {
}
}
}
}
void
TextClassDataFeed
::
StartOneEpoch
()
{
void
MultiSlotDataFeed
::
Start
()
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
std
::
random_shuffle
(
s_filelist_
.
begin
(),
s_filelist_
.
end
());
s_current_file_idx_
=
0
;
LOG
(
INFO
)
<<
"Beginning epoch "
<<
s_current_epoch_
;
{
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
s_epoch_start_flag_
=
true
;
}
s_condition_epoch_start_
.
notify_all
();
}
}
void
TextClassDataFeed
::
WaitNextEpoch
()
{
int
MultiSlotDataFeed
::
Next
()
{
std
::
unique_lock
<
std
::
mutex
>
lock
(
s_locker_epoch_start_
);
return
0
;
s_condition_epoch_start_
.
wait
(
lock
,
[]{
return
s_epoch_start_flag_
;});
}
}
const
char
*
TextClass
DataFeed
::
PickOneFile
()
{
const
char
*
MultiSlot
DataFeed
::
PickOneFile
()
{
std
::
string
file_to_be_processed
;
std
::
string
file_to_be_processed
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
std
::
lock_guard
<
std
::
mutex
>
lock
(
s_locker_for_pick_file_
);
...
...
paddle/fluid/framework/data_feed.h
浏览文件 @
91fc8f35
...
@@ -81,8 +81,8 @@ class DataFeed {
...
@@ -81,8 +81,8 @@ class DataFeed {
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
unsigned
int
GetCurrentEpoch
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
const
char
*
PickOneFile
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
UpdateEpochNum
()
=
0
;
virtual
void
Start
OneEpoch
()
=
0
;
virtual
void
Start
()
=
0
;
virtual
void
WaitNextEpoch
()
=
0
;
virtual
int
Next
()
=
0
;
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
std
::
vector
<
LoDTensor
*>&
GetFeedVec
()
{
return
feed_vec_
;
return
feed_vec_
;
...
@@ -106,13 +106,13 @@ class DataFeed {
...
@@ -106,13 +106,13 @@ class DataFeed {
int
thread_id_
;
int
thread_id_
;
};
};
class
TextClass
DataFeed
:
public
DataFeed
{
class
MultiSlot
DataFeed
:
public
DataFeed
{
public:
public:
TextClass
DataFeed
();
MultiSlot
DataFeed
();
TextClassDataFeed
(
const
TextClass
DataFeed
&
data_feed
);
MultiSlotDataFeed
(
const
MultiSlot
DataFeed
&
data_feed
);
public:
public:
virtual
~
TextClass
DataFeed
()
{}
virtual
~
MultiSlot
DataFeed
()
{}
virtual
void
Init
();
virtual
void
Init
();
virtual
bool
ReadBatch
();
virtual
bool
ReadBatch
();
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
virtual
void
AddFeedVar
(
Variable
*
feed
,
const
std
::
string
&
name
);
...
@@ -125,8 +125,8 @@ class TextClassDataFeed : public DataFeed {
...
@@ -125,8 +125,8 @@ class TextClassDataFeed : public DataFeed {
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
void
SetBatchSize
(
int
batch
)
{
batch_size_
=
batch
;}
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
unsigned
int
GetCurrentEpoch
()
{
return
s_current_epoch_
;}
void
UpdateEpochNum
();
void
UpdateEpochNum
();
void
Start
OneEpoch
();
void
Start
();
void
WaitNextEpoch
();
int
Next
();
public:
public:
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
void
SetFieldNames
(
const
std
::
vector
<
std
::
string
>&
field_names
);
...
...
paddle/fluid/framework/data_feed.proto
浏览文件 @
91fc8f35
...
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
syntax
=
"proto2"
;
syntax
=
"proto2"
;
package
paddle
;
package
paddle
.
framework
;
message
DataFeedDesc
{
message
DataFeedDesc
{
optional
string
name
=
1
;
optional
string
name
=
1
;
optional
int32
batch
=
2
[
default
=
32
];
optional
int32
batch
=
2
[
default
=
32
];
repeated
string
field_names
=
3
;
}
}
paddle/fluid/framework/data_feed_factory.cc
浏览文件 @
91fc8f35
...
@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,18 +12,23 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/framework/data_feed_factory.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
typedef
shared_ptr
<
DataFeed
>
(
*
Createdata_feedFunction
)();
typedef
s
td
::
s
hared_ptr
<
DataFeed
>
(
*
Createdata_feedFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdata_feedFunction
>
data_feedMap
;
typedef
std
::
unordered_map
<
std
::
string
,
Createdata_feedFunction
>
data_feedMap
;
data_feedMap
g_data_feed_map
;
data_feedMap
g_data_feed_map
;
#define REGISTER_DATAFEED_CLASS(data_feed_class) \
#define REGISTER_DATAFEED_CLASS(data_feed_class) \
namespace { \
namespace { \
shared_ptr<DataFeed> Creator_##data_feed_class() { \
s
td::s
hared_ptr<DataFeed> Creator_##data_feed_class() { \
return shared_ptr<DataFeed>(new data_feed_class); \
return s
td::s
hared_ptr<DataFeed>(new data_feed_class); \
} \
} \
class __Registerer_##data_feed_class { \
class __Registerer_##data_feed_class { \
public: \
public: \
...
@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map;
...
@@ -35,8 +40,8 @@ data_feedMap g_data_feed_map;
} // namespace
} // namespace
string
DataFeedFactory
::
DataFeedTypeList
()
{
st
d
::
st
ring
DataFeedFactory
::
DataFeedTypeList
()
{
string
data_feed_types
;
st
d
::
st
ring
data_feed_types
;
for
(
auto
iter
=
g_data_feed_map
.
begin
();
for
(
auto
iter
=
g_data_feed_map
.
begin
();
iter
!=
g_data_feed_map
.
end
();
++
iter
)
{
iter
!=
g_data_feed_map
.
end
();
++
iter
)
{
if
(
iter
!=
g_data_feed_map
.
begin
())
{
if
(
iter
!=
g_data_feed_map
.
begin
())
{
...
@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() {
...
@@ -47,9 +52,9 @@ string DataFeedFactory::DataFeedTypeList() {
return
data_feed_types
;
return
data_feed_types
;
}
}
shared_ptr
<
DataFeed
>
DataFeedFactory
::
CreateDataFeed
(
s
td
::
s
hared_ptr
<
DataFeed
>
DataFeedFactory
::
CreateDataFeed
(
const
char
*
data_feed_class
)
{
std
::
string
data_feed_class
)
{
if
(
g_data_feed_map
.
count
(
string
(
data_feed_class
)
)
<
1
)
{
if
(
g_data_feed_map
.
count
(
data_feed_class
)
<
1
)
{
exit
(
-
1
);
exit
(
-
1
);
}
}
return
g_data_feed_map
[
data_feed_class
]();
return
g_data_feed_map
[
data_feed_class
]();
...
...
paddle/fluid/framework/data_feed_factory.h
浏览文件 @
91fc8f35
...
@@ -16,14 +16,15 @@ limitations under the License. */
...
@@ -16,14 +16,15 @@ limitations under the License. */
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#define PADDLE_FLUID_FRAMEWORK_DATA_FEED_FACTORY_H_
#include <string>
#include <string>
#include "paddle/framework/data_feed.h"
#include <memory>
#include "paddle/fluid/framework/data_feed.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
DataFeedFactory
{
class
DataFeedFactory
{
public:
public:
static
std
::
string
DataFeedTypeList
();
static
std
::
string
DataFeedTypeList
();
static
s
hared_ptr
<
DataFeed
>
CreateDataFeed
(
const
char
*
data_feed_class
);
static
s
td
::
shared_ptr
<
DataFeed
>
CreateDataFeed
(
std
::
string
data_feed_class
);
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
...
...
paddle/fluid/framework/executor_thread_worker.cc
浏览文件 @
91fc8f35
...
@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource(
...
@@ -93,6 +93,11 @@ void ExecutorThreadWorker::CreateThreadResource(
void
ExecutorThreadWorker
::
CreateThreadScope
(
const
ProgramDesc
&
program
)
{
void
ExecutorThreadWorker
::
CreateThreadScope
(
const
ProgramDesc
&
program
)
{
auto
&
block
=
program
.
Block
(
0
);
auto
&
block
=
program
.
Block
(
0
);
PADDLE_ENFORCE_NOT_NULL
(
root_scope_
,
"root_scope should be set before creating thread scope"
);
thread_scope_
=
&
root_scope_
->
NewScope
();
thread_scope_
=
&
root_scope_
->
NewScope
();
for
(
auto
&
var
:
block
.
AllVars
())
{
for
(
auto
&
var
:
block
.
AllVars
())
{
if
(
var
->
Persistable
())
{
if
(
var
->
Persistable
())
{
...
@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
...
@@ -107,17 +112,24 @@ void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
void
ExecutorThreadWorker
::
SetDataFeed
(
void
ExecutorThreadWorker
::
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
)
{
local
_reader_
=
datafeed
;
thread
_reader_
=
datafeed
;
}
}
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
void
ExecutorThreadWorker
::
BindingDataFeedMemory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
const
std
::
vector
<
std
::
string
>&
input_feed
=
thread_reader_
->
GetUseSlotAlias
();
thread_reader_
->
GetUseSlotAlias
();
for
(
auto
name
:
input_feed
)
{
for
(
auto
name
:
input_feed
)
{
local
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
thread
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
}
}
void
ExecutorThreadWorker
::
SetFetchVarNames
(
const
std
::
vector
<
std
::
string
>&
fetch_var_names
)
{
fetch_var_names_
.
clear
();
fetch_var_names_
.
insert
(
fetch_var_names_
.
end
(),
fetch_var_names
.
begin
(),
fetch_var_names
.
end
());
}
void
ExecutorThreadWorker
::
SetDevice
()
{
void
ExecutorThreadWorker
::
SetDevice
()
{
// at most 48 threads binding currently
// at most 48 threads binding currently
static
unsigned
priority
[]
=
{
static
unsigned
priority
[]
=
{
...
@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() {
...
@@ -156,12 +168,28 @@ void ExecutorThreadWorker::SetDevice() {
void
ExecutorThreadWorker
::
TrainFiles
()
{
void
ExecutorThreadWorker
::
TrainFiles
()
{
// todo: configurable
// todo: configurable
SetDevice
();
SetDevice
();
int
fetch_var_num
=
fetch_var_names_
.
size
();
fetch_values_
.
clear
();
fetch_values_
.
resize
(
fetch_var_num
,
0
);
thread_reader_
->
Start
();
thread_reader_
->
Start
();
while
(
int
cur_batch
=
thread_reader_
->
Next
())
{
int
cur_batch
;
while
((
cur_batch
=
thread_reader_
->
Next
())
>
0
)
{
// executor run here
// executor run here
for
(
auto
&
op
:
ops_
)
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
op
->
Run
(
*
thread_scope_
,
place_
);
}
}
float
avg_inspect
=
0.0
;
for
(
int
i
=
0
;
i
<
fetch_var_num
;
++
i
)
{
avg_inspect
=
thread_scope_
->
FindVar
(
fetch_var_names_
[
i
])
->
GetMutable
<
LoDTensor
>
()
->
data
<
float
>
()[
0
];
fetch_values_
[
i
]
+=
avg_inspect
;
}
thread_scope_
->
DropKids
();
thread_scope_
->
DropKids
();
}
}
}
}
...
...
paddle/fluid/framework/executor_thread_worker.h
浏览文件 @
91fc8f35
...
@@ -43,6 +43,9 @@ class ExecutorThreadWorker {
...
@@ -43,6 +43,9 @@ class ExecutorThreadWorker {
void
SetDevice
();
void
SetDevice
();
void
BindingDataFeedMemory
();
void
BindingDataFeedMemory
();
void
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
void
SetDataFeed
(
const
std
::
shared_ptr
<
DataFeed
>&
datafeed
);
void
TrainFiles
();
void
SetFetchVarNames
(
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
std
::
vector
<
float
>&
GetFetchValues
()
{
return
fetch_values_
;}
private:
private:
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
void
CreateThreadScope
(
const
framework
::
ProgramDesc
&
program
);
...
@@ -66,9 +69,13 @@ class ExecutorThreadWorker {
...
@@ -66,9 +69,13 @@ class ExecutorThreadWorker {
Scope
*
root_scope_
;
Scope
*
root_scope_
;
// a thread scope, father scope is global score which is shared
// a thread scope, father scope is global score which is shared
Scope
*
thread_scope_
;
Scope
*
thread_scope_
;
private:
std
::
vector
<
std
::
string
>
fetch_var_names_
;
std
::
vector
<
float
>
fetch_values_
;
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
#endif // PADDLE_FLUID_FRAMEWORK_
ASYNC_EXECUTO
R_H_
#endif // PADDLE_FLUID_FRAMEWORK_
EXECUTOR_THREAD_WORKE
R_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
paddle/fluid/pybind/async_executor_py.cc
浏览文件 @
91fc8f35
...
@@ -30,36 +30,32 @@ limitations under the License. */
...
@@ -30,36 +30,32 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/
async_executor_param
.pb.h"
#include "paddle/fluid/framework/
data_feed
.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/data_feed.h"
namespace
py
=
pybind11
;
namespace
py
=
pybind11
;
namespace
pd
=
paddle
::
framework
;
namespace
paddle
{
namespace
paddle
{
namespace
pybind
{
namespace
pybind
{
using
set_name_func
=
void
(
pd
::
DataFeedDesc
::*
)(
const
std
::
string
&
);
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
void
BindAsyncExecutor
(
py
::
module
*
m
)
{
py
::
class_
<
framework
::
DataFeed
>
(
*
m
,
"DataFeed"
);
py
::
class_
<
pd
::
DataFeedDesc
>
(
*
m
,
"DataFeedDesc"
)
py
::
class_
<
framework
::
TextClassDataFeed
,
.
def
(
pybind11
::
init
<>
())
framework
::
DataFeed
>
(
*
m
,
"TextDataFeed"
)
.
def
(
"set_name"
,
(
set_name_func
)
&
pd
::
DataFeedDesc
::
set_name
)
.
def
(
py
::
init
())
.
def
(
"set_batch"
,
&
pd
::
DataFeedDesc
::
set_batch
)
.
def
(
"set_filelist"
,
.
def
(
"set_field_names"
,
[]
(
framework
::
TextClassDataFeed
&
self
,
const
char
*
data_list_file
)
{
[]
(
pd
::
DataFeedDesc
&
self
,
const
std
::
vector
<
std
::
string
>
&
fields
)
{
self
.
SetFileList
(
data_list_file
);
for
(
auto
field
:
fields
)
{
})
self
.
add_field_names
(
field
);
.
def
(
"set_batch_size"
,
&
framework
::
TextClassDataFeed
::
SetBatchSize
)
}
.
def
(
"set_field_names"
,
&
framework
::
TextClassDataFeed
::
SetFieldNames
)
});
.
def
(
"start_one_epoch"
,
&
framework
::
TextClassDataFeed
::
StartOneEpoch
);
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
py
::
class_
<
framework
::
AsyncExecutor
>
(
*
m
,
"AsyncExecutor"
)
.
def
(
py
::
init
<
framework
::
ProgramDesc
&
,
.
def
(
py
::
init
<
pd
::
Scope
&
,
const
platform
::
Place
&>
())
std
::
vector
<
std
::
string
>&
,
.
def
(
"run_from_files"
,
&
framework
::
AsyncExecutor
::
RunFromFile
)
framework
::
TextClassDataFeed
&
,
.
def
(
"check_file"
,
&
framework
::
AsyncExecutor
::
CheckFiles
);
unsigned
int
,
const
platform
::
Place
&>
())
.
def
(
"init_root_scope"
,
&
framework
::
AsyncExecutor
::
InitRootScope
)
.
def
(
"run_startup_program"
,
&
framework
::
AsyncExecutor
::
RunStartupProgram
)
.
def
(
"run"
,
&
framework
::
AsyncExecutor
::
Run
);
}
// end BindAsyncExecutor
}
// end BindAsyncExecutor
}
// end namespace pybind
}
// end namespace pybind
}
// end namespace paddle
}
// end namespace paddle
...
...
python/paddle/fluid/async_executor.py
浏览文件 @
91fc8f35
...
@@ -19,30 +19,26 @@ import contextlib
...
@@ -19,30 +19,26 @@ import contextlib
import
six
import
six
from
.framework
import
Program
,
default_main_program
,
Variable
from
.framework
import
Program
,
default_main_program
,
Variable
from
.
import
core
from
.
import
core
from
.
import
Executor
from
.
executor
import
global_scope
__all__
=
[
'
Tex
tDataFeed'
,
'AsyncExecutor'
]
__all__
=
[
'
MultiSlo
tDataFeed'
,
'AsyncExecutor'
]
g_scope
=
core
.
Scope
()
g_scope
=
core
.
Scope
()
class
TextDataFeed
(
):
class
DataFeedDesc
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
feed
=
core
.
TextDataFeed
()
self
.
desc
=
core
.
DataFeedDesc
()
def
set_filelist
(
self
,
filelist
):
self
.
feed
.
set_filelist
(
filelist
)
def
set_batch_size
(
self
,
batch_size
):
def
set_batch_size
(
self
,
batch_size
):
self
.
feed
.
set_batch_size
(
batch_size
)
self
.
desc
.
set_batch
(
batch_size
)
def
set_field_name
(
self
,
field_names
):
def
set_field_names
(
self
,
field_names
):
if
isinstance
(
field_names
,
str
):
if
isinstance
(
field_names
,
Variable
):
field_names
=
[
field_names
]
field_names
=
[
field_names
]
self
.
desc
.
set_field_names
(
field_names
)
self
.
feed
.
set_field_names
(
field_names
)
class
MultiSlotDataFeed
(
DataFeedDesc
):
def
__init__
(
self
):
def
start_an_epoch
(
self
):
super
(
MultiSlotDataFeed
,
self
).
__init__
()
self
.
feed
.
start_one_epoch
(
)
self
.
desc
.
set_name
(
"MultiSlotDataFeed"
)
class
AsyncExecutor
(
object
):
class
AsyncExecutor
(
object
):
"""
"""
...
@@ -55,45 +51,19 @@ class AsyncExecutor(object):
...
@@ -55,45 +51,19 @@ class AsyncExecutor(object):
They has the exactly same arguments, and expected the same results.
They has the exactly same arguments, and expected the same results.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
place
=
None
):
program
,
param_names
,
data_feed
,
thread_num
,
place
=
None
,
scope
=
None
):
if
program
is
None
:
program
=
default_main_program
()
program_desc
=
program
.
desc
if
not
isinstance
(
data_feed
,
TextDataFeed
):
raise
ValueError
(
"data_feed for AsyncExecutor.run() type error"
)
if
place
is
None
:
if
place
is
None
:
place
=
core
.
CPUPlace
()
place
=
core
.
CPUPlace
()
if
not
isinstance
(
place
,
core
.
CPUPlace
):
if
not
isinstance
(
place
,
core
.
CPUPlace
):
raise
ValueError
(
"AsyncExecutor only supports CPU device"
)
raise
ValueError
(
"AsyncExecutor only supports CPU device"
)
if
isinstance
(
param_names
,
Variable
):
param_names
=
[
param_names
]
p
=
core
.
Place
()
p
=
core
.
Place
()
p
.
set_place
(
place
)
p
.
set_place
(
place
)
self
.
executor
=
core
.
AsyncExecutor
(
program_desc
,
param_names
,
data_feed
.
feed
,
thread_num
,
p
)
def
run_startup_program
(
self
,
program
=
None
,
scope
=
None
):
if
program
is
None
:
program
=
default_startup_program
()
program_desc
=
program
.
_get_desc
()
if
scope
is
None
:
scope
=
global_scope
()
scope
=
g_scope
self
.
executor
=
core
.
AsyncExecutor
(
scope
,
p
)
self
.
executor
.
run_startup_program
(
program_desc
,
scope
)
def
run
(
self
,
program
,
data_feed
,
filelist
,
thread_num
,
fetch
):
def
run
(
self
,
inspect_vars
,
scope
=
None
):
"""
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
Python executor takes a program, add feed operators and fetch operators to this program according
...
@@ -136,16 +106,27 @@ class AsyncExecutor(object):
...
@@ -136,16 +106,27 @@ class AsyncExecutor(object):
>>> feed={'X': x},
>>> feed={'X': x},
>>> fetch_list=[loss.name])
>>> fetch_list=[loss.name])
"""
"""
if
inspect_vars
is
not
None
:
if
program
is
None
:
if
isinstance
(
inspect_vars
,
Variable
):
program
=
default_main_program
()
inspect_vars
=
[
inspect_vars
]
program_desc
=
program
.
desc
inspect_var_names
=
[
var
.
name
for
var
in
inspect_vars
]
if
data_feed
is
None
:
raise
ValueError
(
'ValueError: data_feed should be provided'
)
if
filelist
is
None
:
raise
ValueError
(
'ValueError: filelist should be provided'
)
if
isinstance
(
filelist
,
str
):
filelist
=
[
filelist
]
if
scope
is
None
:
if
not
isinstance
(
thread_num
,
int
)
:
scope
=
g_scope
raise
TypeError
(
'TypeError: thread_num should be a positive number'
)
self
.
executor
.
init_root_scope
(
scope
)
if
fetch
is
not
None
:
if
isinstance
(
fetch
,
Variable
):
fetch
=
[
fetch
]
fetch_var_names
=
[
var
.
name
for
var
in
fetch
]
evaluation
=
self
.
executor
.
run
(
inspect
_var_names
)
evaluation
=
self
.
executor
.
run
_from_files
(
program_desc
,
data_feed
.
desc
,
filelist
,
thread_num
,
fetch
_var_names
)
return
evaluation
return
evaluation
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录