From 64e261c6cdaae175dcd1c2ec78b67be3bb0ae36d Mon Sep 17 00:00:00 2001
From: Yihua Xu
Date: Mon, 3 Dec 2018 13:57:58 +0800
Subject: [PATCH] Implement the fusion of convolution and bias for mkldnn
 (test=develop)

---
 paddle/fluid/framework/ir/CMakeLists.txt      |  1 +
 .../ir/conv3d_bias_mkldnn_fuse_pass.cc        | 18 ++++++++++++
 .../ir/conv3d_bias_mkldnn_fuse_pass.h         | 29 +++++++++++++++++++
 .../ir/conv_bias_mkldnn_fuse_pass.cc          |  8 +++--
 .../framework/ir/conv_bias_mkldnn_fuse_pass.h |  1 +
 .../framework/ir/graph_pattern_detector.cc    | 25 +++++++++-------
 .../framework/ir/graph_pattern_detector.h     |  2 +-
 .../fluid/inference/api/paddle_pass_builder.h |  7 +++--
 .../tests/api/analyzer_dam_tester.cc          | 12 ++++++--
 9 files changed, 83 insertions(+), 20 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 883575e41..0bbfe3c0e 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -46,6 +46,7 @@ if(WITH_MKLDNN)
   pass_library(mkldnn_placement_pass base)
   pass_library(depthwise_conv_mkldnn_pass base)
   pass_library(conv_bias_mkldnn_fuse_pass inference)
+  pass_library(conv3d_bias_mkldnn_fuse_pass inference)
   pass_library(conv_relu_mkldnn_fuse_pass inference)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
 endif()
diff --git a/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc
new file mode 100644
index 000000000..e2968ddf6
--- /dev/null
+++ b/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.cc
@@ -0,0 +1,18 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h"
+
+REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv3DBiasFusePass);
diff --git a/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h
new file mode 100644
index 000000000..5afe2cc61
--- /dev/null
+++ b/paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+/*
+* Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp.
+*/
+class Conv3DBiasFusePass : public ConvBiasFusePass {
+ public:
+  bool is_conv3d() const override { return true; }
+};
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
index 449cc78be..c3ad3a0f1 100644
--- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.cc
@@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
   auto* scope = param_scope();
   PADDLE_ENFORCE(scope);
 
+  std::string type = is_conv3d() ? "conv3d" : "conv2d";
+
   GraphPatternDetector gpd;
   auto* conv_input =
       gpd.mutable_pattern()
           ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
           ->AsInput()
-          ->assert_is_op_input("conv2d", "Input");
+          ->assert_is_op_input(type, "Input");
   patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
-  conv_bias_pattern(conv_input);
+  conv_bias_pattern(conv_input, is_conv3d());
   int found_conv_bias_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
@@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
     desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
     desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
     desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
-    desc.SetType("conv2d");
+    desc.SetType(type);
 
     for (auto& attr : conv->Op()->GetAttrMap()) {
       desc.SetAttr(attr.first, attr.second);
diff --git a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
index 5775b83b8..c3b58bf58 100644
--- a/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h
@@ -26,6 +26,7 @@ namespace ir {
 class ConvBiasFusePass : public FusePassBase {
  public:
   virtual ~ConvBiasFusePass() {}
+  virtual bool is_conv3d() const { return false; }
 
  protected:
   std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 258182b25..ed99d9883 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1030,23 +1030,26 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
 }
 
 PDNode *patterns::ConvBias::operator()(
-    paddle::framework::ir::PDNode *conv_input) {
+    paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
   // Create Operators
-  conv_input->assert_is_op_input("conv2d", "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+  conv_input->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())
+                      ->assert_is_op(is_conv3d ? "conv3d" : "conv2d");
   auto *eltiwse_op =
       pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
   // Create variables
   // Filter
-  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
-                              ->AsInput()
-                              ->assert_is_persistable_var()
-                              ->assert_is_op_input("conv2d", "Filter");
+  auto *conv_weight_var =
+      pattern->NewNode(conv_weight_repr())
+          ->AsInput()
+          ->assert_is_persistable_var()
+          ->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Filter");
   // intermediate variable, will be removed in the IR after fuse.
-  auto *conv_out_var = pattern->NewNode(conv_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op("conv2d")
-                           ->assert_is_op_input("elementwise_add");
+  auto *conv_out_var =
+      pattern->NewNode(conv_out_repr())
+          ->AsIntermediate()
+          ->assert_is_only_output_of_op(is_conv3d ? "conv3d" : "conv2d")
+          ->assert_is_op_input("elementwise_add");
   // Bias stored in elementwise_add
   auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
                                ->AsInput()
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index c12b9503f..d044802f2 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -623,7 +623,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
 struct ConvBias : public PatternBase {
   ConvBias(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "conv_bias") {}
-  PDNode* operator()(PDNode* conv_input);
+  PDNode* operator()(PDNode* conv_input, bool is_conv3d = false);
   // declare operator node's name
   PATTERN_DECL_NODE(conv);
   PATTERN_DECL_NODE(eltwise);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h
index 825bee833..bc5139a7e 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.h
+++ b/paddle/fluid/inference/api/paddle_pass_builder.h
@@ -98,9 +98,10 @@ class CpuPassStrategy : public PassStrategy {
     passes_.insert(passes_.begin(), "mkldnn_placement_pass");
 
     for (auto &pass :
-         std::vector<std::string>({"depthwise_conv_mkldnn_pass",    //
-                                   "conv_bias_mkldnn_fuse_pass",    //
-                                   "conv_relu_mkldnn_fuse_pass",    //
+         std::vector<std::string>({"depthwise_conv_mkldnn_pass",    //
+                                   "conv_bias_mkldnn_fuse_pass",    //
+                                   "conv3d_bias_mkldnn_fuse_pass",  //
+                                   "conv_relu_mkldnn_fuse_pass",    //
                                    "conv_elementwise_add_mkldnn_fuse_pass"})) {
       passes_.push_back(pass);
     }
diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
index a3a6130db..6186c6a42 100644
--- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
@@ -222,9 +222,12 @@ TEST(Analyzer_dam, fuse_statis) {
 }
 
 // Compare result of NativeConfig and AnalysisConfig
-TEST(Analyzer_dam, compare) {
-  contrib::AnalysisConfig cfg;
+void compare(bool use_mkldnn = false) {
+  AnalysisConfig cfg;
   SetConfig(&cfg);
+  if (use_mkldnn) {
+    cfg.EnableMKLDNN();
+  }
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
@@ -233,5 +236,10 @@
       reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }
 
+TEST(Analyzer_dam, compare) { compare(); }
+#ifdef PADDLE_WITH_MKLDNN
+TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); }
+#endif
+
 }  // namespace inference
 }  // namespace paddle
--
GitLab