未验证 提交 46d6f4ce 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #9301 from Xreki/core_inference_fix_run

Expose a interface to create all variables and enable the test of not creating variables every time.
...@@ -93,6 +93,43 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -93,6 +93,43 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name); "Tensor %s contains NAN", name);
} }
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope,
int block_id) {
auto& global_block = pdesc.Block(block_id);
const Scope* ancestor_scope = scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}
if (ancestor_scope != scope) {
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var->Persistable()) {
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : global_block.AllVars()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id); platform::RecordBlock b(block_id);
...@@ -184,8 +221,8 @@ static bool has_fetch_operators( ...@@ -184,8 +221,8 @@ static bool has_fetch_operators(
void Executor::Run(const ProgramDesc& program, Scope* scope, void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets, std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name, bool create_vars, const std::string& feed_holder_name,
const std::string& fetch_holder_name, bool create_vars) { const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId); platform::RecordBlock b(kProgramId);
bool has_feed_ops = bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name); has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
...@@ -296,38 +333,13 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare( ...@@ -296,38 +333,13 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_);
Scope* local_scope = scope; Scope* local_scope = scope;
if (create_vars) { if (create_vars) {
if (create_local_scope) { if (create_local_scope) {
local_scope = &scope->NewScope(); local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) { }
if (var->Name() == framework::kEmptyVarName) { CreateVariables(ctx->prog_, local_scope, ctx->block_id_);
continue; }
}
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
InitializeVariable(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
} // if (create_local_scope)
} // if (create_vars)
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
......
...@@ -54,9 +54,9 @@ class Executor { ...@@ -54,9 +54,9 @@ class Executor {
void Run(const ProgramDesc& program, Scope* scope, void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets, std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets, std::map<std::string, LoDTensor*>& fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed", const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch", const std::string& fetch_holder_name = "fetch");
bool create_vars = true);
static std::unique_ptr<ExecutorPrepareContext> Prepare( static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id); const ProgramDesc& program, int block_id);
...@@ -64,6 +64,8 @@ class Executor { ...@@ -64,6 +64,8 @@ class Executor {
static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare( static std::vector<std::shared_ptr<ExecutorPrepareContext>> Prepare(
const ProgramDesc& program, const std::vector<int>& block_ids); const ProgramDesc& program, const std::vector<int>& block_ids);
void CreateVariables(const ProgramDesc& pdesc, Scope* scope, int block_id);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true, bool create_local_scope = true,
bool create_vars = true); bool create_vars = true);
......
...@@ -58,7 +58,7 @@ class Scope { ...@@ -58,7 +58,7 @@ class Scope {
/// nullptr if cannot find. /// nullptr if cannot find.
Variable* FindVar(const std::string& name) const; Variable* FindVar(const std::string& name) const;
const Scope& parent() const { return *parent_; } const Scope* parent() const { return parent_; }
/// Find the scope or an ancestor scope that contains the given variable. /// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const; const Scope* FindScope(const Variable* var) const;
......
...@@ -46,8 +46,8 @@ TEST(inference, image_classification) { ...@@ -46,8 +46,8 @@ TEST(inference, image_classification) {
// Run inference on CPU // Run inference on CPU
LOG(INFO) << "--- CPU Runs: ---"; LOG(INFO) << "--- CPU Runs: ---";
TestInference<paddle::platform::CPUPlace>(dirname, cpu_feeds, cpu_fetchs1, TestInference<paddle::platform::CPUPlace, false>(dirname, cpu_feeds,
FLAGS_repeat); cpu_fetchs1, FLAGS_repeat);
LOG(INFO) << output1.dims(); LOG(INFO) << output1.dims();
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -57,8 +57,8 @@ TEST(inference, image_classification) { ...@@ -57,8 +57,8 @@ TEST(inference, image_classification) {
// Run inference on CUDA GPU // Run inference on CUDA GPU
LOG(INFO) << "--- GPU Runs: ---"; LOG(INFO) << "--- GPU Runs: ---";
TestInference<paddle::platform::CUDAPlace>(dirname, cpu_feeds, cpu_fetchs2, TestInference<paddle::platform::CUDAPlace, false>(dirname, cpu_feeds,
FLAGS_repeat); cpu_fetchs2, FLAGS_repeat);
LOG(INFO) << output2.dims(); LOG(INFO) << output2.dims();
CheckError<float>(output1, output2); CheckError<float>(output1, output2);
......
...@@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1, ...@@ -88,7 +88,7 @@ void CheckError(const paddle::framework::LoDTensor& output1,
EXPECT_EQ(count, 0U) << "There are " << count << " different elements."; EXPECT_EQ(count, 0U) << "There are " << count << " different elements.";
} }
template <typename Place> template <typename Place, bool CreateVars = true>
void TestInference(const std::string& dirname, void TestInference(const std::string& dirname,
const std::vector<paddle::framework::LoDTensor*>& cpu_feeds, const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs, const std::vector<paddle::framework::LoDTensor*>& cpu_fetchs,
...@@ -166,8 +166,16 @@ void TestInference(const std::string& dirname, ...@@ -166,8 +166,16 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program // 6. Run the inference program
{ {
if (!CreateVars) {
// If users don't want to create and destroy variables every time they
// run, they need to set `create_vars` to false and manually call
// `CreateVariables` before running.
executor.CreateVariables(*inference_program, scope, 0);
}
// Ignore the profiling results of the first run // Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets); executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
// Enable the profiler // Enable the profiler
paddle::platform::EnableProfiler(state); paddle::platform::EnableProfiler(state);
...@@ -178,7 +186,8 @@ void TestInference(const std::string& dirname, ...@@ -178,7 +186,8 @@ void TestInference(const std::string& dirname,
"run_inference", "run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place)); paddle::platform::DeviceContextPool::Instance().Get(place));
executor.Run(*inference_program, scope, feed_targets, fetch_targets); executor.Run(*inference_program, scope, feed_targets, fetch_targets,
CreateVars);
} }
// Disable the profiler and print the timing information // Disable the profiler and print the timing information
......
...@@ -56,11 +56,11 @@ class GoOp : public framework::OperatorBase { ...@@ -56,11 +56,11 @@ class GoOp : public framework::OperatorBase {
// TODO(varunarora): Consider moving this root scope lookup to scope.h. // TODO(varunarora): Consider moving this root scope lookup to scope.h.
const framework::Scope *root_scope = &scope; const framework::Scope *root_scope = &scope;
const framework::Scope *parent_scope = &(root_scope->parent()); const framework::Scope *parent_scope = root_scope->parent();
while (parent_scope != nullptr) { while (parent_scope != nullptr) {
root_scope = parent_scope; root_scope = parent_scope;
parent_scope = &(parent_scope->parent()); parent_scope = parent_scope->parent();
} }
framework::BlockDesc *block = Attr<framework::BlockDesc *>(kBlock); framework::BlockDesc *block = Attr<framework::BlockDesc *>(kBlock);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册