diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 16a0f0d03fc40738ec4bdaee7d5eb2b0fbd5551f..15b3429ef170a7e750b2a4d004ba21100a8071ef 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1092,12 +1092,12 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()( return ele_add_grad; } +// conv_type: conv2d, conv3d, conv2d_transpose PDNode *patterns::ConvBias::operator()( - paddle::framework::ir::PDNode *conv_input, bool is_conv3d) { - std::string type = is_conv3d ? "conv3d" : "conv2d"; + paddle::framework::ir::PDNode *conv_input, std::string conv_type) { // Create Operators - conv_input->assert_is_op_input(type, "Input"); - auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type); + conv_input->assert_is_op_input(conv_type, "Input"); + auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type); auto *eltiwse_op = pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); // Create variables @@ -1105,11 +1105,11 @@ PDNode *patterns::ConvBias::operator()( auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) ->AsInput() ->assert_is_persistable_var() - ->assert_is_op_input(type, "Filter"); + ->assert_is_op_input(conv_type, "Filter"); // intermediate variable, will be removed in the IR after fuse. auto *conv_out_var = pattern->NewNode(conv_out_repr()) ->AsIntermediate() - ->assert_is_only_output_of_op(type) + ->assert_is_only_output_of_op(conv_type) ->assert_is_op_input("elementwise_add"); // Bias stored in elementwise_add auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr()) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 4a90f086fe4c35fc4e12d2f623c561c2de5d335b..1c53b9105225e6840bacb2edbe6ffe373ac16110 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -669,7 +669,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase { struct ConvBias : public PatternBase { ConvBias(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "conv_bias") {} - PDNode* operator()(PDNode* conv_input, bool is_conv3d = false); + PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d"); // declare operator node's name PATTERN_DECL_NODE(conv); PATTERN_DECL_NODE(eltwise); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 8ef3993b065bcd37dcd571ba5a284cd35cfe052d..bbfc8c005580bb949b498e4474c4059cd09f56b3 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -45,16 +45,14 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { auto* scope = param_scope(); PADDLE_ENFORCE(scope); - std::string type = is_conv3d() ? "conv3d" : "conv2d"; - GraphPatternDetector gpd; auto* conv_input = gpd.mutable_pattern() ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) ->AsInput() - ->assert_is_op_input(type, "Input"); + ->assert_is_op_input(type(), "Input"); patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_); - conv_bias_pattern(conv_input, is_conv3d()); + conv_bias_pattern(conv_input, type()); int found_conv_bias_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -75,7 +73,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { // check if fuse can be done and if MKL-DNN should be used FuseOptions fuse_option = FindFuseOption(*conv, *eltwise); if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) { - VLOG(3) << "do not perform conv+bias fuse"; + VLOG(3) << "do not perform " + type() + "+bias fuse"; return; } @@ -110,7 +108,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetInput("Filter", std::vector({conv_weight->Name()})); desc.SetInput("Bias", std::vector({eltwise_bias->Name()})); desc.SetOutput("Output", std::vector({eltwise_out->Name()})); - desc.SetType(type); + desc.SetType(type()); for (auto& attr : conv->Op()->GetAttrMap()) { desc.SetAttr(attr.first, attr.second); @@ -135,5 +133,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { } // namespace paddle REGISTER_PASS(conv_bias_mkldnn_fuse_pass, paddle::framework::ir::ConvBiasFusePass); +REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, + paddle::framework::ir::Conv2DTransposeBiasFusePass); REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, paddle::framework::ir::Conv3DBiasFusePass); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h index 84106d0655d5578338da3b5993f3d2ec191542fd..833fbc748ebd03377ebaa6a5fa72d334ff8b7d37 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h @@ -26,7 +26,7 @@ namespace ir { class ConvBiasFusePass : public FusePassBase { public: virtual ~ConvBiasFusePass() {} - virtual bool is_conv3d() const { return false; } + virtual std::string type() const { return "conv2d"; } protected: void ApplyImpl(ir::Graph* graph) const override; @@ -35,9 +35,14 @@ class ConvBiasFusePass : public FusePassBase { /* * Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp. */ +class Conv2DTransposeBiasFusePass : public ConvBiasFusePass { + public: + std::string type() const override { return "conv2d_transpose"; } +}; + class Conv3DBiasFusePass : public ConvBiasFusePass { public: - bool is_conv3d() const override { return true; } + std::string type() const override { return "conv3d"; } }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc index 9f61817674767cbe72b7ef3c6be64a0dd28946ab..a6546cb452a1ae0939ca7a189b8a9ca45c876fd5 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc @@ -141,7 +141,12 @@ TEST(ConvBiasFusePass, conv_with_existing_bias) { MainTest(true); } TEST(ConvBiasFusePass, conv3d) { Conv3DBiasFusePass pass; - ASSERT_TRUE(pass.is_conv3d()); + ASSERT_EQ(pass.type(), std::string("conv3d")); +} + +TEST(ConvBiasFusePass, conv2d_transpose) { + Conv2DTransposeBiasFusePass pass; + ASSERT_EQ(pass.type(), std::string("conv2d_transpose")); } } // namespace ir diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 3dc9814d0d192dff62c7206f5038e55fddf672ee..f5bca9e5562e8c623f7f91e26462d2e084642afb 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -169,6 +169,7 @@ void CpuPassStrategy::EnableMKLDNN() { "conv_bn_fuse_pass", // Execute BN passes again to "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order "conv_bias_mkldnn_fuse_pass", // + "conv_transpose_bias_mkldnn_fuse_pass", "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass",