未验证 提交 1db36584 编写于 作者: W Wangzheee 提交者: GitHub

[pass_enhance] mul_gru_fuse_pass; fc_gru_fuse_pass (#33793)

上级 97f86d84
......@@ -30,8 +30,137 @@ namespace ir {
class Node;
static int BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) {
MulGRUFusePass::MulGRUFusePass() {
AddOpCompat(OpCompat("gru"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("H0")
.IsTensor()
.IsOptional()
.End()
.AddInput("Weight")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("BatchGate")
.IsTensor()
.End()
.AddOutput("BatchResetHiddenPrev")
.IsTensor()
.End()
.AddOutput("BatchHidden")
.IsTensor()
.End()
.AddOutput("Hidden")
.IsTensor()
.End()
.AddAttr("activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("gate_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("is_reverse")
.IsType<bool>()
.End()
.AddAttr("origin_mode")
.IsType<bool>()
.IsOptional()
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
}
FCGRUFusePass::FCGRUFusePass() {
AddOpCompat(OpCompat("gru"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("H0")
.IsTensor()
.IsOptional()
.End()
.AddInput("Weight")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("BatchGate")
.IsTensor()
.End()
.AddOutput("BatchResetHiddenPrev")
.IsTensor()
.End()
.AddOutput("BatchHidden")
.IsTensor()
.End()
.AddOutput("Hidden")
.IsTensor()
.End()
.AddAttr("activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("gate_activation")
.IsStringIn({"sigmoid", "tanh", "relu", "identity"})
.End()
.AddAttr("is_reverse")
.IsType<bool>()
.End()
.AddAttr("origin_mode")
.IsType<bool>()
.IsOptional()
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumEQ(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumGE(-1)
.End();
}
int FCGRUFusePass::BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) const {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
......@@ -133,6 +262,10 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
auto* x_n = subgraph.at(x);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
......@@ -189,8 +322,8 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), false /*with_fc_bias*/);
int fusion_count = MulGRUFusePass::BuildFusion(
graph, name_scope_, param_scope(), false /*with_fc_bias*/);
AddStatis(fusion_count);
}
......@@ -198,8 +331,8 @@ void MulGRUFusePass::ApplyImpl(ir::Graph* graph) const {
void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph);
int fusion_count =
BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
int fusion_count = FCGRUFusePass::BuildFusion(
graph, name_scope_, param_scope(), true /*with_fc_bias*/);
AddStatis(fusion_count);
}
......
......@@ -18,7 +18,6 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
......@@ -26,21 +25,22 @@ namespace ir {
// The MulGRUFusePass and MulGRUFusePass will fuse to the same FusionGRU op.
class Graph;
class FCGRUFusePass : public FusePassBase {
public:
FCGRUFusePass();
virtual ~FCGRUFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"fc_gru_fuse"};
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
bool with_fc_bias) const;
};
// Just FC without bias
class MulGRUFusePass : public FusePassBase {
class MulGRUFusePass : public FCGRUFusePass {
public:
MulGRUFusePass();
virtual ~MulGRUFusePass() {}
protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册