Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
68409533
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
68409533
编写于
6月 06, 2018
作者:
T
tensor-tang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine nlp multi-threads
上级
b74362f9
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
26 addition
and
31 deletion
+26
-31
paddle/fluid/inference/tests/book/test_inference_nlp.cc
paddle/fluid/inference/tests/book/test_inference_nlp.cc
+26
-31
未找到文件。
paddle/fluid/inference/tests/book/test_inference_nlp.cc
浏览文件 @
68409533
...
...
@@ -101,23 +101,22 @@ void SplitData(
}
void
ThreadRunInfer
(
const
int
tid
,
paddle
::
framework
::
Executor
*
executor
,
paddle
::
framework
::
Scope
*
scope
,
const
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>&
inference_program
,
const
int
tid
,
paddle
::
framework
::
Scope
*
scope
,
const
std
::
vector
<
std
::
vector
<
const
paddle
::
framework
::
LoDTensor
*>>&
jobs
)
{
auto
copy_program
=
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
(
new
paddle
::
framework
::
ProgramDesc
(
*
inference_program
));
// maybe framework:ProgramDesc is not thread-safe
auto
&
sub_scope
=
scope
->
NewScope
();
auto
place
=
paddle
::
platform
::
CPUPlace
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
auto
inference_program
=
paddle
::
inference
::
Load
(
&
executor
,
scope
,
FLAGS_model_path
);
std
::
string
feed_holder_name
=
"feed_"
+
paddle
::
string
::
to_string
(
tid
);
std
::
string
fetch_holder_name
=
"fetch_"
+
paddle
::
string
::
to_string
(
tid
);
copy_program
->
SetFeedHolderName
(
feed_holder_name
);
copy_program
->
SetFetchHolderName
(
fetch_holder_name
);
auto
ctx
=
executor
.
Prepare
(
*
inference_program
,
/*block_id*/
0
);
executor
.
CreateVariables
(
*
inference_program
,
&
sub_scope
,
/*block_id*/
0
);
const
std
::
vector
<
std
::
string
>&
feed_target_names
=
copy
_program
->
GetFeedTargetNames
();
inference
_program
->
GetFeedTargetNames
();
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
copy
_program
->
GetFetchTargetNames
();
inference
_program
->
GetFetchTargetNames
();
PADDLE_ENFORCE_EQ
(
fetch_target_names
.
size
(),
1UL
);
std
::
map
<
std
::
string
,
paddle
::
framework
::
LoDTensor
*>
fetch_targets
;
...
...
@@ -131,9 +130,8 @@ void ThreadRunInfer(
auto
start_ms
=
GetCurrentMs
();
for
(
size_t
i
=
0
;
i
<
inputs
.
size
();
++
i
)
{
feed_targets
[
feed_target_names
[
0
]]
=
inputs
[
i
];
executor
->
Run
(
*
copy_program
,
&
sub_scope
,
&
feed_targets
,
&
fetch_targets
,
true
/*create_local_scope*/
,
true
/*create_vars*/
,
feed_holder_name
,
fetch_holder_name
);
executor
.
RunPreparedContext
(
ctx
.
get
(),
&
sub_scope
,
&
feed_targets
,
&
fetch_targets
,
false
/*create_local_scope*/
);
}
auto
stop_ms
=
GetCurrentMs
();
scope
->
DeleteScope
(
&
sub_scope
);
...
...
@@ -158,22 +156,10 @@ TEST(inference, nlp) {
LOG
(
INFO
)
<<
"Number of samples (seq_len<1024): "
<<
datasets
.
size
();
LOG
(
INFO
)
<<
"Total number of words: "
<<
num_total_words
;
const
bool
model_combined
=
false
;
// 0. Call `paddle::framework::InitDevices()` initialize all the devices
// 1. Define place, executor, scope
auto
place
=
paddle
::
platform
::
CPUPlace
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
std
::
unique_ptr
<
paddle
::
framework
::
Scope
>
scope
(
new
paddle
::
framework
::
Scope
());
// 2. Initialize the inference_program and load parameters
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
inference_program
;
inference_program
=
InitProgram
(
&
executor
,
scope
.
get
(),
FLAGS_model_path
,
model_combined
);
if
(
FLAGS_use_mkldnn
)
{
EnableMKLDNN
(
inference_program
);
}
#ifdef PADDLE_WITH_MKLML
// only use 1 thread number per std::thread
omp_set_dynamic
(
0
);
...
...
@@ -189,21 +175,30 @@ TEST(inference, nlp) {
start_ms
=
GetCurrentMs
();
for
(
int
i
=
0
;
i
<
FLAGS_num_threads
;
++
i
)
{
threads
.
emplace_back
(
new
std
::
thread
(
ThreadRunInfer
,
i
,
&
executor
,
scope
.
get
(),
std
::
ref
(
inference_program
),
std
::
ref
(
jobs
)));
new
std
::
thread
(
ThreadRunInfer
,
i
,
scope
.
get
(),
std
::
ref
(
jobs
)));
}
for
(
int
i
=
0
;
i
<
FLAGS_num_threads
;
++
i
)
{
threads
[
i
]
->
join
();
}
stop_ms
=
GetCurrentMs
();
}
else
{
if
(
FLAGS_prepare_vars
)
{
executor
.
CreateVariables
(
*
inference_program
,
scope
.
get
(),
0
);
// 1. Define place, executor, scope
auto
place
=
paddle
::
platform
::
CPUPlace
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
// 2. Initialize the inference_program and load parameters
std
::
unique_ptr
<
paddle
::
framework
::
ProgramDesc
>
inference_program
;
inference_program
=
InitProgram
(
&
executor
,
scope
.
get
(),
FLAGS_model_path
,
/*model combined*/
false
);
if
(
FLAGS_use_mkldnn
)
{
EnableMKLDNN
(
inference_program
);
}
// always prepare context
std
::
unique_ptr
<
paddle
::
framework
::
ExecutorPrepareContext
>
ctx
;
ctx
=
executor
.
Prepare
(
*
inference_program
,
0
);
if
(
FLAGS_prepare_vars
)
{
executor
.
CreateVariables
(
*
inference_program
,
scope
.
get
(),
0
);
}
// preapre fetch
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
inference_program
->
GetFetchTargetNames
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录