diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index 9049da9179258342de971784ece0ff9de3f969a2..6aced31feb2f003ad30095adf9d0a34facfdf687 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -2,6 +2,7 @@ pass_library( build_cinn_pass base DEPS + cinn_subgraph_detector subgraph_detector cinn_compiler errors @@ -11,6 +12,10 @@ cc_library( cinn_cache_key SRCS cinn_cache_key.cc DEPS graph graph_helper lod_tensor proto_desc) +cc_library( + cinn_subgraph_detector + SRCS cinn_subgraph_detector.cc + DEPS graph graph_helper subgraph_detector lod_tensor proto_desc) cc_library( transform_desc SRCS transform_desc.cc diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 9aca93dd73551a2ea0d90e8256143a1a8f4c771b..77b7bd7e338f128b36298a8466f41378f99be59a 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -31,10 +31,10 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/node.h" -#include "paddle/fluid/framework/ir/subgraph_detector.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h" #include "paddle/fluid/operators/cinn/cinn_launch_op.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" @@ -645,10 +645,9 @@ void SearchAllSubgraphs(Graph* graph) { }; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; - std::vector clusters = - framework::ir::SubgraphDetector(graph, teller)(); - LOG(INFO) << "--- [build_cinn_pass] detected " << clusters.size() - << " cinn supported subgraphs"; + std::vector clusters = CinnSubgraphDetector(graph, teller)(); + VLOG(3) << "--- [build_cinn_pass] detected " << clusters.size() + << " cinn supported subgraphs"; auto cluster_debug_info = [](const GraphNodeSet& cluster) { std::string res = "("; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc new file mode 100644 index 0000000000000000000000000000000000000000..26416269c9e1f1a8d888e9f4c4736458bdc6cd36 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc @@ -0,0 +1,352 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h" +#include "glog/logging.h" +#include "paddle/fluid/framework/ir/subgraph_detector.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +std::unordered_set GetProducerOps(Node* node) { + CHECK(node->IsOp()); + std::unordered_set producers; + + for (auto input_var : node->inputs) { + CHECK(input_var->IsVar()); + for (auto input_op : input_var->inputs) { + CHECK(input_op->IsOp()); + producers.insert(input_op); + } + } + + return producers; +} + +std::unordered_set GetConsumerOps(Node* node) { + CHECK(node->IsOp()); + std::unordered_set consumers; + + for (auto output_var : node->outputs) { + CHECK(output_var->IsVar()); + for (auto output_op : output_var->outputs) { + CHECK(output_op->IsOp()); + consumers.insert(output_op); + } + } + + return consumers; +} + +struct Hasher { + size_t operator()(const CinnSubGraphPtr& subgraph) const noexcept { + return std::hash()(reinterpret_cast(subgraph.get())); + } +}; +struct Comparator { + bool operator()(const CinnSubGraphPtr& first, + const CinnSubGraphPtr& second) const noexcept { + return first.get() == second.get(); + } +}; + +struct CinnSubGraph { + using CinnSubGraphPtr = std::shared_ptr; + // construct function + CinnSubGraph() {} + // construct function + CinnSubGraph(Node* op, bool subst) : substitute(subst) { Insert(op); } + + void Insert(Node* op) { + nodes.push_back(op); + node_set.insert(op); + + auto producers = GetProducerOps(op); + for (auto producer : producers) { + input_nodes.insert(producer); + } + input_nodes.erase(op); + } + + int depth{0}; + int max_depth{0}, min_depth{INT_MAX}; + bool substitute{true}; + std::vector nodes; + std::unordered_set node_set; + std::unordered_set input_nodes; + + std::unordered_set producers; + std::unordered_set consumers; +}; + +void CinnSubgraphDetector::DoOpFusion() { + // sort node from input to output + for (auto& node : TopologicalSort(*graph_)) { + if (node.IsVar()) { + continue; + } + nodes_.push_back(&node); + } + // reverse from output to input + std::reverse(nodes_.begin(), nodes_.end()); + + // do fusion + for (auto* node : nodes_) { + auto subgraph = + subgraph_map_.count(node) + ? subgraph_map_[node] + : std::make_shared(node, node_classifier_(node)); + if (!subgraph_map_.count(node)) { + subgraph_map_[node] = subgraph; + } + auto producers = GetProducerOps(node); + + for (auto producer : producers) { + if (node_classifier_(producer) != subgraph->substitute) { + continue; + } + + bool can_fused = true; + auto consumers = GetConsumerOps(producer); + for (auto consumer : consumers) { + if (!subgraph->node_set.count(consumer)) { + can_fused = false; + break; + } + } + + if (!can_fused) { + continue; + } + + // fuse producer to sub-graph + subgraph->Insert(producer); + subgraph_map_[producer] = subgraph; + } + } +} + +void CinnSubgraphDetector::BuildSubGraph() { + std::unordered_set subgraph_set; + for (auto node : nodes_) { + CHECK(subgraph_map_.count(node)); + auto& subgraph = subgraph_map_[node]; + if (subgraph_set.count(subgraph.get())) { + continue; + } + + subgraph_set.insert(subgraph.get()); + subgraph_list_.push_back(subgraph); + } + + for (auto& subgraph : subgraph_list_) { + for (auto& input_node : subgraph->input_nodes) { + CHECK(subgraph_map_.count(input_node)); + auto& producer = subgraph_map_[input_node]; + subgraph->producers.insert(producer); + producer->consumers.insert(subgraph); + } + } + + // init group depth. + for (auto& subgraph : subgraph_list_) { + for (auto& consumer : subgraph->consumers) { + // update depth. + subgraph->depth = std::max(subgraph->depth, consumer->depth + 1); + } + subgraph->max_depth = subgraph->depth; + subgraph->min_depth = subgraph->depth; + } + + // reverse to keep fusion group in order. + std::reverse(subgraph_list_.begin(), subgraph_list_.end()); +} + +void CinnSubgraphDetector::DoSubGraphFusion() { + while (true) { + bool update = false; + for (auto& subgraph : subgraph_list_) { + // sub graph is not substitute + if (!subgraph->substitute) { + continue; + } + // do fusion + update |= FuseSubGraph(&subgraph); + } + if (!update) { + break; + } + } +} + +bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) { + auto producer = *subgraph_ptr; + auto& consumers = producer->consumers; + std::vector candidates; + for (auto& consumer : consumers) { + if (!consumer->substitute) { + continue; + } + // fast depency check. + if (IsDependencySimplify(producer, consumer, consumers)) { + continue; + } + // global depency check. + if (IsDependency(producer, consumer, consumers)) { + continue; + } + + candidates.push_back(consumer); + } + + if (!candidates.size()) { + return false; + } + + // fuse candidate to producer + for (auto& candidate : candidates) { + candidate->substitute = false; + + // merge nodes + producer->nodes.insert(producer->nodes.end(), + candidate->nodes.begin(), + candidate->nodes.end()); + producer->node_set.insert(candidate->node_set.begin(), + candidate->node_set.end()); + + // update bound for check depency + producer->max_depth = std::max(producer->max_depth, candidate->max_depth); + producer->min_depth = std::min(producer->min_depth, candidate->min_depth); + + // merge producer/consumer + producer->producers.insert(candidate->producers.begin(), + candidate->producers.end()); + producer->consumers.insert(candidate->consumers.begin(), + candidate->consumers.end()); + // update producers's consumer + for (auto& tmp : candidate->producers) { + if (tmp.get() == producer.get()) { + continue; + } + tmp->consumers.insert(producer); + tmp->consumers.erase(candidate); + } + // update consumers's producer + for (auto& tmp : candidate->consumers) { + tmp->producers.insert(producer); + tmp->producers.erase(candidate); + } + + // remove candicate in producer/consumer + producer->producers.erase(candidate); + producer->consumers.erase(candidate); + + // merge input nodes + producer->input_nodes.insert(candidate->input_nodes.begin(), + candidate->input_nodes.end()); + } + + // remove input nodes that is in node set + auto input_nodes = producer->input_nodes; + for (auto input_node : input_nodes) { + if (producer->node_set.count(input_node)) { + producer->input_nodes.erase(input_node); + } + } + + // remove producer from set. + producer->producers.erase(producer); + producer->consumers.erase(producer); + + return true; +} + +bool CinnSubgraphDetector::IsDependency( + const CinnSubGraphPtr& producer_g, + const CinnSubGraphPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (auto& producer : candidate->producers) { + if (producer.get() == producer_g.get()) { + continue; + } + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; +} + +bool CinnSubgraphDetector::IsDependencySimplify( + const CinnSubGraphPtr& producer_g, + const CinnSubGraphPtr& consumer, + const std::unordered_set& consumers) { + std::queue candidates; + candidates.push(consumer); + // check upper bound. + int check_upper_depth = producer_g->max_depth; + std::unordered_set visited_set; + while (!candidates.empty()) { + auto& candidate = candidates.front(); + candidates.pop(); + for (auto& producer : candidate->producers) { + if (producer.get() == producer_g.get()) { + continue; + } + if (producer->min_depth > check_upper_depth) { + continue; + } + if (consumers.count(producer)) { + return true; + } + if (!visited_set.count(producer)) { + visited_set.insert(producer); + candidates.push(producer); + } + } + } + return false; +} + +std::vector> CinnSubgraphDetector::operator()() { + DoOpFusion(); + BuildSubGraph(); + DoSubGraphFusion(); + + std::vector> clusters; + for (auto& subgraph : subgraph_list_) { + if (!subgraph->substitute) { + continue; + } + clusters.push_back(subgraph->nodes); + } + + return clusters; +} + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h new file mode 100644 index 0000000000000000000000000000000000000000..1eb3ebbe62fca50519ff836b2de985b4dbcaf2d2 --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_traits.h" +#include "paddle/fluid/framework/ir/node.h" + +namespace paddle { +namespace framework { +namespace paddle2cinn { + +using Node = ir::Node; +using Graph = ir::Graph; + +struct Hasher; +struct Comparator; +struct CinnSubGraph; +using CinnSubGraphPtr = std::shared_ptr; +/* + * Detect the nodes in a subgraph that meet some conditions. This class doesn't + * modify the graph. + */ +class CinnSubgraphDetector { + public: + // Tell whether a node is inside a sub-graph. + using NodeClassifier = std::function; + + CinnSubgraphDetector(Graph *graph, const NodeClassifier &classifier) + : graph_(graph), node_classifier_(classifier) {} + + std::vector> operator()(); + + protected: + // Do Op Fusion + void DoOpFusion(); + void BuildSubGraph(); + // SubGraph Fusion + void DoSubGraphFusion(); + bool FuseSubGraph(CinnSubGraphPtr *); + // check exist depency. + bool IsDependency( + const CinnSubGraphPtr &, + const CinnSubGraphPtr &, + const std::unordered_set &); + bool IsDependencySimplify( + const CinnSubGraphPtr &, + const CinnSubGraphPtr &, + const std::unordered_set &); + + private: + Graph *graph_; + NodeClassifier node_classifier_; + + std::vector nodes_; + std::vector subgraph_list_; + std::unordered_map subgraph_map_; +}; + +} // namespace paddle2cinn +} // namespace framework +} // namespace paddle