未验证 提交 cdc44a54 编写于 作者: S sunli 提交者: GitHub

fix wz review (#46937)

* fix wz review

* update code
上级 acdaa4fb
...@@ -50,46 +50,16 @@ std::unordered_set<Node*> GetConsumerOps(Node* node) { ...@@ -50,46 +50,16 @@ std::unordered_set<Node*> GetConsumerOps(Node* node) {
return consumers; return consumers;
} }
struct Hasher { void CinnSubGraph::Insert(Node* op) {
size_t operator()(const CinnSubGraphPtr& subgraph) const noexcept { nodes.push_back(op);
return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(subgraph.get())); node_set.insert(op);
}
};
struct Comparator {
bool operator()(const CinnSubGraphPtr& first,
const CinnSubGraphPtr& second) const noexcept {
return first.get() == second.get();
}
};
struct CinnSubGraph {
using CinnSubGraphPtr = std::shared_ptr<CinnSubGraph>;
// construct function
CinnSubGraph() {}
// construct function
CinnSubGraph(Node* op, bool subst) : substitute(subst) { Insert(op); }
void Insert(Node* op) { auto producers = GetProducerOps(op);
nodes.push_back(op); for (auto producer : producers) {
node_set.insert(op); input_nodes.insert(producer);
auto producers = GetProducerOps(op);
for (auto producer : producers) {
input_nodes.insert(producer);
}
input_nodes.erase(op);
} }
input_nodes.erase(op);
int depth{0}; }
int max_depth{0}, min_depth{INT_MAX};
bool substitute{true};
std::vector<Node*> nodes;
std::unordered_set<Node*> node_set;
std::unordered_set<Node*> input_nodes;
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> producers;
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> consumers;
};
void CinnSubgraphDetector::DoOpFusion() { void CinnSubgraphDetector::DoOpFusion() {
// sort node from input to output // sort node from input to output
...@@ -183,7 +153,7 @@ void CinnSubgraphDetector::DoSubGraphFusion() { ...@@ -183,7 +153,7 @@ void CinnSubgraphDetector::DoSubGraphFusion() {
continue; continue;
} }
// do fusion // do fusion
update |= FuseSubGraph(&subgraph); update |= FuseSubGraph(subgraph);
} }
if (!update) { if (!update) {
break; break;
...@@ -191,8 +161,8 @@ void CinnSubgraphDetector::DoSubGraphFusion() { ...@@ -191,8 +161,8 @@ void CinnSubgraphDetector::DoSubGraphFusion() {
} }
} }
bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) { bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr subgraph_ptr) {
auto producer = *subgraph_ptr; auto producer = subgraph_ptr;
auto& consumers = producer->consumers; auto& consumers = producer->consumers;
std::vector<CinnSubGraphPtr> candidates; std::vector<CinnSubGraphPtr> candidates;
for (auto& consumer : consumers) { for (auto& consumer : consumers) {
...@@ -276,11 +246,11 @@ bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) { ...@@ -276,11 +246,11 @@ bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) {
bool CinnSubgraphDetector::IsDependency( bool CinnSubgraphDetector::IsDependency(
const CinnSubGraphPtr& producer_g, const CinnSubGraphPtr& producer_g,
const CinnSubGraphPtr& consumer, const CinnSubGraphPtr& consumer,
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator>& consumers) { const std::unordered_set<CinnSubGraphPtr>& consumers) {
std::queue<CinnSubGraphPtr> candidates; std::queue<CinnSubGraphPtr> candidates;
candidates.push(consumer); candidates.push(consumer);
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> visited_set; std::unordered_set<CinnSubGraphPtr> visited_set;
while (!candidates.empty()) { while (!candidates.empty()) {
auto& candidate = candidates.front(); auto& candidate = candidates.front();
candidates.pop(); candidates.pop();
...@@ -303,12 +273,12 @@ bool CinnSubgraphDetector::IsDependency( ...@@ -303,12 +273,12 @@ bool CinnSubgraphDetector::IsDependency(
bool CinnSubgraphDetector::IsDependencySimplify( bool CinnSubgraphDetector::IsDependencySimplify(
const CinnSubGraphPtr& producer_g, const CinnSubGraphPtr& producer_g,
const CinnSubGraphPtr& consumer, const CinnSubGraphPtr& consumer,
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator>& consumers) { const std::unordered_set<CinnSubGraphPtr>& consumers) {
std::queue<CinnSubGraphPtr> candidates; std::queue<CinnSubGraphPtr> candidates;
candidates.push(consumer); candidates.push(consumer);
// check upper bound. // check upper bound.
int check_upper_depth = producer_g->max_depth; int check_upper_depth = producer_g->max_depth;
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> visited_set; std::unordered_set<CinnSubGraphPtr> visited_set;
while (!candidates.empty()) { while (!candidates.empty()) {
auto& candidate = candidates.front(); auto& candidate = candidates.front();
candidates.pop(); candidates.pop();
......
...@@ -31,10 +31,32 @@ namespace paddle2cinn { ...@@ -31,10 +31,32 @@ namespace paddle2cinn {
using Node = ir::Node; using Node = ir::Node;
using Graph = ir::Graph; using Graph = ir::Graph;
struct Hasher; /*
struct Comparator; *
*
*/
struct CinnSubGraph; struct CinnSubGraph;
using CinnSubGraphPtr = std::shared_ptr<CinnSubGraph>; using CinnSubGraphPtr = std::shared_ptr<CinnSubGraph>;
struct CinnSubGraph {
// construct function
CinnSubGraph() {}
// construct function
CinnSubGraph(Node *op, bool subst) : substitute(subst) { Insert(op); }
void Insert(Node *op);
int depth{0};
int max_depth{0};
int min_depth{INT_MAX};
bool substitute{true};
std::vector<Node *> nodes;
std::unordered_set<Node *> node_set;
std::unordered_set<Node *> input_nodes;
std::unordered_set<CinnSubGraphPtr> producers;
std::unordered_set<CinnSubGraphPtr> consumers;
};
/* /*
* Detect the nodes in a subgraph that meet some conditions. This class doesn't * Detect the nodes in a subgraph that meet some conditions. This class doesn't
* modify the graph. * modify the graph.
...@@ -55,16 +77,14 @@ class CinnSubgraphDetector { ...@@ -55,16 +77,14 @@ class CinnSubgraphDetector {
void BuildSubGraph(); void BuildSubGraph();
// SubGraph Fusion // SubGraph Fusion
void DoSubGraphFusion(); void DoSubGraphFusion();
bool FuseSubGraph(CinnSubGraphPtr *); bool FuseSubGraph(CinnSubGraphPtr);
// check exist depency. // check exist depency.
bool IsDependency( bool IsDependency(const CinnSubGraphPtr &,
const CinnSubGraphPtr &, const CinnSubGraphPtr &,
const CinnSubGraphPtr &, const std::unordered_set<CinnSubGraphPtr> &);
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> &); bool IsDependencySimplify(const CinnSubGraphPtr &,
bool IsDependencySimplify( const CinnSubGraphPtr &,
const CinnSubGraphPtr &, const std::unordered_set<CinnSubGraphPtr> &);
const CinnSubGraphPtr &,
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> &);
private: private:
Graph *graph_; Graph *graph_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册