未验证 提交 28bab073 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the null ptr bug in build_cinn_pass. (#36698)

* Fix the null ptr bug in build_cinn_pass.

* Add test for empty&ctrl var.
上级 81e0c1ba
......@@ -114,7 +114,8 @@ void AddOutputVar(const std::unordered_set<Node*>& output_vars,
// var node are from internal nodes
std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs) {
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
// Graph's constructor must has one parameter, and in our code,
// the ProgramDesc is useless, so here we pass a temporary object.
auto subgraph = std::make_unique<Graph>(framework::ProgramDesc());
......@@ -127,7 +128,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
std::unordered_map<Node*, Node*> old_var2new_var;
for (auto* var : cluster_internals) {
auto sub_node = subgraph->CreateVarNode(var->Var());
Node* sub_node;
if (var->Var() == nullptr) {
sub_node = subgraph->CreateEmptyNode(var->Name(), var->NodeType());
} else {
sub_node = subgraph->CreateVarNode(var->Var());
}
old_var2new_var[var] = sub_node;
}
......@@ -140,7 +146,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
for (auto* var : op->inputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]);
} else if (cluster_inputs.count(var)) {
} else if (cluster_inputs.count(var) && var->Var() != nullptr) {
if (var->Var()->IsParameter()) {
// Parameters have been preserved in scope, compared to feed var,
// param just need add new var and don't need add feed op.
......@@ -157,7 +163,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
for (auto* var : op->outputs) {
if (cluster_internals.count(var)) {
old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]);
} else {
} else if (cluster_outputs.count(var) && var->Var() != nullptr) {
// Create new output var node to guarantee the independency of
// subgraph. In other words, the subgraph has no connection with
// other graph, even the input graph.
......@@ -239,14 +245,20 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
framework::OpDesc special_op_desc;
special_op_desc.SetType(kCinnLaunchOp);
std::vector<std::string> input_names;
std::transform(cluster_inputs.begin(), cluster_inputs.end(),
std::back_inserter(input_names),
[](Node* n) { return n->Name(); });
std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
[&input_names](Node* n) {
if (n->Var() != nullptr) {
input_names.emplace_back(n->Name());
}
});
special_op_desc.SetInput("X", input_names);
std::vector<std::string> output_names;
std::transform(cluster_outputs.begin(), cluster_outputs.end(),
std::back_inserter(output_names),
[](Node* n) { return n->Name(); });
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
[&output_names](Node* n) {
if (n->Var() != nullptr) {
output_names.emplace_back(n->Name());
}
});
special_op_desc.SetOutput("Out", output_names);
special_op_desc.SetAttr(kCompilationKey, compilation_key);
special_op_desc.Flush();
......@@ -362,8 +374,8 @@ void SearchAllSubgraphs(Graph* graph) {
&cluster_internals);
// Create a new subgraph according to the found cluster and
// save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(
CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs));
std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
// Replace the found cluster to a new special op node
ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs,
cluster_outputs, cluster_internals,
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include "gtest/gtest.h"
......@@ -50,9 +51,10 @@ inline int CountNode(const std::unordered_set<Node*>& nodes,
inline Node* GetNode(const std::unordered_set<Node*>& nodes,
const std::string& op_name) {
return *std::find_if(
nodes.begin(), nodes.end(),
[&op_name](const Node* node) { return node->Name() == op_name; });
return *std::find_if(nodes.begin(), nodes.end(),
[&op_name](const Node* node) {
return node->Name().find(op_name) != std::string::npos;
});
}
inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
......@@ -185,22 +187,25 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
ir::Node* mul = g->CreateOpNode(&mul_op);
ir::Node* relu = g->CreateOpNode(&relu_op);
ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
ir::Node* v5 = g->CreateVarNode(&var5);
ir::Node* v6 = g->CreateVarNode(&var6);
ir::Node* v7 = g->CreateControlDepVar();
// fill op node
mul->inputs = {v1, v2};
mul->inputs = {v0, v1, v2};
mul->outputs = {v3};
add->inputs = {v3, v4};
add->outputs = {v5};
relu->inputs = {v5};
relu->outputs = {v6};
relu->outputs = {v6, v7};
// fill variable node
v0->outputs = {mul};
v1->outputs = {mul};
v2->outputs = {mul};
......@@ -213,6 +218,7 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
v5->outputs = {relu};
v6->inputs = {relu};
v7->inputs = {relu};
return g;
}
......@@ -225,25 +231,28 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
pass->Apply(g.get());
// After search, the graph should as following
// v1 --|
// v2 --| --> kCinnLaunchOp --> v6
// v0 --|
// v1 --| |--> v6
// v2 --| --> kCinnLaunchOp |--> v7
// v4 --|
const auto& nodes = g->Nodes();
ASSERT_EQ(nodes.size(), static_cast<size_t>(5));
ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
ASSERT_TRUE(CheckGraphIndependence(nodes));
// A new op named kCinnLaunchOp should be added
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
auto* cinn_op = GetNode(nodes, kCinnLaunchOp);
auto* v0 = GetNode(nodes, "var0");
auto* v1 = GetNode(nodes, "var1");
auto* v2 = GetNode(nodes, "var2");
auto* v4 = GetNode(nodes, "var4");
auto* v6 = GetNode(nodes, "var6");
auto* v7 = GetNode(nodes, Node::kControlDepVarName);
ASSERT_EQ(
std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
std::unordered_set<Node*>({v1, v2, v4}));
ASSERT_EQ(cinn_op->outputs, std::vector<Node*>({v6}));
std::unordered_set<Node*>({v0, v1, v2, v4}));
ASSERT_EQ(cinn_op->outputs, std::vector<Node*>({v6, v7}));
ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册