未验证 提交 da9dda5c 编写于 作者: W whs 提交者: GitHub

Make CreateProgramDesc more robust (#31543)

上级 99dcd665
......@@ -69,6 +69,7 @@ UniqueBlockVarGenerator::UniqueBlockVarGenerator(
std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
const std::string &prefix) {
VLOG(3) << "Finding: " << var.lock()->Name();
auto all_vars_iter = all_vars_.find(var);
PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true,
platform::errors::NotFound(
......@@ -111,6 +112,15 @@ void UniqueBlockVarGenerator::InsertNewVarInBlock(
}
}
bool ProgramDescTracer::ContainVar(const std::weak_ptr<VarBase> &var) const {
auto vars_iter = vars_.find(var);
bool ret = (vars_iter != vars_.end());
if (!ret) {
VLOG(5) << "Can't found variable: " << var.lock()->Name();
}
return ret;
}
void ProgramDescTracer::InsertOp(const std::string &type,
const NameVarBaseMap &inputs,
const NameVarBaseMap &outputs,
......@@ -147,12 +157,16 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::vector<std::string> feed_var_names;
for (auto &feed_var : feed_vars) {
feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix));
if (ContainVar(feed_var)) {
feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix));
}
}
std::vector<std::string> fetch_var_names;
for (auto &fetch_var : fetch_vars) {
fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix));
if (ContainVar(fetch_var)) {
fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix));
}
}
for (auto &op : ops_) {
......@@ -164,7 +178,9 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::vector<std::string> names;
names.reserve(pair.second.size());
for (auto &var : pair.second) {
names.emplace_back(generator.NameOf(var, tmp_prefix));
if (ContainVar(var)) {
names.emplace_back(generator.NameOf(var, tmp_prefix));
}
}
op_desc->SetInput(pair.first, std::move(names));
......@@ -174,7 +190,9 @@ TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
std::vector<std::string> names;
names.reserve(pair.second.size());
for (auto &var : pair.second) {
names.emplace_back(generator.NameOf(var, tmp_prefix));
if (ContainVar(var)) {
names.emplace_back(generator.NameOf(var, tmp_prefix));
}
}
op_desc->SetOutput(pair.first, std::move(names));
......
......@@ -66,7 +66,7 @@ class ProgramDescTracer {
const std::string &feed_prefix,
const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
const std::string &fetch_prefix, const std::string &tmp_prefix) const;
bool ContainVar(const std::weak_ptr<VarBase> &var) const;
void Reset();
private:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册