未验证 提交 cdc44a54 编写于 作者: S sunli 提交者: GitHub

fix wz review (#46937)

* fix wz review

* update code
上级 acdaa4fb
......@@ -50,46 +50,16 @@ std::unordered_set<Node*> GetConsumerOps(Node* node) {
return consumers;
}
struct Hasher {
size_t operator()(const CinnSubGraphPtr& subgraph) const noexcept {
return std::hash<uint64_t>()(reinterpret_cast<uint64_t>(subgraph.get()));
}
};
struct Comparator {
bool operator()(const CinnSubGraphPtr& first,
const CinnSubGraphPtr& second) const noexcept {
return first.get() == second.get();
}
};
struct CinnSubGraph {
using CinnSubGraphPtr = std::shared_ptr<CinnSubGraph>;
// construct function
CinnSubGraph() {}
// construct function
CinnSubGraph(Node* op, bool subst) : substitute(subst) { Insert(op); }
void CinnSubGraph::Insert(Node* op) {
nodes.push_back(op);
node_set.insert(op);
void Insert(Node* op) {
nodes.push_back(op);
node_set.insert(op);
auto producers = GetProducerOps(op);
for (auto producer : producers) {
input_nodes.insert(producer);
}
input_nodes.erase(op);
auto producers = GetProducerOps(op);
for (auto producer : producers) {
input_nodes.insert(producer);
}
int depth{0};
int max_depth{0}, min_depth{INT_MAX};
bool substitute{true};
std::vector<Node*> nodes;
std::unordered_set<Node*> node_set;
std::unordered_set<Node*> input_nodes;
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> producers;
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> consumers;
};
input_nodes.erase(op);
}
void CinnSubgraphDetector::DoOpFusion() {
// sort node from input to output
......@@ -183,7 +153,7 @@ void CinnSubgraphDetector::DoSubGraphFusion() {
continue;
}
// do fusion
update |= FuseSubGraph(&subgraph);
update |= FuseSubGraph(subgraph);
}
if (!update) {
break;
......@@ -191,8 +161,8 @@ void CinnSubgraphDetector::DoSubGraphFusion() {
}
}
bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) {
auto producer = *subgraph_ptr;
bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr subgraph_ptr) {
auto producer = subgraph_ptr;
auto& consumers = producer->consumers;
std::vector<CinnSubGraphPtr> candidates;
for (auto& consumer : consumers) {
......@@ -276,11 +246,11 @@ bool CinnSubgraphDetector::FuseSubGraph(CinnSubGraphPtr* subgraph_ptr) {
bool CinnSubgraphDetector::IsDependency(
const CinnSubGraphPtr& producer_g,
const CinnSubGraphPtr& consumer,
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator>& consumers) {
const std::unordered_set<CinnSubGraphPtr>& consumers) {
std::queue<CinnSubGraphPtr> candidates;
candidates.push(consumer);
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> visited_set;
std::unordered_set<CinnSubGraphPtr> visited_set;
while (!candidates.empty()) {
auto& candidate = candidates.front();
candidates.pop();
......@@ -303,12 +273,12 @@ bool CinnSubgraphDetector::IsDependency(
bool CinnSubgraphDetector::IsDependencySimplify(
const CinnSubGraphPtr& producer_g,
const CinnSubGraphPtr& consumer,
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator>& consumers) {
const std::unordered_set<CinnSubGraphPtr>& consumers) {
std::queue<CinnSubGraphPtr> candidates;
candidates.push(consumer);
// check upper bound.
int check_upper_depth = producer_g->max_depth;
std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> visited_set;
std::unordered_set<CinnSubGraphPtr> visited_set;
while (!candidates.empty()) {
auto& candidate = candidates.front();
candidates.pop();
......
......@@ -31,10 +31,32 @@ namespace paddle2cinn {
using Node = ir::Node;
using Graph = ir::Graph;
struct Hasher;
struct Comparator;
/*
*
*
*/
struct CinnSubGraph;
using CinnSubGraphPtr = std::shared_ptr<CinnSubGraph>;
struct CinnSubGraph {
// construct function
CinnSubGraph() {}
// construct function
CinnSubGraph(Node *op, bool subst) : substitute(subst) { Insert(op); }
void Insert(Node *op);
int depth{0};
int max_depth{0};
int min_depth{INT_MAX};
bool substitute{true};
std::vector<Node *> nodes;
std::unordered_set<Node *> node_set;
std::unordered_set<Node *> input_nodes;
std::unordered_set<CinnSubGraphPtr> producers;
std::unordered_set<CinnSubGraphPtr> consumers;
};
/*
* Detect the nodes in a subgraph that meet some conditions. This class doesn't
* modify the graph.
......@@ -55,16 +77,14 @@ class CinnSubgraphDetector {
void BuildSubGraph();
// SubGraph Fusion
void DoSubGraphFusion();
bool FuseSubGraph(CinnSubGraphPtr *);
bool FuseSubGraph(CinnSubGraphPtr);
// check exist depency.
bool IsDependency(
const CinnSubGraphPtr &,
const CinnSubGraphPtr &,
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> &);
bool IsDependencySimplify(
const CinnSubGraphPtr &,
const CinnSubGraphPtr &,
const std::unordered_set<CinnSubGraphPtr, Hasher, Comparator> &);
bool IsDependency(const CinnSubGraphPtr &,
const CinnSubGraphPtr &,
const std::unordered_set<CinnSubGraphPtr> &);
bool IsDependencySimplify(const CinnSubGraphPtr &,
const CinnSubGraphPtr &,
const std::unordered_set<CinnSubGraphPtr> &);
private:
Graph *graph_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册