Commit 311334e0 authored by Siddharth Goyal, committed by Yiqun Liu

Implement basic `Load()` and modify example based on updated inference design (#7690)

* Initial commit

* Remove resolution bug

* Modify IsParam

* Remove small bugs

* First commit unifying Run and Load

* Fix bugs

* Fix Cmake

* Modify Cmake and dir structure

* Add io.* files to inference dir

* Fix include in example

* Address review comments: part 1

* Address review comments: round 2

* Address review comments: round 3

* Address review comments: round 4
Parent 308f6022
paddle/framework/program_desc.cc
@@ -18,6 +18,9 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+const std::string kFeedOpType = "feed";
+const std::string kFetchOpType = "fetch";
+
 BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
   auto *b = desc_.add_blocks();
   b->set_parent_idx(parent.ID());
@@ -64,5 +67,27 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
   }
 }
 
+const std::vector<std::string> ProgramDesc::GetFeedVarNames() {
+  BlockDesc *global_block = blocks_[0].get();
+  std::vector<std::string> feed_var_names;
+  for (auto *op : global_block->AllOps()) {
+    if (op->Type() == "feed") {
+      feed_var_names.insert(feed_var_names.begin(), op->Output("Out")[0]);
+    }
+  }
+  return feed_var_names;
+}
+
+const std::vector<std::string> ProgramDesc::GetFetchVarNames() {
+  BlockDesc *global_block = blocks_[0].get();
+  std::vector<std::string> fetch_var_names;
+  for (auto *op : global_block->AllOps()) {
+    if (op->Type() == "fetch") {
+      fetch_var_names.push_back(op->Input("X")[0]);
+    }
+  }
+  return fetch_var_names;
+}
+
 }  // namespace framework
 }  // namespace paddle
paddle/framework/program_desc.h
@@ -45,6 +45,10 @@ class ProgramDesc {
 
   proto::ProgramDesc *Proto();
 
+  const std::vector<std::string> GetFeedVarNames();
+
+  const std::vector<std::string> GetFetchVarNames();
+
  private:
   proto::ProgramDesc desc_;
...
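Note: a minimal sketch of how the two new accessors are meant to be used once a serialized program has been read back; the helper name PrintIOVarNames and its setup are hypothetical, not part of this commit. GetFeedVarNames() inserts at the front of its result because feed ops are prepended to the global block, so walking the ops in block order and inserting at the front recovers the names in feed-column order.

    // Hypothetical usage sketch for the new ProgramDesc accessors.
    #include <iostream>
    #include <string>
    #include "paddle/framework/program_desc.h"

    void PrintIOVarNames(const std::string& program_desc_str) {
      // Deserialize a program from its binary proto string, as Load() does.
      paddle::framework::ProgramDesc program(program_desc_str);
      for (const auto& name : program.GetFeedVarNames()) {
        std::cout << "feed var: " << name << std::endl;
      }
      for (const auto& name : program.GetFetchVarNames()) {
        std::cout << "fetch var: " << name << std::endl;
      }
    }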
paddle/inference/CMakeLists.txt
 set(FLUID_CORE_MODULES proto_desc paddle_memory executor prune init)
 
 cc_library(paddle_fluid_api
-    SRCS inference.cc
+    SRCS io.cc
     DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
 
 # Merge all modules into a single static library
 cc_library(paddle_fluid DEPS paddle_fluid_api ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
 
 # Create shared library
-add_library(paddle_fluid_shared SHARED inference.cc)
+add_library(paddle_fluid_shared SHARED io.cc)
 
 target_circle_link_libraries(paddle_fluid_shared
     ARCHIVE_START
@@ -20,7 +20,7 @@ SET_TARGET_PROPERTIES(paddle_fluid_shared PROPERTIES OUTPUT_NAME paddle_fluid)
 
 # install library & headers
 if(NOT WITH_C_API AND WITH_FLUID)
-  install(FILES inference.h DESTINATION include/paddle/inference)
+  install(FILES io.h DESTINATION include/paddle/inference)
   install(TARGETS paddle_fluid_shared DESTINATION lib)
 endif()
...
paddle/inference/example.cc
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -15,7 +15,9 @@ limitations under the License. */
 
 #include <time.h>
 #include <iostream>
 #include "gflags/gflags.h"
-#include "paddle/inference/inference.h"
+#include "paddle/framework/init.h"
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/inference/io.h"
 
 DEFINE_string(dirname, "", "Directory of the inference model.");
 
@@ -28,12 +30,27 @@ int main(int argc, char** argv) {
     exit(1);
   }
 
+  // 1. Define place, executor, scope
+  auto place = paddle::platform::CPUPlace();
+  paddle::framework::InitDevices();
+  auto* executor = new paddle::framework::Executor(place);
+  auto* scope = new paddle::framework::Scope();
+
   std::cout << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
   std::string dirname = FLAGS_dirname;
 
-  paddle::InferenceEngine* engine = new paddle::InferenceEngine();
-  engine->LoadInferenceModel(dirname);
+  // 2. Initialize the inference program
+  auto* inference_program = paddle::inference::Load(*executor, *scope, dirname);
+
+  // 3. Optional: perform optimization on the inference_program
+
+  // 4. Get the feed_var_names and fetch_var_names
+  const std::vector<std::string>& feed_var_names =
+      inference_program->GetFeedVarNames();
+  const std::vector<std::string>& fetch_var_names =
+      inference_program->GetFetchVarNames();
 
+  // 5. Generate input
   paddle::framework::LoDTensor input;
   srand(time(0));
   float* input_ptr =
@@ -45,8 +62,26 @@ int main(int argc, char** argv) {
 
   std::vector<paddle::framework::LoDTensor> feeds;
   feeds.push_back(input);
   std::vector<paddle::framework::LoDTensor> fetchs;
-  engine->Execute(feeds, fetchs);
+
+  // Set up maps for feed and fetch targets
+  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
+  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
+
+  // set_feed_variable
+  for (size_t i = 0; i < feed_var_names.size(); ++i) {
+    feed_targets[feed_var_names[i]] = &feeds[i];
+  }
+
+  // get_fetch_variable
+  fetchs.resize(fetch_var_names.size());
+  for (size_t i = 0; i < fetch_var_names.size(); ++i) {
+    fetch_targets[fetch_var_names[i]] = &fetchs[i];
+  }
+
+  // Run the inference program
+  executor->Run(*inference_program, scope, feed_targets, fetch_targets);
 
+  // Get outputs
   for (size_t i = 0; i < fetchs.size(); ++i) {
     auto dims_i = fetchs[i].dims();
     std::cout << "dims_i:";
@@ -62,6 +97,9 @@ int main(int argc, char** argv) {
     std::cout << std::endl;
   }
 
-  delete engine;
+  delete inference_program;
+  delete scope;
+  delete executor;
+
   return 0;
 }
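Note: the example manages the executor, scope, and returned program through raw pointers and explicit deletes. A sketch of the same flow wrapped in a helper with scoped ownership, using only APIs visible in this diff; the function name RunInference is hypothetical:

    // Hypothetical helper: the example's steps 1-5 with scoped ownership.
    #include <map>
    #include <memory>
    #include <string>
    #include <vector>
    #include "paddle/framework/executor.h"
    #include "paddle/framework/init.h"
    #include "paddle/framework/lod_tensor.h"
    #include "paddle/framework/scope.h"
    #include "paddle/inference/io.h"

    void RunInference(const std::string& dirname,
                      const std::vector<paddle::framework::LoDTensor>& feeds,
                      std::vector<paddle::framework::LoDTensor>* fetchs) {
      paddle::framework::InitDevices();
      auto place = paddle::platform::CPUPlace();
      paddle::framework::Executor executor(place);
      paddle::framework::Scope scope;

      // Load() deserializes __model__ and restores the persistable parameters.
      std::unique_ptr<paddle::framework::ProgramDesc> program(
          paddle::inference::Load(executor, scope, dirname));

      // Map feed/fetch variable names to the caller's tensors.
      std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
      const auto& feed_names = program->GetFeedVarNames();
      for (size_t i = 0; i < feed_names.size(); ++i) {
        feed_targets[feed_names[i]] = &feeds[i];
      }

      std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
      const auto& fetch_names = program->GetFetchVarNames();
      fetchs->resize(fetch_names.size());
      for (size_t i = 0; i < fetch_names.size(); ++i) {
        fetch_targets[fetch_names[i]] = &(*fetchs)[i];
      }

      executor.Run(*program, &scope, feed_targets, fetch_targets);
    }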
paddle/inference/inference.cc → paddle/inference/io.cc
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,48 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "inference.h"
+#include "paddle/inference/io.h"
 #include <fstream>
-#include "paddle/framework/executor.h"
-#include "paddle/framework/init.h"
-#include "paddle/framework/scope.h"
 
 namespace paddle {
+namespace inference {
 
-void InferenceEngine::LoadInferenceModel(const std::string& dirname) {
-  std::string model_filename = dirname + "/__model__";
-  LOG(INFO) << "loading model from " << model_filename;
-  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
-  std::string program_desc_str;
-  inputfs.seekg(0, std::ios::end);
-  program_desc_str.resize(inputfs.tellg());
-  inputfs.seekg(0, std::ios::beg);
-  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
-  inputfs.read(&program_desc_str[0], program_desc_str.size());
-  inputfs.close();
-
-  program_ = new framework::ProgramDesc(program_desc_str);
-  GenerateLoadProgram(dirname);
-
-  framework::BlockDesc* global_block = program_->MutableBlock(0);
-  feed_var_names_.clear();
-  fetch_var_names_.clear();
-  for (auto* op : global_block->AllOps()) {
-    if (op->Type() == "feed") {
-      feed_var_names_.insert(feed_var_names_.begin(), op->Output("Out")[0]);
-    } else if (op->Type() == "fetch") {
-      fetch_var_names_.push_back(op->Input("X")[0]);
-    }
-  }
-}
+const std::string kFeedOpType = "feed";
 
-bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
+bool IsParameter(const framework::VarDesc* var,
+                 const framework::ProgramDesc* main_program) {
   if (var->Persistable()) {
     // There are many unreachable variables in the program
-    for (size_t i = 0; i < program_->Size(); ++i) {
-      const framework::BlockDesc& block = program_->Block(i);
+    for (size_t i = 0; i < main_program->Size(); ++i) {
+      const framework::BlockDesc& block = main_program->Block(i);
       for (auto* op : block.AllOps()) {
-        if (op->Type() == "feed") {
+        if (op->Type() == kFeedOpType) {
           continue;
         }
         for (auto input_argument_name : op->InputArgumentNames()) {
@@ -67,13 +41,16 @@ bool InferenceEngine::IsParameter(const framework::VarDesc* var) {
   return false;
 }
 
-void InferenceEngine::GenerateLoadProgram(const std::string& dirname) {
-  framework::BlockDesc* global_block = program_->MutableBlock(0);
+void LoadPersistables(framework::Executor& executor,
+                      framework::Scope& scope,
+                      const std::string& dirname,
+                      framework::ProgramDesc* main_program) {
+  framework::BlockDesc* global_block = main_program->MutableBlock(0);
 
-  load_program_ = new framework::ProgramDesc();
-  framework::BlockDesc* load_block = load_program_->MutableBlock(0);
+  framework::ProgramDesc* load_program = new framework::ProgramDesc();
+  framework::BlockDesc* load_block = load_program->MutableBlock(0);
   for (auto* var : global_block->AllVars()) {
-    if (IsParameter(var)) {
+    if (IsParameter(var, main_program)) {
       LOG(INFO) << "parameter's name: " << var->Name();
 
       framework::VarDesc* new_var = load_block->Var(var->Name());
@@ -91,97 +68,30 @@ void InferenceEngine::GenerateLoadProgram(const std::string& dirname) {
       op->CheckAttrs();
     }
   }
+  executor.Run(*load_program, &scope, 0, true, true);
+  delete load_program;
 }
 
-void InferenceEngine::PrependFeedOp() {
-  if (!program_) {
-    LOG(FATAL) << "Please initialize the program_ first.";
-  }
-
-  framework::BlockDesc* global_block = program_->MutableBlock(0);
-
-  // create_var
-  framework::VarDesc* feed_var = global_block->Var("feed");
-  feed_var->SetType(framework::proto::VarDesc::FEED_MINIBATCH);
-  feed_var->SetPersistable(true);
-
-  // prepend feed_op
-  for (size_t i = 0; i < feed_var_names_.size(); ++i) {
-    std::string var_name = feed_var_names_[i];
-    LOG(INFO) << "feed var's name: " << var_name;
-
-    // prepend_op
-    framework::OpDesc* op = global_block->PrependOp();
-    op->SetType("feed");
-    op->SetInput("X", {"feed"});
-    op->SetOutput("Out", {var_name});
-    op->SetAttr("col", {static_cast<int>(i)});
-    op->CheckAttrs();
-  }
-}
-
-void InferenceEngine::AppendFetchOp() {
-  if (!program_) {
-    LOG(FATAL) << "Please initialize the program_ first.";
-  }
-
-  framework::BlockDesc* global_block = program_->MutableBlock(0);
-
-  // create_var
-  framework::VarDesc* fetch_var = global_block->Var("fetch");
-  fetch_var->SetType(framework::proto::VarDesc::FETCH_LIST);
-  fetch_var->SetPersistable(true);
-
-  // append fetch_op
-  for (size_t i = 0; i < fetch_var_names_.size(); ++i) {
-    std::string var_name = fetch_var_names_[i];
-    LOG(INFO) << "fetch var's name: " << var_name;
-
-    // append_op
-    framework::OpDesc* op = global_block->AppendOp();
-    op->SetType("fetch");
-    op->SetInput("X", {var_name});
-    op->SetOutput("Out", {"fetch"});
-    op->SetAttr("col", {static_cast<int>(i)});
-    op->CheckAttrs();
-  }
-}
-
-void InferenceEngine::Execute(const std::vector<framework::LoDTensor>& feeds,
-                              std::vector<framework::LoDTensor>& fetchs) {
-  if (!program_ || !load_program_) {
-    LOG(FATAL) << "Please initialize the program_ and load_program_ first.";
-  }
-  if (feeds.size() != feed_var_names_.size()) {
-    LOG(FATAL) << "Please feed " << feed_var_names_.size() << " input Tensors.";
-  }
-
-  auto* place = new platform::CPUPlace();
-  framework::InitDevices();
-  framework::Executor* executor = new framework::Executor(*place);
-  framework::Scope* scope = new framework::Scope();
-
-  executor->Run(*load_program_, scope, 0, true, true);
-
-  std::map<std::string, const framework::LoDTensor*> feed_targets;
-  std::map<std::string, framework::LoDTensor*> fetch_targets;
-
-  // set_feed_variable
-  for (size_t i = 0; i < feed_var_names_.size(); ++i) {
-    feed_targets[feed_var_names_[i]] = &feeds[i];
-  }
-
-  // get_fetch_variable
-  fetchs.resize(fetch_var_names_.size());
-  for (size_t i = 0; i < fetch_var_names_.size(); ++i) {
-    fetch_targets[fetch_var_names_[i]] = &fetchs[i];
-  }
-
-  executor->Run(*program_, scope, feed_targets, fetch_targets);
-
-  delete place;
-  delete scope;
-  delete executor;
-}
+framework::ProgramDesc* Load(framework::Executor& executor,
+                             framework::Scope& scope,
+                             const std::string& dirname) {
+  std::string model_filename = dirname + "/__model__";
+  LOG(INFO) << "loading model from " << model_filename;
+  std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
+  std::string program_desc_str;
+  inputfs.seekg(0, std::ios::end);
+  program_desc_str.resize(inputfs.tellg());
+  inputfs.seekg(0, std::ios::beg);
+  LOG(INFO) << "program_desc_str's size: " << program_desc_str.size();
+  inputfs.read(&program_desc_str[0], program_desc_str.size());
+  inputfs.close();
 
+  framework::ProgramDesc* main_program =
+      new framework::ProgramDesc(program_desc_str);
+
+  LoadPersistables(executor, scope, dirname, main_program);
+  return main_program;
+}
+
+}  // namespace inference
 }  // namespace paddle
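Note: the hunk header @@ -91,97 +68,30 @@ collapses the middle of LoadPersistables(), between load_block->Var(var->Name()) and op->CheckAttrs(). That hidden body copies each parameter's metadata into the load block and appends one load op per parameter. A sketch of what it plausibly contains; the VarDesc getter names and the "load" op with its "file_path" attribute are assumptions about the Fluid API of this period, not visible in the diff:

    // Hypothetical reconstruction of the collapsed body of LoadPersistables():
    // mirror one parameter into the load block, then emit a "load" op for it.
    #include <string>
    #include "paddle/framework/block_desc.h"
    #include "paddle/framework/op_desc.h"
    #include "paddle/framework/var_desc.h"

    void AppendLoadOp(paddle::framework::BlockDesc* load_block,
                      const paddle::framework::VarDesc* var,
                      const std::string& dirname) {
      paddle::framework::VarDesc* new_var = load_block->Var(var->Name());
      new_var->SetShape(var->Shape());  // assumed getter names
      new_var->SetDataType(var->GetDataType());
      new_var->SetType(var->GetType());
      new_var->SetLoDLevel(var->GetLoDLevel());
      new_var->SetPersistable(true);

      paddle::framework::OpDesc* op = load_block->AppendOp();
      op->SetType("load");
      op->SetOutput("Out", {new_var->Name()});
      // Assumption: each parameter is saved to its own file named after it.
      op->SetAttr("file_path", {dirname + "/" + new_var->Name()});
      op->CheckAttrs();
    }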
paddle/inference/inference.h → paddle/inference/io.h
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -14,35 +14,28 @@ limitations under the License. */
 
 #pragma once
 
-#include <string>
-#include <vector>
 #include "paddle/framework/block_desc.h"
-#include "paddle/framework/lod_tensor.h"
+#include "paddle/framework/executor.h"
 #include "paddle/framework/program_desc.h"
+#include "paddle/framework/scope.h"
+#include "paddle/framework/var_desc.h"
 
 namespace paddle {
+namespace inference {
 
-class InferenceEngine {
- public:
-  InferenceEngine() : program_(nullptr), load_program_(nullptr) {}
-  ~InferenceEngine() {
-    delete program_;
-    delete load_program_;
-  }
-
-  void LoadInferenceModel(const std::string& dirname);
-  void Execute(const std::vector<framework::LoDTensor>& feeds,
-               std::vector<framework::LoDTensor>& fetchs);
-
- private:
-  bool IsParameter(const framework::VarDesc* var);
-  void GenerateLoadProgram(const std::string& dirname);
-  void PrependFeedOp();
-  void AppendFetchOp();
-
- private:
-  framework::ProgramDesc* program_;
-  framework::ProgramDesc* load_program_;
-  std::vector<std::string> feed_var_names_;
-  std::vector<std::string> fetch_var_names_;
-};
+bool IsParameter(const framework::VarDesc* var,
+                 const framework::ProgramDesc* main_program);
+
+void LoadPersistables(framework::Executor& executor,
+                      framework::Scope& scope,
+                      const std::string& dirname,
+                      framework::ProgramDesc* main_program);
+
+framework::ProgramDesc* Load(framework::Executor& executor,
+                             framework::Scope& scope,
+                             const std::string& dirname);
+
+}  // namespace inference
 }  // namespace paddle
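Note: exposing LoadPersistables() separately from Load() means a caller that already holds a ProgramDesc — for instance one rewritten by the optional optimization step in the example — can restore parameters into a fresh scope on its own. A minimal sketch under that assumption; the wrapper name ReloadParams is hypothetical:

    // Hypothetical wrapper: reload parameters for an existing program.
    #include <string>
    #include "paddle/inference/io.h"

    void ReloadParams(paddle::framework::Executor& executor,
                      paddle::framework::Scope& fresh_scope,
                      const std::string& dirname,
                      paddle::framework::ProgramDesc* program) {
      // Builds and runs a transient load program covering every variable
      // that IsParameter() classifies as a parameter of `program`.
      paddle::inference::LoadPersistables(executor, fresh_scope, dirname,
                                          program);
    }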