Commit 12e35141 authored by Siddharth Goyal, committed by Yiqun Liu

Modify inference.cc to run example without pickletools (#7262)

Parent 3b543756
...
@@ -38,23 +38,16 @@ void InferenceEngine::LoadInferenceModel(
   LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
   // PicklingTools cannot parse the vector of strings correctly.
 #else
-  // program_desc_str
-  // the inference.model is stored by following python codes:
-  //     inference_program = fluid.io.get_inference_program(predict)
-  //     model_filename = "recognize_digits_mlp.inference.model/inference.model"
-  //     with open(model_filename, "w") as f:
-  //         program_str = inference_program.desc.serialize_to_string()
-  //         f.write(struct.pack('q', len(program_str)))
-  //         f.write(program_str)
-  std::string model_filename = dirname + "/inference.model";
+  std::string model_filename = dirname + "/__model__.dat";
   LOG(INFO) << "loading model from " << model_filename;
-  std::ifstream fs(model_filename, std::ios_base::binary);
-  int64_t size = 0;
-  fs.read(reinterpret_cast<char*>(&size), sizeof(int64_t));
-  LOG(INFO) << "program_desc_str's size: " << size;
+  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
   std::string program_desc_str;
-  program_desc_str.resize(size);
-  fs.read(&program_desc_str[0], size);
+  inputfs.seekg(0, std::ios::end);
+  program_desc_str.resize(inputfs.tellg());
+  inputfs.seekg(0, std::ios::beg);
+  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
+  inputfs.read(&program_desc_str[0], program_desc_str.size());
+  inputfs.close();
 #endif
   program_ = new framework::ProgramDesc(program_desc_str);
   GenerateLoadProgram(dirname);
...
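The new loader drops the hand-rolled length header (the old code read an int64_t size written by Python's struct.pack('q', ...)) and instead sizes the buffer from the stream itself. A minimal standalone sketch of that whole-file-read idiom, with the helper name and error check added for illustration only:

#include <fstream>
#include <stdexcept>
#include <string>

// Read an entire binary file into a string: seek to the end to learn the
// size, resize the buffer, rewind, and pull the payload in a single read.
std::string ReadFileToString(const std::string& filename) {
  std::ifstream inputfs(filename, std::ios::in | std::ios::binary);
  if (!inputfs) {
    throw std::runtime_error("cannot open " + filename);
  }
  inputfs.seekg(0, std::ios::end);
  std::string buffer;
  buffer.resize(static_cast<std::string::size_type>(inputfs.tellg()));
  inputfs.seekg(0, std::ios::beg);
  inputfs.read(&buffer[0], buffer.size());
  return buffer;
}

Since __model__.dat contains nothing but the serialized ProgramDesc, the file's own length is the payload length and no extra framing is needed.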
...
@@ -212,6 +212,11 @@ def save_inference_model(dirname,
             "fetch_var_names": fetch_var_names
         }, f, -1)
 
+    # Save only programDesc of inference_program in binary format
+    # in another file: __model__.dat
+    with open(model_file_name + ".dat", "wb") as fp:
+        fp.write(inference_program.desc.serialize_to_string())
+
     save_params(executor, dirname, main_program)
...
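The #else branch in the C++ hunk above is the consumer of this file. A condensed round-trip sketch, assuming Paddle's framework headers and reusing the hypothetical ReadFileToString helper from the earlier sketch:

// Rebuild the program from the raw bytes written by save_inference_model.
// The framework::ProgramDesc(const std::string&) constructor parses the
// serialized desc, exactly as the modified LoadInferenceModel does.
std::string program_desc_str = ReadFileToString(dirname + "/__model__.dat");
framework::ProgramDesc* program_ = new framework::ProgramDesc(program_desc_str);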