diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 302c53a1d0377def39e580551944dfe54c0d1a7f..78e346bbdf0ae6e50ec926b9627ad6c9966b53c5 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -800,60 +800,6 @@ PDNode *patterns::ConvActivation::operator()(
   return activation_out_var;
 }
 
-PDNode *patterns::ConvReLU::operator()(
-    paddle::framework::ir::PDNode *conv_input) {
-  // Create Operators
-  conv_input->assert_is_op_input("conv2d", "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
-  auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");
-  // Create variables
-  // Filter
-  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
-                              ->AsInput()
-                              ->assert_is_persistable_var()
-                              ->assert_is_op_input("conv2d", "Filter");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *conv_out_var = pattern->NewNode(conv_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op("conv2d")
-                           ->assert_is_op_input("relu");
-  // output
-  auto *relu_out_var = pattern->NewNode(relu_out_repr())
-                           ->AsOutput()
-                           ->assert_is_op_output("relu");
-
-  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
-  relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
-  return relu_out_var;
-}
-
-PDNode *patterns::ConvBReLU::operator()(
-    paddle::framework::ir::PDNode *conv_input) {
-  // Create Operators
-  conv_input->assert_is_op_input("conv2d", "Input");
-  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
-  auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6");
-  // Create variables
-  // Filter
-  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
-                              ->AsInput()
-                              ->assert_is_persistable_var()
-                              ->assert_is_op_input("conv2d", "Filter");
-  // intermediate variable, will be removed in the IR after fuse.
-  auto *conv_out_var = pattern->NewNode(conv_out_repr())
-                           ->AsIntermediate()
-                           ->assert_is_only_output_of_op("conv2d")
-                           ->assert_is_op_input("relu6");
-  // output
-  auto *brelu_out_var = pattern->NewNode(brelu_out_repr())
-                            ->AsOutput()
-                            ->assert_is_op_output("relu6");
-
-  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
-  brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var});
-  return brelu_out_var;
-}
-
 PDNode *patterns::SeqConvEltAddRelu::operator()(
     paddle::framework::ir::PDNode *seqconv_input) {
   // Create Operators
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index de091ce87ed8bc3be28b6e4e9e01d84b56dc3d76..dafe9a6cbf4bad4bba31886ef4da937dc37ecbd9 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -453,48 +453,6 @@ struct ConvActivation : public PatternBase {
   PATTERN_DECL_NODE(activation_out);
 };
 
-// CONV with ReLU
-// op: conv + relu
-// named nodes:
-// conv_input, conv_weight,
-// conv_out, conv,
-// relu_out, relu
-struct ConvReLU : public PatternBase {
-  ConvReLU(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_relu") {}
-
-  PDNode* operator()(PDNode* conv_input);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(conv);
-  PATTERN_DECL_NODE(relu);
-  // declare variable node's name
-  PATTERN_DECL_NODE(conv_weight);
-  PATTERN_DECL_NODE(conv_out);
-  PATTERN_DECL_NODE(relu_out);
-};
-
-// CONV with ReLU6
-// op: conv + relu6
-// named nodes:
-// conv_input, conv_weight,
-// conv_out, conv,
-// relu6_out, relu6
-struct ConvBReLU : public PatternBase {
-  ConvBReLU(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_bounded_relu") {}
-
-  PDNode* operator()(PDNode* conv_input);
-
-  // declare operator node's name
-  PATTERN_DECL_NODE(conv);
-  PATTERN_DECL_NODE(brelu);
-  // declare variable node's name
-  PATTERN_DECL_NODE(conv_weight);
-  PATTERN_DECL_NODE(conv_out);
-  PATTERN_DECL_NODE(brelu_out);
-};
-
 // SEQCONV with Elementwise_Add ReLU
 // op: seqconv + elementwise_add + relu
 // named nodes:
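
Note (reviewer sketch, not part of the patch): both deleted patterns are special cases of the ConvActivation pattern kept above, which takes the activation op type as an argument instead of hard-coding "relu"/"relu6". Roughly how a caller covers the old ConvBReLU case (parameter names assumed, not copied from this patch):

    GraphPatternDetector gpd;
    auto* conv_input = gpd.mutable_pattern()
                           ->NewNode("conv_activation_mkldnn_fuse/conv_input")
                           ->AsInput()
                           ->assert_is_op_input("conv2d", "Input");
    patterns::ConvActivation conv_activation_pattern(
        gpd.mutable_pattern(), "conv_activation_mkldnn_fuse");
    conv_activation_pattern(conv_input, "conv2d", /*activation_type=*/"relu6");
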
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 947beeccbdbd6e971be3935ed7aecec358eacc00..2226169e65b03ce3a0d37c026f38f8031828c0ac 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -62,14 +62,10 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
                     boost::get<float>(activation->Op()->GetAttr("threshold")));
     } else {
       desc->SetAttr("fuse_alpha",
-                    activation->Op()->HasAttr("alpha")
-                        ? boost::get<float>(activation->Op()->GetAttr("alpha"))
-                        : 0.0f);
+                    activation->Op()->GetAttrIfExists<float>("alpha"));
     }
     desc->SetAttr("fuse_beta",
-                  activation->Op()->HasAttr("beta")
-                      ? boost::get<float>(activation->Op()->GetAttr("beta"))
-                      : 0.0f);
+                  activation->Op()->GetAttrIfExists<float>("beta"));
 
     GraphSafeRemoveNodes(graph, {activation, conv_out});
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
index 12611fb84a016cc08dbb2307e7bd4eef25add455..ec38788bb4bf59f97c1a7bbbf63d8e389457d7eb 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -114,11 +114,7 @@ void MainTest(std::string activation) {
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
       auto op_name = boost::get<std::string>(op->GetAttr("name"));
-      std::string fuse_activation =
-          op->HasAttr("fuse_activation")
-              ? boost::get<std::string>(op->GetAttr("fuse_activation"))
-              : "";
-      if (fuse_activation == activation) {
+      if (op->GetAttrIfExists<std::string>("fuse_activation") == activation) {
         ++conv_activation_count;
       }
       // check if only "conv1" convolution is fused
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc
deleted file mode 100644
index dd9d448634806377b5f62b045f2ff59f65529780..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
-#include <string>
-#include <vector>
-#include "paddle/fluid/platform/enforce.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
-  FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph);
-
-  GraphPatternDetector gpd;
-  auto* conv_input = gpd.mutable_pattern()
-                         ->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input")
-                         ->AsInput()
-                         ->assert_is_op_input("conv2d", "Input");
-  patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(),
-                                         "conv_bounded_relu_mkldnn_fuse");
-  conv_brelu_pattern(conv_input);
-
-  int found_conv_brelu_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    VLOG(4) << "handle ConvBoundedReLUFusePass fuse";
-    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
-                              conv_brelu_pattern);                    // Filter
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern);  // tmp
-    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern);     // CONV op
-    GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern);  // Out
-    GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern);   // ReLU op
-
-    // Transform Conv node into ConvBReLU node.
-    OpDesc* desc = conv->Op();
-    desc->SetOutput("Output", std::vector<std::string>({brelu_out->Name()}));
-    desc->SetAttr("fuse_brelu", true);
-    desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold"));
-
-    GraphSafeRemoveNodes(graph, {brelu, conv_out});
-
-    PADDLE_ENFORCE(subgraph.count(conv_input));
-    IR_NODE_LINK_TO(conv, brelu_out);
-    found_conv_brelu_count++;
-  };
-
-  gpd(graph, handler);
-
-  AddStatis(found_conv_brelu_count);
-}
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
-
-REGISTER_PASS(conv_brelu_mkldnn_fuse_pass,
-              paddle::framework::ir::ConvBReLUFusePass);
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h
deleted file mode 100644
index c898be69caf049d2de14f13714036a8f45508f98..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-#include "paddle/fluid/framework/ir/pass.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-/*
- * Fuse the CONV and ReLU6 to a ConvReLU6Op.
- */
-class ConvBReLUFusePass : public FusePassBase {
- public:
-  virtual ~ConvBReLUFusePass() {}
-
- protected:
-  void ApplyImpl(ir::Graph* graph) const override;
-};
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc
deleted file mode 100644
index 5a546bfaedadf4d7038a0636098936c2ffd7ed72..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc
+++ /dev/null
@@ -1,135 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
-
-#include <gtest/gtest.h>
-#include "paddle/fluid/framework/op_proto_maker.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
-           const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
-  auto* op = prog->MutableBlock(0)->AppendOp();
-  op->SetType(type);
-  if (type == "conv2d") {
-    op->SetAttr("use_mkldnn", use_mkldnn);
-    op->SetAttr("name", name);
-    op->SetInput("Input", {inputs[0]});
-    op->SetInput("Filter", {inputs[1]});
-    op->SetInput("Bias", {inputs[2]});
-  } else if (type == "relu6") {
-    op->SetAttr("use_mkldnn", use_mkldnn);
-    if (use_mkldnn) {
-      op->SetAttr("threshold", 6.0f);
-    }
-    op->SetInput("X", inputs);
-  }
-  op->SetOutput("Out", outputs);
-  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-              static_cast<int>(OpRole::kForward));
-}
-
-// a->OP0->b
-// b->OP1->c
-// (c, weights, bias)->conv->f
-// (f)->brelu->g
-ProgramDesc BuildProgramDesc() {
-  ProgramDesc prog;
-  for (auto& v :
-       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
-                                 "h", "weights2", "bias2", "k", "l"})) {
-    auto* var = prog.MutableBlock(0)->Var(v);
-    var->SetType(proto::VarType::SELECTED_ROWS);
-    if (v == "weights" || v == "bias") {
-      var->SetPersistable(true);
-    }
-  }
-
-  SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
-        std::vector<std::string>({"b"}));
-  SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
-        std::vector<std::string>({"c"}));
-  // conv+brelu, both with MKL-DNN
-  SetOp(&prog, "conv2d", "conv1",
-        std::vector<std::string>({"c", "weights", "bias"}),
-        std::vector<std::string>({"f"}), true);
-  SetOp(&prog, "relu6", "relu1", std::vector<std::string>({"f"}),
-        std::vector<std::string>({"g"}), true);
-  SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
-        std::vector<std::string>({"h"}));
-  // conv+brelu, only one with MKL-DNN
-  SetOp(&prog, "conv2d", "conv2",
-        std::vector<std::string>({"h", "weights2", "bias2"}),
-        std::vector<std::string>({"k"}), true);
-  SetOp(&prog, "relu6", "relu2", std::vector<std::string>({"k"}),
-        std::vector<std::string>({"l"}));
-
-  return prog;
-}
-
-TEST(ConvBReLUFusePass, basic) {
-  auto prog = BuildProgramDesc();
-
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
-
-  auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass");
-
-  int original_nodes_num = graph->Nodes().size();
-
-  graph.reset(pass->Apply(graph.release()));
-
-  int current_nodes_num = graph->Nodes().size();
-
-  // Remove 3 Nodes: CONV, BRELU, conv_out
-  // Add 1 Node: ConvBReLU
-  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
-
-  // Assert conv_brelu op in newly generated graph
-  int conv_brelu_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      auto* op = node->Op();
-      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
-      EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
-      // check if only "conv1" convolution is fused
-      auto op_name = boost::get<std::string>(op->GetAttr("name"));
-      if (op_name == "conv1") {
-        ASSERT_TRUE(op->HasAttr("fuse_brelu"));
-        ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold"));
-
-        bool fuse_brelu = boost::get<bool>(op->GetAttr("fuse_brelu"));
-        if (fuse_brelu) {
-          ++conv_brelu_count;
-          float fuse_brelu_threshold =
-              boost::get<float>(op->GetAttr("fuse_brelu_threshold"));
-          EXPECT_EQ(fuse_brelu_threshold, 6.0f);
-        }
-      } else if (op_name == "conv2") {
-        ASSERT_FALSE(op->HasAttr("fuse_brelu"));
-      }
-    }
-  }
-  EXPECT_EQ(conv_brelu_count, 1);
-}
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
-
-USE_PASS(conv_brelu_mkldnn_fuse_pass);
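
Note (reviewer sketch, not part of the patch): the deleted fuse_brelu/fuse_brelu_threshold assertions have a counterpart under the unified pass, which stores relu6's threshold in the generic fuse_alpha attribute (see the fuse_alpha hunk above). Assuming the same test graph, the equivalent check would look roughly like:

    if (op->GetAttrIfExists<std::string>("fuse_activation") == "relu6") {
      EXPECT_EQ(boost::get<float>(op->GetAttr("fuse_alpha")), 6.0f);
    }
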
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
index ff1a3dc9c340ee5eb04b50046fa70d76c5f3b184..1263ddd147e86a47b8e5952f6a8cdfd40d1ee305 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc
@@ -109,11 +109,7 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
   if (!IsReachable(graph, elementwise_add_identity, conv_output)) return;
 
-  std::string fuse_activation =
-      conv_op->Op()->HasAttr("fuse_activation")
-          ? boost::get<std::string>(conv_op->Op()->GetAttr("fuse_activation"))
-          : "";
-  if (fuse_activation == "relu" || fuse_activation == "relu6") return;
+  if (HasFusedActivation(conv_op)) return;
 
   conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()});
   conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
@@ -182,12 +178,7 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
     return;
   }
 
-  std::string fuse_activation =
-      residual_conv_op->Op()->HasAttr("fuse_activation")
-          ? boost::get<std::string>(
-                residual_conv_op->Op()->GetAttr("fuse_activation"))
-          : "";
-  if (fuse_activation == "relu" || fuse_activation == "relu6") return;
+  if (HasFusedActivation(residual_conv_op)) return;
 
   residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()});
   residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()});
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
index 9bf1ae607937f0cae2fd312b0f6c7f7e14bd8fbf..b95aec34d30745d99f6066e36f19c883927e2b53 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h
@@ -126,6 +126,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
  protected:
   void ApplyImpl(graph_ptr graph) const;
 
+  static bool HasFusedActivation(Node* conv_node) {
+    return !(conv_node->Op()
+                 ->GetAttrIfExists<std::string>("fuse_activation")
+                 .empty());
+  }
   const std::string name_scope_{"residual_connection_fuse_pass"};
 };
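
Note (reviewer sketch, not part of the patch): HasFusedActivation leans on GetAttrIfExists<std::string> value-initializing its result, so a conv2d without the attribute yields "" and counts as "no fused activation". It also deliberately broadens the old guard from the relu/relu6 pair to any non-empty fuse_activation:

    // Expanded equivalent of the new call sites:
    std::string act =
        conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation");
    if (!act.empty()) return;  // was: act == "relu" || act == "relu6"
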
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc
deleted file mode 100644
index dd0fb456040fcf4e135333f938f8e3bdb18b7bcf..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.cc
+++ /dev/null
@@ -1,76 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h"
-#include <string>
-#include <vector>
-#include "paddle/fluid/platform/enforce.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-void ConvReLUFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE(graph);
-  FusePassBase::Init("conv_relu_mkldnn_fuse", graph);
-
-  GraphPatternDetector gpd;
-  auto* conv_input = gpd.mutable_pattern()
-                         ->NewNode("conv_relu_mkldnn_fuse/conv_input")
-                         ->AsInput()
-                         ->assert_is_op_input("conv2d", "Input");
-  patterns::ConvReLU conv_relu_pattern(gpd.mutable_pattern(),
-                                       "conv_relu_mkldnn_fuse");
-  conv_relu_pattern(conv_input);
-
-  int found_conv_relu_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    VLOG(4) << "handle ConvReLU fuse";
-    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
-                              conv_relu_pattern);                    // Filter
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern);  // tmp
-    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern);      // CONV op
-    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern);  // Out
-    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern);      // ReLU op
-
-    FuseOptions fuse_option = FindFuseOption(*conv, *relu);
-    if (fuse_option == DO_NOT_FUSE) {
-      VLOG(3) << "do not perform conv+relu fuse";
-      return;
-    }
-
-    // Transform Conv node into ConvReLU node.
-    OpDesc* desc = conv->Op();
-    desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
-    desc->SetAttr("fuse_relu", true);
-    GraphSafeRemoveNodes(graph, {relu, conv_out});
-
-    PADDLE_ENFORCE(subgraph.count(conv_input));
-    IR_NODE_LINK_TO(conv, relu_out);
-
-    found_conv_relu_count++;
-  };
-
-  gpd(graph, handler);
-
-  AddStatis(found_conv_relu_count);
-}
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
-
-REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
-              paddle::framework::ir::ConvReLUFusePass);
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h
deleted file mode 100644
index 2174c22dbf53790015be4c651b6e0c40b8e159fb..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h
+++ /dev/null
@@ -1,39 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph.h"
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-#include "paddle/fluid/framework/ir/pass.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-/*
- * Fuse the CONV and ReLU to a ConvReLUOp.
- */
-class ConvReLUFusePass : public FusePassBase {
- public:
-  virtual ~ConvReLUFusePass() {}
-
- protected:
-  void ApplyImpl(ir::Graph* graph) const override;
-};
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc
deleted file mode 100644
index 67a9957059a501f39f20c1de2ae17cafbe51a53a..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc
+++ /dev/null
@@ -1,127 +0,0 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/ir/mkldnn/conv_relu_mkldnn_fuse_pass.h"
-
-#include <gtest/gtest.h>
-#include "paddle/fluid/framework/op_proto_maker.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
-           const std::vector<std::string>& inputs,
-           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
-  auto* op = prog->MutableBlock(0)->AppendOp();
-  op->SetType(type);
-  if (type == "conv2d") {
-    op->SetAttr("use_mkldnn", use_mkldnn);
-    op->SetAttr("name", name);
-    op->SetInput("Input", {inputs[0]});
-    op->SetInput("Filter", {inputs[1]});
-    op->SetInput("Bias", {inputs[2]});
-  } else if (type == "relu") {
-    op->SetAttr("use_mkldnn", use_mkldnn);
-    op->SetInput("X", inputs);
-  }
-  op->SetOutput("Out", outputs);
-  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
-              static_cast<int>(OpRole::kForward));
-}
-
-// a->OP0->b
-// b->OP1->c
-// (c, weights, bias)->conv->f
-// (f)->relu->g
-ProgramDesc BuildProgramDesc() {
-  ProgramDesc prog;
-  for (auto& v :
-       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
-                                 "h", "weights2", "bias2", "k", "l"})) {
-    auto* var = prog.MutableBlock(0)->Var(v);
-    var->SetType(proto::VarType::SELECTED_ROWS);
-    if (v == "weights" || v == "bias") {
-      var->SetPersistable(true);
-    }
-  }
-
-  SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
-        std::vector<std::string>({"b"}));
-  SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
-        std::vector<std::string>({"c"}));
-  // conv+relu, both with MKL-DNN
-  SetOp(&prog, "conv2d", "conv1",
-        std::vector<std::string>({"c", "weights", "bias"}),
-        std::vector<std::string>({"f"}), true);
-  SetOp(&prog, "relu", "relu1", std::vector<std::string>({"f"}),
-        std::vector<std::string>({"g"}), true);
-  SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
-        std::vector<std::string>({"h"}));
-  // conv+relu, only one with MKL-DNN
-  SetOp(&prog, "conv2d", "conv2",
-        std::vector<std::string>({"h", "weights2", "bias2"}),
-        std::vector<std::string>({"k"}), true);
-  SetOp(&prog, "relu", "relu2", std::vector<std::string>({"k"}),
-        std::vector<std::string>({"l"}));
-
-  return prog;
-}
-
-TEST(ConvReLUFusePass, basic) {
-  auto prog = BuildProgramDesc();
-
-  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
-
-  auto pass = PassRegistry::Instance().Get("conv_relu_mkldnn_fuse_pass");
-
-  int original_nodes_num = graph->Nodes().size();
-
-  graph.reset(pass->Apply(graph.release()));
-
-  int current_nodes_num = graph->Nodes().size();
-
-  // Remove 3 Nodes: CONV, RELU, conv_out
-  // Add 1 Node: ConvReLU
-  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
-
-  // Assert conv_relu op in newly generated graph
-  int conv_relu_count = 0;
-
-  for (auto* node : graph->Nodes()) {
-    if (node->IsOp() && node->Op()->Type() == "conv2d") {
-      auto* op = node->Op();
-      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
-      EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
-      // check if only "conv1" convolution is fused
-      auto op_name = boost::get<std::string>(op->GetAttr("name"));
-      if (op_name == "conv1") {
-        ASSERT_TRUE(op->HasAttr("fuse_relu"));
-        bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
-        if (fuse_relu) {
-          ++conv_relu_count;
-        }
-      } else if (op_name == "conv2") {
-        ASSERT_FALSE(op->HasAttr("fuse_relu"));
-      }
-    }
-  }
-  EXPECT_EQ(conv_relu_count, 1);
-}
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
-
-USE_PASS(conv_relu_mkldnn_fuse_pass);
diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
index de8b346203f2f4f665acc0c1da615984fbd59f81..9cf55ee3254f4f1eacd717dd0c8d4497b7c559de 100644
--- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc
@@ -209,12 +209,11 @@ void CPUQuantizePass::QuantizeConv(Graph* graph,
                     is_output_unsigned, "Scale_out");
 
     // change threshold in bounded ReLu
-    if (conv_op->Op()->HasAttr("fuse_brelu") &&
-        boost::get<bool>(conv_op->Op()->GetAttr("fuse_brelu"))) {
+    if (conv_op->Op()->GetAttrIfExists<std::string>("fuse_activation") ==
+        "relu6") {
       float scale_out = boost::get<float>(conv_op->Op()->GetAttr("Scale_out"));
-      float threshold =
-          boost::get<float>(conv_op->Op()->GetAttr("fuse_brelu_threshold"));
-      conv_op->Op()->SetAttr("fuse_brelu_threshold", scale_out * threshold);
+      float threshold = boost::get<float>(conv_op->Op()->GetAttr("fuse_alpha"));
+      conv_op->Op()->SetAttr("fuse_alpha", scale_out * threshold);
     }
 
     ++quantize_conv_count;
diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h
index dedaf24364703877a4cacb23a27550b54dad53f8..2f6fb9e298440e0aaac79d0dc5ad1e7d1aed6990 100644
--- a/paddle/fluid/framework/op_desc.h
+++ b/paddle/fluid/framework/op_desc.h
@@ -80,6 +80,15 @@ class OpDesc {
 
   Attribute GetAttr(const std::string &name) const;
 
+  template <typename T>
+  T GetAttrIfExists(const std::string &name) const {
+    T result{};
+    if (HasAttr(name)) {
+      result = boost::get<T>(GetAttr(name));
+    }
+    return result;
+  }
+
   const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
 
   Attribute GetNullableAttr(const std::string &name) const;
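
Note (reviewer sketch, not part of the patch): GetAttrIfExists<T> collapses the HasAttr/boost::get/fallback pattern removed throughout this patch, with T{} supplying the default (0.0f, false, ""). As before, boost::get throws boost::bad_get if the stored attribute holds a type other than T. Typical call sites:

    // old:
    float beta = op->HasAttr("beta")
                     ? boost::get<float>(op->GetAttr("beta"))
                     : 0.0f;
    // new:
    float beta = op->GetAttrIfExists<float>("beta");
    std::string act = op->GetAttrIfExists<std::string>("fuse_activation");
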
diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc
index 8ebdbf1673fb2648b84d0451a6bc64426dfb6ce7..94c556ce52d61258475e4e9cc497b23b073938fc 100644
--- a/paddle/fluid/inference/api/mkldnn_quantizer.cc
+++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc
@@ -69,10 +69,7 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() {
       if (op->Type() == "conv2d") {
         // output of conv2d with relu must be unsigned
         std::string fuse_activation =
-            op->HasAttr("fuse_activation")
-                ? boost::get<std::string>(
-                      op->GetAttr("fuse_activation"))
-                : "";
+            op->GetAttrIfExists<std::string>("fuse_activation");
         is_unsigned =
             (fuse_activation == "relu" || fuse_activation == "relu6");
       } else if (op->Type() == "relu") {
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 25d8e4ee953969a54a0ddc1534bc4bfb868ddcee..dbb792616a3f2e8f3b04e1a15549c24742ed06b1 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -482,14 +482,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       handler.reset(
           new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
       // create a conv primitive descriptor and save it for usage in backward
-
-      // TODO(grygielski) if INT8 brelu post-op will be available, just delete
-      // whole if statement
-      if (fuse_activation == "relu6") {
-        fuse_activation = "relu";
-        fuse_alpha = 0.0f;
-      }
-
       auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
                                  : mkldnn::prop_kind::forward_training;
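
Note (reviewer sketch, not part of the patch): removing the relu6-to-relu downgrade assumes the INT8 path can now emit a bounded-ReLU post-op directly. The mapping this relies on looks roughly like the following (MKL-DNN 0.x API; variable names taken from the kernel above, surrounding code assumed):

    mkldnn::post_ops post_operations;
    if (fuse_activation == "relu") {
      post_operations.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu,
                                     fuse_alpha, fuse_beta);
    } else if (fuse_activation == "relu6") {
      // fuse_alpha carries the relu6 threshold (cf. cpu_quantize_pass above)
      post_operations.append_eltwise(1.0f,
                                     mkldnn::algorithm::eltwise_bounded_relu,
                                     fuse_alpha, fuse_beta);
    }
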