未验证 提交 7e964579 编写于 作者: W Wangzheee 提交者: GitHub

[pass_enhance] fc_lstm_fuse_pass; mul_lstm_fuse_pass (#33811)

上级 1db36584
...@@ -29,8 +29,149 @@ namespace ir { ...@@ -29,8 +29,149 @@ namespace ir {
class Node; class Node;
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, MulLstmFusePass::MulLstmFusePass() {
bool with_fc_bias) { AddOpCompat(OpCompat("lstm"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("H0")
.IsTensor()
.IsOptional()
.End()
.AddInput("C0")
.IsTensor()
.IsOptional()
.End()
.AddInput("Weight")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Hidden")
.IsTensor()
.End()
.AddOutput("Cell")
.IsTensor()
.End()
.AddOutput("BatchGate")
.IsTensor()
.End()
.AddOutput("BatchCellPreAct")
.IsTensor()
.End()
.AddAttr("use_peepholes")
.IsType<bool>()
.End()
.AddAttr("is_reverse")
.IsType<bool>()
.End()
.AddAttr("gate_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("cell_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("candidate_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
FCLstmFusePass::FCLstmFusePass() {
AddOpCompat(OpCompat("lstm"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("H0")
.IsTensor()
.IsOptional()
.End()
.AddInput("C0")
.IsTensor()
.IsOptional()
.End()
.AddInput("Weight")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Hidden")
.IsTensor()
.End()
.AddOutput("Cell")
.IsTensor()
.End()
.AddOutput("BatchGate")
.IsTensor()
.End()
.AddOutput("BatchCellPreAct")
.IsTensor()
.End()
.AddAttr("use_peepholes")
.IsType<bool>()
.End()
.AddAttr("is_reverse")
.IsType<bool>()
.End()
.AddAttr("gate_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("cell_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("candidate_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(-1)
.End();
}
int FCLstmFusePass::BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
...@@ -140,6 +281,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -140,6 +281,10 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
......
...@@ -31,16 +31,19 @@ class Graph; ...@@ -31,16 +31,19 @@ class Graph;
class FCLstmFusePass : public FusePassBase { class FCLstmFusePass : public FusePassBase {
public: public:
FCLstmFusePass();
virtual ~FCLstmFusePass() {} virtual ~FCLstmFusePass() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
bool with_fc_bias) const;
const std::string name_scope_{"fc_lstm_fuse"}; const std::string name_scope_{"fc_lstm_fuse"};
}; };
class MulLstmFusePass : public FusePassBase { class MulLstmFusePass : public FCLstmFusePass {
public: public:
MulLstmFusePass();
virtual ~MulLstmFusePass() {} virtual ~MulLstmFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册