未验证 提交 dfbfbd01 编写于 作者: F feng_shuai 提交者: GitHub

enhance Conv elementwise add2 act fuse pass (#33564)

* tmp

* pass con_element_add2_act

* recover unittests CMakeLists

* init pass enhance

* fix the attr according to review

* repair the attr conv2d

* repair axis of elementwise_add

* CI-coverage test=allcase

* repari some attr

* recover batch_norm_act

* conv_elementwise_add2_act_fuse
上级 68106509
...@@ -52,6 +52,56 @@ framework::proto::OpDesc PrepareOpDesc( ...@@ -52,6 +52,56 @@ framework::proto::OpDesc PrepareOpDesc(
desc.Flush(); desc.Flush();
return *desc.Proto(); return *desc.Proto();
} }
ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.End()
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsStringIn({"NHWC", "NCHW"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
// the first elementwise_add-axis needs to be 1, the second has to be -1
.IsIntIn({1, -1})
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "conv_elementwise_add2_act_fuse"; const std::string pattern_name = "conv_elementwise_add2_act_fuse";
...@@ -66,6 +116,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -66,6 +116,10 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass op compat failed.";
return;
}
GET_NODES; GET_NODES;
auto base_op_desc = *conv_op->Op()->Proto(); auto base_op_desc = *conv_op->Op()->Proto();
......
...@@ -24,6 +24,7 @@ class Graph; ...@@ -24,6 +24,7 @@ class Graph;
class ConvElementwiseAdd2ActFusePass : public FusePassBase { class ConvElementwiseAdd2ActFusePass : public FusePassBase {
public: public:
ConvElementwiseAdd2ActFusePass();
virtual ~ConvElementwiseAdd2ActFusePass() {} virtual ~ConvElementwiseAdd2ActFusePass() {}
protected: protected:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册