Unverified commit d77e6a67, authored by kexinzhao, committed by GitHub

Merge pull request #7636 from kexinzhao/save_inference_model

Add feed and fetch op to ProgramDesc before saving for inference
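With this change, save_inference_model records the feed and fetch variable names directly in the saved ProgramDesc as feed/fetch operators, so the C++ inference engine can recover them from the model file instead of requiring them on the command line. A minimal sketch of the save side, assuming the fluid Python API of this period (the toy network and all variable names are illustrative):

import paddle.v2.fluid as fluid

# Toy network: a single fully-connected layer.
x = fluid.layers.data(name='x', shape=[784], dtype='float32')
y = fluid.layers.fc(input=x, size=10, act='softmax')

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

# save_inference_model now prepends feed ops and appends fetch ops
# to the inference program before serializing it to __model__.dat.
fluid.io.save_inference_model('recognize_digits_mlp.inference.model',
                              ['x'], [y], exe)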
paddle/inference/CMakeLists.txt
@@ -8,27 +8,6 @@ cc_library(paddle_fluid_api
 # Merge all modules into a single static library
 cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES})
-# ptools
-# just for testing; we may need to change the storing format for inference_model
-# and remove the dependency on pickle.
-# download from http://www.picklingtools.com/
-# build in the C++ sub-directory, using command
-# make -f Makefile.Linux libptools.so
-set(PTOOLS_LIB)
-set(PTOOLS_ROOT $ENV{PTOOLS_ROOT} CACHE PATH "Folder containing PicklingTools")
-find_path(PTOOLS_INC_DIR chooseser.h PATHS ${PTOOLS_ROOT}/C++)
-find_library(PTOOLS_SHARED_LIB NAMES ptools PATHS ${PTOOLS_ROOT}/C++)
-if(PTOOLS_INC_DIR AND PTOOLS_SHARED_LIB)
-  add_definitions(-DPADDLE_USE_PTOOLS)
-  set(PTOOLS_LIB ptools)
-  message(STATUS "Found PicklingTools: ${PTOOLS_SHARED_LIB}")
-  add_library(${PTOOLS_LIB} SHARED IMPORTED GLOBAL)
-  set_property(TARGET ${PTOOLS_LIB} PROPERTY IMPORTED_LOCATION ${PTOOLS_SHARED_LIB})
-  include_directories(${PTOOLS_ROOT}/C++)
-  include_directories(${PTOOLS_ROOT}/C++/opencontainers_1_8_5/include)
-  add_definitions(-DOC_NEW_STYLE_INCLUDES) # used in ptools
-endif()
 add_executable(example example.cc)
 if(APPLE)
   set(OPTIONAL_LINK_FLAGS)
...
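Since the ProgramDesc is now serialized directly in binary form to __model__.dat, the experimental PicklingTools path is no longer needed, and its detection and build hooks are dropped from the build.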
paddle/inference/example.cc
@@ -18,33 +18,21 @@ limitations under the License. */
 #include "paddle/inference/inference.h"

 DEFINE_string(dirname, "", "Directory of the inference model.");
-DEFINE_string(feed_var_names, "", "Names of feeding variables");
-DEFINE_string(fetch_var_names, "", "Names of fetching variables");

 int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
-  if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() ||
-      FLAGS_fetch_var_names.empty()) {
+  if (FLAGS_dirname.empty()) {
     // Example:
     // ./example --dirname=recognize_digits_mlp.inference.model
-    //           --feed_var_names="x"
-    //           --fetch_var_names="fc_2.tmp_2"
-    std::cout << "Usage: ./example --dirname=path/to/your/model "
-                 "--feed_var_names=x --fetch_var_names=y"
-              << std::endl;
+    std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
     exit(1);
   }

   std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
-  std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
-  std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;

   std::string dirname = FLAGS_dirname;
-  std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
-  std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};

   paddle::InferenceEngine* engine = new paddle::InferenceEngine();
-  engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
+  engine->LoadInferenceModel(dirname);

   paddle::framework::LoDTensor input;
   srand(time(0));
...
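The example binary now needs only the model directory, e.g. ./example --dirname=recognize_digits_mlp.inference.model; the feed and fetch variable names are recovered from the feed/fetch operators embedded in the saved program.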
paddle/inference/inference.cc
@@ -25,19 +25,37 @@ limitations under the License. */

 namespace paddle {

+void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
+  std::string model_filename = dirname + "/__model__.dat";
+  LOG(INFO) << "loading model from " << model_filename;
+  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
+  std::string program_desc_str;
+  inputfs.seekg(0, std::ios::end);
+  program_desc_str.resize(inputfs.tellg());
+  inputfs.seekg(0, std::ios::beg);
+  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
+  inputfs.read(&program_desc_str[0], program_desc_str.size());
+  inputfs.close();
+
+  program_ = new framework::ProgramDesc(program_desc_str);
+  GenerateLoadProgram(dirname);
+
+  framework::BlockDesc* global_block = program_->MutableBlock(0);
+  feed_var_names_.clear();
+  fetch_var_names_.clear();
+  for (auto* op : global_block->AllOps()) {
+    if (op->Type() == "feed") {
+      feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
+    } else if (op->Type() == "fetch") {
+      fetch_var_names_.push_back(op->Input("X")[0]);
+    }
+  }
+}
+
 void InferenceEngine::LoadInferenceModel(
     const std::string& dirname,
     const std::vector<std::string>& feed_var_names,
     const std::vector<std::string>& fetch_var_names) {
-#ifdef PADDLE_USE_PTOOLS
-  std::string model_filename = dirname + "/__model__";
-  LOG(INFO) << "Using PicklingTools, loading model from " << model_filename;
-  Val v;
-  LoadValFromFile(model_filename.c_str(), v, SERIALIZE_P0);
-  std::string program_desc_str = v["program_desc_str"];
-  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
-  // PicklingTools cannot parse the vector of strings correctly.
-#else
   std::string model_filename = dirname + "/__model__.dat";
   LOG(INFO) << "loading model from " << model_filename;
   std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
@@ -48,7 +66,7 @@ void InferenceEngine::LoadInferenceModel(
   LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
   inputfs.read(&program_desc_str[0], program_desc_str.size());
   inputfs.close();
-#endif

   program_ = new framework::ProgramDesc(program_desc_str);
   GenerateLoadProgram(dirname);
@@ -62,7 +80,7 @@ void InferenceEngine::LoadInferenceModel(
 }

 bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
-  if (var->Persistable()) {
+  if (var->Persistable() && var->Name() != "feed" && var->Name() != "fetch") {
     // There are many unreachable variables in the program
     for (size_t i = 0; i < program_->Size(); ++i) {
       const framework::BlockDesc& block = program_->Block(i);
...
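Note the ordering logic in the new overload: prepend_feed_ops inserts each feed op at the front of the global block, so the feed ops appear in reverse column order; inserting each recovered name at the front of feed_var_names_ restores the original order, while fetch ops are appended and read back in sequence. The same recovery can be sketched in Python (a hypothetical helper, assuming the fluid framework Operator API of this period):

def recover_feed_fetch_names(program):
    # Mirrors the loop in InferenceEngine::LoadInferenceModel(dirname):
    # feed names are re-reversed by inserting at the front,
    # fetch names are collected in encounter order.
    feed_var_names, fetch_var_names = [], []
    for op in program.global_block().ops:
        if op.type == 'feed':
            feed_var_names.insert(0, op.output('Out')[0])
        elif op.type == 'fetch':
            fetch_var_names.append(op.input('X')[0])
    return feed_var_names, fetch_var_names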
paddle/inference/inference.h
@@ -28,6 +28,7 @@ public:
     delete load_program_;
   }

+  void LoadInferenceModel(const std::string& dirname);
   void LoadInferenceModel(const std::string& dirname,
                           const std::vector<std::string>& feed_var_names,
                           const std::vector<std::string>& fetch_var_names);
...
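Both overloads are kept: the new single-argument form reads the feed/fetch names from the operators embedded in the program, while the three-argument form still serves models saved without them.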
python/paddle/v2/fluid/io.py
@@ -15,6 +15,7 @@ import os
 import cPickle as pickle

 from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
+from . import core

 __all__ = [
     'save_vars',
@@ -191,6 +192,33 @@ def get_inference_program(target_vars, main_program=None):
     return inference_program


+def prepend_feed_ops(inference_program, feeded_var_names):
+    global_block = inference_program.global_block()
+    feed_var = global_block.create_var(
+        name='feed',
+        type=core.VarDesc.VarType.FEED_MINIBATCH,
+        persistable=True)
+    for i, name in enumerate(feeded_var_names):
+        out = global_block.var(name)
+        global_block.prepend_op(
+            type='feed',
+            inputs={'X': [feed_var]},
+            outputs={'Out': [out]},
+            attrs={'col': i})
+
+
+def append_fetch_ops(inference_program, fetch_var_names):
+    global_block = inference_program.global_block()
+    fetch_var = global_block.create_var(
+        name='fetch',
+        type=core.VarDesc.VarType.FETCH_LIST,
+        persistable=True)
+    for i, name in enumerate(fetch_var_names):
+        global_block.append_op(
+            type='fetch',
+            inputs={'X': [name]},
+            outputs={'Out': [fetch_var]},
+            attrs={'col': i})
+
+
 def save_inference_model(dirname,
                          feeded_var_names,
                          target_vars,
@@ -241,6 +269,9 @@ def save_inference_model(dirname,
             "fetch_var_names": fetch_var_names
         }, f, -1)

+    prepend_feed_ops(inference_program, feeded_var_names)
+    append_fetch_ops(inference_program, fetch_var_names)
+
     # Save only programDesc of inference_program in binary format
     # in another file: __model__.dat
     with open(model_file_name + ".dat", "wb") as fp:
...
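To check the effect of the two helpers, the saved __model__.dat can be parsed back and its operator list inspected; the feed ops should come first and the fetch ops last. A hedged sketch, assuming Program.parse_from_string from the fluid framework of this period:

from paddle.v2.fluid.framework import Program

with open('recognize_digits_mlp.inference.model/__model__.dat', 'rb') as f:
    program = Program.parse_from_string(f.read())

# Expected to start with 'feed' ops and end with 'fetch' ops.
print([op.type for op in program.global_block().ops])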