提交 64e261c6 编写于 作者: Y Yihua Xu

Implement the fusion of convolution and bias for mkldnn

(test=develop)
上级 4f71a6ee
...@@ -46,6 +46,7 @@ if(WITH_MKLDNN) ...@@ -46,6 +46,7 @@ if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base) pass_library(mkldnn_placement_pass base)
pass_library(depthwise_conv_mkldnn_pass base) pass_library(depthwise_conv_mkldnn_pass base)
pass_library(conv_bias_mkldnn_fuse_pass inference) pass_library(conv_bias_mkldnn_fuse_pass inference)
pass_library(conv3d_bias_mkldnn_fuse_pass inference)
pass_library(conv_relu_mkldnn_fuse_pass inference) pass_library(conv_relu_mkldnn_fuse_pass inference)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference)
endif() endif()
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/conv3d_bias_mkldnn_fuse_pass.h"
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
paddle::framework::ir::Conv3DBiasFusePass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse the Conv3D and Elementwise_add to a Conv3DBiasOp.
*/
class Conv3DBiasFusePass : public ConvBiasFusePass {
public:
bool is_conv3d() const override { return true; }
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl( ...@@ -46,14 +46,16 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
auto* scope = param_scope(); auto* scope = param_scope();
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
std::string type = is_conv3d() ? "conv3d" : "conv2d";
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* conv_input = auto* conv_input =
gpd.mutable_pattern() gpd.mutable_pattern()
->NewNode(patterns::PDNodeName(name_scope_, "conv_input")) ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
->AsInput() ->AsInput()
->assert_is_op_input("conv2d", "Input"); ->assert_is_op_input(type, "Input");
patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_); patterns::ConvBias conv_bias_pattern(gpd.mutable_pattern(), name_scope_);
conv_bias_pattern(conv_input); conv_bias_pattern(conv_input, is_conv3d());
int found_conv_bias_count = 0; int found_conv_bias_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
...@@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl( ...@@ -109,7 +111,7 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()})); desc.SetInput("Filter", std::vector<std::string>({conv_weight->Name()}));
desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()})); desc.SetInput("Bias", std::vector<std::string>({eltwise_bias->Name()}));
desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()})); desc.SetOutput("Output", std::vector<std::string>({eltwise_out->Name()}));
desc.SetType("conv2d"); desc.SetType(type);
for (auto& attr : conv->Op()->GetAttrMap()) { for (auto& attr : conv->Op()->GetAttrMap()) {
desc.SetAttr(attr.first, attr.second); desc.SetAttr(attr.first, attr.second);
......
...@@ -26,6 +26,7 @@ namespace ir { ...@@ -26,6 +26,7 @@ namespace ir {
class ConvBiasFusePass : public FusePassBase { class ConvBiasFusePass : public FusePassBase {
public: public:
virtual ~ConvBiasFusePass() {} virtual ~ConvBiasFusePass() {}
virtual bool is_conv3d() const { return false; }
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const; std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
......
...@@ -1030,23 +1030,26 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()( ...@@ -1030,23 +1030,26 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()(
} }
PDNode *patterns::ConvBias::operator()( PDNode *patterns::ConvBias::operator()(
paddle::framework::ir::PDNode *conv_input) { paddle::framework::ir::PDNode *conv_input, bool is_conv3d) {
// Create Operators // Create Operators
conv_input->assert_is_op_input("conv2d", "Input"); conv_input->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Input");
auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); auto *conv_op = pattern->NewNode(conv_repr())
->assert_is_op(is_conv3d ? "conv3d" : "conv2d");
auto *eltiwse_op = auto *eltiwse_op =
pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add"); pattern->NewNode(eltwise_repr())->assert_is_op("elementwise_add");
// Create variables // Create variables
// Filter // Filter
auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) auto *conv_weight_var =
->AsInput() pattern->NewNode(conv_weight_repr())
->assert_is_persistable_var() ->AsInput()
->assert_is_op_input("conv2d", "Filter"); ->assert_is_persistable_var()
->assert_is_op_input(is_conv3d ? "conv3d" : "conv2d", "Filter");
// intermediate variable, will be removed in the IR after fuse. // intermediate variable, will be removed in the IR after fuse.
auto *conv_out_var = pattern->NewNode(conv_out_repr()) auto *conv_out_var =
->AsIntermediate() pattern->NewNode(conv_out_repr())
->assert_is_only_output_of_op("conv2d") ->AsIntermediate()
->assert_is_op_input("elementwise_add"); ->assert_is_only_output_of_op(is_conv3d ? "conv3d" : "conv2d")
->assert_is_op_input("elementwise_add");
// Bias stored in elementwise_add // Bias stored in elementwise_add
auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr()) auto *eltwise_bias_var = pattern->NewNode(eltwise_bias_repr())
->AsInput() ->AsInput()
......
...@@ -623,7 +623,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase { ...@@ -623,7 +623,7 @@ struct ElewiseAddActInplaceGrad : public PatternBase {
struct ConvBias : public PatternBase { struct ConvBias : public PatternBase {
ConvBias(PDPattern* pattern, const std::string& name_scope) ConvBias(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "conv_bias") {} : PatternBase(pattern, name_scope, "conv_bias") {}
PDNode* operator()(PDNode* conv_input); PDNode* operator()(PDNode* conv_input, bool is_conv3d = false);
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(conv); PATTERN_DECL_NODE(conv);
PATTERN_DECL_NODE(eltwise); PATTERN_DECL_NODE(eltwise);
......
...@@ -98,9 +98,10 @@ class CpuPassStrategy : public PassStrategy { ...@@ -98,9 +98,10 @@ class CpuPassStrategy : public PassStrategy {
passes_.insert(passes_.begin(), "mkldnn_placement_pass"); passes_.insert(passes_.begin(), "mkldnn_placement_pass");
for (auto &pass : for (auto &pass :
std::vector<std::string>({"depthwise_conv_mkldnn_pass", // std::vector<std::string>({"depthwise_conv_mkldnn_pass", //
"conv_bias_mkldnn_fuse_pass", // "conv_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", // "conv3d_bias_mkldnn_fuse_pass", //
"conv_relu_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass"})) { "conv_elementwise_add_mkldnn_fuse_pass"})) {
passes_.push_back(pass); passes_.push_back(pass);
} }
......
...@@ -222,9 +222,12 @@ TEST(Analyzer_dam, fuse_statis) { ...@@ -222,9 +222,12 @@ TEST(Analyzer_dam, fuse_statis) {
} }
// Compare result of NativeConfig and AnalysisConfig // Compare result of NativeConfig and AnalysisConfig
TEST(Analyzer_dam, compare) { void compare(bool use_mkldnn = false) {
contrib::AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
if (use_mkldnn) {
cfg.EnableMKLDNN();
}
std::vector<std::vector<PaddleTensor>> input_slots_all; std::vector<std::vector<PaddleTensor>> input_slots_all;
SetInput(&input_slots_all); SetInput(&input_slots_all);
...@@ -233,5 +236,10 @@ TEST(Analyzer_dam, compare) { ...@@ -233,5 +236,10 @@ TEST(Analyzer_dam, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
TEST(Analyzer_dam, compare) { compare(); }
#ifdef PADDLE_WITH_MKLDNN
TEST(Analyzer_dam, compare_mkldnn) { compare(true /* use_mkldnn */); }
#endif
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册