diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index a00d183a83af386709c4231498b0e3471b42d794..d205a78841189ffe34d4f866d3f50f4563a567e6 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -86,6 +86,7 @@ if(WITH_MKLDNN)
     pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
     pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
     pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
+    pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
     pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
     pass_library(cpu_quantize_placement_pass base mkldnn)
     pass_library(cpu_quantize_pass inference mkldnn)
@@ -116,6 +117,7 @@ if (WITH_MKLDNN)
     cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
     cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
     cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
+    cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
     cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
     cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
     cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 789eee8aa1e47ea164e3a6ba70ea85955eece37a..cb7ef41861e436f952bd8dee63a9ce7ac46187d7 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -1184,6 +1184,52 @@ PDNode *patterns::ElementwiseAdd::operator()(PDNode *x_var, PDNode *y_var) {
   return out_var;
 }
 
+// Registers the pattern: concat -> concat_out -> relu -> relu_out.
+// Returns the pattern node of the ReLU output.
+PDNode *patterns::ConcatReLU::operator()() {
+  auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
+  auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
+
+  auto concat_out =
+      pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");
+
+  auto relu_out = pattern->NewNode(relu_out_repr())
+                      ->AsOutput()
+                      ->assert_is_op_output("relu", "Out");
+
+  concat_op->LinksTo({concat_out});
+  relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
+
+  return relu_out;
+}
+
+// Registers the pattern:
+//   conv2d -> conv_out -> concat -> concat_out -> relu -> relu_out.
+// Only a single conv2d input of the concat is part of the pattern.
+// Returns the pattern node of the ReLU output.
+PDNode *patterns::ConvConcatReLU::operator()() {
+  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
+  auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
+  auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
+
+  auto conv_out = pattern->NewNode(conv_out_repr())
+                      ->assert_is_op_output("conv2d", "Output");
+
+  auto concat_out = pattern->NewNode(concat_out_repr())
+                        ->assert_is_op_output("concat", "Out")
+                        ->assert_is_op_input("relu", "X");
+
+  auto relu_out = pattern->NewNode(relu_out_repr())
+                      ->AsOutput()
+                      ->assert_is_op_output("relu", "Out");
+
+  conv_op->LinksTo({conv_out});
+  concat_op->LinksFrom({conv_out}).LinksTo({concat_out});
+  relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
+
+  return relu_out;
+}
+
 std::unordered_set<std::string> conv_act_set({"identity", "relu"});
 
 PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 1147f1e8ce00294cd0e7886e257c9d7a41ca289c..bb62716ec7fb696924d6dd8a50ef899eb0d80d08 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -728,6 +728,39 @@ struct ElementwiseAdd : public PatternBase {
   PATTERN_DECL_NODE(elementwise_add_out);
 };
 
