提交 09687534 编写于 作者: L Liu Yiqun

Enable the test of not creating variables every time.

上级 ed2bc194
......@@ -93,6 +93,42 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name);
}
void Executor::CreateVariables(const ProgramDesc& pdesc, Scope* scope) {
auto& global_block = pdesc.Block(0);
const Scope* ancestor_scope = scope;
while (ancestor_scope->parent()) {
ancestor_scope = ancestor_scope->parent();
}
if (ancestor_scope != scope) {
for (auto& var : global_block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var->Persistable()) {
auto* ptr = const_cast<Scope*>(ancestor_scope)->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : global_block.AllVars()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
platform::RecordBlock b(block_id);
......@@ -184,8 +220,8 @@ static bool has_fetch_operators(
void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
const std::string& fetch_holder_name, bool create_vars) {
bool create_vars, const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
platform::RecordBlock b(kProgramId);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
......@@ -281,39 +317,16 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_);
Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
if (var->Name() == framework::kEmptyVarName) {
continue;
}
if (var->Persistable()) {
auto* ptr = scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " global, which pointer is " << ptr;
} else {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create Variable " << var->Name()
<< " locally, which pointer is " << ptr;
}
}
} else {
for (auto& var : block.AllVars()) {
auto* ptr = local_scope->Var(var->Name());
CreateTensor(ptr, var->GetType());
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
} // if (create_local_scope)
} // if (create_vars)
CreateVariables(ctx->prog_, local_scope);
for (auto& op : ctx->ops_) {
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
......
......@@ -53,13 +53,15 @@ class Executor {
void Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, const LoDTensor*>& feed_targets,
std::map<std::string, LoDTensor*>& fetch_targets,
bool create_vars = true,
const std::string& feed_holder_name = "feed",
const std::string& fetch_holder_name = "fetch",
bool create_vars = true);
const std::string& fetch_holder_name = "fetch");
static std::unique_ptr<ExecutorPrepareContext> Prepare(
const ProgramDesc& program, int block_id);
void CreateVariables(const ProgramDesc& pdesc, Scope* scope);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
bool create_vars = true);
......
......@@ -57,7 +57,7 @@ class Scope {
/// nullptr if cannot find.
Variable* FindVar(const std::string& name) const;
const Scope& parent() const { return *parent_; }
const Scope* parent() const { return parent_; }
/// Find the scope or an ancestor scope that contains the given variable.
const Scope* FindScope(const Variable* var) const;
......
......@@ -169,8 +169,14 @@ void TestInference(const std::string& dirname,
// 6. Run the inference program
{
const bool create_vars = false;
if (!create_vars) {
executor.CreateVariables(*inference_program, scope);
}
// Ignore the profiling results of the first run
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(
*inference_program, scope, feed_targets, fetch_targets, create_vars);
// Enable the profiler
paddle::platform::EnableProfiler(state);
......@@ -181,7 +187,8 @@ void TestInference(const std::string& dirname,
"run_inference",
paddle::platform::DeviceContextPool::Instance().Get(place));
executor.Run(*inference_program, scope, feed_targets, fetch_targets);
executor.Run(
*inference_program, scope, feed_targets, fetch_targets, create_vars);
}
// Disable the profiler and print the timing information
......
......@@ -56,11 +56,11 @@ class GoOp : public framework::OperatorBase {
// TODO(varunarora): Consider moving this root scope lookup to scope.h.
const framework::Scope *root_scope = &scope;
const framework::Scope *parent_scope = &(root_scope->parent());
const framework::Scope *parent_scope = root_scope->parent();
while (parent_scope != nullptr) {
root_scope = parent_scope;
parent_scope = &(parent_scope->parent());
parent_scope = parent_scope->parent();
}
framework::BlockDesc *block = Attr<framework::BlockDesc *>(kBlock);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册