diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc index 56d5831f3329b94d06940107f99150616b03eeb9..7334d9ad466c0fc0248b7146e50b13808540f104 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.cc @@ -94,6 +94,77 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight, } } +ConvAffineChannelFusePass::ConvAffineChannelFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC"}) + .End(); + + AddOpCompat(OpCompat("affine_channel")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("data_layout") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); +} + void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -116,6 +187,11 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { int found_conv_ac_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "ConvAffineChannelFusePass in op compat failed."; + return; + } + VLOG(4) << "handle ConvAffineChannel fuse"; GET_CONV_BN_NODES(conv_ac_pattern); @@ -149,6 +225,12 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetType("elementwise_add"); desc.SetAttr("axis", 1); desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists("use_mkldnn")); + + if (!IsCompat(desc)) { + LOG(WARNING) << "ConvAffineChannelFusePass in out fc op compat failed."; + return; + } + auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel}); @@ -164,6 +246,75 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_conv_ac_count); } +ConvEltwiseAddAffineChannelFusePass::ConvEltwiseAddAffineChannelFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC"}) + .End(); + AddOpCompat(OpCompat("affine_channel")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Scale") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("data_layout") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); +} + void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -186,6 +337,12 @@ void ConvEltwiseAddAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { int found_conv_ac_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) + << "ConvEltwiseAddAffineChannelFusePass in op compat failed."; + return; + } + VLOG(4) << "handle ConvBN fuse"; GET_CONV_BN_NODES(conv_ac_pattern); diff --git a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h index 916384ec44704537f472c8b99bc5766489bd1ced..8cfaf5c6a89f06b453dbbc94b5a7fe8b83e5c111 100644 --- a/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h +++ b/paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h @@ -31,6 +31,7 @@ class Graph; class ConvAffineChannelFusePass : public FusePassBase { public: + ConvAffineChannelFusePass(); virtual ~ConvAffineChannelFusePass() {} protected: @@ -40,6 +41,7 @@ class ConvAffineChannelFusePass : public FusePassBase { class ConvEltwiseAddAffineChannelFusePass : public FusePassBase { public: + ConvEltwiseAddAffineChannelFusePass(); virtual ~ConvEltwiseAddAffineChannelFusePass() {} protected: