Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d77e6a67
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d77e6a67
编写于
1月 18, 2018
作者:
K
kexinzhao
提交者:
GitHub
1月 18, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7636 from kexinzhao/save_inference_model
Add feed and fetch op to ProgramDesc before saving for inference
上级
7905e367
856f650a
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
64 addition
and
47 deletion
+64
-47
paddle/inference/CMakeLists.txt
paddle/inference/CMakeLists.txt
+0
-21
paddle/inference/example.cc
paddle/inference/example.cc
+3
-15
paddle/inference/inference.cc
paddle/inference/inference.cc
+29
-11
paddle/inference/inference.h
paddle/inference/inference.h
+1
-0
python/paddle/v2/fluid/io.py
python/paddle/v2/fluid/io.py
+31
-0
未找到文件。
paddle/inference/CMakeLists.txt
浏览文件 @
d77e6a67
...
@@ -8,27 +8,6 @@ cc_library(paddle_fluid_api
...
@@ -8,27 +8,6 @@ cc_library(paddle_fluid_api
# Merge all modules into a simgle static library
# Merge all modules into a simgle static library
cc_library
(
paddle_fluid DEPS paddle_fluid_api
${
FLUID_CORE_MODULES
}
)
cc_library
(
paddle_fluid DEPS paddle_fluid_api
${
FLUID_CORE_MODULES
}
)
# ptools
# just for testing, we may need to change the storing format for inference_model
# and move the dependent of pickle.
# download from http://www.picklingtools.com/
# build in the C++ sub-directory, using command
# make -f Makefile.Linux libptools.so
set
(
PTOOLS_LIB
)
set
(
PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH
"Folder contains PicklingTools"
)
find_path
(
PTOOLS_INC_DIR chooseser.h PATHS
${
PTOOLS_ROOT
}
/C++
)
find_library
(
PTOOLS_SHARED_LIB NAMES ptools PATHS
${
PTOOLS_ROOT
}
/C++
)
if
(
PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB
)
add_definitions
(
-DPADDLE_USE_PTOOLS
)
set
(
PTOOLS_LIB ptools
)
message
(
STATUS
"Found PicklingTools:
${
PTOOLS_SHARED_LIB
}
"
)
add_library
(
${
PTOOLS_LIB
}
SHARED IMPORTED GLOBAL
)
set_property
(
TARGET
${
PTOOLS_LIB
}
PROPERTY IMPORTED_LOCATION
${
PTOOLS_SHARED_LIB
}
)
include_directories
(
${
PTOOLS_ROOT
}
/C++
)
include_directories
(
${
PTOOLS_ROOT
}
/C++/opencontainers_1_8_5/include
)
add_definitions
(
-DOC_NEW_STYLE_INCLUDES
)
# used in ptools
endif
()
add_executable
(
example example.cc
)
add_executable
(
example example.cc
)
if
(
APPLE
)
if
(
APPLE
)
set
(
OPTIONAL_LINK_FLAGS
)
set
(
OPTIONAL_LINK_FLAGS
)
...
...
paddle/inference/example.cc
浏览文件 @
d77e6a67
...
@@ -18,33 +18,21 @@ limitations under the License. */
...
@@ -18,33 +18,21 @@ limitations under the License. */
#include "paddle/inference/inference.h"
#include "paddle/inference/inference.h"
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
dirname
,
""
,
"Directory of the inference model."
);
DEFINE_string
(
feed_var_names
,
""
,
"Names of feeding variables"
);
DEFINE_string
(
fetch_var_names
,
""
,
"Names of fetching variables"
);
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
google
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
true
);
if
(
FLAGS_dirname
.
empty
()
||
FLAGS_feed_var_names
.
empty
()
||
if
(
FLAGS_dirname
.
empty
())
{
FLAGS_fetch_var_names
.
empty
())
{
// Example:
// Example:
// ./example --dirname=recognize_digits_mlp.inference.model
// ./example --dirname=recognize_digits_mlp.inference.model
// --feed_var_names="x"
std
::
cout
<<
"Usage: ./example --dirname=path/to/your/model"
<<
std
::
endl
;
// --fetch_var_names="fc_2.tmp_2"
std
::
cout
<<
"Usage: ./example --dirname=path/to/your/model "
"--feed_var_names=x --fetch_var_names=y"
<<
std
::
endl
;
exit
(
1
);
exit
(
1
);
}
}
std
::
cout
<<
"FLAGS_dirname: "
<<
FLAGS_dirname
<<
std
::
endl
;
std
::
cout
<<
"FLAGS_dirname: "
<<
FLAGS_dirname
<<
std
::
endl
;
std
::
cout
<<
"FLAGS_feed_var_names: "
<<
FLAGS_feed_var_names
<<
std
::
endl
;
std
::
cout
<<
"FLAGS_fetch_var_names: "
<<
FLAGS_fetch_var_names
<<
std
::
endl
;
std
::
string
dirname
=
FLAGS_dirname
;
std
::
string
dirname
=
FLAGS_dirname
;
std
::
vector
<
std
::
string
>
feed_var_names
=
{
FLAGS_feed_var_names
};
std
::
vector
<
std
::
string
>
fetch_var_names
=
{
FLAGS_fetch_var_names
};
paddle
::
InferenceEngine
*
engine
=
new
paddle
::
InferenceEngine
();
paddle
::
InferenceEngine
*
engine
=
new
paddle
::
InferenceEngine
();
engine
->
LoadInferenceModel
(
dirname
,
feed_var_names
,
fetch_var_names
);
engine
->
LoadInferenceModel
(
dirname
);
paddle
::
framework
::
LoDTensor
input
;
paddle
::
framework
::
LoDTensor
input
;
srand
(
time
(
0
));
srand
(
time
(
0
));
...
...
paddle/inference/inference.cc
浏览文件 @
d77e6a67
...
@@ -25,19 +25,37 @@ limitations under the License. */
...
@@ -25,19 +25,37 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
void
InferenceEngine
::
LoadInferenceModel
(
const
std
::
string
&
dirname
)
{
std
::
string
model_filename
=
dirname
+
"/__model__.dat"
;
LOG
(
INFO
)
<<
"loading model from "
<<
model_filename
;
std
::
ifstream
inputfs
(
model_filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
string
program_desc_str
;
inputfs
.
seekg
(
0
,
std
::
ios
::
end
);
program_desc_str
.
resize
(
inputfs
.
tellg
());
inputfs
.
seekg
(
0
,
std
::
ios
::
beg
);
LOG
(
INFO
)
<<
"program_desc_str's size: "
<<
program_desc_str
.
size
();
inputfs
.
read
(
&
program_desc_str
[
0
],
program_desc_str
.
size
());
inputfs
.
close
();
program_
=
new
framework
::
ProgramDesc
(
program_desc_str
);
GenerateLoadProgram
(
dirname
);
framework
::
BlockDesc
*
global_block
=
program_
->
MutableBlock
(
0
);
feed_var_names_
.
clear
();
fetch_var_names_
.
clear
();
for
(
auto
*
op
:
global_block
->
AllOps
())
{
if
(
op
->
Type
()
==
"feed"
)
{
feed_var_names_
.
insert
(
feed_var_names_
.
begin
(),
op
->
Output
(
"Out"
)[
0
]);
}
else
if
(
op
->
Type
()
==
"fetch"
)
{
fetch_var_names_
.
push_back
(
op
->
Input
(
"X"
)[
0
]);
}
}
}
void
InferenceEngine
::
LoadInferenceModel
(
void
InferenceEngine
::
LoadInferenceModel
(
const
std
::
string
&
dirname
,
const
std
::
string
&
dirname
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
)
{
const
std
::
vector
<
std
::
string
>&
fetch_var_names
)
{
#ifdef PADDLE_USE_PTOOLS
std
::
string
model_filename
=
dirname
+
"/__model__"
;
LOG
(
INFO
)
<<
"Using PicklingTools, loading model from "
<<
model_filename
;
Val
v
;
LoadValFromFile
(
model_filename
.
c_str
(),
v
,
SERIALIZE_P0
);
std
::
string
program_desc_str
=
v
[
"program_desc_str"
];
LOG
(
INFO
)
<<
"program_desc_str's size: "
<<
program_desc_str
.
size
();
// PicklingTools cannot parse the vector of strings correctly.
#else
std
::
string
model_filename
=
dirname
+
"/__model__.dat"
;
std
::
string
model_filename
=
dirname
+
"/__model__.dat"
;
LOG
(
INFO
)
<<
"loading model from "
<<
model_filename
;
LOG
(
INFO
)
<<
"loading model from "
<<
model_filename
;
std
::
ifstream
inputfs
(
model_filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
std
::
ifstream
inputfs
(
model_filename
,
std
::
ios
::
in
|
std
::
ios
::
binary
);
...
@@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
...
@@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
LOG
(
INFO
)
<<
"program_desc_str's size: "
<<
program_desc_str
.
size
();
LOG
(
INFO
)
<<
"program_desc_str's size: "
<<
program_desc_str
.
size
();
inputfs
.
read
(
&
program_desc_str
[
0
],
program_desc_str
.
size
());
inputfs
.
read
(
&
program_desc_str
[
0
],
program_desc_str
.
size
());
inputfs
.
close
();
inputfs
.
close
();
#endif
program_
=
new
framework
::
ProgramDesc
(
program_desc_str
);
program_
=
new
framework
::
ProgramDesc
(
program_desc_str
);
GenerateLoadProgram
(
dirname
);
GenerateLoadProgram
(
dirname
);
...
@@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
...
@@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
}
}
bool
InferenceEngine
::
IsParameter
(
const
framework
::
VarDesc
*
var
)
{
bool
InferenceEngine
::
IsParameter
(
const
framework
::
VarDesc
*
var
)
{
if
(
var
->
Persistable
())
{
if
(
var
->
Persistable
()
&&
var
->
Name
()
!=
"feed"
&&
var
->
Name
()
!=
"fetch"
)
{
// There are many unreachable variables in the program
// There are many unreachable variables in the program
for
(
size_t
i
=
0
;
i
<
program_
->
Size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
program_
->
Size
();
++
i
)
{
const
framework
::
BlockDesc
&
block
=
program_
->
Block
(
i
);
const
framework
::
BlockDesc
&
block
=
program_
->
Block
(
i
);
...
...
paddle/inference/inference.h
浏览文件 @
d77e6a67
...
@@ -28,6 +28,7 @@ public:
...
@@ -28,6 +28,7 @@ public:
delete
load_program_
;
delete
load_program_
;
}
}
void
LoadInferenceModel
(
const
std
::
string
&
dirname
);
void
LoadInferenceModel
(
const
std
::
string
&
dirname
,
void
LoadInferenceModel
(
const
std
::
string
&
dirname
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
const
std
::
vector
<
std
::
string
>&
feed_var_names
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
const
std
::
vector
<
std
::
string
>&
fetch_var_names
);
...
...
python/paddle/v2/fluid/io.py
浏览文件 @
d77e6a67
...
@@ -15,6 +15,7 @@ import os
...
@@ -15,6 +15,7 @@ import os
import
cPickle
as
pickle
import
cPickle
as
pickle
from
paddle.v2.fluid.framework
import
Program
,
Parameter
,
default_main_program
,
Variable
from
paddle.v2.fluid.framework
import
Program
,
Parameter
,
default_main_program
,
Variable
from
.
import
core
__all__
=
[
__all__
=
[
'save_vars'
,
'save_vars'
,
...
@@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
...
@@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
return
inference_program
return
inference_program
def
prepend_feed_ops
(
inference_program
,
feeded_var_names
):
global_block
=
inference_program
.
global_block
()
feed_var
=
global_block
.
create_var
(
name
=
'feed'
,
type
=
core
.
VarDesc
.
VarType
.
FEED_MINIBATCH
,
persistable
=
True
)
for
i
,
name
in
enumerate
(
feeded_var_names
):
out
=
global_block
.
var
(
name
)
global_block
.
prepend_op
(
type
=
'feed'
,
inputs
=
{
'X'
:
[
feed_var
]},
outputs
=
{
'Out'
:
[
out
]},
attrs
=
{
'col'
:
i
})
def
append_fetch_ops
(
inference_program
,
fetch_var_names
):
global_block
=
inference_program
.
global_block
()
fetch_var
=
global_block
.
create_var
(
name
=
'fetch'
,
type
=
core
.
VarDesc
.
VarType
.
FETCH_LIST
,
persistable
=
True
)
for
i
,
name
in
enumerate
(
fetch_var_names
):
global_block
.
append_op
(
type
=
'fetch'
,
inputs
=
{
'X'
:
[
name
]},
outputs
=
{
'Out'
:
[
fetch_var
]},
attrs
=
{
'col'
:
i
})
def
save_inference_model
(
dirname
,
def
save_inference_model
(
dirname
,
feeded_var_names
,
feeded_var_names
,
target_vars
,
target_vars
,
...
@@ -241,6 +269,9 @@ def save_inference_model(dirname,
...
@@ -241,6 +269,9 @@ def save_inference_model(dirname,
"fetch_var_names"
:
fetch_var_names
"fetch_var_names"
:
fetch_var_names
},
f
,
-
1
)
},
f
,
-
1
)
prepend_feed_ops
(
inference_program
,
feeded_var_names
)
append_fetch_ops
(
inference_program
,
fetch_var_names
)
# Save only programDesc of inference_program in binary format
# Save only programDesc of inference_program in binary format
# in another file: __model__.dat
# in another file: __model__.dat
with
open
(
model_file_name
+
".dat"
,
"wb"
)
as
fp
:
with
open
(
model_file_name
+
".dat"
,
"wb"
)
as
fp
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录