From 6d8075ecef58f127336f3f6e9f152ce95a34539a Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Tue, 28 May 2019 11:50:35 +0200 Subject: [PATCH] [MKL-DNN] conv_transpose mkldnn bias pass (#17644) * - changes to graph detector - Changes to pass - Added ut for new pass - use_pass - Added pass to mkldnn passes - fix to registration - improved verbose messaging for conv bias passes - Lint fixes test=develop * - Lint fixes test=develop --- paddle/fluid/framework/ir/graph_pattern_detector.cc | 12 ++++++------ paddle/fluid/framework/ir/graph_pattern_detector.h | 2 +- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc | 12 ++++++------ .../framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h | 9 +++++++-- .../ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc | 7 ++++++- paddle/fluid/inference/api/paddle_pass_builder.cc | 1 + 6 files changed, 27 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 16a0f0d03f..15b3429ef1 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1092,12 +1092,12 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()( return ele_add_grad; } +// conv_type: conv2d, conv3d, conv2d_transpose PDNode *patterns::ConvBias::operator()( - paddle::framework::ir::PDNode *conv_input, bool is_conv3d) { - std::string type = is_conv3d ? "conv3d" : "conv2d"; + paddle::framework::ir::PDNode *conv_input, std::string conv_type) { // Create Operators - conv_input->assert_is_op_input(type, "Input"); - auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(type); + conv_input->assert_is_op_input(conv_type, "Input"); + auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type); auto *eltiwse_op = pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); // Create variables @@ -1105,11 +1105,11 @@ PDNode *patterns::ConvBias::operator()( auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) ->AsInput() ->assert_is_persistable_var() - ->assert_is_op_input(type, "Filter"); + ->assert_is_op_input(conv_type, "Filter"); // intermediate variable, will be removed in the IR after fuse. auto *conv_out_var = pattern->NewNode(conv_out_repr()) ->AsIntermediate() - ->assert_is_only_output_of_op(type) + ->assert_is_only_output_of_op(conv_type) ->assert_is_op_input("elementwise_add"); // Bias stored in elementwise_add auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr()) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 4a90f086fe..1c53b91052 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -669,7 +669,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase { struct ConvBias : public PatternBase { ConvBias(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "conv_bias") {} - PDNode* operator()(PDNode* conv_input, bool is_conv3d = false); + PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d"); // declare operator node's name PATTERN_DECL_NODE(conv); PATTERN_DECL_NODE(eltwise); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc index 8ef3993b06..bbfc8c0055 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc @@ -45,16 +45,14 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { auto* scope = param_scope(); PADDLE_ENFORCE(scope); - std::string type = is_conv3d() ? "conv3d" : "conv2d"; - GraphPatternDetector gpd; auto* conv_input = gpd.mutable_pattern() ->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) ->AsInput() - ->assert_is_op_input(type, "Input"); + ->assert_is_op_input(type(), "Input"); patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_); - conv_bias_pattern(conv_input, is_conv3d()); + conv_bias_pattern(conv_input, type()); int found_conv_bias_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { @@ -75,7 +73,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { // check if fuse can be done and if MKL-DNN should be used FuseOptions fuse_option = FindFuseOption(*conv, *eltwise); if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) { - VLOG(3) << "do not perform conv+bias fuse"; + VLOG(3) << "do not perform " + type() + "+bias fuse"; return; } @@ -110,7 +108,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { desc.SetInput("Filter", std::vector({conv_weight->Name()})); desc.SetInput("Bias", std::vector({eltwise_bias->Name()})); desc.SetOutput("Output", std::vector({eltwise_out->Name()})); - desc.SetType(type); + desc.SetType(type()); for (auto& attr : conv->Op()->GetAttrMap()) { desc.SetAttr(attr.first, attr.second); @@ -135,5 +133,7 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const { } // namespace paddle REGISTER_PASS(conv_bias_mkldnn_fuse_pass, paddle::framework::ir::ConvBiasFusePass); +REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass, + paddle::framework::ir::Conv2DTransposeBiasFusePass); REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass, paddle::framework::ir::Conv3DBiasFusePass); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h index 84106d0655..833fbc748e 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h @@ -26,7 +26,7 @@ namespace ir { class ConvBiasFusePass : public FusePassBase { public: virtual ~ConvBiasFusePass() {} - virtual bool is_conv3d() const { return false; } + virtual std::string type() const { return "conv2d"; } protected: void ApplyImpl(ir::Graph* graph) const override; @@ -35,9 +35,14 @@ class ConvBiasFusePass : public FusePassBase { /* * Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp. */ +class Conv2DTransposeBiasFusePass : public ConvBiasFusePass { + public: + std::string type() const override { return "conv2d_transpose"; } +}; + class Conv3DBiasFusePass : public ConvBiasFusePass { public: - bool is_conv3d() const override { return true; } + std::string type() const override { return "conv3d"; } }; } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc index 9f61817674..a6546cb452 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc @@ -141,7 +141,12 @@ TEST(ConvBiasFusePass, conv_with_existing_bias) { MainTest(true); } TEST(ConvBiasFusePass, conv3d) { Conv3DBiasFusePass pass; - ASSERT_TRUE(pass.is_conv3d()); + ASSERT_EQ(pass.type(), std::string("conv3d")); +} + +TEST(ConvBiasFusePass, conv2d_transpose) { + Conv2DTransposeBiasFusePass pass; + ASSERT_EQ(pass.type(), std::string("conv2d_transpose")); } } // namespace ir diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 3dc9814d0d..f5bca9e556 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -169,6 +169,7 @@ void CpuPassStrategy::EnableMKLDNN() { "conv_bn_fuse_pass", // Execute BN passes again to "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order "conv_bias_mkldnn_fuse_pass", // + "conv_transpose_bias_mkldnn_fuse_pass", "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", -- GitLab