From 41de582bb092dfa67bd2a1fa5d3b469db1ae81e2 Mon Sep 17 00:00:00 2001
From: Sylwester Fraczek
Date: Wed, 12 Sep 2018 10:22:11 +0200
Subject: [PATCH] create conv relu pass for MKLDNN (#13258)

---
 paddle/fluid/framework/ir/CMakeLists.txt      |   6 +
 .../ir/conv_relu_mkldnn_fuse_pass.cc          |  90 +++++++++++++++
 .../framework/ir/conv_relu_mkldnn_fuse_pass.h |  39 +++++++
 .../ir/conv_relu_mkldnn_fuse_pass_tester.cc   | 108 ++++++++++++++++++
 .../framework/ir/graph_pattern_detector.cc    |  33 ++++++
 .../framework/ir/graph_pattern_detector.h     |  22 ++++
 6 files changed, 298 insertions(+)
 create mode 100644 paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h
 create mode 100644 paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index ce3ebed00b..7004f484a9 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -28,6 +28,9 @@ cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph grap
 pass_library(graph_to_program_pass base)
 pass_library(graph_viz_pass base)
 pass_library(fc_fuse_pass inference)
+if(WITH_MKLDNN)
+  pass_library(conv_relu_mkldnn_fuse_pass inference)
+endif()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
@@ -42,3 +45,6 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
 cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
 cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
 cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
+if(WITH_MKLDNN)
+  cc_test(test_conv_relu_mkldnn_fuse_pass SRCS conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
+endif()
diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc
new file mode 100644
index 0000000000..4408cb45ac
--- /dev/null
+++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc
@@ -0,0 +1,90 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
+    std::unique_ptr<ir::Graph> graph) const {
+  PADDLE_ENFORCE(graph.get());
+  FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
+
+  std::unordered_set<Node*> nodes2delete;
+
+  GraphPatternDetector gpd;
+  auto* conv_input = gpd.mutable_pattern()
+                         ->NewNode("conv_relu_mkldnn_fuse/conv_input")
+                         ->AsInput()
+                         ->assert_is_op_input("conv2d", "Input");
+  patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(),
+                                       "conv_relu_mkldnn_fuse");
+  conv_relu_pattern(conv_input);
+
+  int found_conv_relu_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "handle ConvReLU fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
+                              conv_relu_pattern);  // Filter
+    GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern);  // Bias
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern);  // tmp
+    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern);  // CONV op
+    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern);  // Out
+    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);  // ReLU op
+
+    // Create a ConvReLU Node.
+    OpDesc desc;
+    std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
+    std::string conv_relu_w_in = conv_weight->Name();
+    std::string conv_relu_b_in = conv_bias->Name();
+    std::string conv_relu_out = relu_out->Name();
+    desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
+    desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
+    desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
+    desc.SetOutput("Out", std::vector<std::string>({conv_relu_out}));
+    desc.SetType("conv2d");
+    for (auto& attr : conv->Op()->GetAttrMap()) {
+      desc.SetAttr(attr.first, attr.second);
+    }
+    desc.SetAttr("fuse_relu", true);
+    auto conv_relu_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
+    GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});
+
+    PADDLE_ENFORCE(subgraph.count(conv_input));
+    IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
+    IR_NODE_LINK_TO(conv_weight, conv_relu_node);
+    IR_NODE_LINK_TO(conv_bias, conv_relu_node);
+    IR_NODE_LINK_TO(conv_relu_node, relu_out);
+
+    found_conv_relu_count++;
+  };
+
+  gpd(graph.get(), handler);
+
+  AddStatis(found_conv_relu_count);
+  return graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
+              paddle::framework::ir::ConvReLUFusePass);
diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h
new file mode 100644
index 0000000000..b5de0d5487
--- /dev/null
+++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/*
+ * Fuse the CONV and ReLU to a ConvReLUOp.
+ */
+class ConvReLUFusePass : public FusePassBase {
+ public:
+  virtual ~ConvReLUFusePass() {}
+
+ protected:
+  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
new file mode 100644
index 0000000000..82b5fa1886
--- /dev/null
+++ b/paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"
+
+#include <gtest/gtest.h>
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void SetOp(ProgramDesc* prog, const std::string& type,
+           const std::vector<std::string>& inputs,
+           const std::vector<std::string>& outputs) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  op->SetType(type);
+  if (type == "conv2d") {
+    op->SetAttr("use_mkldnn", true);
+    op->SetInput("Input", {inputs[0]});
+    op->SetInput("Filter", {inputs[1]});
+    op->SetInput("Bias", {inputs[2]});
+  } else if (type == "relu") {
+    op->SetInput("X", inputs);
+  }
+  op->SetOutput("Out", outputs);
+}
+
+// a->OP0->b
+// b->OP1->c
+// (c, weights, bias)->conv->f
+// (f)->relu->g
+ProgramDesc BuildProgramDesc() {
+  ProgramDesc prog;
+  for (auto& v :
+       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g"})) {
+    auto* var = prog.MutableBlock(0)->Var(v);
+    var->SetType(proto::VarType::SELECTED_ROWS);
+    if (v == "weights" || v == "bias") {
+      var->SetPersistable(true);
+    }
+  }
+
+  SetOp(&prog, "OP0", std::vector<std::string>({"a"}),
+        std::vector<std::string>({"b"}));
+  SetOp(&prog, "OP1", std::vector<std::string>({"b"}),
+        std::vector<std::string>({"c"}));
+  SetOp(&prog, "conv2d", std::vector<std::string>({"c", "weights", "bias"}),
+        std::vector<std::string>({"f"}));
+  SetOp(&prog, "relu", std::vector<std::string>({"f"}),
+        std::vector<std::string>({"g"}));
+
+  return prog;
+}
+
+TEST(ConvReLUFusePass, basic) {
+  auto prog = BuildProgramDesc();
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+
+  auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");
+
+  int original_nodes_num = graph->Nodes().size();
+
+  graph = pass->Apply(std::move(graph));
+
+  int current_nodes_num = graph->Nodes().size();
+
+  // Remove 3 Nodes: CONV, RELU, conv_out
+  // Add 1 Node: ConvReLU
+  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
+
+  // Assert conv_relu op in newly generated graph
+  int conv_relu_count = 0;
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      if (node->Op()->HasAttr("use_mkldnn")) {
+        bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
+        if (use_mkldnn) {
+          if (node->Op()->HasAttr("fuse_relu")) {
+            bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu"));
+            if (fuse_relu) {
+              ++conv_relu_count;
+            }
+          }
+        }
+      }
+    }
+  }
+  EXPECT_EQ(conv_relu_count, 1);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(conv_relu_mkldnn_fuse_pass);
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 5825a129b7..11d5998aaf 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -522,6 +522,39 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
   return false;
 }
 
+PDNode* patterns::ConvReLU::operator()(
+    paddle::framework::ir::PDNode* conv_input) {
+  // Create Operators
+  conv_input->assert_is_op_input("conv2d", "Input");
+  auto* conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+  auto* relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
+  // Create variables
+  // Filter
+  auto* conv_weight_var = pattern->NewNode(conv_weight_repr())
+                              ->AsInput()
+                              ->assert_is_persistable_var()
+                              ->assert_is_op_input("conv2d", "Filter");
+  // Bias
+  auto* conv_bias_var = pattern->NewNode(conv_bias_repr())
+                            ->AsInput()
+                            ->assert_is_persistable_var()
+                            ->assert_is_op_input("conv2d", "Bias");
+  // intermediate variable, will be removed in the IR after fuse.
+  auto* conv_out_var = pattern->NewNode(conv_out_repr())
+                           ->AsIntermediate()
+                           ->assert_is_only_output_of_op("conv2d")
+                           ->assert_is_op_input("relu");
+  // output
+  auto* relu_out_var = pattern->NewNode(relu_out_repr())
+                           ->AsOutput()
+                           ->assert_is_op_output("relu");
+
+  conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var})
+      .LinksTo({conv_out_var});
+  relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
+  return relu_out_var;
+}
+
 PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
                                  bool with_bias) {
   // Create shared nodes.
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 57482a07b6..371384dc56 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -360,6 +360,28 @@ struct PatternBase {
   size_t id_;
 };
 
+// CONV with ReLU
+// op: conv + relu
+// named nodes:
+// conv_input, conv_weight,
+// conv_bias, conv_out, conv,
+// relu_out, relu
+struct ConvReLU : public PatternBase {
+  ConvReLU(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "conv_relu") {}
+
+  PDNode* operator()(PDNode* conv_input);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(conv);
+  PATTERN_DECL_NODE(relu);
+  // declare variable node's name
+  PATTERN_DECL_NODE(conv_weight);
+  PATTERN_DECL_NODE(conv_bias);
+  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(relu_out);
+};
+
 // FC with bias
 // op: mul + elementwise_add
 // named nodes:
--
GitLab
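
Usage note (not part of the patch): once registered, the pass is retrieved and run like any other IR pass. The sketch below only mirrors what the tester above already does with PassRegistry and Pass::Apply; the function name RunConvReluFuse and the assumption that a ProgramDesc containing a conv2d followed by a relu is already available are illustrative, not part of this change.

// Minimal usage sketch, assuming a ProgramDesc `prog` that already describes
// a conv2d op followed by a relu op (as BuildProgramDesc() constructs above).
#include <memory>
#include <utility>

#include "paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.h"

// Pull the pass into the link, exactly as the tester does.
USE_PASS(conv_relu_mkldnn_fuse_pass);

std::unique_ptr<paddle::framework::ir::Graph> RunConvReluFuse(
    const paddle::framework::ProgramDesc& prog) {
  namespace ir = paddle::framework::ir;
  // Build the IR graph from the program.
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  // Look up the pass by its registered name and apply it; matched conv2d+relu
  // pairs are rewritten into a single conv2d op carrying fuse_relu=true.
  auto pass = ir::PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");
  return pass->Apply(std::move(graph));
}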