未验证 提交 2c552d4e 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #9630 from Xreki/core_inference_prepare

Split Executor.Run to Executor.Prepare and Executor.RunPreparedContext for inference
...@@ -78,7 +78,7 @@ if(NOT CMAKE_CROSSCOMPILING) ...@@ -78,7 +78,7 @@ if(NOT CMAKE_CROSSCOMPILING)
/usr/lib/reference/ /usr/lib/reference/
) )
else() else()
# Diable the finding of reference cblas under host's system path # Disable the finding of reference cblas under host's system path
set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/include) set(REFERENCE_CBLAS_INCLUDE_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/include)
set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib) set(REFERENCE_CBLAS_LIB_SEARCH_PATHS ${REFERENCE_CBLAS_ROOT}/lib)
endif() endif()
......
...@@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name,
if (tensor.memory_size() == 0) { if (tensor.memory_size() == 0) {
return; return;
} }
if (tensor.type().hash_code() != typeid(float).hash_code() && if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
tensor.type().hash_code() != typeid(double).hash_code()) { tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
return; return;
} }
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor), PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
...@@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// Return true if the block has feed operators and holder of matching info. // Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators( static bool has_feed_operators(
const BlockDesc& block, const BlockDesc& block,
std::map<std::string, const LoDTensor*>& feed_targets, const std::map<std::string, const LoDTensor*>& feed_targets,
const std::string& feed_holder_name) { const std::string& feed_holder_name) {
size_t feed_count = 0; size_t feed_count = 0;
for (auto* op : block.AllOps()) { for (auto* op : block.AllOps()) {
if (op->Type() == kFeedOpType) { if (op->Type() == kFeedOpType) {
feed_count++; feed_count++;
// The input variable's name of feed_op should be feed_holder_name.
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
"Input to feed op should be '%s'", feed_holder_name); "Input to feed op should be '%s'", feed_holder_name);
std::string feed_target_name = op->Output("Out")[0]; std::string feed_target_name = op->Output("Out")[0];
...@@ -166,13 +167,15 @@ static bool has_feed_operators( ...@@ -166,13 +167,15 @@ static bool has_feed_operators(
feed_count, feed_targets.size(), feed_count, feed_targets.size(),
"The number of feed operators should match 'feed_targets'"); "The number of feed operators should match 'feed_targets'");
// When feed operator are present, so should be feed_holder if (!feed_holder_name.empty()) {
auto var = block.FindVar(feed_holder_name); // When feed operator are present, so should be feed_holder.
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", auto var = block.FindVar(feed_holder_name);
feed_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, feed_holder_name);
"'%s' variable should be 'FEED_MINIBATCH' type", PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
feed_holder_name); "'%s' variable should be 'FEED_MINIBATCH' type",
feed_holder_name);
}
} }
return feed_count > 0; return feed_count > 0;
...@@ -185,12 +188,14 @@ static bool has_feed_operators( ...@@ -185,12 +188,14 @@ static bool has_feed_operators(
// and fetch_holder_name. Raise exception when any mismatch is found. // and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info. // Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators( static bool has_fetch_operators(
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets, const BlockDesc& block,
const std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name) {
size_t fetch_count = 0; size_t fetch_count = 0;
for (auto* op : block.AllOps()) { for (auto* op : block.AllOps()) {
if (op->Type() == kFetchOpType) { if (op->Type() == kFetchOpType) {
fetch_count++; fetch_count++;
// The output variable's name of fetch_op should be fetch_holder_name.
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
"Output of fetch op should be '%s'", fetch_holder_name); "Output of fetch op should be '%s'", fetch_holder_name);
std::string fetch_target_name = op->Input("X")[0]; std::string fetch_target_name = op->Input("X")[0];
...@@ -206,13 +211,15 @@ static bool has_fetch_operators( ...@@ -206,13 +211,15 @@ static bool has_fetch_operators(
fetch_count, fetch_targets.size(), fetch_count, fetch_targets.size(),
"The number of fetch operators should match 'fetch_targets'"); "The number of fetch operators should match 'fetch_targets'");
// When fetch operator are present, so should be fetch_holder if (!fetch_holder_name.empty()) {
auto var = block.FindVar(fetch_holder_name); // When fetch operator are present, so should be fetch_holder.
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", auto var = block.FindVar(fetch_holder_name);
fetch_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, fetch_holder_name);
"'%s' variable should be 'FETCH_LIST' type", PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
fetch_holder_name); "'%s' variable should be 'FETCH_LIST' type",
fetch_holder_name);
}
} }
return fetch_count > 0; return fetch_count > 0;
...@@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
} }
// map the data of feed_targets to feed_holder
for (auto* op : global_block->AllOps()) {
if (op->Type() == kFeedOpType) {
std::string feed_target_name = op->Output("Out")[0];
int idx = boost::get<int>(op->GetAttr("col"));
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
idx);
}
}
if (!has_fetch_ops) { if (!has_fetch_ops) {
// create fetch_holder variable // create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name); auto* fetch_holder = global_block->Var(fetch_holder_name);
...@@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
} }
Run(*copy_program, scope, 0, create_vars, create_vars); auto ctx = Prepare(*copy_program, 0);
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars,
// obtain the data of fetch_targets from fetch_holder feed_holder_name, fetch_holder_name);
for (auto* op : global_block->AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
int idx = boost::get<int>(op->GetAttr("col"));
*fetch_targets[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
} }
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare( std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
...@@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} }
} }
void Executor::RunPreparedContext(
ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets, bool create_vars,
const std::string& feed_holder_name, const std::string& fetch_holder_name) {
auto& global_block = ctx->prog_.Block(ctx->block_id_);
PADDLE_ENFORCE(
has_feed_operators(global_block, feed_targets, feed_holder_name),
"Program in ExecutorPrepareContext should has feed_ops.");
PADDLE_ENFORCE(
has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
"Program in the prepared context should has fetch_ops.");
// map the data of feed_targets to feed_holder
for (auto* op : global_block.AllOps()) {
if (op->Type() == kFeedOpType) {
std::string feed_target_name = op->Output("Out")[0];
int idx = boost::get<int>(op->GetAttr("col"));
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
idx);
}
}
RunPreparedContext(ctx, scope, create_vars, create_vars);
// obtain the data of fetch_targets from fetch_holder
for (auto* op : global_block.AllOps()) {
if (op->Type() == kFetchOpType) {
std::string fetch_target_name = op->Input("X")[0];
int idx = boost::get<int>(op->GetAttr("col"));
*fetch_targets[fetch_target_name] =
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,9 @@ limitations under the License. */ ...@@ -14,6 +14,9 @@ limitations under the License. */
#pragma once #pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -70,6 +73,13 @@ class Executor { ...@@ -70,6 +73,13 @@ class Executor {
bool create_local_scope = true, bool create_local_scope = true,
bool create_vars = true); bool create_vars = true);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch");
private: private:
const platform::Place place_; const platform::Place place_;
}; };
......
...@@ -23,7 +23,7 @@ limitations under the License. */ ...@@ -23,7 +23,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace inference { namespace inference {
// Temporarilly add this function for exposing framework::InitDevices() when // Temporarily add this function for exposing framework::InitDevices() when
// linking the inference shared library. // linking the inference shared library.
void Init(bool init_p2p) { framework::InitDevices(init_p2p); } void Init(bool init_p2p) { framework::InitDevices(init_p2p); }
......
...@@ -46,8 +46,8 @@ TEST(inference, image_classification) { ...@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
// Run inference on CPU // Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---"; LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds, TestInference<paddle::platform::CPUPlace, false, true>(
cpu_fetchs1, FLAGS_repeat); dirname, cpu_feeds, cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims(); LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -57,8 +57,8 @@ TEST(inference, image_classification) { ...@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU // Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---"; LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds, TestInference<paddle::platform::CUDAPlace, false, true>(
cpu_fetchs2, FLAGS_repeat); dirname, cpu_feeds, cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims(); LOG(INFO) << output2.dims();
CheckError<float>(output1, output2); CheckError<float>(output1, output2);
......
...@@ -89,7 +89,7 @@ void CheckError(const paddle::framework::LoDTensor& output1, ...@@ -89,7 +89,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
} }
template <typename Place, bool CreateVars = true> template <typename Place, bool CreateVars = true, bool PrepareContext = false>
void TestInference(const std::string& dirname, void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds, const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs, const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
...@@ -175,8 +175,15 @@ void TestInference(const std::string& dirname, ...@@ -175,8 +175,15 @@ void TestInference(const std::string& dirname,
} }
// Ignore the profiling results of the first run // Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets, std::unique_ptr<paddle::framework::ExecutorPrepareContext> ctx;
CreateVars); if (PrepareContext) {
ctx = executor.Prepare(*inference_program, 0);
executor.RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
CreateVars);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
}
// Enable the profiler // Enable the profiler
paddle::platform::EnableProfiler(state); paddle::platform::EnableProfiler(state);
...@@ -187,8 +194,15 @@ void TestInference(const std::string& dirname, ...@@ -187,8 +194,15 @@ void TestInference(const std::string& dirname,
"run_inference", "run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place)); paddle::platform::DeviceContextPool::Instance().Get(place));
executor.Run(*inference_program, scope, feed_targets, fetch_targets, if (PrepareContext) {
CreateVars); // Note: if you change the inference_program, you need to call
// executor.Prepare() again to get a new ExecutorPrepareContext.
executor.RunPreparedContext(ctx.get(), scope, feed_targets,
fetch_targets, CreateVars);
} else {
executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
}
} }
// Disable the profiler and print the timing information // Disable the profiler and print the timing information
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册