Unverified commit 899792c1, authored by W Wangzheee, committed by GitHub

pass enhance (#33710)

Parent 9bf00cd5
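This commit adds operator compatibility (op compat) checks to ConvElementwiseAddActFusePass and ConvElementwiseAddFusePass: each pass now declares in its constructor, via AddOpCompat(OpCompat(...)), which inputs, outputs, and attributes the matched conv2d / elementwise_add / relu ops must have, and its pattern handler calls IsCompat() to skip subgraphs that violate those constraints. The sketch below is for orientation only and is not part of the diff; FooFusePass is a hypothetical name, and the interfaces it uses (FusePassBase, AddOpCompat, OpCompat, IsCompat, GraphPatternDetector) are assumed to behave as they do in the code that follows.

// Hypothetical illustration only (FooFusePass does not exist in Paddle);
// it mirrors the pattern this commit applies to the two real passes.
class FooFusePass : public FusePassBase {
 public:
  FooFusePass() {
    // Declare which inputs/outputs/attrs a matched "relu" op must have
    // for the fusion to be considered applicable.
    AddOpCompat(OpCompat("relu"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End();
  }

 protected:
  void ApplyImpl(ir::Graph* graph) const override {
    // ... build the pattern and register it with a GraphPatternDetector ...
    auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                       Graph* g) {
      // Reject matched subgraphs that violate the declared constraints
      // instead of rewriting them.
      if (!IsCompat(subgraph, g)) {
        LOG(WARNING) << "Pass in op compat failed.";
        return;
      }
      // ... perform the actual fusion rewrite here ...
    };
    // ... run the detector with the handler over `graph` ...
  }
};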
@@ -48,6 +48,60 @@ framework::proto::OpDesc PrepareOpDesc(
return *desc.Proto();
}
ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("ResidualData")
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .End()
      .AddAttr("paddings")
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();

  AddOpCompat(OpCompat("relu"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End();
}
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "conv_elementwise_add_act_fuse";
  FusePassBase::Init(pattern_name, graph);
@@ -63,6 +117,10 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    GET_NODES;
    auto base_op_desc = *conv_op->Op()->Proto();
......
@@ -24,6 +24,7 @@ class Graph;
class ConvElementwiseAddActFusePass : public FusePassBase {
 public:
  ConvElementwiseAddActFusePass();
  virtual ~ConvElementwiseAddActFusePass() {}

 protected:
......
@@ -29,6 +29,52 @@ namespace ir {
GET_IR_NODE(elementwise_add_in_y); \
GET_IR_NODE(elementwise_add_out);
ConvElementwiseAddFusePass::ConvElementwiseAddFusePass() {
  AddOpCompat(OpCompat("conv2d"))
      .AddInput("Input")
      .IsTensor()
      .End()
      .AddInput("Filter")
      .IsTensor()
      .End()
      .AddInput("ResidualData")
      .IsOptional()
      .End()
      .AddOutput("Output")
      .IsTensor()
      .End()
      .AddAttr("strides")
      .End()
      .AddAttr("paddings")
      .End()
      .AddAttr("padding_algorithm")
      .IsOptional()
      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
      .End()
      .AddAttr("groups")
      .IsNumGE(1)
      .End()
      .AddAttr("dilations")
      .End()
      .AddAttr("data_format")
      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
      .End();

  AddOpCompat(OpCompat("elementwise_add"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Y")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsNumEQ(1)
      .End();
}
void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "conv_elementwise_add_fuse";
  FusePassBase::Init(pattern_name, graph);
@@ -44,6 +90,10 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "Pass in op compat failed.";
      return;
    }
    GET_NODES;
    auto base_op_desc = *conv_op->Op()->Proto();
......
@@ -24,6 +24,7 @@ class Graph;
class ConvElementwiseAddFusePass : public FusePassBase {
 public:
  ConvElementwiseAddFusePass();
  virtual ~ConvElementwiseAddFusePass() {}

 protected:
......
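Design note (summary of the diff above, not part of the patch): each pass declares its constraints once in the constructor and then validates every matched subgraph with IsCompat() before rewriting, so incompatible subgraphs only trigger the "Pass in op compat failed." warning and are left unfused rather than transformed incorrectly. The accompanying header changes also add an empty virtual destructor to each pass class.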