Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
dc168ed0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dc168ed0
编写于
1月 17, 2018
作者:
K
Kexin Zhao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify programDesc based on feed and fetch names
上级
c5067a6a
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
53 addition
and
15 deletion
+53
-15
paddle/inference/example.cc
paddle/inference/example.cc
+3
-15
paddle/inference/inference.cc
paddle/inference/inference.cc
+27
-0
paddle/inference/inference.h
paddle/inference/inference.h
+1
-0
python/paddle/v2/fluid/io.py
python/paddle/v2/fluid/io.py
+22
-0
未找到文件。
paddle/inference/example.cc
浏览文件 @
dc168ed0
...
@@ -18,33 +18,21 @@ limitations under the License. */
...
@@ -18,33 +18,21 @@ limitations under the License. */
#include "paddle/inference/inference.h"
#include "paddle/inference/inference.h"
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
feed_var_names
,
""
,
"Names of feeding variables"
);
DEFINE_string
(
fetch_var_names
,
""
,
"Names of fetching variables"
);
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
if
(
FLAGS_dirname
.
empty
()
||
FLAGS_feed_var_names
.
empty
()
||
if
(
FLAGS_dirname
.
empty
())
{
FLAGS_fetch_var_names
.
empty
())
{
// Example:
// Example:
// ./example --dirname=recognize_digits_mlp.inference.model
// ./example --dirname=recognize_digits_mlp.inference.model
// --feed_var_names="x"
std
::
cout
<<
"Usage: ./example --dirname=path/to/your/model"
<<
std
::
endl
;
// --fetch_var_names="fc_2.tmp_2"
std
::
cout
<<
"Usage: ./example --dirname=path/to/your/model "
"--feed_var_names=x --fetch_var_names=y"
<<
std
::
endl
;
exit
(
1
);
exit
(
1
);
}
}
std
::
cout
<<
"FLAGS_dirname: "
<<
FLAGS_dirname
<<
std
::
endl
;
std
::
cout
<<
"FLAGS_dirname: "
<<
FLAGS_dirname
<<
std
::
endl
;
std
::
cout
<<
"FLAGS_feed_var_names: "
<<
FLAGS_feed_var_names
<<
std
::
endl
;
std
::
cout
<<
"FLAGS_fetch_var_names: "
<<
FLAGS_fetch_var_names
<<
std
::
endl
;
std
::
string
dirname
=
FLAGS_dirname
;
std
::
string
dirname
=
FLAGS_dirname
;
std
::
vector
<
std
::
string
>
feed_var_names
=
{
FLAGS_feed_var_names
};
std
::
vector
<
std
::
string
>
fetch_var_names
=
{
FLAGS_fetch_var_names
};
paddle
::
InferenceEngine
*
engine
=
new
paddle
::
InferenceEngine
();
paddle
::
InferenceEngine
*
engine
=
new
paddle
::
InferenceEngine
();
engine
->
LoadInferenceModel
(
dirname
,
feed_var_names
,
fetch_var_names
);
engine
->
LoadInferenceModel
(
dirname
);
paddle
::
framework
::
LoDTensor
input
;
paddle
::
framework
::
LoDTensor
input
;
srand
(
time
(
0
));
srand
(
time
(
0
));
...
...
paddle/inference/inference.cc
浏览文件 @
dc168ed0
...
@@ -25,6 +25,33 @@ limitations under the License. */
...
@@ -25,6 +25,33 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
void
InferenceEngine
::
LoadInferenceModel
(
const
std
::
string
&
dirname
)
{
std
::
string
model_filename
=
dirname
+
"/__model__.dat"
;
LOG
(
INFO
)
<<
"loading model from "
<<
model_filename
;
std
::
ifstream
inputfs
(
model_filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
string
program_desc_str
;
inputfs
.
seekg
(
0
,
std
::
ios
::
end
);
program_desc_str
.
resize
(
inputfs
.
tellg
());
inputfs
.
seekg
(
0
,
std
::
ios
::
beg
);
LOG
(
INFO
)
<<
"program_desc_str's size: "
<<
program_desc_str
.
size
();
inputfs
.
read
(
&
program_desc_str
[
0
],
program_desc_str
.
size
());
inputfs
.
close
();
program_
=
new
framework
::
ProgramDesc
(
program_desc_str
);
GenerateLoadProgram
(
dirname
);
framework
::
BlockDesc
*
global_block
=
program_
->
MutableBlock
(
0
);
feed_var_names_
.
clear
();
fetch_var_names_
.
clear
();
for
(
auto
*
op
:
global_block
->
AllOps
())
{
if
(
op
->
Type
()
==
"feed"
)
{
feed_var_names_
.
insert
(
feed_var_names_
.
begin
(),
op
->
Output
(
"Out"
)[
0
]);
}
else
if
(
op
->
Type
()
==
"fetch"
)
{
fetch_var_names_
.
push_back
(
op
->
Input
(
"X"
)[
0
]);
}
}
}
void
InferenceEngine
::
LoadInferenceModel
(
void
InferenceEngine
::
LoadInferenceModel
(
const
std
::
string
&
dirname
,
const
std
::
string
&
dirname
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
...
...
paddle/inference/inference.h
浏览文件 @
dc168ed0
...
@@ -28,6 +28,7 @@ public:
...
@@ -28,6 +28,7 @@ public:
delete
load_program_
;
delete
load_program_
;
}
}
void
LoadInferenceModel
(
const
std
::
string
&
dirname
);
void
LoadInferenceModel
(
const
std
::
string
&
dirname
,
void
LoadInferenceModel
(
const
std
::
string
&
dirname
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
...
...
python/paddle/v2/fluid/io.py
浏览文件 @
dc168ed0
...
@@ -243,6 +243,28 @@ def save_inference_model(dirname,
...
@@ -243,6 +243,28 @@ def save_inference_model(dirname,
# Save only programDesc of inference_program in binary format
# Save only programDesc of inference_program in binary format
# in another file: __model__.dat
# in another file: __model__.dat
global_block
=
inference_program
.
global_block
()
feed_var
=
global_blok
.
create_var
(
name
=
'feed'
,
type
=
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
,
persistable
=
True
)
for
i
,
name
in
enumerated
(
feeded_var_names
):
out
=
global_block
.
var
(
name
)
global_block
.
prepend_op
(
type
=
'feed'
,
inputs
=
{
'X'
:
[
feed_var
]},
outputs
=
{
'Out'
:
[
out
]},
attrs
=
{
'col'
:
i
})
fetch_var
=
global_block
.
create_var
(
name
=
'fetch'
,
type
=
core
.
VarDesc
.
VarType
.
FETCH_LIST
,
persistable
=
True
)
for
i
,
name
in
enumerated
(
fetch_var_names
):
global_block
.
append_op
(
type
=
'fetch'
,
inputs
=
{
'X'
:
[
var
]},
outputs
=
{
'Out'
:
[
fetch_var
]},
attrs
=
{
'col'
:
i
})
with
open
(
model_file_name
+
".dat"
,
"wb"
)
as
fp
:
with
open
(
model_file_name
+
".dat"
,
"wb"
)
as
fp
:
fp
.
write
(
inference_program
.
desc
.
serialize_to_string
())
fp
.
write
(
inference_program
.
desc
.
serialize_to_string
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录