未验证 提交 310b7dba 编写于 作者: J jiangcheng 提交者: GitHub

fix build_cinn_pass internal var may be control var problem (#40812)

* fix build_cinn_pass internal var may be control var problem

* add annotation and vlog by review advice
上级 98244a9a
......@@ -220,8 +220,12 @@ std::unordered_set<std::string> ExtractNoNeedBufferFeeds(
// 1. Find op with NoNeedBufferVarsInferer defined and collect its input nodes
std::unordered_map<Node*, GraphNodeSet> op_node2no_need_buffer_nodes;
for (auto* op_node : cluster) {
auto& inferer =
OpInfoMap::Instance().Get(op_node->Name()).NoNeedBufferVarsInferer();
const auto* op = OpInfoMap::Instance().GetNullable(op_node->Name());
// If op not registered in Paddle, skip
if (!op) {
continue;
}
auto& inferer = op->NoNeedBufferVarsInferer();
if (!inferer) {
continue;
}
......@@ -300,10 +304,19 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
GraphNodeMap old_var2new_var;
for (auto* var : cluster_internals) {
PADDLE_ENFORCE_NOT_NULL(var->Var(),
platform::errors::PreconditionNotMet(
"The var desc of the node in cluster_internals "
"shouldn't be null."));
if (!var->Var()) {
// skip control var
// TODO(jiangcheng05): CINN not support control var now, so here we skip
// it, but it may incur result incorrect problem. In detail, for two
// unconnected ops, with control var, an op must run before another op.
// If we remove the control var, the program wouldn't guarantee the run
// ordering, in other words, the result may incorrect.
VLOG(4)
<< "The internal var [" << var->Name() << "]'s vardesc empty,"
<< " it may be a control var, but CINN not support control var now.";
continue;
}
auto* sub_node = subgraph->CreateVarNode(var->Var());
old_var2new_var[var] = sub_node;
}
......@@ -327,6 +340,10 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
// out-graph.
for (auto* op : cluster) {
for (auto* var : op->inputs) {
if (!var->Var()) {
// skip control var
continue;
}
// one output var maybe an input of the cluster
if (cluster_internals.count(var) ||
(cluster_outputs.count(var) && old_var2new_var.count(var))) {
......@@ -346,6 +363,10 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
}
for (auto* var : op->outputs) {
if (!var->Var()) {
// skip control var
continue;
}
if (cluster_internals.count(var)) {
IR_NODE_LINK_TO(old_op2new_op.at(op), old_var2new_var.at(var));
} else if (cluster_outputs.count(var) && var->Var() != nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册