diff --git a/paddle/inference/example.cc b/paddle/inference/example.cc
index 9711b20e6fb4099a2cc497029468ebd1fd0b3456..0c18b45624dedcb5839d4b771e044b4a7b32af52 100644
--- a/paddle/inference/example.cc
+++ b/paddle/inference/example.cc
@@ -18,33 +18,21 @@ limitations under the License. */
 #include "paddle/inference/inference.h"
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
-DEFINE_string(feed_var_names, "", "Names of feeding variables");
-DEFINE_string(fetch_var_names, "", "Names of fetching variables");
 
 int main(int argc, char** argv) {
   google::ParseCommandLineFlags(&argc, &argv, true);
-  if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() ||
-      FLAGS_fetch_var_names.empty()) {
+  if (FLAGS_dirname.empty()) {
     // Example:
     //   ./example --dirname=recognize_digits_mlp.inference.model
-    //   --feed_var_names="x"
-    //   --fetch_var_names="fc_2.tmp_2"
-    std::cout << "Usage: ./example --dirname=path/to/your/model "
-                 "--feed_var_names=x --fetch_var_names=y"
-              << std::endl;
+    std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
     exit(1);
   }
 
   std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
-  std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
-  std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;
-
   std::string dirname = FLAGS_dirname;
-  std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
-  std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};
 
   paddle::InferenceEngine* engine = new paddle::InferenceEngine();
-  engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names);
+  engine->LoadInferenceModel(dirname);
 
   paddle::framework::LoDTensor input;
   srand(time(0));
diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc
index 37b8b20ddfcf2566b8410f950308309e5b2b2a7c..1f6e510b8b04125d00f732fe2ed1227d375fda81 100644
--- a/paddle/inference/inference.cc
+++ b/paddle/inference/inference.cc
@@ -25,6 +25,33 @@ limitations under the License.
 */
 
 namespace paddle {
+void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
+  std::string model_filename = dirname + "/__model__.dat";
+  LOG(INFO) << "loading model from " << model_filename;
+  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
+  std::string program_desc_str;
+  inputfs.seekg(0, std::ios::end);
+  program_desc_str.resize(inputfs.tellg());
+  inputfs.seekg(0, std::ios::beg);
+  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
+  inputfs.read(&program_desc_str[0], program_desc_str.size());
+  inputfs.close();
+
+  program_ = new framework::ProgramDesc(program_desc_str);
+  GenerateLoadProgram(dirname);
+
+  framework::BlockDesc* global_block = program_->MutableBlock(0);
+  feed_var_names_.clear();
+  fetch_var_names_.clear();
+  for (auto* op : global_block->AllOps()) {
+    if (op->Type() == "feed") {
+      feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
+    } else if (op->Type() == "fetch") {
+      fetch_var_names_.push_back(op->Input("X")[0]);
+    }
+  }
+}
+
 void InferenceEngine::LoadInferenceModel(
     const std::string& dirname,
     const std::vector<std::string>& feed_var_names,
diff --git a/paddle/inference/inference.h b/paddle/inference/inference.h
index a3f3ef4b440036a0b27353cc092eed1bbf96eeb3..7fc09cb9e539a65a8cd3cceb1543bc7d111c22b3 100644
--- a/paddle/inference/inference.h
+++ b/paddle/inference/inference.h
@@ -28,6 +28,7 @@ public:
     delete load_program_;
   }
 
+  void LoadInferenceModel(const std::string& dirname);
   void LoadInferenceModel(const std::string& dirname,
                           const std::vector<std::string>& feed_var_names,
                           const std::vector<std::string>& fetch_var_names);
diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py
index 499df05e592855f63f41ec8ceb939edf0e4d435c..516b8471dfdf91a04264c161cc1403db53757bc3 100644
--- a/python/paddle/v2/fluid/io.py
+++ b/python/paddle/v2/fluid/io.py
@@ -243,6 +243,28 @@ def save_inference_model(dirname,
 
     # Save only programDesc of inference_program in binary format
     # in another file: __model__.dat
+    global_block = inference_program.global_block()
+    feed_var = global_block.create_var(
+        name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
+
+    for i, name in enumerate(feeded_var_names):
+        out = global_block.var(name)
+        global_block.prepend_op(
+            type='feed',
+            inputs={'X': [feed_var]},
+            outputs={'Out': [out]},
+            attrs={'col': i})
+
+    fetch_var = global_block.create_var(
+        name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
+
+    for i, name in enumerate(fetch_var_names):
+        global_block.append_op(
+            type='fetch',
+            inputs={'X': [name]},
+            outputs={'Out': [fetch_var]},
+            attrs={'col': i})
+
     with open(model_file_name + ".dat", "wb") as fp:
         fp.write(inference_program.desc.serialize_to_string())
 
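For reference, the workflow this patch enables looks roughly like the sketch below. It is not part of the patch: it assumes the `paddle.v2.fluid` API on this branch (`fluid.layers.data`, `fluid.layers.fc`, `fluid.io.save_inference_model`) and substitutes a toy network for the real recognize_digits_mlp model; the variable names are illustrative. It shows how `save_inference_model` now embeds the `feed`/`fetch` ops that the new `InferenceEngine::LoadInferenceModel(dirname)` scans to recover the feed and fetch variable names.

```python
import paddle.v2.fluid as fluid

# Toy stand-in for the recognize_digits_mlp program; 'x' and y_predict are
# illustrative names, not taken from the patch.
x = fluid.layers.data(name='x', shape=[784], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=10, act='softmax')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# save_inference_model now prepends one `feed` op per feeded var name and
# appends one `fetch` op per target var before serializing the ProgramDesc
# to <dirname>/__model__.dat, so the C++ loader can read the names back.
fluid.io.save_inference_model("recognize_digits_mlp.inference.model",
                              ['x'], [y_predict], exe)
```

With the names embedded in `__model__.dat`, the C++ example only needs `./example --dirname=recognize_digits_mlp.inference.model`; the `--feed_var_names`/`--fetch_var_names` flags removed above are no longer required.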