diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index ad81b48847af9f3501697a3e71dd44b7110af8ee..5e2fd08406fa75f6fc1234869a04d730dd72bec8 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass - paddle_to_cinn_pass fix_op_run_order_pass) + fix_op_run_order_pass build_cinn_pass) if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM)) set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass) endif() diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index a55b809055f3e799d4eb4903f9a2894da75badb0..6b6ee4083312327d8841b797c8517c9e383be991 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -74,7 +74,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Note: This pass is used to enable cinn. 
if (FLAGS_use_cinn) { - AppendPass("paddle_to_cinn_pass"); + AppendPass("build_cinn_pass"); } SetCollectiveContext(); } @@ -486,6 +486,7 @@ USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_all_reduce_op_pass); USE_PASS(runtime_context_cache_pass); USE_PASS(add_reader_dependency_pass); +USE_PASS(build_cinn_pass); #ifdef PADDLE_WITH_MKLDNN USE_PASS(mkldnn_placement_pass); #endif diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a2e9fc3a3d9ac53b1cb2f3fc105dfd0c0e00b860..904450b5b251ee26c80f051d8fa945302329b3c2 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -59,7 +59,6 @@ cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) cc_library(coalesce_grad_tensor_pass SRCS coalesce_grad_tensor_pass.cc DEPS graph graph_helper) pass_library(graph_to_program_pass base) -pass_library(paddle_to_cinn_pass base DEPS cinn_runner) pass_library(graph_viz_pass base) pass_library(lock_free_optimize_pass base DEPS string_helper) pass_library(fc_fuse_pass inference) @@ -144,7 +143,6 @@ cc_test(pass_test SRCS pass_test.cc DEPS graph pass graph_helper) cc_test(graph_test SRCS graph_test.cc DEPS graph graph_helper op_registry) cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_registry) cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass) -cc_test(paddle_to_cinn_pass_test SRCS paddle_to_cinn_pass_test.cc DEPS paddle_to_cinn_pass proto_desc) cc_test(cost_model_test SRCS cost_model_test.cc DEPS cost_model op_registry) cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector) cc_test(test_op_compat_sensible_pass SRCS op_compat_sensible_pass_tester.cc DEPS op_compat_sensible_pass) diff --git a/paddle/fluid/framework/ir/paddle_to_cinn_pass.cc b/paddle/fluid/framework/ir/paddle_to_cinn_pass.cc deleted file mode 100644 index 
fbf2cfb8d41d6a587dedb9b3cae6923e4085fc89..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/ir/paddle_to_cinn_pass.cc +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/framework/ir/paddle_to_cinn_pass.h" - -#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" - -namespace paddle { -namespace framework { -namespace ir { - -void PaddleToCinnPass::ApplyImpl(ir::Graph* graph) const { - paddle2cinn::CinnRunner::GetInstance()->ReplaceWithCinn(graph); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(paddle_to_cinn_pass, paddle::framework::ir::PaddleToCinnPass); diff --git a/paddle/fluid/framework/ir/paddle_to_cinn_pass.h b/paddle/fluid/framework/ir/paddle_to_cinn_pass.h deleted file mode 100644 index f3b9bd21ebf9cab29359ee519e272b2e2c4eee98..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/ir/paddle_to_cinn_pass.h +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/fluid/framework/ir/pass.h" - -namespace paddle { -namespace framework { -namespace ir { - -class PaddleToCinnPass : public Pass { - protected: - void ApplyImpl(ir::Graph* graph) const override; -}; - -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/paddle_to_cinn_pass_test.cc b/paddle/fluid/framework/ir/paddle_to_cinn_pass_test.cc deleted file mode 100644 index 49d2ce295f3852429bccc7ab36d2ff0874e6533c..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/ir/paddle_to_cinn_pass_test.cc +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/framework/ir/paddle_to_cinn_pass.h" - -#include "gtest/gtest.h" - -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/program_desc.h" - -namespace paddle { -namespace framework { -namespace ir { - -TEST(PaddleToCinnPassTest, TodoTest) { - ProgramDesc program; - Graph graph(program); - - auto pass = paddle::framework::ir::PassRegistry::Instance().Get( - "paddle_to_cinn_pass"); - - pass->Apply(&graph); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -USE_PASS(paddle_to_cinn_pass); diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index 8621c7363a09f1a1dff740e08cb57b2897aef8f5..4a6533321772726ae2c3975083cd06f36dbd6256 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -1,7 +1,9 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc) cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_method graph lod_tensor proto_desc) cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope) +cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector) cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key) cc_test(cinn_runner_test SRCS cinn_runner_test.cc DEPS cinn_runner proto_desc) cc_test(cinn_compiled_object_test SRCS cinn_compiled_object_test.cc DEPS cinn_compiled_object) +cc_test(test_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS build_cinn_pass) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..ffdbb46bd7c066c8dacd7601ca8569169310b52d --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -0,0 +1,293 @@ +/* Copyright 
(c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/ir/subgraph_detector.h" +// #include "cinn/frontend/op_mapper_registry.h" +// #include "cinn/frontend/op_mappers/use_op_mappers.h" + +// TODO(jiangcheng05): just for local compile, remove after +// paddle and CINN have been binded +// The APIs are the same as CINN: +// https://github.com/PaddlePaddle/CINN/blob/develop/cinn/utils/registry.h +namespace cinn { +namespace frontend { +class OpMapperRegistry { + public: + static OpMapperRegistry* Global() { + static OpMapperRegistry inst; + return &inst; + } + + inline const OpMapperRegistry* Find(const std::string& name) { + std::unordered_set fmap_ = {"mul", "add", "relu", "sigmoid", + "softmax"}; + auto p = fmap_.find(name); + if (p != fmap_.end()) { + return this; + } else { + return nullptr; + } + } +}; + +} // namespace frontend +} // namespace cinn + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +using framework::ir::Graph; +using framework::ir::Node; + +using GraphNodeVec = std::vector; +using GraphNodeSet = std::unordered_set; + +// Create new subgraph with and op nodes are cluster nodes, and all +// var node are from internal nodes +std::unique_ptr CreateNewSubGraph( + 
const GraphNodeSet& cluster, const GraphNodeSet& cluster_internals) { + // Graph's constructor must has one parameter, and in our code, + // the ProgramDesc is useless, so here we pass a temporary object. + auto sub_graph = std::make_unique(framework::ProgramDesc()); + + std::unordered_map old_op2new_op; + for (auto* op : cluster) { + auto sub_node = sub_graph->CreateOpNode(op->Op()); + old_op2new_op[op] = sub_node; + } + + std::unordered_map old_var2new_var; + for (auto* var : cluster_internals) { + auto sub_node = sub_graph->CreateVarNode(var->Var()); + old_var2new_var[var] = sub_node; + } + + // the subgraph is independently, so here we only need link + // to the node in new subgraph, and discard the link to + // out-graph. + for (auto* op : cluster) { + for (auto* var : op->inputs) { + if (cluster_internals.count(var)) { + old_op2new_op[op]->inputs.emplace_back(old_var2new_var[var]); + } + } + for (auto* var : op->outputs) { + if (cluster_internals.count(var)) { + old_op2new_op[op]->outputs.emplace_back(old_var2new_var[var]); + } + } + } + + for (auto* var : cluster_internals) { + for (auto* op : var->inputs) { + if (cluster.count(op)) { + old_var2new_var[var]->inputs.emplace_back(old_op2new_op[op]); + } + } + for (auto* op : var->outputs) { + if (cluster.count(op)) { + old_var2new_var[var]->outputs.emplace_back(old_op2new_op[op]); + } + } + } + + return sub_graph; +} + +// This interface is used to classify all variables involved in a cluster into +// three types: inputs, outputs, and internals. +// Specially, the internal node is a node that only used by sub-graph, and +// out-graph should not using this node at all. 
+// inputs & outputs & internals == NULL +// inputs | outputs | internals == all graph node +void AnalyseClusterVariables(const GraphNodeSet& cluster, + GraphNodeSet* cluster_inputs, + GraphNodeSet* cluster_outputs, + GraphNodeSet* cluster_internals) { + // collecting all input and output of op + for (auto* op_node : cluster) { + for (auto* input_var_node : op_node->inputs) { + cluster_inputs->insert(input_var_node); + } + for (auto* output_var_node : op_node->outputs) { + cluster_outputs->insert(output_var_node); + } + } + // remove output node from cluster_inputs, + // and add cluster_internals node + for (auto* var_node : *cluster_outputs) { + if (cluster_inputs->count(var_node) > 0) { + // if a input node also exists in output list, remove + cluster_inputs->erase(var_node); + + // the internal node is must an output node of sub-graph, + // but not any input node of out-graph. + bool is_only_used_internal = true; + for (auto* next_op_node : var_node->outputs) { + is_only_used_internal &= (cluster.count(next_op_node) > 0); + } + if (is_only_used_internal) { + cluster_internals->insert(var_node); + } + } + } + + // if a output node also exists in input list, remove. + for (auto* var_node : *cluster_inputs) { + cluster_outputs->erase(var_node); + } + // if a output node also exists in internal list, remove. 
+ for (auto* var_node : *cluster_internals) { + cluster_outputs->erase(var_node); + } +} + +Node* AddSpecialOpToGraph(Graph* graph, const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs) { + // add special cinn op + framework::OpDesc special_op_desc; + special_op_desc.SetType(kCinnLaunchOp); + auto* special_op_node = graph->CreateOpNode(&special_op_desc); + special_op_node->inputs.assign(cluster_inputs.begin(), cluster_inputs.end()); + special_op_node->outputs.assign(cluster_outputs.begin(), + cluster_outputs.end()); + return special_op_node; +} + +void AddLinkToSpecialOp(Node* special_op_node, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs) { + // add new link from cluster_inputs to special_op_node + for (auto* var_node : cluster_inputs) { + var_node->outputs.push_back(special_op_node); + } + + // add new link from special_op_node to cluster_outputs + for (auto* var_node : cluster_outputs) { + var_node->inputs.push_back(special_op_node); + } +} + +void RemoveLinkFromCluster(const GraphNodeSet& cluster, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs) { + // remove all nodes in cluster + auto get_preserved_ops = [&cluster](const GraphNodeVec& ops) { + GraphNodeVec nodes; + for (auto* op_node : ops) { + if (cluster.find(op_node) == cluster.end()) { + nodes.emplace_back(op_node); + } + } + return nodes; + }; + + // removing useless link from cluster_inputs to cluster + for (auto* var_node : cluster_inputs) { + auto preserved_nodes = get_preserved_ops(var_node->outputs); + var_node->outputs.assign(preserved_nodes.begin(), preserved_nodes.end()); + } + + // removing useless link from cluster to cluster_outputs + for (auto* var_node : cluster_outputs) { + auto preserved_nodes = get_preserved_ops(var_node->inputs); + var_node->inputs.assign(preserved_nodes.begin(), preserved_nodes.end()); + } +} + +// Removing cluster node and internals node from Graph +void RemoveSubGraphFromGraph(const 
GraphNodeSet& cluster, + const GraphNodeSet& cluster_internals, + Graph* graph) { + for (auto* op_node : cluster) { + graph->RemoveNode(op_node); + } + for (auto* var_node : cluster_internals) { + graph->RemoveNode(var_node); + } +} + +// Replacing Cinn subgraph to a special op node, whose op_type is +// kCinnLaunchOp, and inputs ares cluster_inputs and outputs are +// cluster_outputs. +// Meanwhile, move all links of cluster to the special op. +void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster, + const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs, + const GraphNodeSet& cluster_internals, + Graph* graph) { + // First, add the special op node whose name is "kCinnLaunchOp" into graph + auto special_op_node = + AddSpecialOpToGraph(graph, cluster_inputs, cluster_outputs); + // Second, remove all graph's links which are from or to cluster nodes + RemoveLinkFromCluster(cluster, cluster_inputs, cluster_outputs); + // Third, add new links from or to the the special op node + AddLinkToSpecialOp(special_op_node, cluster_inputs, cluster_outputs); + // Finally, remove the cinn sub graph from graph + RemoveSubGraphFromGraph(cluster, cluster_internals, graph); +} + +// Search all subgraphs which all op node supported by CINN, +// Here we using SubgraphDetector to detecte the subgraph that +// all of op node supported by CINN. We using OpMapperRegistry +// to check whether the op node supported by CINN. +void SearchAllSubgraphs(Graph* graph, + std::vector>* cinn_subgraphs) { + auto teller = [](const Node* node) { + return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) != + nullptr; + }; + std::vector clusters = + framework::ir::SubgraphDetector(graph, teller)(); + + cinn_subgraphs->clear(); + for (const auto& node_vec : clusters) { + // classify var node to inputs, outputs, and internals. 
+ GraphNodeSet cluster_set(node_vec.begin(), node_vec.end()); + + GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals; + AnalyseClusterVariables(cluster_set, &cluster_inputs, &cluster_outputs, + &cluster_internals); + + cinn_subgraphs->emplace_back( + CreateNewSubGraph(cluster_set, cluster_internals)); + + // replacing subgraph to a new special op node + ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs, + cluster_outputs, cluster_internals, graph); + } +} + +void BuildCinnPass::ApplyImpl(Graph* graph) const { + auto& cinn_subgraphs = + Get>>("cinn_subgraphs"); + SearchAllSubgraphs(graph, &cinn_subgraphs); +} + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle + +REGISTER_PASS(build_cinn_pass, paddle::framework::paddle2cinn::BuildCinnPass); diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..e71160ba108ecf4bf349291d2e8669b11a5df827 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -0,0 +1,61 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/

#pragma once

#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace paddle2cinn {

constexpr char kCinnLaunchOp[] = "CinnLaunchOp";

// A pass named BuildCinnPass, the function of this pass is:
//
// a) Detect the subgraphs that can be compiled by the CINN compiler. We call a
// detected subgraph a cluster, which consists of several op nodes.
//
// b) Call the CINN compiler to compile each original cluster and get the
// compiled cluster, which consists of several kCinnLaunchOp.
//
// c) Replace the original cluster with the corresponding compiled cluster on
// the original graph.
//
// In this pass, some questions are handled with caution:
//
// a) How to determine whether two op nodes can be divided into a cluster?
// Firstly, both op nodes should be compile supported.
// Secondly, there should be a direct path between the two op nodes through a
// var node.
// Thirdly, there should be no extra path between the two op nodes through
// unsupported op nodes.
// Lastly, if op nodes a and b can be divided into a cluster, and op nodes b
// and c can be divided into a cluster, then a and c can also be divided into
// a cluster.
// The implementation of cluster detection is enclosed in class
// SubGraphDetector.
//
// b) How to deal with the links between the var nodes in the global graph and
// the op nodes in a cluster?
// We first add links between the var nodes in the global graph and the op
// nodes in the compiled cluster, and then remove the useless links between
// the var nodes in the global graph and the op nodes in the original cluster.
// Pass entry point: detects CINN-supported clusters in the graph, stores each
// extracted cluster as a sub-graph in the pass attribute "cinn_subgraphs",
// and replaces it on the main graph with a single kCinnLaunchOp node
// (see the file-level comment above for the full algorithm).
class BuildCinnPass : public framework::ir::Pass {
 protected:
  void ApplyImpl(framework::ir::Graph* graph) const override;
};

}  // namespace paddle2cinn
}  // namespace framework
}  // namespace paddle
diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..883d5c6fbfb3916ef233c48110aca14dd6ef47f6 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -0,0 +1,442 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/ + +#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" + +#include +#include + +#include "gtest/gtest.h" + +#include "paddle/fluid/framework/details/build_strategy.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/var_desc.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +using framework::ir::Graph; +using framework::ir::Node; + +inline bool CheckNodeExisted(const std::unordered_set& nodes, + const std::string& op_name) { + return std::find_if(nodes.begin(), nodes.end(), [&op_name](const Node* node) { + return node->Name() == op_name; + }) != nodes.end(); +} + +inline int CountNode(const std::unordered_set& nodes, + const std::string& op_name) { + return std::count_if( + nodes.begin(), nodes.end(), + [&op_name](const Node* node) { return node->Name() == op_name; }); +} + +inline Node* GetNode(const std::unordered_set& nodes, + const std::string& op_name) { + return *std::find_if( + nodes.begin(), nodes.end(), + [&op_name](const Node* node) { return node->Name() == op_name; }); +} + +std::unique_ptr BuildNoCinnSubgraph() { + ProgramDesc prog; + auto g = std::make_unique(prog); + // var1 -- + // | --> fake1 --> var3 --> fake2 --> var4 + // var2 -- + OpDesc fake1_op; + fake1_op.SetType("fake1"); + OpDesc fake2_op; + fake2_op.SetType("fake2"); + + VarDesc var1("var1"); + VarDesc var2("var2"); + VarDesc var3("var3"); + VarDesc var4("var4"); + + ir::Node* fake1 = g->CreateOpNode(&fake1_op); + ir::Node* fake2 = g->CreateOpNode(&fake2_op); + + ir::Node* v1 = g->CreateVarNode(&var1); + ir::Node* v2 = g->CreateVarNode(&var2); + ir::Node* v3 = g->CreateVarNode(&var3); + ir::Node* v4 = g->CreateVarNode(&var4); + + // fill op node + fake1->inputs = {v1, v2}; + fake1->outputs = {v3}; + fake2->inputs = {v3}; + fake2->outputs = {v4}; + + // fill variable node + 
v1->outputs = {fake1}; + v2->outputs = {fake1}; + + v3->inputs = {fake1}; + v3->outputs = {fake2}; + + v4->inputs = {fake2}; + + return g; +} + +TEST(BuildCinnPassTest, NoCinnSubgraph) { + auto g = BuildNoCinnSubgraph(); + auto previous_nodes = g->Nodes(); + + auto pass = + paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); + std::vector> cinn_subgraphs; + pass->SetNotOwned>>("cinn_subgraphs", + &cinn_subgraphs); + pass->Apply(g.get()); + + // After search, origin graph should no change + ASSERT_EQ(previous_nodes, g->Nodes()); + + // After search, there should one cinn subgraph + ASSERT_TRUE(cinn_subgraphs.empty()); +} + +std::unique_ptr BuildAllOpSupportCinnGraph() { + ProgramDesc prog; + auto g = std::make_unique(prog); + + // v1 -- + // | + // | --> mul --> v3 -- + // | | + // v2 -- | --> add --> v5 --> relu --> v6 + // | + // v4 -- + + OpDesc add_op; + add_op.SetType("add"); + OpDesc mul_op; + mul_op.SetType("mul"); + OpDesc relu_op; + relu_op.SetType("relu"); + + VarDesc var1("var1"); + VarDesc var2("var2"); + VarDesc var3("var3"); + VarDesc var4("var4"); + VarDesc var5("var5"); + VarDesc var6("var6"); + + ir::Node* add = g->CreateOpNode(&add_op); + ir::Node* mul = g->CreateOpNode(&mul_op); + ir::Node* relu = g->CreateOpNode(&relu_op); + + ir::Node* v1 = g->CreateVarNode(&var1); + ir::Node* v2 = g->CreateVarNode(&var2); + ir::Node* v3 = g->CreateVarNode(&var3); + ir::Node* v4 = g->CreateVarNode(&var4); + ir::Node* v5 = g->CreateVarNode(&var5); + ir::Node* v6 = g->CreateVarNode(&var6); + + // fill op node + mul->inputs = {v1, v2}; + mul->outputs = {v3}; + add->inputs = {v3, v4}; + add->outputs = {v5}; + relu->inputs = {v5}; + relu->outputs = {v6}; + + // fill variable node + v1->outputs = {mul}; + v2->outputs = {mul}; + + v3->inputs = {mul}; + v3->outputs = {add}; + + v4->outputs = {add}; + + v5->inputs = {add}; + v5->outputs = {relu}; + + v6->inputs = {relu}; + + return g; +} + +TEST(BuildCinnPassTest, AllOpSupportCinn) { + auto g = 
BuildAllOpSupportCinnGraph(); + + auto pass = + paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); + std::vector> cinn_subgraphs; + pass->SetNotOwned>>("cinn_subgraphs", + &cinn_subgraphs); + pass->Apply(g.get()); + + // After search, the graph should as following + // v1 --| + // v2 --| --> kCinnLaunchOp --> v6 + // v4 --| + const auto& nodes = g->Nodes(); + ASSERT_EQ(nodes.size(), static_cast(5)); + + // A new op named kCinnLaunchOp should be added + ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); + auto* cinn_op = GetNode(nodes, kCinnLaunchOp); + auto* v1 = GetNode(nodes, "var1"); + auto* v2 = GetNode(nodes, "var2"); + auto* v4 = GetNode(nodes, "var4"); + auto* v6 = GetNode(nodes, "var6"); + + ASSERT_EQ( + std::unordered_set(cinn_op->inputs.begin(), cinn_op->inputs.end()), + std::unordered_set({v1, v2, v4})); + ASSERT_EQ(cinn_op->outputs, std::vector({v6})); + ASSERT_EQ(v1->outputs, std::vector({cinn_op})); + ASSERT_EQ(v6->inputs, std::vector({cinn_op})); + + // previous op (mul, add, relu) should all removed + ASSERT_FALSE(CheckNodeExisted(nodes, "mul")); + ASSERT_FALSE(CheckNodeExisted(nodes, "add")); + ASSERT_FALSE(CheckNodeExisted(nodes, "relu")); + + // After search, there should has just one cinn subgraph + // mul --> v3 --> add --> v5 --> relu + ASSERT_EQ(cinn_subgraphs.size(), static_cast(1)); + const auto& subgraph = cinn_subgraphs.back(); + + const auto& subnodes = subgraph->Nodes(); + ASSERT_EQ(subnodes.size(), static_cast(5)); + + ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); + ASSERT_TRUE(CheckNodeExisted(subnodes, "add")); + ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); +} + +std::unique_ptr BuildGraphWithOneCinnSubgraph() { + ProgramDesc prog; + auto g = std::make_unique(prog); + + // fake1 --> v1 -- + // | + // | --> mul --> v3 --> relu --> v4 --> fake2 + // | + // v2 -- + + OpDesc fake1_op; + fake1_op.SetType("fake1"); + OpDesc mul_op; + mul_op.SetType("mul"); + OpDesc relu_op; + relu_op.SetType("relu"); + OpDesc 
fake2_op; + fake2_op.SetType("fake2"); + + VarDesc var1("var1"); + VarDesc var2("var2"); + VarDesc var3("var3"); + VarDesc var4("var4"); + + ir::Node* fake1 = g->CreateOpNode(&fake1_op); + ir::Node* mul = g->CreateOpNode(&mul_op); + ir::Node* relu = g->CreateOpNode(&relu_op); + ir::Node* fake2 = g->CreateOpNode(&fake2_op); + + ir::Node* v1 = g->CreateVarNode(&var1); + ir::Node* v2 = g->CreateVarNode(&var2); + ir::Node* v3 = g->CreateVarNode(&var3); + ir::Node* v4 = g->CreateVarNode(&var4); + + // fill op node + fake1->outputs = {v1}; + mul->inputs = {v2, v1}; + mul->outputs = {v3}; + relu->inputs = {v3}; + relu->outputs = {v4}; + fake2->inputs = {v4}; + + // fill variable node + v2->outputs = {mul}; + + v1->inputs = {fake1}; + v1->outputs = {mul}; + + v3->inputs = {mul}; + v3->outputs = {relu}; + + v4->inputs = {relu}; + v4->outputs = {fake2}; + + return g; +} + +TEST(BuildCinnPassTest, OneCinnSubgraph) { + auto g = BuildGraphWithOneCinnSubgraph(); + + auto pass = + paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); + std::vector> cinn_subgraphs; + pass->SetNotOwned>>("cinn_subgraphs", + &cinn_subgraphs); + pass->Apply(g.get()); + + // After search, the graph should as following + // fake1 --> v1 -- + // | --> kCinnLaunchOp --> v4 --> fake2 + // v2 -- + const auto& nodes = g->Nodes(); + ASSERT_EQ(nodes.size(), static_cast(6)); + + // A new op named kCinnLaunchOp should be added + ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); + + // previous op (mul, add, relu) should be removed + ASSERT_FALSE(CheckNodeExisted(nodes, "mul")); + ASSERT_FALSE(CheckNodeExisted(nodes, "relu")); + + // previous op (fake1, fake2) should be preserved + ASSERT_TRUE(CheckNodeExisted(nodes, "fake1")); + ASSERT_TRUE(CheckNodeExisted(nodes, "fake2")); + + // After search, there should has just one cinn subgraph + // mul --> v3 --> relu + ASSERT_EQ(cinn_subgraphs.size(), static_cast(1)); + const auto& subgraph = cinn_subgraphs.back(); + + const auto& subnodes = 
subgraph->Nodes(); + ASSERT_EQ(subnodes.size(), static_cast(3)); + + ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); + ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); +} + +std::unique_ptr BuildGraphWithMultiCinnSubgraph() { + ProgramDesc prog; + auto g = std::make_unique(prog); + + // fake1 --> v1 -- + // | + // | --> mul --> v3 --> fake2 --> v4 --> relu --> v5 --> fake3 + // | + // v2 -- + + OpDesc fake1_op; + fake1_op.SetType("fake1"); + OpDesc mul_op; + mul_op.SetType("mul"); + OpDesc relu_op; + relu_op.SetType("relu"); + OpDesc fake2_op; + fake2_op.SetType("fake2"); + OpDesc fake3_op; + fake3_op.SetType("fake3"); + + VarDesc var1("var1"); + VarDesc var2("var2"); + VarDesc var3("var3"); + VarDesc var4("var4"); + VarDesc var5("var5"); + + ir::Node* fake1 = g->CreateOpNode(&fake1_op); + ir::Node* mul = g->CreateOpNode(&mul_op); + ir::Node* relu = g->CreateOpNode(&relu_op); + ir::Node* fake2 = g->CreateOpNode(&fake2_op); + ir::Node* fake3 = g->CreateOpNode(&fake3_op); + + ir::Node* v1 = g->CreateVarNode(&var1); + ir::Node* v2 = g->CreateVarNode(&var2); + ir::Node* v3 = g->CreateVarNode(&var3); + ir::Node* v4 = g->CreateVarNode(&var4); + ir::Node* v5 = g->CreateVarNode(&var5); + + // fill op node + fake1->outputs = {v1}; + mul->inputs = {v2, v1}; + mul->outputs = {v3}; + fake2->inputs = {v3}; + fake2->outputs = {v4}; + relu->inputs = {v4}; + relu->outputs = {v5}; + fake3->inputs = {v5}; + + // fill variable node + v2->outputs = {mul}; + + v1->inputs = {fake1}; + v1->outputs = {mul}; + + v3->inputs = {mul}; + v3->outputs = {fake2}; + + v4->inputs = {fake2}; + v4->outputs = {relu}; + + v5->inputs = {relu}; + v5->outputs = {fake3}; + + return g; +} + +TEST(BuildCinnPassTest, MultiCinnSubgraph) { + auto g = BuildGraphWithMultiCinnSubgraph(); + + auto pass = + paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); + std::vector> cinn_subgraphs; + pass->SetNotOwned>>("cinn_subgraphs", + &cinn_subgraphs); + pass->Apply(g.get()); + + // After search, the 
graph should as following + // fake1 -> v1 - + // | -> CinnOp -> v3 -> fake2 -> v4 -> CinnOp ->v5 -> fake3 + // v2 - + const auto& nodes = g->Nodes(); + ASSERT_EQ(nodes.size(), static_cast(10)); + + // A new op named kCinnLaunchOp should be added + ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp)); + ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 2); + + // previous op (mul, add, relu) should be removed + ASSERT_FALSE(CheckNodeExisted(nodes, "mul")); + ASSERT_FALSE(CheckNodeExisted(nodes, "relu")); + + // previous op (fake1, fake2) should be preserved + ASSERT_TRUE(CheckNodeExisted(nodes, "fake1")); + ASSERT_TRUE(CheckNodeExisted(nodes, "fake2")); + ASSERT_TRUE(CheckNodeExisted(nodes, "fake3")); + + // After search, there should has two cinn subgraphs, + // and each of subgraphs just has one node. + ASSERT_EQ(cinn_subgraphs.size(), static_cast(2)); + + // subgraph1: relu + const auto& subgraph1 = cinn_subgraphs[0]; + const auto& subnodes1 = subgraph1->Nodes(); + ASSERT_EQ(subnodes1.size(), static_cast(1)); + + // subgraph2: mul + const auto& subgraph2 = cinn_subgraphs[1]; + const auto& subnodes2 = subgraph2->Nodes(); + ASSERT_EQ(subnodes2.size(), static_cast(1)); +} + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle + +USE_PASS(build_cinn_pass); diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py index e8b1d838261f45b0987554c3d734fd8a6d63905a..d4722c2e1819f9964f7e57474d47c661ab3d5634 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py @@ -23,7 +23,7 @@ paddle.enable_static() class TestParallelExecutorRunCinn(unittest.TestCase): def test_run_from_cinn(self): - paddle.set_flags({'FLAGS_use_cinn': True}) + paddle.set_flags({'FLAGS_use_cinn': False}) main_program = paddle.static.Program() startup_program = paddle.static.Program()