未验证 提交 20eafd79 编写于 作者: F feng_shuai 提交者: GitHub

Add squared_mat_sub_fuse_pass (#33597)

上级 cf3ddd3b
...@@ -298,7 +298,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern, ...@@ -298,7 +298,8 @@ PDNode* BuildSquaredMatSubPattern(PDPattern* pattern,
return last_out_var; return last_out_var;
} }
static int BuildFusion(Graph* graph, const std::string& name_scope) { static int BuildFusion(Graph* graph, const std::string& name_scope,
const SquaredMatSubFusePass* pass) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
...@@ -320,6 +321,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -320,6 +321,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
LOG(INFO) << "handle sqaure mat sub fuse"; LOG(INFO) << "handle sqaure mat sub fuse";
if (!pass->IsAcceptable(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
auto& fused_pattern = gpd.pattern(); auto& fused_pattern = gpd.pattern();
auto* matx = retrieve_node(name_scope + "/x", subgraph, fused_pattern); auto* matx = retrieve_node(name_scope + "/x", subgraph, fused_pattern);
...@@ -368,14 +374,109 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) { ...@@ -368,14 +374,109 @@ static int BuildFusion(Graph* graph, const std::string& name_scope) {
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count; ++fusion_count;
}; };
gpd(graph, handler); gpd(graph, handler);
return fusion_count; return fusion_count;
} }
// Constructor: registers the op-compat constraints for every op type listed
// below (square, matmul, matmul_v2, elementwise_sub, elementwise_mul,
// fill_constant). IsAcceptable() forwards to IsCompat(), which is presumably
// what evaluates these constraints against a matched subgraph -- confirm
// against the base class.
SquaredMatSubFusePass::SquaredMatSubFusePass() {
// square: one input "X" and one output "Out", both declared with IsTensor().
AddOpCompat(OpCompat("square"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
// matmul: inputs "X"/"Y" and output "Out" declared with IsTensor().
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// alpha must lie in [0.99, 1.01], i.e. be 1 within a small tolerance.
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
// Neither operand may be transposed.
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
// matmul_v2: same input/output declarations as matmul; its transpose
// attributes are named trans_x / trans_y and must both be false. No alpha
// constraint is registered for this op.
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
// elementwise_sub: inputs "X"/"Y", output "Out"; axis must equal -1.
AddOpCompat(OpCompat("elementwise_sub"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(-1)
.End();
// elementwise_mul: same constraints as elementwise_sub (axis must equal -1).
AddOpCompat(OpCompat("elementwise_mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(-1)
.End();
// fill_constant: only an output "Out" is declared (no inputs registered).
AddOpCompat(OpCompat("fill_constant"))
.AddOutput("Out")
.IsTensor()
.End()
// dtype must be a number in the inclusive range [0, 25].
.AddAttr("dtype")
.IsNumGE(0)
.IsNumLE(25)
.End()
// shape is registered without any restriction on its value.
.AddAttr("shape")
.End()
// value is a float attribute; no restriction is placed on it.
.AddAttr("value")
.End();
}
// Public forwarding wrapper around IsCompat(): returns whatever IsCompat()
// reports for the given matched subgraph and graph. It lets code that only
// holds a pointer to the pass (e.g. a pattern-match handler calling
// pass->IsAcceptable(subgraph, g)) run the op-compat check.
// NOTE(review): presumably needed because IsCompat is not publicly
// accessible -- confirm against the base class.
bool SquaredMatSubFusePass::IsAcceptable(
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) const {
const bool compatible = IsCompat(subgraph, g);
return compatible;
}
void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const { void SquaredMatSubFusePass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(name_scope_, graph); FusePassBase::Init(name_scope_, graph);
int fusion_count = BuildFusion(graph, name_scope_); int fusion_count = BuildFusion(graph, name_scope_, this);
AddStatis(fusion_count); AddStatis(fusion_count);
} }
......
...@@ -31,11 +31,13 @@ class Graph; ...@@ -31,11 +31,13 @@ class Graph;
class SquaredMatSubFusePass : public FusePassBase { class SquaredMatSubFusePass : public FusePassBase {
public: public:
SquaredMatSubFusePass();
bool IsAcceptable(const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) const;
virtual ~SquaredMatSubFusePass() {} virtual ~SquaredMatSubFusePass() {}
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"squared_mat_sub_fuse"}; const std::string name_scope_{"squared_mat_sub_fuse"};
}; };
......
# Op description for "elementwise_mul".
# `def` lists two inputs (X, Y), one output (Out) and the INT attribute "axis".
type: "elementwise_mul"
def {
inputs {
name: "X"
}
inputs {
name: "Y"
}
outputs {
name: "Out"
}
attrs {
name: "axis"
type: INT
}
}
# `extra` lists the remaining attributes of the op, kept separate from `def`.
extra {
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "x_data_format"
type: STRING
}
attrs {
name: "y_data_format"
type: STRING
}
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "Scale_x"
type: FLOAT
}
attrs {
name: "Scale_y"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
...@@ -24,12 +24,13 @@ def { ...@@ -24,12 +24,13 @@ def {
name: "value" name: "value"
type: FLOAT type: FLOAT
} }
attrs {
}
extra {
attrs {
name: "str_value" name: "str_value"
type: STRING type: STRING
} }
}
extra {
attrs { attrs {
name: "force_cpu" name: "force_cpu"
type: BOOLEAN type: BOOLEAN
......
# Op description for "square".
# `def` lists one input (X) and one output (Out); it declares no attributes.
type: "square"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
}
# `extra` lists the remaining attributes of the op, kept separate from `def`.
extra {
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "use_cudnn"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册