+// Concat + ReLU
+// named nodes:
+// concat_op, concat_out, relu_op, relu_out
+struct ConcatReLU : public PatternBase {
+  ConcatReLU(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "concat_relu") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(concat_op);
+  PATTERN_DECL_NODE(concat_out);
+  PATTERN_DECL_NODE(relu_op);
+  PATTERN_DECL_NODE(relu_out);
+};
+
+// Conv + Concat + ReLU
+// named nodes:
+// conv_op, conv_out
+// concat_op, concat_out, relu_op, relu_out
+struct ConvConcatReLU : public PatternBase {
+  ConvConcatReLU(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "conv_concat_relu") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(conv_op);
+  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(concat_op);
+  PATTERN_DECL_NODE(concat_out);
+  PATTERN_DECL_NODE(relu_op);
+  PATTERN_DECL_NODE(relu_out);
+};
+
 // Conv + ElementwiseAdd + an activation
 // This pattern can futher fuse the conv related ops after the conv+bn fusion.
 struct ConvElementwiseaddAct : public PatternBase {
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a037a6bf90979ec1d6cd76ff7c07fa2858be8796
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
@@ -0,0 +1,137 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
+#include <vector>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Records every concat whose result feeds a ReLU and whose inputs are all
+// produced by conv2d ops that FindFuseOption allows to fuse with that ReLU.
+// For each such concat its number of inputs is stored in the counter map.
+void ConvConcatReLUFusePass::FindConcatWithConvs(
+    ir::Graph* graph,
+    std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
+  GraphPatternDetector gpd;
+  patterns::ConcatReLU concat_relu_pattern{gpd.mutable_pattern(),
+                                           "concat_relu"};
+  concat_relu_pattern();
+
+  int found_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "Find Concats with Convs";
+    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_relu_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, concat_relu_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, concat_relu_pattern);
+
+    // The fuse removes concat_out together with the ReLU, so the ReLU must be
+    // the only consumer of the concat result.
+    if (concat_out->outputs.size() != 1) return;
+
+    auto concat_inputs = concat_op->inputs;
+
+    for (auto node : concat_inputs) {
+      // Fusing ReLU into a conv2d changes the tensor it produces, so the
+      // concat must be the only consumer of that tensor.
+      if (node->outputs.size() != 1) return;
+
+      // An input without exactly one producer cannot come from a conv2d;
+      // skip this concat instead of aborting the whole pass.
+      auto prev_op_node = node->inputs;
+      if (prev_op_node.size() != 1) return;
+      auto* conv_op = prev_op_node[0];
+      if (conv_op->Op()->Type() != "conv2d") return;
+
+      FuseOptions fuse_option = FindFuseOption(*conv_op, *relu_op);
+      if (fuse_option == DO_NOT_FUSE) {
+        return;
+      }
+    }
+
+    (*concat_with_convs_counter)[concat_op] = concat_inputs.size();
+    found_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
+// Sets fuse_relu on every conv2d that feeds a concat recorded by
+// FindConcatWithConvs. Once all inputs of such a concat are handled, the ReLU
+// and concat_out are removed and the concat writes directly to relu_out.
+void ConvConcatReLUFusePass::FuseConvConcatReLU(
+    ir::Graph* graph,
+    std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
+  GraphPatternDetector gpd;
+  auto pattern = gpd.mutable_pattern();
+  patterns::ConvConcatReLU conv_concat_relu(pattern, name_scope_);
+  conv_concat_relu();
+
+  int found_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "handle ConvConcatReLU fuse";
+
+    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_concat_relu);
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_concat_relu);
+    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, conv_concat_relu);
+    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, conv_concat_relu);
+    GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, conv_concat_relu);
+    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_concat_relu);
+
+    if (!concat_with_convs_counter->count(concat_op)) {
+      VLOG(4) << "this concat has input from non-conv2d operator";
+      return;
+    }
+
+    // Transform Conv node into ConvReLU node.
+    OpDesc* conv_desc = conv_op->Op();
+    conv_desc->SetAttr("fuse_relu", true);
+
+    // Remove ReLU when all Convs were transformed.
+    auto number_of_unfused_convs_left =
+        --(*concat_with_convs_counter)[concat_op];
+    if (number_of_unfused_convs_left == 0) {
+      OpDesc* concat_desc = concat_op->Op();
+      concat_desc->SetOutput("Out",
+                             std::vector<std::string>({relu_out->Name()}));
+      GraphSafeRemoveNodes(graph, {relu_op, concat_out});
+      IR_NODE_LINK_TO(concat_op, relu_out);
+    }
+
+    found_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
+// Runs the two stages: collect the fusable concats, then fuse them.
+void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init(name_scope_, graph);
+
+  std::unordered_map<const Node*, int> concat_with_convs_counter;
+  FindConcatWithConvs(graph, &concat_with_convs_counter);
+  FuseConvConcatReLU(graph, &concat_with_convs_counter);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
+              paddle::framework::ir::ConvConcatReLUFusePass);
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..91ff0760f0483c41cb5be5507426290c90142b13
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h
@@ -0,0 +1,53 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/*
+ * Fuse the (multi conv) -> Concat -> ReLU -> next_op
+ * to a:
+ * (multi ConvReLU) -> Concat -> next_op.
+ */
+class ConvConcatReLUFusePass : public FusePassBase {
+ public:
+  virtual ~ConvConcatReLUFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+  void FindConcatWithConvs(
+      Graph* graph,
+      std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
+
+  void FuseConvConcatReLU(
+      Graph* graph,
+      std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
+
+  const std::string name_scope_{"conv_concat_relu_mkldnn_fuse"};
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0d7ddac8884d22af636c3b8e3964f6e8fe69880d
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
@@ -0,0 +1,163 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Appends an op of the given type to block 0 of prog and sets its inputs and
+// outputs. For conv2d the inputs are {Input, Filter[, Bias]}.
+void SetOp(ProgramDesc* prog, const std::string& type,
+           const std::vector<std::string>& inputs,
+           const std::vector<std::string>& outputs, bool use_mkldnn = true) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  op->SetType(type);
+  if (type == "conv2d") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetAttr("fuse_relu", false);
+    op->SetInput("Input", {inputs[0]});
+    op->SetInput("Filter", {inputs[1]});
+    if (inputs.size() > 2) {
+      op->SetInput("Bias", {inputs[2]});
+    }
+    op->SetOutput("Output", outputs);
+  } else if (type == "relu") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetInput("X", inputs);
+    op->SetOutput("Out", outputs);
+  } else if (type == "pool2d") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetInput("X", inputs);
+    op->SetOutput("Out", outputs);
+  } else if (type == "concat") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetInput("X", inputs);
+    op->SetOutput("Out", outputs);
+  }
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
+}
+
+// (a1,w1)->conv1->c1
+// (a2,w2,b2)->conv2->c2
+// if put_only_convs_before_concat=true
+//   (a3,w3)->conv3->c3
+// else
+//   a3->pool1->c3
+//
+// (c1,c2,c3)->concat1->d
+// d->relu1->e
+//
+// conv1 gets use_mkldnn=all_convs_use_mkldnn; every other op uses MKL-DNN.
+ProgramDesc BuildProgramDesc(bool put_only_convs_before_concat,
+                             bool all_convs_use_mkldnn) {
+  ProgramDesc prog;
+  for (auto& v :
+       std::initializer_list<std::string>({"a1", "w1", "c1", "a2", "w2", "b2",
+                                           "c2", "a3", "w3", "c3", "d", "e"})) {
+    auto* var = prog.MutableBlock(0)->Var(v);
+    var->SetType(proto::VarType::SELECTED_ROWS);
+    if (v.find("w") == 0 || v.find("b") == 0) {
+      var->SetPersistable(true);
+    }
+  }
+
+  SetOp(&prog, "conv2d", {"a1", "w1"}, {"c1"}, all_convs_use_mkldnn);
+  SetOp(&prog, "conv2d", {"a2", "w2", "b2"}, {"c2"});
+  if (put_only_convs_before_concat) {
+    SetOp(&prog, "conv2d", {"a3", "w3"}, {"c3"});
+  } else {
+    SetOp(&prog, "pool2d", {"a3"}, {"c3"});
+  }
+  SetOp(&prog, "concat", {"c1", "c2", "c3"}, {"d"});
+  SetOp(&prog, "relu", {"d"}, {"e"});
+
+  return prog;
+}
+
+// Applies the pass to prog and checks the graph node count, the fuse_relu
+// attribute of every conv2d and the number of remaining relu ops.
+void MainTest(const ProgramDesc& prog, bool fuse_relu) {
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+
+  int original_nodes_num = graph->Nodes().size();
+
+  auto pass = PassRegistry::Instance().Get("conv_concat_relu_mkldnn_fuse_pass");
+  graph.reset(pass->Apply(graph.release()));
+
+  int current_nodes_num = graph->Nodes().size();
+
+  if (fuse_relu) {
+    // Remove 2 nodes: concat_out, relu
+    EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
+  } else {
+    EXPECT_EQ(original_nodes_num, current_nodes_num);
+  }
+
+  int relu_count = 0;
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp()) {
+      auto* op = node->Op();
+      if (op->Type() == "conv2d") {
+        ASSERT_TRUE(op->HasAttr("fuse_relu"));
+        bool fuse_relu_attr = boost::get<bool>(op->GetAttr("fuse_relu"));
+        EXPECT_EQ(fuse_relu, fuse_relu_attr);
+      } else if (op->Type() == "relu") {
+        relu_count++;
+      }
+    }
+  }
+  EXPECT_EQ(relu_count, fuse_relu ? 0 : 1);
+}
+
+TEST(ConvConcatReLUFusePass, only_convs_before_concat) {
+  bool all_convs_use_mkldnn = true;
+  bool put_only_convs_before_concat = true;
+  auto prog =
+      BuildProgramDesc(put_only_convs_before_concat, all_convs_use_mkldnn);
+
+  bool expect_relu_fuse = true;
+  MainTest(prog, expect_relu_fuse);
+}
+
+TEST(ConvConcatReLUFusePass, only_convs_before_concat_but_one_non_mkldnn) {
+  bool all_convs_use_mkldnn = false;
+  bool put_only_convs_before_concat = true;
+  auto prog =
+      BuildProgramDesc(put_only_convs_before_concat, all_convs_use_mkldnn);
+
+  bool expect_relu_fuse = false;
+  MainTest(prog, expect_relu_fuse);
+}
+
+TEST(ConvConcatReLUFusePass, convs_and_pool_before_concat) {
+  bool all_convs_use_mkldnn = true;
+  bool put_only_convs_before_concat = false;
+  auto prog =
+      BuildProgramDesc(put_only_convs_before_concat, all_convs_use_mkldnn);
+
+  bool expect_relu_fuse = false;
+  MainTest(prog, expect_relu_fuse);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(conv_concat_relu_mkldnn_fuse_pass);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index b8167b0ab324c82d545f8ee3bdbf700804d05ce4..b39f740ec025d3a87f85c87f852fcf26827fff59 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -153,6 +153,7 @@ void CpuPassStrategy::EnableMKLDNN() {
              "conv_bias_mkldnn_fuse_pass",  //
              "conv3d_bias_mkldnn_fuse_pass",  //
              "conv_elementwise_add_mkldnn_fuse_pass",
+             "conv_concat_relu_mkldnn_fuse_pass",
              "conv_relu_mkldnn_fuse_pass",  //
              "conv_brelu_mkldnn_fuse_pass"})) {
       passes_.push_back(pass);