Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
438aad24
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
438aad24
编写于
1月 26, 2018
作者:
L
Liu Yiqun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Update the inference unittest using the new Executor.Run().
上级
2cf56367
变更
3
显示空白变更内容
内联
并排
Showing
3 changed files
with
59 additions
and
118 deletions
+59
-118
paddle/inference/inference.cc
paddle/inference/inference.cc
+7
-96
paddle/inference/inference.h
paddle/inference/inference.h
+13
-5
paddle/inference/tests/book/test_inference_recognize_digits.cc
...e/inference/tests/book/test_inference_recognize_digits.cc
+39
-17
未找到文件。
paddle/inference/inference.cc
浏览文件 @
438aad24
...
...
@@ -14,13 +14,13 @@ limitations under the License. */
#include "inference.h"
#include <fstream>
#include "paddle/framework/executor.h"
#include "paddle/framework/init.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
void
InferenceEngine
::
LoadInferenceModel
(
const
std
::
string
&
dirname
)
{
framework
::
ProgramDesc
*
InferenceEngine
::
LoadInferenceModel
(
framework
::
Executor
&
exe
,
framework
::
Scope
*
scope
,
const
std
::
string
&
dirname
)
{
std
::
string
model_filename
=
dirname
+
"/__model__"
;
LOG
(
INFO
)
<<
"loading model from "
<<
model_filename
;
std
::
ifstream
inputfs
(
model_filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
...
...
@@ -34,6 +34,7 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
program_
=
new
framework
::
ProgramDesc
(
program_desc_str
);
GenerateLoadProgram
(
dirname
);
exe
.
Run
(
*
load_program_
,
scope
,
0
,
true
,
true
);
framework
::
BlockDesc
*
global_block
=
program_
->
MutableBlock
(
0
);
feed_var_names_
.
clear
();
...
...
@@ -45,6 +46,8 @@ void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
fetch_var_names_
.
push_back
(
op
->
Input
(
"X"
)[
0
]);
}
}
return
program_
;
}
bool
InferenceEngine
::
IsParameter
(
const
framework
::
VarDesc
*
var
)
{
...
...
@@ -92,96 +95,4 @@ void InferenceEngine::GenerateLoadProgram(const std::string& dirname) {
}
}
}
void
InferenceEngine
::
PrependFeedOp
()
{
if
(
!
program_
)
{
LOG
(
FATAL
)
<<
"Please initialize the program_ first."
;
}
framework
::
BlockDesc
*
global_block
=
program_
->
MutableBlock
(
0
);
// create_var
framework
::
VarDesc
*
feed_var
=
global_block
->
Var
(
"feed"
);
feed_var
->
SetType
(
framework
::
proto
::
VarDesc
::
FEED_MINIBATCH
);
feed_var
->
SetPersistable
(
true
);
// prepend feed_op
for
(
size_t
i
=
0
;
i
<
feed_var_names_
.
size
();
++
i
)
{
std
::
string
var_name
=
feed_var_names_
[
i
];
LOG
(
INFO
)
<<
"feed var's name: "
<<
var_name
;
// prepend_op
framework
::
OpDesc
*
op
=
global_block
->
PrependOp
();
op
->
SetType
(
"feed"
);
op
->
SetInput
(
"X"
,
{
"feed"
});
op
->
SetOutput
(
"Out"
,
{
var_name
});
op
->
SetAttr
(
"col"
,
{
static_cast
<
int
>
(
i
)});
op
->
CheckAttrs
();
}
}
void
InferenceEngine
::
AppendFetchOp
()
{
if
(
!
program_
)
{
LOG
(
FATAL
)
<<
"Please initialize the program_ first."
;
}
framework
::
BlockDesc
*
global_block
=
program_
->
MutableBlock
(
0
);
// create_var
framework
::
VarDesc
*
fetch_var
=
global_block
->
Var
(
"fetch"
);
fetch_var
->
SetType
(
framework
::
proto
::
VarDesc
::
FETCH_LIST
);
fetch_var
->
SetPersistable
(
true
);
// append fetch_op
for
(
size_t
i
=
0
;
i
<
fetch_var_names_
.
size
();
++
i
)
{
std
::
string
var_name
=
fetch_var_names_
[
i
];
LOG
(
INFO
)
<<
"fetch var's name: "
<<
var_name
;
// append_op
framework
::
OpDesc
*
op
=
global_block
->
AppendOp
();
op
->
SetType
(
"fetch"
);
op
->
SetInput
(
"X"
,
{
var_name
});
op
->
SetOutput
(
"Out"
,
{
"fetch"
});
op
->
SetAttr
(
"col"
,
{
static_cast
<
int
>
(
i
)});
op
->
CheckAttrs
();
}
}
// Runs inference on CPU: executes load_program_ to load parameters into a
// fresh scope, then runs program_ with `feeds` bound to feed_var_names_ and
// `fetchs` resized and bound to fetch_var_names_.
//
// @param feeds   Input tensors; must match feed_var_names_ in count/order.
// @param fetchs  Output tensors; resized to fetch_var_names_.size() and
//                filled by the run.
//
// Aborts via LOG(FATAL) if the programs are not initialized or the feed
// count is wrong.
//
// Fix: the place, executor, and scope were heap-allocated with raw
// new/delete, leaking all three if either Run() threw. They are now
// stack-allocated (RAII), with identical observable behavior.
void InferenceEngine::Execute(const std::vector<framework::LoDTensor>& feeds,
                              std::vector<framework::LoDTensor>& fetchs) {
  if (!program_ || !load_program_) {
    LOG(FATAL) << "Please initialize the program_ and load_program_ first.";
  }
  if (feeds.size() != feed_var_names_.size()) {
    LOG(FATAL) << "Please feed " << feed_var_names_.size()
               << " input Tensors.";
  }

  // Stack-allocated resources: released automatically even if Run() throws.
  platform::CPUPlace place;
  framework::InitDevices();
  framework::Executor executor(place);
  framework::Scope scope;

  // Load all parameters into the scope first.
  executor.Run(*load_program_, &scope, 0, true, true);

  std::map<std::string, const framework::LoDTensor*> feed_targets;
  std::map<std::string, framework::LoDTensor*> fetch_targets;

  // set_feed_variable: bind each input tensor to its feed variable name.
  for (size_t i = 0; i < feed_var_names_.size(); ++i) {
    feed_targets[feed_var_names_[i]] = &feeds[i];
  }

  // get_fetch_variable: bind each output slot to its fetch variable name.
  fetchs.resize(fetch_var_names_.size());
  for (size_t i = 0; i < fetch_var_names_.size(); ++i) {
    fetch_targets[fetch_var_names_[i]] = &fetchs[i];
  }

  executor.Run(*program_, &scope, feed_targets, fetch_targets);
}
}
// namespace paddle
paddle/inference/inference.h
浏览文件 @
438aad24
...
...
@@ -15,8 +15,10 @@ limitations under the License. */
#pragma once
#include "paddle/framework/block_desc.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
namespace
paddle
{
...
...
@@ -28,15 +30,21 @@ public:
delete
load_program_
;
}
void
LoadInferenceModel
(
const
std
::
string
&
dirname
);
void
Execute
(
const
std
::
vector
<
framework
::
LoDTensor
>&
feeds
,
std
::
vector
<
framework
::
LoDTensor
>&
fetchs
);
framework
::
ProgramDesc
*
LoadInferenceModel
(
framework
::
Executor
&
exe
,
framework
::
Scope
*
scope
,
const
std
::
string
&
dirname
);
const
std
::
vector
<
std
::
string
>&
GetFeedVarNames
()
const
{
return
feed_var_names_
;
}
const
std
::
vector
<
std
::
string
>&
GetFetchVarNames
()
const
{
return
fetch_var_names_
;
}
private:
bool
IsParameter
(
const
framework
::
VarDesc
*
var
);
void
GenerateLoadProgram
(
const
std
::
string
&
dirname
);
void
PrependFeedOp
();
void
AppendFetchOp
();
private:
framework
::
ProgramDesc
*
program_
;
...
...
paddle/inference/tests/book/test_inference_recognize_digits.cc
浏览文件 @
438aad24
...
...
@@ -16,11 +16,12 @@ limitations under the License. */
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/framework/init.h"
#include "paddle/inference/inference.h"
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
TEST
(
inference
,
recognize_digits
)
{
TEST
(
recognize_digits
,
CPU
)
{
if
(
FLAGS_dirname
.
empty
())
{
LOG
(
FATAL
)
<<
"Usage: ./example --dirname=path/to/your/model"
;
}
...
...
@@ -28,33 +29,54 @@ TEST(inference, recognize_digits) {
std
::
cout
<<
"FLAGS_dirname: "
<<
FLAGS_dirname
<<
std
::
endl
;
std
::
string
dirname
=
FLAGS_dirname
;
// 0. Initialize all the devices
paddle
::
framework
::
InitDevices
();
// 1. Define place, executor and scope
auto
place
=
paddle
::
platform
::
CPUPlace
();
auto
executor
=
paddle
::
framework
::
Executor
(
place
);
auto
*
scope
=
new
paddle
::
framework
::
Scope
();
// 2. Initialize the inference_program and load all parameters from file
paddle
::
InferenceEngine
*
engine
=
new
paddle
::
InferenceEngine
();
engine
->
LoadInferenceModel
(
dirname
);
paddle
::
framework
::
ProgramDesc
*
inference_program
=
engine
->
LoadInferenceModel
(
executor
,
scope
,
dirname
);
// 3. Get the feed_var_names and fetch_var_names
const
std
::
vector
<
std
::
string
>&
feed_target_names
=
engine
->
GetFeedVarNames
();
const
std
::
vector
<
std
::
string
>&
fetch_target_names
=
engine
->
GetFetchVarNames
();
// 4. Prepare inputs
std
::
map
<
std
::
string
,
const
paddle
::
framework
::
LoDTensor
*>
feed_targets
;
paddle
::
framework
::
LoDTensor
input
;
srand
(
time
(
0
));
float
*
input_ptr
=
input
.
mutable_data
<
float
>
({
1
,
784
},
paddle
::
platform
::
CPUPlace
());
input
.
mutable_data
<
float
>
({
1
,
28
,
28
},
paddle
::
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
784
;
++
i
)
{
input_ptr
[
i
]
=
rand
()
/
(
static_cast
<
float
>
(
RAND_MAX
));
}
feed_targets
[
feed_target_names
[
0
]]
=
&
input
;
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
feeds
;
feeds
.
push_back
(
input
)
;
std
::
vector
<
paddle
::
framework
::
LoDTensor
>
fetchs
;
engine
->
Execute
(
feeds
,
fetchs
)
;
// 5. Define Tensor to get the outputs
std
::
map
<
std
::
string
,
paddle
::
framework
::
LoDTensor
*>
fetch_targets
;
paddle
::
framework
::
LoDTensor
output
;
fetch_targets
[
fetch_target_names
[
0
]]
=
&
output
;
for
(
size_t
i
=
0
;
i
<
fetchs
.
size
();
++
i
)
{
LOG
(
INFO
)
<<
fetchs
[
i
].
dims
();
// 6. Run the inference program
executor
.
Run
(
*
inference_program
,
scope
,
feed_targets
,
fetch_targets
);
// 7. Use the output as your expect.
LOG
(
INFO
)
<<
output
.
dims
();
std
::
stringstream
ss
;
ss
<<
"result:"
;
float
*
output_ptr
=
fetchs
[
i
]
.
data
<
float
>
();
for
(
int
j
=
0
;
j
<
fetchs
[
i
]
.
numel
();
++
j
)
{
float
*
output_ptr
=
output
.
data
<
float
>
();
for
(
int
j
=
0
;
j
<
output
.
numel
();
++
j
)
{
ss
<<
" "
<<
output_ptr
[
j
];
}
LOG
(
INFO
)
<<
ss
.
str
();
}
delete
scope
;
delete
engine
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录