From cdc44a54b6b2a1d9d62e07cb231ad7dcd0f29cbd Mon Sep 17 00:00:00 2001 From: sunli Date: Wed, 12 Oct 2022 19:48:10 +0800 Subject: [PATCH] fix wz review (#46937) * fix wz review * update code --- .../paddle2cinn/cinn_subgraph_detector.cc | 60 +++++-------------- .../paddle2cinn/cinn_subgraph_detector.h | 42 +++++++++---- 2 files changed, 46 insertions(+), 56 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc index 26416269c9e..dc36f40d9c6 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.cc @@ -50,46 +50,16 @@ std::unordered_set GetConsumerOps(Node* node) { return consumers; } -struct Hasher { - size_t operator()(const CinnSubGraphPtr& subgraph) const noexcept { - return std::hash()(reinterpret_cast(subgraph.get())); - } -}; -struct Comparator { - bool operator()(const CinnSubGraphPtr& first, - const CinnSubGraphPtr& second) const noexcept { - return first.get() == second.get(); - } -}; - -struct CinnSubGraph { - using CinnSubGraphPtr = std::shared_ptr; - // construct function - CinnSubGraph() {} - // construct function - CinnSubGraph(Node* op, bool subst) : substitute(subst) { Insert(op); } +void CinnSubGraph::Insert(Node* op) { + nodes.push_back(op); + node_set.insert(op); - void Insert(Node* op) { - nodes.push_back(op); - node_set.insert(op); - - auto producers = GetProducerOps(op); - for (auto producer : producers) { - input_nodes.insert(producer); - } - input_nodes.erase(op); + auto producers = GetProducerOps(op); + for (auto producer : producers) { + input_nodes.insert(producer); } - - int depth{0}; - int max_depth{0}, min_depth{INT_MAX}; - bool substitute{true}; - std::vector nodes; - std::unordered_set node_set; - std::unordered_set input_nodes; - - std::unordered_set producers; - std::unordered_set consumers; -}; + input_nodes.erase(op); +} void CinnSubgraphDetector::DoOpFusion() { // sort node from input to output @@ -183,7 +153,7 @@ void CinnSubgraphDetector::DoSubGraphFusion() { continue; } // do fusion - update |= FuseSubGraph(&subgraph); + update |= FuseSubGraph(subgraph); } if (!update) { break; @@ -191,8 +161,8 @@ void CinnSubgraphDetector::DoSubGraphFusion() { } } -bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) { - auto producer = *subgraph_ptr; +bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr subgraph_ptr) { + auto producer = subgraph_ptr; auto& consumers = producer->consumers; std::vector candidates; for (auto& consumer : consumers) { @@ -276,11 +246,11 @@ bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) { bool CinnSubgraphDetector::IsDependency( const CinnSubGraphPtr& producer_g, const CinnSubGraphPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); @@ -303,12 +273,12 @@ bool CinnSubgraphDetector::IsDependency( bool CinnSubgraphDetector::IsDependencySimplify( const CinnSubGraphPtr& producer_g, const CinnSubGraphPtr& consumer, - const std::unordered_set& consumers) { + const std::unordered_set& consumers) { std::queue candidates; candidates.push(consumer); // check upper bound. int check_upper_depth = producer_g->max_depth; - std::unordered_set visited_set; + std::unordered_set visited_set; while (!candidates.empty()) { auto& candidate = candidates.front(); candidates.pop(); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h index 1eb3ebbe62f..e8ff3915c85 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_subgraph_detector.h @@ -31,10 +31,32 @@ namespace paddle2cinn { using Node = ir::Node; using Graph = ir::Graph; -struct Hasher; -struct Comparator; +/* + * + * + */ struct CinnSubGraph; using CinnSubGraphPtr = std::shared_ptr; + +struct CinnSubGraph { + // construct function + CinnSubGraph() {} + // construct function + CinnSubGraph(Node *op, bool subst) : substitute(subst) { Insert(op); } + void Insert(Node *op); + + int depth{0}; + int max_depth{0}; + int min_depth{INT_MAX}; + bool substitute{true}; + std::vector nodes; + std::unordered_set node_set; + std::unordered_set input_nodes; + + std::unordered_set producers; + std::unordered_set consumers; +}; + /* * Detect the nodes in a subgraph that meet some conditions. This class doesn't * modify the graph. @@ -55,16 +77,14 @@ class CinnSubgraphDetector { void BuildSubGraph(); // SubGraph Fusion void DoSubGraphFusion(); - bool FuseSubGraph(CinnSubGraphPtr *); + bool FuseSubGraph(CinnSubGraphPtr); // check exist depency. - bool IsDependency( - const CinnSubGraphPtr &, - const CinnSubGraphPtr &, - const std::unordered_set &); - bool IsDependencySimplify( - const CinnSubGraphPtr &, - const CinnSubGraphPtr &, - const std::unordered_set &); + bool IsDependency(const CinnSubGraphPtr &, + const CinnSubGraphPtr &, + const std::unordered_set &); + bool IsDependencySimplify(const CinnSubGraphPtr &, + const CinnSubGraphPtr &, + const std::unordered_set &); private: Graph *graph_; -- GitLab