From da9dda5c9b6b2d43e5e81d53baef9d9abaa7f1ce Mon Sep 17 00:00:00 2001 From: whs Date: Fri, 12 Mar 2021 14:54:49 +0800 Subject: [PATCH] Make CreateProgramDesc more robust (#31543) --- .../imperative/jit/program_desc_tracer.cc | 26 ++++++++++++++++--- .../imperative/jit/program_desc_tracer.h | 2 +- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.cc b/paddle/fluid/imperative/jit/program_desc_tracer.cc index 53750f7bf02..1a44f50275e 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.cc +++ b/paddle/fluid/imperative/jit/program_desc_tracer.cc @@ -69,6 +69,7 @@ UniqueBlockVarGenerator::UniqueBlockVarGenerator( std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr &var, const std::string &prefix) { + VLOG(3) << "Finding: " << var.lock()->Name(); auto all_vars_iter = all_vars_.find(var); PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true, platform::errors::NotFound( @@ -111,6 +112,15 @@ void UniqueBlockVarGenerator::InsertNewVarInBlock( } } +bool ProgramDescTracer::ContainVar(const std::weak_ptr &var) const { + auto vars_iter = vars_.find(var); + bool ret = (vars_iter != vars_.end()); + if (!ret) { + VLOG(5) << "Can't found variable: " << var.lock()->Name(); + } + return ret; +} + void ProgramDescTracer::InsertOp(const std::string &type, const NameVarBaseMap &inputs, const NameVarBaseMap &outputs, @@ -147,12 +157,16 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( std::vector feed_var_names; for (auto &feed_var : feed_vars) { - feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix)); + if (ContainVar(feed_var)) { + feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix)); + } } std::vector fetch_var_names; for (auto &fetch_var : fetch_vars) { - fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix)); + if (ContainVar(fetch_var)) { + fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix)); + } } for (auto &op : ops_) { @@ -164,7 +178,9 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( std::vector names; names.reserve(pair.second.size()); for (auto &var : pair.second) { - names.emplace_back(generator.NameOf(var, tmp_prefix)); + if (ContainVar(var)) { + names.emplace_back(generator.NameOf(var, tmp_prefix)); + } } op_desc->SetInput(pair.first, std::move(names)); @@ -174,7 +190,9 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( std::vector names; names.reserve(pair.second.size()); for (auto &var : pair.second) { - names.emplace_back(generator.NameOf(var, tmp_prefix)); + if (ContainVar(var)) { + names.emplace_back(generator.NameOf(var, tmp_prefix)); + } } op_desc->SetOutput(pair.first, std::move(names)); diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.h b/paddle/fluid/imperative/jit/program_desc_tracer.h index 8e2e59a49ed..b231efb0e53 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.h +++ b/paddle/fluid/imperative/jit/program_desc_tracer.h @@ -66,7 +66,7 @@ class ProgramDescTracer { const std::string &feed_prefix, const std::vector> &fetch_vars, const std::string &fetch_prefix, const std::string &tmp_prefix) const; - + bool ContainVar(const std::weak_ptr &var) const; void Reset(); private: -- GitLab