From 12e351413988ca10010d727d64d256a9db2ee518 Mon Sep 17 00:00:00 2001 From: Siddharth Goyal Date: Sun, 7 Jan 2018 17:52:08 -0800 Subject: [PATCH] Modify inference.cc to run example without pickletools (#7262) --- paddle/inference/inference.cc | 23 ++++++++--------------- python/paddle/v2/fluid/io.py | 5 +++++ 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/paddle/inference/inference.cc b/paddle/inference/inference.cc index 48a51efcd2..49e39358e8 100644 --- a/paddle/inference/inference.cc +++ b/paddle/inference/inference.cc @@ -38,23 +38,16 @@ void InferenceEngine::LoadInferenceModel( LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); // PicklingTools cannot parse the vector of strings correctly. #else - // program_desc_str - // the inference.model is stored by following python codes: - // inference_program = fluid.io.get_inference_program(predict) - // model_filename = "recognize_digits_mlp.inference.model/inference.model" - // with open(model_filename, "w") as f: - // program_str = inference_program.desc.serialize_to_string() - // f.write(struct.pack('q', len(program_str))) - // f.write(program_str) - std::string model_filename = dirname + "/inference.model"; + std::string model_filename = dirname + "/__model__.dat"; LOG(INFO) << "loading model from " << model_filename; - std::ifstream fs(model_filename, std::ios_base::binary); - int64_t size = 0; - fs.read(reinterpret_cast(&size), sizeof(int64_t)); - LOG(INFO) << "program_desc_str's size: " << size; + std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); std::string program_desc_str; - program_desc_str.resize(size); - fs.read(&program_desc_str[0], size); + inputfs.seekg(0, std::ios::end); + program_desc_str.resize(inputfs.tellg()); + inputfs.seekg(0, std::ios::beg); + LOG(INFO) << "program_desc_str's size: " << program_desc_str.size(); + inputfs.read(&program_desc_str[0], program_desc_str.size()); + inputfs.close(); #endif program_ = new framework::ProgramDesc(program_desc_str); GenerateLoadProgram(dirname); diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 926327b70c..c63567601a 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -212,6 +212,11 @@ def save_inference_model(dirname, "fetch_var_names": fetch_var_names }, f, -1) + # Save only programDesc of inference_program in binary format + # in another file: __model__.dat + with open(model_file_name + ".dat", "wb") as fp: + fp.write(inference_program.desc.serialize_to_string()) + save_params(executor, dirname, main_program) -- GitLab