Commit f5990b46 authored by Liu Yiqun

Merge branch 'develop' into core_add_inference_unittest

@@ -16,6 +16,7 @@ limitations under the License. */

 #include <memory>
 #include <vector>
 #include "paddle/framework/block_desc.h"
 #include "paddle/framework/framework.pb.h"
 #include "paddle/framework/proto_desc.h"
 #include "paddle/platform/macros.h"
@@ -21,11 +21,12 @@ limitations under the License. */
 #include <vector>
 #include <glog/logging.h>
 #include "paddle/framework/feed_fetch_type.h"

 namespace paddle {
 namespace framework {

 const std::string kFeedOpType = "feed";
 const std::string kFetchOpType = "fetch";
 const std::string kDropOutOpType = "dropout";
 const std::string kBatchNormOpType = "batch_norm";
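These string constants identify ops by type when a program is pruned for inference: feed/fetch ops mark the I/O boundary, while dropout and batch_norm behave differently at inference time than at training time. A minimal standalone sketch of filtering ops by such type constants (the `Op` struct is a hypothetical stand-in for `framework::OpDesc`, not the Paddle sources):

```cpp
#include <iostream>
#include <string>
#include <vector>

const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
const std::string kDropOutOpType = "dropout";

// Hypothetical stand-in for framework::OpDesc.
struct Op {
  std::string type;
};

int main() {
  std::vector<Op> ops = {{kFeedOpType}, {"mul"}, {kDropOutOpType}, {kFetchOpType}};
  for (const Op& op : ops) {
    // Dropout is an identity mapping at inference time, so a pruner may
    // strip it; feed/fetch ops mark the program's I/O boundary.
    if (op.type == kDropOutOpType) {
      continue;
    }
    std::cout << "keep op: " << op.type << "\n";
  }
  return 0;
}
```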
@@ -22,11 +22,11 @@ namespace paddle {
 namespace inference {

 bool IsParameter(const framework::VarDesc* var,
-                 const framework::ProgramDesc* main_program) {
+                 const framework::ProgramDesc& main_program) {
   if (var->Persistable()) {
     // There are many unreachable variables in the program
-    for (size_t i = 0; i < main_program->Size(); ++i) {
-      const framework::BlockDesc& block = main_program->Block(i);
+    for (size_t i = 0; i < main_program.Size(); ++i) {
+      const framework::BlockDesc& block = main_program.Block(i);
       for (auto* op : block.AllOps()) {
         if (op->Type() == framework::kFeedOpType) {
           continue;
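The change above is a pointer-to-const-reference migration: `main_program` is required and never null, so `const framework::ProgramDesc&` encodes that precondition in the type and turns `->` into `.`. A standalone illustration of the pattern (the `Program` struct is a hypothetical stand-in):

```cpp
#include <cstddef>
#include <iostream>

// Hypothetical stand-in for framework::ProgramDesc.
struct Program {
  size_t Size() const { return 2; }
};

// Before: a pointer admits nullptr, so every caller and callee must
// either check for the null case or silently assume it away.
bool HasBlocksOld(const Program* program) {
  return program != nullptr && program->Size() > 0;
}

// After: a reference cannot be null; the precondition moves into the
// signature and the body uses `.` instead of `->`.
bool HasBlocksNew(const Program& program) { return program.Size() > 0; }

int main() {
  Program program;
  std::cout << HasBlocksOld(&program) << " " << HasBlocksNew(program) << "\n";
  return 0;
}
```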
@@ -45,12 +45,12 @@ bool IsParameter(const framework::VarDesc* var,
 void LoadPersistables(framework::Executor& executor,
                       framework::Scope& scope,
                       const std::string& dirname,
-                      framework::ProgramDesc* main_program) {
-  framework::BlockDesc* global_block = main_program->MutableBlock(0);
+                      const framework::ProgramDesc& main_program) {
+  const framework::BlockDesc& global_block = main_program.Block(0);

   framework::ProgramDesc* load_program = new framework::ProgramDesc();
   framework::BlockDesc* load_block = load_program->MutableBlock(0);
-  for (auto* var : global_block->AllVars()) {
+  for (auto* var : global_block.AllVars()) {
     if (IsParameter(var, main_program)) {
       VLOG(3) << "parameter's name: " << var->Name();
@@ -73,9 +73,9 @@ void LoadPersistables(framework::Executor& executor,
   delete load_program;
 }

-framework::ProgramDesc* Load(framework::Executor& executor,
-                             framework::Scope& scope,
-                             const std::string& dirname) {
+std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
+                                             framework::Scope& scope,
+                                             const std::string& dirname) {
   std::string model_filename = dirname + "/__model__";
   LOG(INFO) << "loading model from " << model_filename;
   std::ifstream inputfs(model_filename, std::ios::in | std::ios::binary);
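`Load` previously returned an owning raw pointer, leaving it to each caller to remember `delete`. Returning `std::unique_ptr` makes the ownership transfer explicit and exception-safe. A minimal sketch of the idiom (stand-in `ProgramDesc`, illustrative path):

```cpp
#include <iostream>
#include <memory>
#include <string>

// Hypothetical stand-in for framework::ProgramDesc.
struct ProgramDesc {
  explicit ProgramDesc(const std::string& desc) : desc_(desc) {}
  std::string desc_;
};

std::unique_ptr<ProgramDesc> Load(const std::string& dirname) {
  // C++11 spelling, matching this diff's era; std::make_unique is C++14.
  return std::unique_ptr<ProgramDesc>(new ProgramDesc(dirname + "/__model__"));
}

int main() {
  auto program = Load("/tmp/model");  // ownership moves to `program`
  std::cout << program->desc_ << "\n";
  return 0;
}  // `program` is destroyed here; no explicit delete needed
```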
@@ -87,10 +87,10 @@ framework::ProgramDesc* Load(framework::Executor& executor,
   inputfs.read(&program_desc_str[0], program_desc_str.size());
   inputfs.close();

-  framework::ProgramDesc* main_program =
-      new framework::ProgramDesc(program_desc_str);
+  std::unique_ptr<framework::ProgramDesc> main_program(
+      new framework::ProgramDesc(program_desc_str));

-  LoadPersistables(executor, scope, dirname, main_program);
+  LoadPersistables(executor, scope, dirname, *main_program);

   return main_program;
 }
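Note the call site: `LoadPersistables` now takes a `const ProgramDesc&`, so the `unique_ptr` is dereferenced with `*main_program`; the callee borrows the object and ownership stays with `Load`. In sketch form (stand-in types):

```cpp
#include <iostream>
#include <memory>

// Hypothetical stand-in for framework::ProgramDesc.
struct ProgramDesc {
  int blocks = 1;
};

void LoadPersistables(const ProgramDesc& program) {
  std::cout << "blocks: " << program.blocks << "\n";
}

int main() {
  std::unique_ptr<ProgramDesc> main_program(new ProgramDesc());
  LoadPersistables(*main_program);  // borrow; ownership is not transferred
  return 0;
}
```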
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

+#include <memory>
 #include <string>
 #include <vector>
 #include "paddle/framework/executor.h"
@@ -26,11 +27,11 @@ namespace inference {

 void LoadPersistables(framework::Executor& executor,
                       framework::Scope& scope,
                       const std::string& dirname,
-                      framework::ProgramDesc* main_program);
+                      const framework::ProgramDesc& main_program);

-framework::ProgramDesc* Load(framework::Executor& executor,
-                             framework::Scope& scope,
-                             const std::string& dirname);
+std::unique_ptr<framework::ProgramDesc> Load(framework::Executor& executor,
+                                             framework::Scope& scope,
+                                             const std::string& dirname);

 }  // namespace inference
 }  // namespace paddle
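With the declarations above, a caller reads roughly as follows. This is a sketch, not code from the commit: the `paddle/inference/io.h` header path and the model directory are assumptions based on this diff, and `Load` already calls `LoadPersistables` internally (see the previous hunk).

```cpp
// Call-site sketch against the declarations above; assumes a Paddle
// source tree of this era is available and that the header lives at
// paddle/inference/io.h. The model path is illustrative.
#include "paddle/framework/executor.h"
#include "paddle/framework/scope.h"
#include "paddle/inference/io.h"

void LoadModel(paddle::framework::Executor& executor,
               paddle::framework::Scope& scope) {
  // Load() reads <dirname>/__model__, restores persistable variables
  // into `scope`, and hands back an owning smart pointer.
  auto program = paddle::inference::Load(executor, scope, "/path/to/model");
  // ... run `*program` with the executor ...
}  // `program` frees the ProgramDesc automatically
```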
@@ -31,7 +31,7 @@ void TestInference(const std::string& dirname,
   auto* scope = new paddle::framework::Scope();

   // 2. Initialize the inference_program and load all parameters from file
-  auto* inference_program = paddle::inference::Load(executor, *scope, dirname);
+  auto inference_program = paddle::inference::Load(executor, *scope, dirname);

   // 3. Get the feed_target_names and fetch_target_names
   const std::vector<std::string>& feed_target_names =
@@ -39,14 +39,14 @@ void TestInference(const std::string& dirname,
   const std::vector<std::string>& fetch_target_names =
       inference_program->GetFetchTargetNames();

-  // 4. Prepare inputs
+  // 4. Prepare inputs: set up maps for feed targets
   std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
   for (size_t i = 0; i < feed_target_names.size(); ++i) {
     // Please make sure that cpu_feeds[i] is right for feed_target_names[i]
     feed_targets[feed_target_names[i]] = cpu_feeds[i];
   }

-  // 5. Define Tensor to get the outputs
+  // 5. Define Tensor to get the outputs: set up maps for fetch targets
   std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
   for (size_t i = 0; i < fetch_target_names.size(); ++i) {
     fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
@@ -55,7 +55,6 @@ void TestInference(const std::string& dirname,
   // 6. Run the inference program
   executor.Run(*inference_program, scope, feed_targets, fetch_targets);

-  delete inference_program;
   delete scope;
 }
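With `Load` returning `std::unique_ptr`, the `delete inference_program;` line becomes unnecessary: the program is released when the pointer goes out of scope. A standalone sketch of that lifetime (stand-in types):

```cpp
#include <iostream>
#include <memory>

// Hypothetical stand-in for framework::ProgramDesc.
struct ProgramDesc {
  ~ProgramDesc() { std::cout << "program freed\n"; }
};

std::unique_ptr<ProgramDesc> Load() {
  return std::unique_ptr<ProgramDesc>(new ProgramDesc());
}

int main() {
  {
    auto inference_program = Load();
    // ... feed inputs, run the program, fetch outputs ...
  }  // destructor runs here, replacing the old `delete inference_program;`
  std::cout << "done\n";
  return 0;
}
```

Note that `scope` is still created with `new` and freed with `delete scope;` in the test helper; only the program's ownership moved to RAII in this diff.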