From 97d1db18742292e8f08a4ece873d829795b297d1 Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Wed, 21 Aug 2019 06:46:54 +0200 Subject: [PATCH] Add generalized Conv+Activation MKLDNN fuse pass creation Part2 (#19237) * Add generalized Conv+Activation MKLDNN fuse pass creation Part2 test=develop * Undefined behaviour of GetAttrIfExists<> FIX test=develop --- .../framework/ir/graph_pattern_detector.cc | 54 ------- .../framework/ir/graph_pattern_detector.h | 42 ------ .../conv_activation_mkldnn_fuse_pass.cc | 8 +- ...conv_activation_mkldnn_fuse_pass_tester.cc | 6 +- .../ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc | 71 --------- .../ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h | 39 ----- .../conv_brelu_mkldnn_fuse_pass_tester.cc | 135 ------------------ .../conv_elementwise_add_mkldnn_fuse_pass.cc | 13 +- .../conv_elementwise_add_mkldnn_fuse_pass.h | 5 + .../ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc | 76 ---------- .../ir/mkldnn/conv_relu_mkldnn_fuse_pass.h | 39 ----- .../conv_relu_mkldnn_fuse_pass_tester.cc | 127 ---------------- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 9 +- paddle/fluid/framework/op_desc.h | 9 ++ .../fluid/inference/api/mkldnn_quantizer.cc | 5 +- .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 8 -- 16 files changed, 24 insertions(+), 622 deletions(-) delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 302c53a1d03..78e346bbdf0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -800,60 +800,6 @@ PDNode *patterns::ConvActivation::operator()( return activation_out_var; } -PDNode *patterns::ConvReLU::operator()( - paddle::framework::ir::PDNode *conv_input) { - // Create Operators - conv_input->assert_is_op_input("conv2d", "Input"); - auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); - auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu"); - // Create variables - // Filter - auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("conv2d", "Filter"); - // intermediate variable, will be removed in the IR after fuse. - auto *conv_out_var = pattern->NewNode(conv_out_repr()) - ->AsIntermediate() - ->assert_is_only_output_of_op("conv2d") - ->assert_is_op_input("relu"); - // output - auto *relu_out_var = pattern->NewNode(relu_out_repr()) - ->AsOutput() - ->assert_is_op_output("relu"); - - conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); - relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var}); - return relu_out_var; -} - -PDNode *patterns::ConvBReLU::operator()( - paddle::framework::ir::PDNode *conv_input) { - // Create Operators - conv_input->assert_is_op_input("conv2d", "Input"); - auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); - auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6"); - // Create variables - // Filter - auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) - ->AsInput() - ->assert_is_persistable_var() - ->assert_is_op_input("conv2d", "Filter"); - // intermediate variable, will be removed in the IR after fuse. - auto *conv_out_var = pattern->NewNode(conv_out_repr()) - ->AsIntermediate() - ->assert_is_only_output_of_op("conv2d") - ->assert_is_op_input("relu6"); - // output - auto *brelu_out_var = pattern->NewNode(brelu_out_repr()) - ->AsOutput() - ->assert_is_op_output("relu6"); - - conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); - brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var}); - return brelu_out_var; -} - PDNode *patterns::SeqConvEltAddRelu::operator()( paddle::framework::ir::PDNode *seqconv_input) { // Create Operators diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index de091ce87ed..dafe9a6cbf4 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -453,48 +453,6 @@ struct ConvActivation : public PatternBase { PATTERN_DECL_NODE(activation_out); }; -// CONV with ReLU -// op: conv + relu -// named nodes: -// conv_input, conv_weight, -// conv_out, conv, -// relu_out, relu -struct ConvReLU : public PatternBase { - ConvReLU(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "conv_relu") {} - - PDNode* operator()(PDNode* conv_input); - - // declare operator node's name - PATTERN_DECL_NODE(conv); - PATTERN_DECL_NODE(relu); - // declare variable node's name - PATTERN_DECL_NODE(conv_weight); - PATTERN_DECL_NODE(conv_out); - PATTERN_DECL_NODE(relu_out); -}; - -// CONV with ReLU6 -// op: conv + relu6 -// named nodes: -// conv_input, conv_weight, -// conv_out, conv, -// relu6_out, relu6 -struct ConvBReLU : public PatternBase { - ConvBReLU(PDPattern* pattern, const std::string& name_scope) - : PatternBase(pattern, name_scope, "conv_bounded_relu") {} - - PDNode* operator()(PDNode* conv_input); - - // declare operator node's name - PATTERN_DECL_NODE(conv); - PATTERN_DECL_NODE(brelu); - // declare variable node's name - PATTERN_DECL_NODE(conv_weight); - PATTERN_DECL_NODE(conv_out); - PATTERN_DECL_NODE(brelu_out); -}; - // SEQCONV with Elementwise_Add ReLU // op: seqconv + elementwise_add + relu // named nodes: diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index 947beeccbdb..2226169e65b 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -62,14 +62,10 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { boost::get(activation->Op()->GetAttr("threshold"))); } else { desc->SetAttr("fuse_alpha", - activation->Op()->HasAttr("alpha") - ? boost::get(activation->Op()->GetAttr("alpha")) - : 0.0f); + activation->Op()->GetAttrIfExists("alpha")); } desc->SetAttr("fuse_beta", - activation->Op()->HasAttr("beta") - ? boost::get(activation->Op()->GetAttr("beta")) - : 0.0f); + activation->Op()->GetAttrIfExists("beta")); GraphSafeRemoveNodes(graph, {activation, conv_out}); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc index 12611fb84a0..ec38788bb4b 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc @@ -114,11 +114,7 @@ void MainTest(std::string activation) { ASSERT_TRUE(op->HasAttr("use_mkldnn")); EXPECT_TRUE(boost::get(op->GetAttr("use_mkldnn"))); auto op_name = boost::get(op->GetAttr("name")); - std::string fuse_activation = - op->HasAttr("fuse_activation") - ? boost::get(op->GetAttr("fuse_activation")) - : ""; - if (fuse_activation == activation) { + if (op->GetAttrIfExists("fuse_activation") == activation) { ++conv_activation_count; } // check if only "conv1" convolution is fused diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc deleted file mode 100644 index dd9d4486348..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h" -#include -#include -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace framework { -namespace ir { - -void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE(graph); - FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph); - - GraphPatternDetector gpd; - auto* conv_input = gpd.mutable_pattern() - ->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input") - ->AsInput() - ->assert_is_op_input("conv2d", "Input"); - patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(), - "conv_bounded_relu_mkldnn_fuse"); - conv_brelu_pattern(conv_input); - - int found_conv_brelu_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - VLOG(4) << "handle ConvBoundedReLUFusePass fuse"; - GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, - conv_brelu_pattern); // Filter - GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern); // tmp - GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern); // CONV op - GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern); // Out - GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern); // ReLU op - - // Transform Conv node into ConvBReLU node. - OpDesc* desc = conv->Op(); - desc->SetOutput("Output", std::vector({brelu_out->Name()})); - desc->SetAttr("fuse_brelu", true); - desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold")); - - GraphSafeRemoveNodes(graph, {brelu, conv_out}); - - PADDLE_ENFORCE(subgraph.count(conv_input)); - IR_NODE_LINK_TO(conv, brelu_out); - found_conv_brelu_count++; - }; - - gpd(graph, handler); - - AddStatis(found_conv_brelu_count); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(conv_brelu_mkldnn_fuse_pass, - paddle::framework::ir::ConvBReLUFusePass); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h deleted file mode 100644 index c898be69caf..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/pass.h" - -namespace paddle { -namespace framework { -namespace ir { - -/* - * Fuse the CONV and ReLU6 to a ConvReLU6Op. - */ -class ConvBReLUFusePass : public FusePassBase { - public: - virtual ~ConvBReLUFusePass() {} - - protected: - void ApplyImpl(ir::Graph* graph) const override; -}; - -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc deleted file mode 100644 index 5a546bfaeda..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h" - -#include -#include "paddle/fluid/framework/op_proto_maker.h" - -namespace paddle { -namespace framework { -namespace ir { - -void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, - const std::vector& inputs, - const std::vector& outputs, bool use_mkldnn = false) { - auto* op = prog->MutableBlock(0)->AppendOp(); - op->SetType(type); - if (type == "conv2d") { - op->SetAttr("use_mkldnn", use_mkldnn); - op->SetAttr("name", name); - op->SetInput("Input", {inputs[0]}); - op->SetInput("Filter", {inputs[1]}); - op->SetInput("Bias", {inputs[2]}); - } else if (type == "relu6") { - op->SetAttr("use_mkldnn", use_mkldnn); - if (use_mkldnn) { - op->SetAttr("threshold", 6.0f); - } - op->SetInput("X", inputs); - } - op->SetOutput("Out", outputs); - op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), - static_cast(OpRole::kForward)); -} - -// a->OP0->b -// b->OP1->c -// (c, weights, bias)->conv->f -// (f)->brelu->g -ProgramDesc BuildProgramDesc() { - ProgramDesc prog; - for (auto& v : - std::vector({"a", "b", "c", "weights", "bias", "f", "g", - "h", "weights2", "bias2", "k", "l"})) { - auto* var = prog.MutableBlock(0)->Var(v); - var->SetType(proto::VarType::SELECTED_ROWS); - if (v == "weights" || v == "bias") { - var->SetPersistable(true); - } - } - - SetOp(&prog, "OP0", "op0", std::vector({"a"}), - std::vector({"b"})); - SetOp(&prog, "OP1", "op1", std::vector({"b"}), - std::vector({"c"})); - // conv+brelu, both with MKL-DNN - SetOp(&prog, "conv2d", "conv1", - std::vector({"c", "weights", "bias"}), - std::vector({"f"}), true); - SetOp(&prog, "relu6", "relu1", std::vector({"f"}), - std::vector({"g"}), true); - SetOp(&prog, "OP3", "op3", std::vector({"g"}), - std::vector({"h"})); - // conv+brelu, only one with MKL-DNN - SetOp(&prog, "conv2d", "conv2", - std::vector({"h", "weights2", "bias2"}), - std::vector({"k"}), true); - SetOp(&prog, "relu6", "relu2", std::vector({"k"}), - std::vector({"l"})); - - return prog; -} - -TEST(ConvBReLUFusePass, basic) { - auto prog = BuildProgramDesc(); - - std::unique_ptr graph(new ir::Graph(prog)); - - auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass"); - - int original_nodes_num = graph->Nodes().size(); - - graph.reset(pass->Apply(graph.release())); - - int current_nodes_num = graph->Nodes().size(); - - // Remove 3 Nodes: CONV, BRELU, conv_out - // Add 1 Node: ConvBReLU - EXPECT_EQ(original_nodes_num - 2, current_nodes_num); - - // Assert conv_brelu op in newly generated graph - int conv_brelu_count = 0; - - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { - auto* op = node->Op(); - ASSERT_TRUE(op->HasAttr("use_mkldnn")); - EXPECT_TRUE(boost::get(op->GetAttr("use_mkldnn"))); - // check if only "conv1" convolution is fused - auto op_name = boost::get(op->GetAttr("name")); - if (op_name == "conv1") { - ASSERT_TRUE(op->HasAttr("fuse_brelu")); - ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold")); - - bool fuse_brelu = boost::get(op->GetAttr("fuse_brelu")); - if (fuse_brelu) { - ++conv_brelu_count; - float fuse_brelu_threshold = - boost::get(op->GetAttr("fuse_brelu_threshold")); - EXPECT_EQ(fuse_brelu_threshold, 6.0f); - } - } else if (op_name == "conv2") { - ASSERT_FALSE(op->HasAttr("fuse_brelu")); - } - } - } - EXPECT_EQ(conv_brelu_count, 1); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -USE_PASS(conv_brelu_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index ff1a3dc9c34..1263ddd147e 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -109,11 +109,7 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; - std::string fuse_activation = - conv_op->Op()->HasAttr("fuse_activation") - ? boost::get(conv_op->Op()->GetAttr("fuse_activation")) - : ""; - if (fuse_activation == "relu" || fuse_activation == "relu6") return; + if (HasFusedActivation(conv_op)) return; conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); @@ -182,12 +178,7 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( return; } - std::string fuse_activation = - residual_conv_op->Op()->HasAttr("fuse_activation") - ? boost::get( - residual_conv_op->Op()->GetAttr("fuse_activation")) - : ""; - if (fuse_activation == "relu" || fuse_activation == "relu6") return; + if (HasFusedActivation(residual_conv_op)) return; residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()}); residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index 9bf1ae60793..b95aec34d30 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -126,6 +126,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { protected: void ApplyImpl(graph_ptr graph) const; + static bool HasFusedActivation(Node* conv_node) { + return !(conv_node->Op() + ->GetAttrIfExists("fuse_activation") + .empty()); + } const std::string name_scope_{"residual_connection_fuse_pass"}; }; diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc deleted file mode 100644 index dd0fb456040..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h" -#include -#include -#include "paddle/fluid/platform/enforce.h" - -namespace paddle { -namespace framework { -namespace ir { - -void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE(graph); - FusePassBase::Init("conv_relu_mkldnn_fuse", graph); - - GraphPatternDetector gpd; - auto* conv_input = gpd.mutable_pattern() - ->NewNode("conv_relu_mkldnn_fuse/conv_input") - ->AsInput() - ->assert_is_op_input("conv2d", "Input"); - patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(), - "conv_relu_mkldnn_fuse"); - conv_relu_pattern(conv_input); - - int found_conv_relu_count = 0; - auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, - Graph* g) { - VLOG(4) << "handle ConvReLU fuse"; - GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, - conv_relu_pattern); // Filter - GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp - GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op - GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out - GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op - - FuseOptions fuse_option = FindFuseOption(*conv, *relu); - if (fuse_option == DO_NOT_FUSE) { - VLOG(3) << "do not perform conv+relu fuse"; - return; - } - - // Transform Conv node into ConvReLU node. - OpDesc* desc = conv->Op(); - desc->SetOutput("Output", std::vector({relu_out->Name()})); - desc->SetAttr("fuse_relu", true); - GraphSafeRemoveNodes(graph, {relu, conv_out}); - - PADDLE_ENFORCE(subgraph.count(conv_input)); - IR_NODE_LINK_TO(conv, relu_out); - - found_conv_relu_count++; - }; - - gpd(graph, handler); - - AddStatis(found_conv_relu_count); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -REGISTER_PASS(conv_relu_mkldnn_fuse_pass, - paddle::framework::ir::ConvReLUFusePass); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h deleted file mode 100644 index 2174c22dbf5..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h +++ /dev/null @@ -1,39 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/pass.h" - -namespace paddle { -namespace framework { -namespace ir { - -/* - * Fuse the CONV and ReLU to a ConvReLUOp. - */ -class ConvReLUFusePass : public FusePassBase { - public: - virtual ~ConvReLUFusePass() {} - - protected: - void ApplyImpl(ir::Graph* graph) const override; -}; - -} // namespace ir -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc deleted file mode 100644 index 67a9957059a..00000000000 --- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc +++ /dev/null @@ -1,127 +0,0 @@ -// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h" - -#include -#include "paddle/fluid/framework/op_proto_maker.h" - -namespace paddle { -namespace framework { -namespace ir { - -void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, - const std::vector& inputs, - const std::vector& outputs, bool use_mkldnn = false) { - auto* op = prog->MutableBlock(0)->AppendOp(); - op->SetType(type); - if (type == "conv2d") { - op->SetAttr("use_mkldnn", use_mkldnn); - op->SetAttr("name", name); - op->SetInput("Input", {inputs[0]}); - op->SetInput("Filter", {inputs[1]}); - op->SetInput("Bias", {inputs[2]}); - } else if (type == "relu") { - op->SetAttr("use_mkldnn", use_mkldnn); - op->SetInput("X", inputs); - } - op->SetOutput("Out", outputs); - op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), - static_cast(OpRole::kForward)); -} - -// a->OP0->b -// b->OP1->c -// (c, weights, bias)->conv->f -// (f)->relu->g -ProgramDesc BuildProgramDesc() { - ProgramDesc prog; - for (auto& v : - std::vector({"a", "b", "c", "weights", "bias", "f", "g", - "h", "weights2", "bias2", "k", "l"})) { - auto* var = prog.MutableBlock(0)->Var(v); - var->SetType(proto::VarType::SELECTED_ROWS); - if (v == "weights" || v == "bias") { - var->SetPersistable(true); - } - } - - SetOp(&prog, "OP0", "op0", std::vector({"a"}), - std::vector({"b"})); - SetOp(&prog, "OP1", "op1", std::vector({"b"}), - std::vector({"c"})); - // conv+relu, both with MKL-DNN - SetOp(&prog, "conv2d", "conv1", - std::vector({"c", "weights", "bias"}), - std::vector({"f"}), true); - SetOp(&prog, "relu", "relu1", std::vector({"f"}), - std::vector({"g"}), true); - SetOp(&prog, "OP3", "op3", std::vector({"g"}), - std::vector({"h"})); - // conv+relu, only one with MKL-DNN - SetOp(&prog, "conv2d", "conv2", - std::vector({"h", "weights2", "bias2"}), - std::vector({"k"}), true); - SetOp(&prog, "relu", "relu2", std::vector({"k"}), - std::vector({"l"})); - - return prog; -} - -TEST(ConvReLUFusePass, basic) { - auto prog = BuildProgramDesc(); - - std::unique_ptr graph(new ir::Graph(prog)); - - auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass"); - - int original_nodes_num = graph->Nodes().size(); - - graph.reset(pass->Apply(graph.release())); - - int current_nodes_num = graph->Nodes().size(); - - // Remove 3 Nodes: CONV, RELU, conv_out - // Add 1 Node: ConvReLU - EXPECT_EQ(original_nodes_num - 2, current_nodes_num); - - // Assert conv_relu op in newly generated graph - int conv_relu_count = 0; - - for (auto* node : graph->Nodes()) { - if (node->IsOp() && node->Op()->Type() == "conv2d") { - auto* op = node->Op(); - ASSERT_TRUE(op->HasAttr("use_mkldnn")); - EXPECT_TRUE(boost::get(op->GetAttr("use_mkldnn"))); - // check if only "conv1" convolution is fused - auto op_name = boost::get(op->GetAttr("name")); - if (op_name == "conv1") { - ASSERT_TRUE(op->HasAttr("fuse_relu")); - bool fuse_relu = boost::get(op->GetAttr("fuse_relu")); - if (fuse_relu) { - ++conv_relu_count; - } - } else if (op_name == "conv2") { - ASSERT_FALSE(op->HasAttr("fuse_relu")); - } - } - } - EXPECT_EQ(conv_relu_count, 1); -} - -} // namespace ir -} // namespace framework -} // namespace paddle - -USE_PASS(conv_relu_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index de8b346203f..9cf55ee3254 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -209,12 +209,11 @@ void CPUQuantizePass::QuantizeConv(Graph* graph, is_output_unsigned, "Scale_out"); // change threshold in bounded ReLu - if (conv_op->Op()->HasAttr("fuse_brelu") && - boost::get(conv_op->Op()->GetAttr("fuse_brelu"))) { + if (conv_op->Op()->GetAttrIfExists("fuse_activation") == + "relu6") { float scale_out = boost::get(conv_op->Op()->GetAttr("Scale_out")); - float threshold = - boost::get(conv_op->Op()->GetAttr("fuse_brelu_threshold")); - conv_op->Op()->SetAttr("fuse_brelu_threshold", scale_out * threshold); + float threshold = boost::get(conv_op->Op()->GetAttr("fuse_alpha")); + conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold); } ++quantize_conv_count; diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index dedaf243647..2f6fb9e2984 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -80,6 +80,15 @@ class OpDesc { Attribute GetAttr(const std::string &name) const; + template + T GetAttrIfExists(const std::string &name) const { + T result{}; + if (HasAttr(name)) { + result = boost::get(GetAttr(name)); + } + return result; + } + const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const; Attribute GetNullableAttr(const std::string &name) const; diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index 8ebdbf1673f..94c556ce52d 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -69,10 +69,7 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() { if (op->Type() == "conv2d") { // output of conv2d with relu must be unsigned std::string fuse_activation = - op->HasAttr("fuse_activation") - ? boost::get( - op->GetAttr("fuse_activation")) - : ""; + op->GetAttrIfExists("fuse_activation"); is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6"); } else if (op->Type() == "relu") { diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index 25d8e4ee953..dbb792616a3 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -482,14 +482,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { handler.reset( new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key)); // create a conv primitive descriptor and save it for usage in backward - - // TODO(grygielski) if INT8 brelu post-op will be available, just delete - // whole if statement - if (fuse_activation == "relu6") { - fuse_activation = "relu"; - fuse_alpha = 0.0f; - } - auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; -- GitLab