未验证 提交 da9dda5c 编写于 作者: W whs 提交者: GitHub

Make CreateProgramDesc more robust (#31543)

上级 99dcd665
...@@ -69,6 +69,7 @@ UniqueBlockVarGenerator::UniqueBlockVarGenerator( ...@@ -69,6 +69,7 @@ UniqueBlockVarGenerator::UniqueBlockVarGenerator(
std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var, std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
const std::string &prefix) { const std::string &prefix) {
VLOG(3) << "Finding: " << var.lock()->Name();
auto all_vars_iter = all_vars_.find(var); auto all_vars_iter = all_vars_.find(var);
PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true, PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true,
platform::errors::NotFound( platform::errors::NotFound(
...@@ -111,6 +112,15 @@ void UniqueBlockVarGenerator::InsertNewVarInBlock( ...@@ -111,6 +112,15 @@ void UniqueBlockVarGenerator::InsertNewVarInBlock(
} }
} }
bool ProgramDescTracer::ContainVar(const std::weak_ptr<VarBase> &var) const {
auto vars_iter = vars_.find(var);
bool ret = (vars_iter != vars_.end());
if (!ret) {
VLOG(5) << "Can't found variable: " << var.lock()->Name();
}
return ret;
}
void ProgramDescTracer::InsertOp(const std::string &type, void ProgramDescTracer::InsertOp(const std::string &type,
const NameVarBaseMap &inputs, const NameVarBaseMap &inputs,
const NameVarBaseMap &outputs, const NameVarBaseMap &outputs,
...@@ -147,13 +157,17 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( ...@@ -147,13 +157,17 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::vector<std::string> feed_var_names; std::vector<std::string> feed_var_names;
for (auto &feed_var : feed_vars) { for (auto &feed_var : feed_vars) {
if (ContainVar(feed_var)) {
feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix)); feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix));
} }
}
std::vector<std::string> fetch_var_names; std::vector<std::string> fetch_var_names;
for (auto &fetch_var : fetch_vars) { for (auto &fetch_var : fetch_vars) {
if (ContainVar(fetch_var)) {
fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix)); fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix));
} }
}
for (auto &op : ops_) { for (auto &op : ops_) {
auto *op_desc = block->AppendOp(); auto *op_desc = block->AppendOp();
...@@ -164,8 +178,10 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( ...@@ -164,8 +178,10 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::vector<std::string> names; std::vector<std::string> names;
names.reserve(pair.second.size()); names.reserve(pair.second.size());
for (auto &var : pair.second) { for (auto &var : pair.second) {
if (ContainVar(var)) {
names.emplace_back(generator.NameOf(var, tmp_prefix)); names.emplace_back(generator.NameOf(var, tmp_prefix));
} }
}
op_desc->SetInput(pair.first, std::move(names)); op_desc->SetInput(pair.first, std::move(names));
} }
...@@ -174,8 +190,10 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc( ...@@ -174,8 +190,10 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::vector<std::string> names; std::vector<std::string> names;
names.reserve(pair.second.size()); names.reserve(pair.second.size());
for (auto &var : pair.second) { for (auto &var : pair.second) {
if (ContainVar(var)) {
names.emplace_back(generator.NameOf(var, tmp_prefix)); names.emplace_back(generator.NameOf(var, tmp_prefix));
} }
}
op_desc->SetOutput(pair.first, std::move(names)); op_desc->SetOutput(pair.first, std::move(names));
} }
......
...@@ -66,7 +66,7 @@ class ProgramDescTracer { ...@@ -66,7 +66,7 @@ class ProgramDescTracer {
const std::string &feed_prefix, const std::string &feed_prefix,
const std::vector<std::shared_ptr<VarBase>> &fetch_vars, const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
const std::string &fetch_prefix, const std::string &tmp_prefix) const; const std::string &fetch_prefix, const std::string &tmp_prefix) const;
bool ContainVar(const std::weak_ptr<VarBase> &var) const;
void Reset(); void Reset();
private: private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册