Fuse based on IR/PASS
Created by: Superjomn
Fusion based on IR/Analysis
PDPattern and PDNode
Pattern related:
- We treat the operator fusion problem as a modification of the IR graph:
  - extract the subgraphs that meet some pre-defined pattern/rules
  - remove those subgraphs and insert the new fused operator (or a new sub-graph)
- Getting a sub-graph and modifying its nodes is easy, so the core problem is how to extract the subgraphs that meet some pattern and how to identify the nodes in those subgraphs. We provide some basic modules and concepts for such graph tasks:
  - PDPattern: a pattern that defines what a target subgraph (or subgraphs) looks like
  - PDNode: the basic node in a PDPattern; one PDNode corresponds to one Node in the IR graph that meets some conditions
  - The prefix PD stands for Pattern-Detector.
  - A PDPattern is a graph of PDNodes, so a pattern has nodes (PDNode) and edges.
- Usage:
PDPattern pattern;
PDNode* new_node = pattern.NewNode("name");
// One can add some assertions to help extract the more specific nodes for a pattern.
PDNode* mul_op = pattern.NewNode("some_op")->assert_is_op("mul");
PDNode* mul_input = pattern.NewNode("some_var")->assert_is_var()->assert_is_op_input("mul");
// More assertions can be combined.
// Because the operator's definition is clear, feel free to add more assertions to make the
// PDNode's or the PDPattern's rules more precise.
PDNode* mul_input_x = pattern.NewNode("some_var_clear")->assert_is_op_input("mul", "X");
// Check the file graph_pattern_detector.h; more assertions are provided there.
After creating the PDNodes, we can add edges to make them a graph:
mul_op->LinksFrom({mul_input, mul_input_x});
- a.LinksFrom({b, c}) means there are two edges, b->a and c->a
- a.LinksTo({b, c}) means a->b and a->c
- Some nodes in the IR will be removed after fusion, so the corresponding PDNodes in a pattern must not be shared with other patterns, for example the MUL's output variable in an FC pattern. We mark these PDNodes as intermediate nodes with the API PDNode.AsIntermediate(). A PDNode marked as intermediate can exist in only one pattern. There are two other similar APIs, AsInput and AsOutput, which do not work currently.
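The edge semantics and the intermediate flag described above can be pictured with a small stand-alone sketch. ToyPDNode and ToyPattern are hypothetical stand-ins written only for this illustration; they are not Paddle's classes, and the real API (see graph_pattern_detector.h) puts LinksFrom/LinksTo on the node itself rather than on the pattern.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a pattern node (hypothetical, not Paddle's PDNode).
struct ToyPDNode {
  std::string name;
  bool intermediate = false;

  // Marks the node as one that disappears after fusion.
  ToyPDNode* AsIntermediate() {
    intermediate = true;
    return this;
  }
};

// Toy stand-in for a pattern: owns the nodes and stores directed edges.
class ToyPattern {
 public:
  using Edge = std::pair<ToyPDNode*, ToyPDNode*>;  // (from, to)

  ToyPDNode* NewNode(const std::string& name) {
    nodes_.push_back(std::make_unique<ToyPDNode>());
    nodes_.back()->name = name;
    return nodes_.back().get();
  }

  // LinksFrom(a, {b, c}) adds the edges b->a and c->a.
  void LinksFrom(ToyPDNode* a, const std::vector<ToyPDNode*>& sources) {
    assert(a != nullptr);
    for (ToyPDNode* src : sources) edges_.emplace_back(src, a);
  }

  // LinksTo(a, {b, c}) adds the edges a->b and a->c.
  void LinksTo(ToyPDNode* a, const std::vector<ToyPDNode*>& targets) {
    assert(a != nullptr);
    for (ToyPDNode* dst : targets) edges_.emplace_back(a, dst);
  }

  bool HasEdge(const ToyPDNode* from, const ToyPDNode* to) const {
    for (const Edge& e : edges_) {
      if (e.first == from && e.second == to) return true;
    }
    return false;
  }

 private:
  std::vector<std::unique_ptr<ToyPDNode>> nodes_;
  std::vector<Edge> edges_;
};
```

With nodes a, b and c created through NewNode, calling LinksFrom(a, {b, c}) makes HasEdge(b, a) and HasEdge(c, a) true, while HasEdge(c, b) stays false.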
Let's take the FullyConnected layer, composed of a mul op and an elementwise_add op, as an example:
PDPattern pattern;
// Create operator nodes
// The mul op will be removed after fusion, so it is marked as intermediate.
auto* mul = pattern.NewNode("mul")->assert_is_op("mul")->AsIntermediate();
auto* elementwise_add = pattern.NewNode("elementwise_add")->assert_is_op("elementwise_add")->AsIntermediate();
// Create variables
auto* mul_out = pattern.NewNode("mul_out")->assert_is_op_output("mul")->AsIntermediate();
// Link each other
mul_out->LinksFrom({mul});
mul_out->LinksTo({elementwise_add});
// PDPattern done
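To make concrete what matching this pattern means, here is a minimal sketch that scans a toy graph for the chain mul -> variable -> elementwise_add. ToyNode, FcMatch and FindMulAddChains are hypothetical names invented for this illustration, and the rule that the intermediate variable must have exactly one consumer is an assumed simplification, not a statement about how Paddle's detector works.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy IR node (hypothetical): either an operator or a variable.
struct ToyNode {
  bool is_op;
  std::string type;          // op type for operators, variable name otherwise
  std::vector<int> outputs;  // indices of successor nodes in the graph vector
};

// Indices of the three nodes that matched the pattern.
struct FcMatch {
  int mul;
  int mul_out;
  int elementwise_add;
};

// Scans the graph for the chain: mul (op) -> variable -> elementwise_add (op).
std::vector<FcMatch> FindMulAddChains(const std::vector<ToyNode>& graph) {
  std::vector<FcMatch> matches;
  for (std::size_t i = 0; i < graph.size(); ++i) {
    const ToyNode& op = graph[i];
    if (!op.is_op || op.type != "mul") continue;
    for (int var_id : op.outputs) {
      assert(var_id >= 0 && static_cast<std::size_t>(var_id) < graph.size());
      const ToyNode& var = graph[var_id];
      if (var.is_op) continue;
      // Assumed simplification: an intermediate variable is removed by the
      // fusion, so it must feed only the op that is being fused.
      if (var.outputs.size() != 1) continue;
      const int next = var.outputs[0];
      assert(next >= 0 && static_cast<std::size_t>(next) < graph.size());
      if (graph[next].is_op && graph[next].type == "elementwise_add") {
        matches.push_back({static_cast<int>(i), var_id, next});
      }
    }
  }
  return matches;
}
```

On a graph with the nodes x, w, mul, mul_out, bias, elementwise_add and out wired in the usual FC layout, the function returns a single match holding the indices of mul, mul_out and elementwise_add.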
GraphPatternDetector
- the core module to detect some patterns in the IR graph
- finally extracts the sub-graphs that meet the provided pattern
- Basic APIs
  - mutable_pattern(): get the mutable pattern object, so one can customize the pattern
  - operator(graph, handler): perform the pattern extraction
- Usages:
GraphPatternDetector pd;
auto* pattern = pd.mutable_pattern();
// Create a pattern
// ...
auto subgraph_handler = [](const GraphPatternDetector::subgraph_t& subgraph, ir::Graph* graph) {
  // Create new ir::Nodes and insert them into the graph
  // Remove the subgraph
};
pd(graph, subgraph_handler);
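The detect-then-handle flow above can be sketched end to end on a deliberately simplified graph. Everything here (ToyGraph, ToySubgraph, ToyDetector, FuseIntoFc) is hypothetical and exists only for this illustration: the graph is assumed to be a straight chain of operator types, and the handler fuses each mul followed by elementwise_add into a single fc entry. The real GraphPatternDetector works on ir::Graph and an arbitrary PDPattern.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <vector>

// Toy graph (hypothetical): a straight chain of operator types.
using ToyGraph = std::vector<std::string>;

// A matched subgraph: positions of the two adjacent ops.
struct ToySubgraph {
  std::size_t mul_pos;
  std::size_t add_pos;
};

using ToyHandler = std::function<void(const ToySubgraph&, ToyGraph*)>;

class ToyDetector {
 public:
  // Finds every "mul" directly followed by "elementwise_add", then calls the
  // handler once per match. Matches are handled from back to front so that a
  // handler that shrinks the graph does not shift the remaining positions.
  void operator()(ToyGraph* graph, const ToyHandler& handler) const {
    assert(graph != nullptr);
    std::vector<ToySubgraph> found;
    for (std::size_t i = 0; i + 1 < graph->size(); ++i) {
      if ((*graph)[i] == "mul" && (*graph)[i + 1] == "elementwise_add") {
        found.push_back({i, i + 1});
      }
    }
    for (auto it = found.rbegin(); it != found.rend(); ++it) {
      handler(*it, graph);
    }
  }
};

// Handler: insert the fused op in place of mul, then remove elementwise_add.
inline void FuseIntoFc(const ToySubgraph& sub, ToyGraph* graph) {
  (*graph)[sub.mul_pos] = "fc";
  graph->erase(graph->begin() + static_cast<std::ptrdiff_t>(sub.add_pos));
}
```

Running ToyDetector on the chain mul, elementwise_add, relu, mul, elementwise_add with FuseIntoFc as the handler leaves fc, relu, fc.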
Best practice
Currently, manually defining a PDPattern is quite tedious and fragile. We tried some ways to make it easier:
- patterns: share the PDPatterns defined for different fusion passes
  - refer to namespace patterns in graph_pattern_detector.h; it should contain the FC and LSTM patterns, which one can use as functions and combine.
- PDPattern.DotString() returns the DOT code for a pattern, so one can easily visualize the pattern.
- graph_viz_pass: visualize the modifications of the IR graph
  - After each fusion pass, add a graph_viz_pass; it will visualize the IR graph and make it easier to compare the graphs generated by different passes.
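As an illustration of what DOT code for a pattern looks like, the sketch below renders a list of named edges as a DOT digraph. ToyDotString is a hypothetical helper written for this example; it is not PDPattern.DotString(), and the exact text that the real method produces may differ.

```cpp
#include <cassert>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical helper: renders (from, to) node-name pairs as a DOT digraph.
std::string ToyDotString(
    const std::vector<std::pair<std::string, std::string>>& edges) {
  std::ostringstream os;
  os << "digraph G {\n";
  for (const auto& e : edges) {
    // Every edge needs two named endpoints.
    assert(!e.first.empty() && !e.second.empty());
    os << "  \"" << e.first << "\" -> \"" << e.second << "\";\n";
  }
  os << "}\n";
  return os.str();
}
```

Saving the returned string for the edges mul -> mul_out and mul_out -> elementwise_add to a file such as pattern.dot and running Graphviz (dot -Tpng pattern.dot -o pattern.png) gives a picture of the pattern.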
Some real code
https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/fluid/framework/ir/fc_fuse_pass.cc#L33