未验证 提交 2ef34c64 编写于 作者: Y Yan Chunwei 提交者: GitHub

refine fc with pattern reusing (#13187)

上级 796c87d5
...@@ -99,17 +99,13 @@ void FindWhileOp(Graph* graph) { ...@@ -99,17 +99,13 @@ void FindWhileOp(Graph* graph) {
auto* cell_init = graph->RetriveNode(6); auto* cell_init = graph->RetriveNode(6);
auto* hidden_init = graph->RetriveNode(8); auto* hidden_init = graph->RetriveNode(8);
#define LINK_TO(node0, node1) \
node0->outputs.push_back(node1); \
node1->inputs.push_back(node0);
auto* lstm_op = graph->CreateOpNode(&op_desc); auto* lstm_op = graph->CreateOpNode(&op_desc);
PrepareParameters(graph, param); PrepareParameters(graph, param);
LINK_TO(X, lstm_op); IR_NODE_LINK_TO(X, lstm_op);
LINK_TO(cell_init, lstm_op); IR_NODE_LINK_TO(cell_init, lstm_op);
LINK_TO(hidden_init, lstm_op); IR_NODE_LINK_TO(hidden_init, lstm_op);
LINK_TO(lstm_op, LSTMOUT); IR_NODE_LINK_TO(lstm_op, LSTMOUT);
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} }
......
...@@ -21,59 +21,6 @@ namespace paddle { ...@@ -21,59 +21,6 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
bool VarOutLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
void BuildFCPattern(PDPattern* pattern) {
// Create Operators
auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
auto* elementwise_add_op =
pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
// Create variables
// w
auto* mul_weight_var = pattern->NewNode("mul_weight")
->AsInput()
->assert_is_op_nth_input("mul", "Y", 0);
// x
auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
->AsInput()
->assert_is_op_nth_input("mul", "X", 0);
// intermediate variable, will be removed in the IR after fuse.
auto* mul_out_var = pattern->NewNode("mul_out")
->AsIntermediate()
->assert_is_only_output_of_op("mul")
->assert_is_op_input("elementwise_add");
// bias
auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
->assert_is_op_input("elementwise_add")
->AsInput();
// output
auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
->AsOutput()
->assert_is_op_output("elementwise_add");
mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
.LinksTo({elementwise_add_out_var});
}
// Replace the node `from` in the links to `to`
bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
for (auto*& n : *links) {
if (n == from) {
n = to;
return true;
}
}
return false;
}
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const { std::unique_ptr<ir::Graph> graph) const {
PADDLE_ENFORCE(graph.get()); PADDLE_ENFORCE(graph.get());
...@@ -82,13 +29,18 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -82,13 +29,18 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unordered_set<Node*> nodes2delete; std::unordered_set<Node*> nodes2delete;
GraphPatternDetector gpd; GraphPatternDetector gpd;
BuildFCPattern(gpd.mutable_pattern()); // BuildFCPattern(gpd.mutable_pattern());
auto* x = gpd.mutable_pattern()
#define GET_NODE(id) \ ->NewNode("fc_fuse/x")
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \ ->AsInput()
"pattern has no Node called %s", #id); \ ->assert_is_op_input("mul", "X");
auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \ patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/);
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
int found_fc_count = 0; int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
...@@ -98,43 +50,33 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -98,43 +50,33 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
// scenerio. // scenerio.
// FC's fusion is simple, just op fuse, no need to process the // FC's fusion is simple, just op fuse, no need to process the
// parameters. // parameters.
GET_NODE(mul_tmp_var); // x GET_NODE(x); // x
GET_NODE(mul_weight); // Y GET_NODE(w); // Y
GET_NODE(elementwise_add_tmpvar); // bias GET_NODE(fc_bias); // bias
GET_NODE(elementwise_add_out); // Out GET_NODE(fc_out); // Out
GET_NODE(mul); // MUL op GET_NODE(mul); // MUL op
GET_NODE(elementwise_add); // ELEMENT_ADD op GET_NODE(elementwise_add); // ELEMENT_ADD op
GET_NODE(mul_out); // tmp GET_NODE(mul_out); // tmp
#undef GET_NODE #undef GET_NODE
// Create an FC Node. // Create an FC Node.
OpDesc desc; OpDesc desc;
std::string fc_x_in = mul_tmp_var->Name(); std::string fc_x_in = x->Name();
std::string fc_Y_in = mul_weight->Name(); std::string fc_Y_in = w->Name();
std::string fc_bias_in = elementwise_add_tmpvar->Name(); std::string fc_bias_in = fc_bias->Name();
std::string fc_out = elementwise_add_out->Name(); std::string fc_out_out = fc_out->Name();
desc.SetInput("Input", std::vector<std::string>({fc_x_in})); desc.SetInput("Input", std::vector<std::string>({fc_x_in}));
desc.SetInput("W", std::vector<std::string>({fc_Y_in})); desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
desc.SetInput("Bias", std::vector<std::string>({fc_bias_in})); desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
desc.SetOutput("Out", std::vector<std::string>({fc_out})); desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
desc.SetType("fc"); desc.SetType("fc");
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
fc_node->inputs = GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
std::vector<Node*>({mul_tmp_var, mul_weight, elementwise_add_tmpvar});
fc_node->outputs.push_back(elementwise_add_out);
// Update link relatons
PADDLE_ENFORCE(LinksReplace(&mul_tmp_var->outputs, mul, fc_node));
PADDLE_ENFORCE(LinksReplace(&mul_weight->outputs, mul, fc_node));
PADDLE_ENFORCE(LinksReplace(&elementwise_add_tmpvar->outputs,
elementwise_add, fc_node));
PADDLE_ENFORCE(
LinksReplace(&elementwise_add_out->inputs, elementwise_add, fc_node));
// Drop old nodes IR_NODE_LINK_TO(x, fc_node);
graph->RemoveNode(mul); IR_NODE_LINK_TO(w, fc_node);
graph->RemoveNode(elementwise_add); IR_NODE_LINK_TO(fc_bias, fc_node);
graph->RemoveNode(mul_out); // tmp variable IR_NODE_LINK_TO(fc_node, fc_out);
found_fc_count++; found_fc_count++;
}; };
......
...@@ -121,15 +121,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -121,15 +121,11 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
#undef TMP_NEW #undef TMP_NEW
#undef TMP_NAME #undef TMP_NAME
#define LINK_TO(a, b) \ IR_NODE_LINK_TO(input_n, op);
a->outputs.push_back(b); \ IR_NODE_LINK_TO(weight_x_n, op);
b->inputs.push_back(a); IR_NODE_LINK_TO(weight_h_n, op);
LINK_TO(input_n, op); IR_NODE_LINK_TO(bias_n, op);
LINK_TO(weight_x_n, op); IR_NODE_LINK_TO(op, hidden_n);
LINK_TO(weight_h_n, op);
LINK_TO(bias_n, op);
LINK_TO(op, hidden_n);
#undef LINK_TO
return op; return op;
}; };
......
...@@ -297,6 +297,10 @@ PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x); ...@@ -297,6 +297,10 @@ PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x);
} // namespace patterns } // namespace patterns
#define IR_NODE_LINK_TO(a, b) \
a->outputs.push_back(b); \
b->inputs.push_back(a);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -219,16 +219,13 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl( ...@@ -219,16 +219,13 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
op_desc.SetAttr("fc_activation", act->Op()->Type()); op_desc.SetAttr("fc_activation", act->Op()->Type());
auto* op_node = graph->CreateOpNode(&op_desc); auto* op_node = graph->CreateOpNode(&op_desc);
// Add links // Add links
#define NODE_LINKS(a, b) \ IR_NODE_LINK_TO(fc_w, op_node);
a->outputs.push_back(b); \ IR_NODE_LINK_TO(fc_bias, op_node);
b->inputs.push_back(a); IR_NODE_LINK_TO(concat_in0, op_node);
NODE_LINKS(fc_w, op_node); IR_NODE_LINK_TO(sequence_expand0_in, op_node);
NODE_LINKS(fc_bias, op_node); IR_NODE_LINK_TO(sequence_expand1_in, op_node);
NODE_LINKS(concat_in0, op_node); IR_NODE_LINK_TO(op_node, fc_out);
NODE_LINKS(sequence_expand0_in, op_node);
NODE_LINKS(sequence_expand1_in, op_node);
NODE_LINKS(op_node, fc_out);
// Clean nodes. // Clean nodes.
std::unordered_set<const Node*> marked_nodes; std::unordered_set<const Node*> marked_nodes;
...@@ -241,7 +238,6 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl( ...@@ -241,7 +238,6 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
marked_nodes.erase(sequence_expand0_in); marked_nodes.erase(sequence_expand0_in);
marked_nodes.erase(sequence_expand1_in); marked_nodes.erase(sequence_expand1_in);
marked_nodes.erase(fc_out); marked_nodes.erase(fc_out);
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册