未验证 提交 28bab073 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the null ptr bug in build_cinn_pass. (#36698)

* Fix the null ptr bug in build_cinn_pass.

* Add test for empty&ctrl var.
上级 81e0c1ba
...@@ -114,7 +114,8 @@ void AddOutputVar(const std::unordered_set<Node*>& output_vars, ...@@ -114,7 +114,8 @@ void AddOutputVar(const std::unordered_set<Node*>& output_vars,
// var node are from internal nodes // var node are from internal nodes
std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_internals, const GraphNodeSet& cluster_internals,
const GraphNodeSet& cluster_inputs) { const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
// Graph's constructor must has one parameter, and in our code, // Graph's constructor must has one parameter, and in our code,
// the ProgramDesc is useless, so here we pass a temporary object. // the ProgramDesc is useless, so here we pass a temporary object.
auto subgraph = std::make_unique<Graph>(framework::ProgramDesc()); auto subgraph = std::make_unique<Graph>(framework::ProgramDesc());
...@@ -127,7 +128,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -127,7 +128,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
std::unordered_map<Node*, Node*> old_var2new_var; std::unordered_map<Node*, Node*> old_var2new_var;
for (auto* var : cluster_internals) { for (auto* var : cluster_internals) {
auto sub_node = subgraph->CreateVarNode(var->Var()); Node* sub_node;
if (var->Var() == nullptr) {
sub_node = subgraph->CreateEmptyNode(var->Name(), var->NodeType());
} else {
sub_node = subgraph->CreateVarNode(var->Var());
}
old_var2new_var[var] = sub_node; old_var2new_var[var] = sub_node;
} }
...@@ -140,7 +146,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -140,7 +146,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
for (auto* var : op->inputs) { for (auto* var : op->inputs) {
if (cluster_internals.count(var)) { if (cluster_internals.count(var)) {
old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]); old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]);
} else if (cluster_inputs.count(var)) { } else if (cluster_inputs.count(var) && var->Var() != nullptr) {
if (var->Var()->IsParameter()) { if (var->Var()->IsParameter()) {
// Parameters have been preserved in scope, compared to feed var, // Parameters have been preserved in scope, compared to feed var,
// param just need add new var and don't need add feed op. // param just need add new var and don't need add feed op.
...@@ -157,7 +163,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster, ...@@ -157,7 +163,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
for (auto* var : op->outputs) { for (auto* var : op->outputs) {
if (cluster_internals.count(var)) { if (cluster_internals.count(var)) {
old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]); old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]);
} else { } else if (cluster_outputs.count(var) && var->Var() != nullptr) {
// Create new output var node to guarantee the independency of // Create new output var node to guarantee the independency of
// subgraph. In other words, the subgraph has no connection with // subgraph. In other words, the subgraph has no connection with
// other graph, even the input graph. // other graph, even the input graph.
...@@ -239,14 +245,20 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs, ...@@ -239,14 +245,20 @@ Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
framework::OpDesc special_op_desc; framework::OpDesc special_op_desc;
special_op_desc.SetType(kCinnLaunchOp); special_op_desc.SetType(kCinnLaunchOp);
std::vector<std::string> input_names; std::vector<std::string> input_names;
std::transform(cluster_inputs.begin(), cluster_inputs.end(), std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
std::back_inserter(input_names), [&input_names](Node* n) {
[](Node* n) { return n->Name(); }); if (n->Var() != nullptr) {
input_names.emplace_back(n->Name());
}
});
special_op_desc.SetInput("X", input_names); special_op_desc.SetInput("X", input_names);
std::vector<std::string> output_names; std::vector<std::string> output_names;
std::transform(cluster_outputs.begin(), cluster_outputs.end(), std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
std::back_inserter(output_names), [&output_names](Node* n) {
[](Node* n) { return n->Name(); }); if (n->Var() != nullptr) {
output_names.emplace_back(n->Name());
}
});
special_op_desc.SetOutput("Out", output_names); special_op_desc.SetOutput("Out", output_names);
special_op_desc.SetAttr(kCompilationKey, compilation_key); special_op_desc.SetAttr(kCompilationKey, compilation_key);
special_op_desc.Flush(); special_op_desc.Flush();
...@@ -362,8 +374,8 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -362,8 +374,8 @@ void SearchAllSubgraphs(Graph* graph) {
&cluster_internals); &cluster_internals);
// Create a new subgraph according to the found cluster and // Create a new subgraph according to the found cluster and
// save it in CinnCompiler // save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph( std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs)); cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
// Replace the found cluster to a new special op node // Replace the found cluster to a new special op node
ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs, ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs,
cluster_outputs, cluster_internals, cluster_outputs, cluster_internals,
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -50,9 +51,10 @@ inline int CountNode(const std::unordered_set<Node*>& nodes, ...@@ -50,9 +51,10 @@ inline int CountNode(const std::unordered_set<Node*>& nodes,
inline Node* GetNode(const std::unordered_set<Node*>& nodes, inline Node* GetNode(const std::unordered_set<Node*>& nodes,
const std::string& op_name) { const std::string& op_name) {
return *std::find_if( return *std::find_if(nodes.begin(), nodes.end(),
nodes.begin(), nodes.end(), [&op_name](const Node* node) {
[&op_name](const Node* node) { return node->Name() == op_name; }); return node->Name().find(op_name) != std::string::npos;
});
} }
inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) { inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
...@@ -185,22 +187,25 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() { ...@@ -185,22 +187,25 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
ir::Node* mul = g->CreateOpNode(&mul_op); ir::Node* mul = g->CreateOpNode(&mul_op);
ir::Node* relu = g->CreateOpNode(&relu_op); ir::Node* relu = g->CreateOpNode(&relu_op);
ir::Node* v0 = g->CreateEmptyNode("var0", Node::Type::kVariable);
ir::Node* v1 = g->CreateVarNode(&var1); ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2); ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3); ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4); ir::Node* v4 = g->CreateVarNode(&var4);
ir::Node* v5 = g->CreateVarNode(&var5); ir::Node* v5 = g->CreateVarNode(&var5);
ir::Node* v6 = g->CreateVarNode(&var6); ir::Node* v6 = g->CreateVarNode(&var6);
ir::Node* v7 = g->CreateControlDepVar();
// fill op node // fill op node
mul->inputs = {v1, v2}; mul->inputs = {v0, v1, v2};
mul->outputs = {v3}; mul->outputs = {v3};
add->inputs = {v3, v4}; add->inputs = {v3, v4};
add->outputs = {v5}; add->outputs = {v5};
relu->inputs = {v5}; relu->inputs = {v5};
relu->outputs = {v6}; relu->outputs = {v6, v7};
// fill variable node // fill variable node
v0->outputs = {mul};
v1->outputs = {mul}; v1->outputs = {mul};
v2->outputs = {mul}; v2->outputs = {mul};
...@@ -213,6 +218,7 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() { ...@@ -213,6 +218,7 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
v5->outputs = {relu}; v5->outputs = {relu};
v6->inputs = {relu}; v6->inputs = {relu};
v7->inputs = {relu};
return g; return g;
} }
...@@ -225,25 +231,28 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ...@@ -225,25 +231,28 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
pass->Apply(g.get()); pass->Apply(g.get());
// After search, the graph should as following // After search, the graph should as following
// v1 --| // v0 --|
// v2 --| --> kCinnLaunchOp --> v6 // v1 --| |--> v6
// v2 --| --> kCinnLaunchOp |--> v7
// v4 --| // v4 --|
const auto& nodes = g->Nodes(); const auto& nodes = g->Nodes();
ASSERT_EQ(nodes.size(), static_cast<size_t>(5)); ASSERT_EQ(nodes.size(), static_cast<size_t>(7));
ASSERT_TRUE(CheckGraphIndependence(nodes)); ASSERT_TRUE(CheckGraphIndependence(nodes));
// A new op named kCinnLaunchOp should be added // A new op named kCinnLaunchOp should be added
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
auto* cinn_op = GetNode(nodes, kCinnLaunchOp); auto* cinn_op = GetNode(nodes, kCinnLaunchOp);
auto* v0 = GetNode(nodes, "var0");
auto* v1 = GetNode(nodes, "var1"); auto* v1 = GetNode(nodes, "var1");
auto* v2 = GetNode(nodes, "var2"); auto* v2 = GetNode(nodes, "var2");
auto* v4 = GetNode(nodes, "var4"); auto* v4 = GetNode(nodes, "var4");
auto* v6 = GetNode(nodes, "var6"); auto* v6 = GetNode(nodes, "var6");
auto* v7 = GetNode(nodes, Node::kControlDepVarName);
ASSERT_EQ( ASSERT_EQ(
std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()), std::unordered_set<Node*>(cinn_op->inputs.begin(), cinn_op->inputs.end()),
std::unordered_set<Node*>({v1, v2, v4})); std::unordered_set<Node*>({v0, v1, v2, v4}));
ASSERT_EQ(cinn_op->outputs, std::vector<Node*>({v6})); ASSERT_EQ(cinn_op->outputs, std::vector<Node*>({v6, v7}));
ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op})); ASSERT_EQ(v1->outputs, std::vector<Node*>({cinn_op}));
ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op})); ASSERT_EQ(v6->inputs, std::vector<Node*>({cinn_op}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册