提交 f5990b46 编写于 作者: L Liu Yiqun

Merge branch 'develop' into core_add_inference_unittest

...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/proto_desc.h" #include "paddle/framework/proto_desc.h"
#include "paddle/platform/macros.h" #include "paddle/platform/macros.h"
......
...@@ -21,11 +21,12 @@ limitations under the License. */ ...@@ -21,11 +21,12 @@ limitations under the License. */
#include <vector> #include <vector>
#include <glog/logging.h> #include <glog/logging.h>
#include "paddle/framework/feed_fetch_type.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
const std::string kDropOutOpType = "dropout"; const std::string kDropOutOpType = "dropout";
const std::string kBatchNormOpType = "batch_norm"; const std::string kBatchNormOpType = "batch_norm";
......
...@@ -22,11 +22,11 @@ namespace paddle { ...@@ -22,11 +22,11 @@ namespace paddle {
namespace inference { namespace inference {
bool IsParameter(const framework::VarDesc* var, bool IsParameter(const framework::VarDesc* var,
const framework::ProgramDesc* main_program) { const framework::ProgramDesc& main_program) {
if (var->Persistable()) { if (var->Persistable()) {
// There are many unreachable variables in the program // There are many unreachable variables in the program
for (size_t i = 0; i < main_program->Size(); ++i) { for (size_t i = 0; i < main_program.Size(); ++i) {
const framework::BlockDesc& block = main_program->Block(i); const framework::BlockDesc& block = main_program.Block(i);
for (auto* op : block.AllOps()) { for (auto* op : block.AllOps()) {
if (op->Type() == framework::kFeedOpType) { if (op->Type() == framework::kFeedOpType) {
continue; continue;
...@@ -45,12 +45,12 @@ bool IsParameter(const framework::VarDesc* var, ...@@ -45,12 +45,12 @@ bool IsParameter(const framework::VarDesc* var,
void LoadPersistables(framework::Executor& executor, void LoadPersistables(framework::Executor& executor,
framework::Scope& scope, framework::Scope& scope,
const std::string& dirname, const std::string& dirname,
framework::ProgramDesc* main_program) { const framework::ProgramDesc& main_program) {
framework::BlockDesc* global_block = main_program->MutableBlock(0); const framework::BlockDesc& global_block = main_program.Block(0);
framework::ProgramDesc* load_program = new framework::ProgramDesc(); framework::ProgramDesc* load_program = new framework::ProgramDesc();
framework::BlockDesc* load_block = load_program->MutableBlock(0); framework::BlockDesc* load_block = load_program->MutableBlock(0);
for (auto* var : global_block->AllVars()) { for (auto* var : global_block.AllVars()) {
if (IsParameter(var, main_program)) { if (IsParameter(var, main_program)) {
VLOG(3) << "parameter's name: " << var->Name(); VLOG(3) << "parameter's name: " << var->Name();
...@@ -73,9 +73,9 @@ void LoadPersistables(framework::Executor& executor, ...@@ -73,9 +73,9 @@ void LoadPersistables(framework::Executor& executor,
delete load_program; delete load_program;
} }
framework::ProgramDesc* Load(framework::Executor& executor, std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
framework::Scope& scope, framework::Scope& scope,
const std::string& dirname) { const std::string& dirname) {
std::string model_filename = dirname + "/__model__"; std::string model_filename = dirname + "/__model__";
LOG(INFO) << "loading model from " << model_filename; LOG(INFO) << "loading model from " << model_filename;
std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary); std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
...@@ -87,10 +87,10 @@ framework::ProgramDesc* Load(framework::Executor& executor, ...@@ -87,10 +87,10 @@ framework::ProgramDesc* Load(framework::Executor& executor,
inputfs.read(&program_desc_str[0], program_desc_str.size()); inputfs.read(&program_desc_str[0], program_desc_str.size());
inputfs.close(); inputfs.close();
framework::ProgramDesc* main_program = std::unique_ptr<framework::ProgramDesc> main_program(
new framework::ProgramDesc(program_desc_str); new framework::ProgramDesc(program_desc_str));
LoadPersistables(executor, scope, dirname, main_program); LoadPersistables(executor, scope, dirname, *main_program);
return main_program; return main_program;
} }
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
...@@ -26,11 +27,11 @@ namespace inference { ...@@ -26,11 +27,11 @@ namespace inference {
void LoadPersistables(framework::Executor& executor, void LoadPersistables(framework::Executor& executor,
framework::Scope& scope, framework::Scope& scope,
const std::string& dirname, const std::string& dirname,
framework::ProgramDesc* main_program); const framework::ProgramDesc& main_program);
framework::ProgramDesc* Load(framework::Executor& executor, std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
framework::Scope& scope, framework::Scope& scope,
const std::string& dirname); const std::string& dirname);
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -31,7 +31,7 @@ void TestInference(const std::string& dirname, ...@@ -31,7 +31,7 @@ void TestInference(const std::string& dirname,
auto* scope = new paddle::framework::Scope(); auto* scope = new paddle::framework::Scope();
// 2. Initialize the inference_program and load all parameters from file // 2. Initialize the inference_program and load all parameters from file
auto* inference_program = paddle::inference::Load(executor, *scope, dirname); auto inference_program = paddle::inference::Load(executor, *scope, dirname);
// 3. Get the feed_target_names and fetch_target_names // 3. Get the feed_target_names and fetch_target_names
const std::vector<std::string>& feed_target_names = const std::vector<std::string>& feed_target_names =
...@@ -39,14 +39,14 @@ void TestInference(const std::string& dirname, ...@@ -39,14 +39,14 @@ void TestInference(const std::string& dirname,
const std::vector<std::string>& fetch_target_names = const std::vector<std::string>& fetch_target_names =
inference_program->GetFetchTargetNames(); inference_program->GetFetchTargetNames();
// 4. Prepare inputs // 4. Prepare inputs: set up maps for feed targets
std::map<std::string, const paddle::framework::LoDTensor*> feed_targets; std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
for (size_t i = 0; i < feed_target_names.size(); ++i) { for (size_t i = 0; i < feed_target_names.size(); ++i) {
// Please make sure that cpu_feeds[i] is right for feed_target_names[i] // Please make sure that cpu_feeds[i] is right for feed_target_names[i]
feed_targets[feed_target_names[i]] = cpu_feeds[i]; feed_targets[feed_target_names[i]] = cpu_feeds[i];
} }
// 5. Define Tensor to get the outputs // 5. Define Tensor to get the outputs: set up maps for fetch targets
std::map<std::string, paddle::framework::LoDTensor*> fetch_targets; std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
for (size_t i = 0; i < fetch_target_names.size(); ++i) { for (size_t i = 0; i < fetch_target_names.size(); ++i) {
fetch_targets[fetch_target_names[i]] = cpu_fetchs[i]; fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
...@@ -55,7 +55,6 @@ void TestInference(const std::string& dirname, ...@@ -55,7 +55,6 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program // 6. Run the inference program
executor.Run(*inference_program, scope, feed_targets, fetch_targets); executor.Run(*inference_program, scope, feed_targets, fetch_targets);
delete inference_program;
delete scope; delete scope;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册