Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fbd3604c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fbd3604c
编写于
4月 03, 2018
作者:
L
Liu Yiqun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Split Executor.Run to Executor.Prepare and Executor.RunPreparedContext for inference.
上级
172c887d
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
85 addition
and
40 deletion
+85
-40
paddle/fluid/framework/executor.cc
paddle/fluid/framework/executor.cc
+59
-35
paddle/fluid/framework/executor.h
paddle/fluid/framework/executor.h
+7
-0
paddle/fluid/inference/tests/book/test_inference_image_classification.cc
...ference/tests/book/test_inference_image_classification.cc
+2
-2
paddle/fluid/inference/tests/test_helper.h
paddle/fluid/inference/tests/test_helper.h
+17
-3
未找到文件。
paddle/fluid/framework/executor.cc
浏览文件 @
fbd3604c
...
...
@@ -129,13 +129,15 @@ static bool has_feed_operators(
feed_count
,
feed_targets
.
size
(),
"The number of feed operators should match 'feed_targets'"
);
// When feed operator are present, so should be feed_holder
auto
var
=
block
.
FindVar
(
feed_holder_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Block should already have a '%s' variable"
,
feed_holder_name
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
proto
::
VarType
::
FEED_MINIBATCH
,
"'%s' variable should be 'FEED_MINIBATCH' type"
,
feed_holder_name
);
if
(
!
feed_holder_name
.
empty
())
{
// When feed operator are present, so should be feed_holder
auto
var
=
block
.
FindVar
(
feed_holder_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Block should already have a '%s' variable"
,
feed_holder_name
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
proto
::
VarType
::
FEED_MINIBATCH
,
"'%s' variable should be 'FEED_MINIBATCH' type"
,
feed_holder_name
);
}
}
return
feed_count
>
0
;
...
...
@@ -169,13 +171,15 @@ static bool has_fetch_operators(
fetch_count
,
fetch_targets
.
size
(),
"The number of fetch operators should match 'fetch_targets'"
);
// When fetch operator are present, so should be fetch_holder
auto
var
=
block
.
FindVar
(
fetch_holder_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Block should already have a '%s' variable"
,
fetch_holder_name
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
proto
::
VarType
::
FETCH_LIST
,
"'%s' variable should be 'FETCH_LIST' type"
,
fetch_holder_name
);
if
(
!
fetch_holder_name
.
empty
())
{
// When fetch operator are present, so should be fetch_holder
auto
var
=
block
.
FindVar
(
fetch_holder_name
);
PADDLE_ENFORCE_NOT_NULL
(
var
,
"Block should already have a '%s' variable"
,
fetch_holder_name
);
PADDLE_ENFORCE_EQ
(
var
->
GetType
(),
proto
::
VarType
::
FETCH_LIST
,
"'%s' variable should be 'FETCH_LIST' type"
,
fetch_holder_name
);
}
}
return
fetch_count
>
0
;
...
...
@@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
// map the data of feed_targets to feed_holder
for
(
auto
*
op
:
global_block
->
AllOps
())
{
if
(
op
->
Type
()
==
kFeedOpType
)
{
std
::
string
feed_target_name
=
op
->
Output
(
"Out"
)[
0
];
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
SetFeedVariable
(
scope
,
*
feed_targets
[
feed_target_name
],
feed_holder_name
,
idx
);
}
}
if
(
!
has_fetch_ops
)
{
// create fetch_holder variable
auto
*
fetch_holder
=
global_block
->
Var
(
fetch_holder_name
);
...
...
@@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
Run
(
*
copy_program
,
scope
,
0
,
create_vars
,
create_vars
);
// obtain the data of fetch_targets from fetch_holder
for
(
auto
*
op
:
global_block
->
AllOps
())
{
if
(
op
->
Type
()
==
kFetchOpType
)
{
std
::
string
fetch_target_name
=
op
->
Input
(
"X"
)[
0
];
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
*
fetch_targets
[
fetch_target_name
]
=
GetFetchVariable
(
*
scope
,
fetch_holder_name
,
idx
);
}
}
auto
ctx
=
Prepare
(
*
copy_program
,
0
);
RunPreparedContext
(
ctx
.
get
(),
scope
,
feed_targets
,
fetch_targets
,
feed_holder_name
,
fetch_holder_name
,
create_vars
);
}
std
::
unique_ptr
<
ExecutorPrepareContext
>
Executor
::
Prepare
(
...
...
@@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
}
}
void
Executor
::
RunPreparedContext
(
ExecutorPrepareContext
*
ctx
,
Scope
*
scope
,
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
feed_targets
,
std
::
map
<
std
::
string
,
LoDTensor
*>&
fetch_targets
,
const
std
::
string
&
feed_holder_name
,
const
std
::
string
&
fetch_holder_name
,
bool
create_vars
)
{
auto
&
global_block
=
ctx
->
prog_
.
Block
(
ctx
->
block_id_
);
// map the data of feed_targets to feed_holder
for
(
auto
*
op
:
global_block
.
AllOps
())
{
if
(
op
->
Type
()
==
kFeedOpType
)
{
std
::
string
feed_target_name
=
op
->
Output
(
"Out"
)[
0
];
PADDLE_ENFORCE
(
feed_targets
.
find
(
feed_target_name
)
!=
feed_targets
.
end
(),
"Variable %s is not feeded."
);
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
SetFeedVariable
(
scope
,
*
feed_targets
[
feed_target_name
],
feed_holder_name
,
idx
);
}
}
RunPreparedContext
(
ctx
,
scope
,
create_vars
,
create_vars
);
// obtain the data of fetch_targets from fetch_holder
for
(
auto
*
op
:
global_block
.
AllOps
())
{
if
(
op
->
Type
()
==
kFetchOpType
)
{
std
::
string
fetch_target_name
=
op
->
Input
(
"X"
)[
0
];
PADDLE_ENFORCE
(
fetch_targets
.
find
(
fetch_target_name
)
!=
fetch_targets
.
end
(),
"Variable %s is not fetched."
);
int
idx
=
boost
::
get
<
int
>
(
op
->
GetAttr
(
"col"
));
*
fetch_targets
[
fetch_target_name
]
=
GetFetchVariable
(
*
scope
,
fetch_holder_name
,
idx
);
}
}
}
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/executor.h
浏览文件 @
fbd3604c
...
...
@@ -65,6 +65,13 @@ class Executor {
bool
create_local_scope
=
true
,
bool
create_vars
=
true
);
void
RunPreparedContext
(
ExecutorPrepareContext
*
ctx
,
Scope
*
scope
,
std
::
map
<
std
::
string
,
const
LoDTensor
*>&
feed_targets
,
std
::
map
<
std
::
string
,
LoDTensor
*>&
fetch_targets
,
const
std
::
string
&
feed_holder_name
=
"feed"
,
const
std
::
string
&
fetch_holder_name
=
"fetch"
,
bool
create_vars
=
true
);
private:
const
platform
::
Place
place_
;
};
...
...
paddle/fluid/inference/tests/book/test_inference_image_classification.cc
浏览文件 @
fbd3604c
...
...
@@ -48,7 +48,7 @@ TEST(inference, image_classification) {
// Run inference on CPU
LOG
(
INFO
)
<<
"--- CPU Runs: ---"
;
TestInference
<
paddle
::
platform
::
CPUPlace
>
(
TestInference
<
paddle
::
platform
::
CPUPlace
,
true
>
(
dirname
,
cpu_feeds
,
cpu_fetchs1
,
FLAGS_repeat
);
LOG
(
INFO
)
<<
output1
.
dims
();
...
...
@@ -59,7 +59,7 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU
LOG
(
INFO
)
<<
"--- GPU Runs: ---"
;
TestInference
<
paddle
::
platform
::
CUDAPlace
>
(
TestInference
<
paddle
::
platform
::
CUDAPlace
,
true
>
(
dirname
,
cpu_feeds
,
cpu_fetchs2
,
FLAGS_repeat
);
LOG
(
INFO
)
<<
output2
.
dims
();
...
...
paddle/fluid/inference/tests/test_helper.h
浏览文件 @
fbd3604c
...
...
@@ -88,7 +88,7 @@ void CheckError(paddle::framework::LoDTensor& output1,
EXPECT_EQ
(
count
,
0U
)
<<
"There are "
<<
count
<<
" different elements."
;
}
template
<
typename
Place
>
template
<
typename
Place
,
bool
PrepareContext
=
false
>
void
TestInference
(
const
std
::
string
&
dirname
,
const
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_feeds
,
std
::
vector
<
paddle
::
framework
::
LoDTensor
*>&
cpu_fetchs
,
...
...
@@ -170,7 +170,14 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
// Ignore the profiling results of the first run
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
std
::
unique_ptr
<
paddle
::
framework
::
ExecutorPrepareContext
>
ctx
;
if
(
PrepareContext
)
{
ctx
=
executor
.
Prepare
(
*
inference_program
,
0
);
executor
.
RunPreparedContext
(
ctx
.
get
(),
scope
,
feed_targets
,
fetch_targets
);
}
else
{
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
}
// Enable the profiler
paddle
::
platform
::
EnableProfiler
(
state
);
...
...
@@ -181,7 +188,14 @@ void TestInference(const std::string& dirname,
"run_inference"
,
paddle
::
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
if
(
PrepareContext
)
{
// Note: if you changed the inference_program, you need to call
// executor.Prepare() again to get a new ExecutorPrepareContext.
executor
.
RunPreparedContext
(
ctx
.
get
(),
scope
,
feed_targets
,
fetch_targets
);
}
else
{
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
}
}
// Disable the profiler and print the timing information
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录