提交 dc168ed0 编写于 作者: K Kexin Zhao

modify programDesc based on feed and fetch names

上级 c5067a6a
...@@ -18,33 +18,21 @@ limitations under the License. */ ...@@ -18,33 +18,21 @@ limitations under the License. */
#include "paddle/inference/inference.h" #include "paddle/inference/inference.h"
DEFINE_string(dirname, "", "Directory of the inference model."); DEFINE_string(dirname, "", "Directory of the inference model.");
DEFINE_string(feed_var_names, "", "Names of feeding variables");
DEFINE_string(fetch_var_names, "", "Names of fetching variables");
int main(int argc, char** argv) { int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true); google::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_dirname.empty() || FLAGS_feed_var_names.empty() || if (FLAGS_dirname.empty()) {
FLAGS_fetch_var_names.empty()) {
// Example: // Example:
// ./example --dirname=recognize_digits_mlp.inference.model // ./example --dirname=recognize_digits_mlp.inference.model
// --feed_var_names="x" std::cout << "Usage: ./example --dirname=path/to/your/model" << std::endl;
// --fetch_var_names="fc_2.tmp_2"
std::cout << "Usage: ./example --dirname=path/to/your/model "
"--feed_var_names=x --fetch_var_names=y"
<< std::endl;
exit(1); exit(1);
} }
std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl; std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
std::cout << "FLAGS_feed_var_names: " << FLAGS_feed_var_names << std::endl;
std::cout << "FLAGS_fetch_var_names: " << FLAGS_fetch_var_names << std::endl;
std::string dirname = FLAGS_dirname; std::string dirname = FLAGS_dirname;
std::vector<std::string> feed_var_names = {FLAGS_feed_var_names};
std::vector<std::string> fetch_var_names = {FLAGS_fetch_var_names};
paddle::InferenceEngine* engine = new paddle::InferenceEngine(); paddle::InferenceEngine* engine = new paddle::InferenceEngine();
engine->LoadInferenceModel(dirname, feed_var_names, fetch_var_names); engine->LoadInferenceModel(dirname);
paddle::framework::LoDTensor input; paddle::framework::LoDTensor input;
srand(time(0)); srand(time(0));
......
...@@ -25,6 +25,33 @@ limitations under the License. */ ...@@ -25,6 +25,33 @@ limitations under the License. */
namespace paddle { namespace paddle {
void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
std::string model_filename = dirname + "/__model__.dat";
LOG(INFO) << "loading model from " << model_filename;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
std::string program_desc_str;
inputfs.seekg(0, std::ios::end);
program_desc_str.resize(inputfs.tellg());
inputfs.seekg(0, std::ios::beg);
LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
inputfs.read(&program_desc_str[0], program_desc_str.size());
inputfs.close();
program_ = new framework::ProgramDesc(program_desc_str);
GenerateLoadProgram(dirname);
framework::BlockDesc* global_block = program_->MutableBlock(0);
feed_var_names_.clear();
fetch_var_names_.clear();
for (auto* op : global_block->AllOps()) {
if (op->Type() == "feed") {
feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
} else if (op->Type() == "fetch") {
fetch_var_names_.push_back(op->Input("X")[0]);
}
}
}
void InferenceEngine::LoadInferenceModel( void InferenceEngine::LoadInferenceModel(
const std::string& dirname, const std::string& dirname,
const std::vector<std::string>& feed_var_names, const std::vector<std::string>& feed_var_names,
......
...@@ -28,6 +28,7 @@ public: ...@@ -28,6 +28,7 @@ public:
delete load_program_; delete load_program_;
} }
void LoadInferenceModel(const std::string& dirname);
void LoadInferenceModel(const std::string& dirname, void LoadInferenceModel(const std::string& dirname,
const std::vector<std::string>& feed_var_names, const std::vector<std::string>& feed_var_names,
const std::vector<std::string>& fetch_var_names); const std::vector<std::string>& fetch_var_names);
......
...@@ -243,6 +243,28 @@ def save_inference_model(dirname, ...@@ -243,6 +243,28 @@ def save_inference_model(dirname,
# Save only programDesc of inference_program in binary format # Save only programDesc of inference_program in binary format
# in another file: __model__.dat # in another file: __model__.dat
global_block = inference_program.global_block()
feed_var = global_blok.create_var(
name='feed', type=core.VarDesc.VarType.FEED_MINIBATCH, persistable=True)
for i, name in enumerated(feeded_var_names):
out = global_block.var(name)
global_block.prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
fetch_var = global_block.create_var(
name='fetch', type=core.VarDesc.VarType.FETCH_LIST, persistable=True)
for i, name in enumerated(fetch_var_names):
global_block.append_op(
type='fetch',
inputs={'X': [var]},
outputs={'Out': [fetch_var]},
attrs={'col': i})
with open(model_file_name + ".dat", "wb") as fp: with open(model_file_name + ".dat", "wb") as fp:
fp.write(inference_program.desc.serialize_to_string()) fp.write(inference_program.desc.serialize_to_string())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册