From d03ef0541959391e7414d6d8780f6248383fef18 Mon Sep 17 00:00:00 2001
From: Sławomir Siwek
Date: Mon, 22 Aug 2022 17:45:31 +0200
Subject: [PATCH] Extend conv_concat_relu to support all activations (#45089)

* merge conv_concat_relu to conv_act
* fix typo
* extend unit test
* reuse existing gpd
* codestyle
* enforce mkldnn conv
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   3 +-
 .../framework/ir/graph_pattern_detector.cc    |  40 ----
 .../framework/ir/graph_pattern_detector.h     |  33 ---
 .../conv_activation_mkldnn_fuse_pass.cc       | 107 ++++++++-
 .../mkldnn/conv_activation_mkldnn_fuse_pass.h |   2 +
 .../conv_concat_relu_mkldnn_fuse_pass.cc      | 197 ---------------
 .../conv_concat_relu_mkldnn_fuse_pass.h       |  53 -----
 ...onv_concat_relu_mkldnn_fuse_pass_tester.cc |   7 +-
 .../inference/api/paddle_pass_builder.cc      |   2 -
 .../quantization/quant2_int8_mkldnn_pass.py   |   1 -
 ...kldnn_conv_concat_relu_mkldnn_fuse_pass.py | 224 +++++++++---------
 11 files changed, 219 insertions(+), 450 deletions(-)
 delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
 delete mode 100644 paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 0882004618..f219117594 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -199,7 +199,6 @@ if(WITH_MKLDNN)
   pass_library(conv_affine_channel_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
-  pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(int8_scale_calculation_mkldnn_pass inference DIR mkldnn)
   pass_library(params_quantization_mkldnn_pass inference DIR mkldnn)
@@ -409,7 +408,7 @@ if(WITH_MKLDNN)
   cc_test(
     test_conv_concat_relu_mkldnn_fuse_pass
     SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
-    DEPS conv_concat_relu_mkldnn_fuse_pass)
+    DEPS conv_activation_mkldnn_fuse_pass)
   cc_test(
     test_conv_elementwise_add_mkldnn_fuse_pass
     SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 271b9b9d02..844738c46e 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2081,46 +2081,6 @@ PDNode *patterns::Concat::operator()() {
   return output_var;
 }
 
-PDNode *patterns::ConcatReLU::operator()() {
-  auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
-  auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
-
-  auto concat_out =
-      pattern->NewNode(concat_out_repr())->assert_is_op_output("concat", "Out");
-
-  auto relu_out = pattern->NewNode(relu_out_repr())
-                      ->AsOutput()
-                      ->assert_is_op_output("relu", "Out");
-
-  concat_op->LinksTo({concat_out});
-  relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
-
-  return relu_out;
-}
-
-PDNode *patterns::ConvConcatReLU::operator()() {
-  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
-  auto concat_op = pattern->NewNode(concat_op_repr())->assert_is_op("concat");
-  auto relu_op = pattern->NewNode(relu_op_repr())->assert_is_op("relu");
-
-  auto conv_out = pattern->NewNode(conv_out_repr())
-                      ->assert_is_op_output("conv2d", "Output");
-
-  auto concat_out = pattern->NewNode(concat_out_repr())
-                        ->assert_is_op_output("concat", "Out")
-                        ->assert_is_op_input("relu", "X");
-
-  auto relu_out = pattern->NewNode(relu_out_repr())
-                      ->AsOutput()
-                      ->assert_is_op_output("relu", "Out");
-
-  conv_op->LinksTo({conv_out});
-  concat_op->LinksFrom({conv_out}).LinksTo({concat_out});
-  relu_op->LinksFrom({concat_out}).LinksTo({relu_out});
-
-  return relu_out;
-}
-
 PDNode *patterns::OpRequant::operator()() {
   auto any_op = pattern->NewNode(any_op_repr())
                     ->assert_is_op()
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 507fb83af4..7bee093c3e 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1228,39 +1228,6 @@ struct Concat : public PatternBase {
   PATTERN_DECL_NODE(concat_out);
 };
 
-// Concat + ReLU
-// named nodes:
-// concat_op, concat_out, relu_op, relu_out
-struct ConcatReLU : public PatternBase {
-  ConcatReLU(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "concat_relu") {}
-
-  PDNode* operator()();
-
-  PATTERN_DECL_NODE(concat_op);
-  PATTERN_DECL_NODE(concat_out);
-  PATTERN_DECL_NODE(relu_op);
-  PATTERN_DECL_NODE(relu_out);
-};
-
-// Conv + Concat + ReLU
-// named nodes:
-// conv_op, conv_out
-// concat_op, concat_out, relu_op, relu_out
-struct ConvConcatReLU : public PatternBase {
-  ConvConcatReLU(PDPattern* pattern, const std::string& name_scope)
-      : PatternBase(pattern, name_scope, "conv_concat_relu") {}
-
-  PDNode* operator()();
-
-  PATTERN_DECL_NODE(conv_op);
-  PATTERN_DECL_NODE(conv_out);
-  PATTERN_DECL_NODE(concat_op);
-  PATTERN_DECL_NODE(concat_out);
-  PATTERN_DECL_NODE(relu_op);
-  PATTERN_DECL_NODE(relu_out);
-};
-
 // Op + Requant
 // named nodes:
 // any_op, any_out
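The two bespoke patterns deleted above are subsumed by the generic patterns::OperatorActivation pattern, which the rewritten pass instantiates below as conv_concat_act("concat", act_type). For orientation, here is a minimal sketch of how such a two-node "preceding op -> activation" pattern is assembled with the same PDPattern API; the node names mirror the GET_IR_NODE_FROM_SUBGRAPH calls in the new pass, but this is an assumed shape, not the verbatim upstream definition:

    // Sketch only (assumed shape of patterns::OperatorActivation): a generic
    // "preceding op -> activation" pattern built with the same PDPattern calls
    // as the ConvConcatReLU pattern removed above.
    PDNode *patterns::OperatorActivation::operator()(
        const std::string &operator_type, const std::string &activation_type) {
      auto *preceding_op =
          pattern->NewNode(preceding_op_repr())->assert_is_op(operator_type);
      auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
                                   ->AsIntermediate()
                                   ->assert_is_op_output(operator_type)
                                   ->assert_is_op_input(activation_type);
      auto *activation =
          pattern->NewNode(activation_repr())->assert_is_op(activation_type);
      auto *activation_out = pattern->NewNode(activation_out_repr())
                                 ->AsOutput()
                                 ->assert_is_op_output(activation_type);
      preceding_op->LinksTo({preceding_op_out});
      activation->LinksFrom({preceding_op_out}).LinksTo({activation_out});
      return activation_out;
    }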
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 5fe6eb50aa..8723cab36c 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -28,10 +28,12 @@ void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
   auto act_types = paddle::platform::GetSupportedActivations();
   std::vector<std::string> conv_types = {"conv2d"};
 
-  for (const auto& conv_type : conv_types)
-    for (auto& act_type : act_types) {
+  for (auto& act_type : act_types) {
+    FuseConvConcatAct(graph, act_type);
+    for (const auto& conv_type : conv_types) {
       FuseConvAct(graph, conv_type, act_type);
     }
+  }
 }
 
 void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
@@ -49,8 +51,6 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
   int found_conv_activation_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
-    VLOG(4) << "handle " + conv_type + "+" + act_type + " fuse";
-
     if (!IsCompat(subgraph, g)) {
       LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed.";
       return;
@@ -89,13 +89,95 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
 
   gpd(graph, handler);
   AddStatis(found_conv_activation_count);
-  if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_conv_activation_count > 0) {
     PrettyLogDetail("--- fused %d conv with %s activation",
                     found_conv_activation_count,
                     act_type);
   }
 }
 
+void ConvActivationMkldnnFusePass::FuseConvConcatAct(
+    Graph* graph, std::string& act_type) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
+  FusePassBase::Init("conv2d_concat_" + act_type + "_mkldnn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  auto pattern = gpd.mutable_pattern();
+  patterns::OperatorActivation conv_concat_act(
+      pattern, "conv2d_concat_" + act_type + "_mkldnn_fuse_pass");
+  conv_concat_act("concat", act_type);
+
+  int found_conv_concat_activation_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING)
+          << "conv_concat_activation_mkldnn_fuse_pass op compat failed.";
+      return;
+    }
+
+    GET_IR_NODE_FROM_SUBGRAPH(concat_op, preceding_op, conv_concat_act);
+    GET_IR_NODE_FROM_SUBGRAPH(concat_out, preceding_op_out, conv_concat_act);
+    GET_IR_NODE_FROM_SUBGRAPH(activation_op, activation, conv_concat_act);
+    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out, conv_concat_act);
+
+    auto concat_inputs = concat_op->inputs;
+    for (auto node : concat_inputs) {
+      auto prev_op_nodes = node->inputs;
+      if (prev_op_nodes.size() != 1) {
+        LOG(WARNING)
+            << "Each concat input must be produced by exactly one operator.";
+        return;
+      }
+
+      bool is_not_conv_mkldnn =
+          !(prev_op_nodes[0]->Op()->GetAttrIfExists<bool>("use_mkldnn"));
+      if (prev_op_nodes[0]->Op()->Type() != "conv2d" || is_not_conv_mkldnn) {
+        LOG(WARNING)
+            << "This fuse pass supports only conv2d (mkldnn) + activation.";
+        return;
+      }
+    }
+
+    for (auto node : concat_inputs) {
+      OpDesc* conv_op = node->inputs[0]->Op();
+      OpDesc* act_op = activation_op->Op();
+
+      auto attr_map = paddle::platform::GetAttributeMap(act_type);
+      for (const auto& attrs : attr_map) {
+        if (act_op->HasAttr(attrs.first)) {
+          conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
+        }
+      }
+
+      if (act_type == "gelu" && act_op->HasAttr("approximate")) {
+        act_type = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"))
+                       ? "gelu_tanh"
+                       : "gelu_erf";
+        conv_op->SetAttr("fuse_alpha", 0.0f);
+        conv_op->SetAttr("fuse_beta", 0.0f);
+      }
+      conv_op->SetAttr("fuse_activation", act_type);
+    }
+
+    concat_op->Op()->SetOutput("Out", {activation_out->Name()});
+    GraphSafeRemoveNodes(graph, {activation_op, concat_out});
+    IR_NODE_LINK_TO(concat_op, activation_out);
+
+    found_conv_concat_activation_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_conv_concat_activation_count);
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_conv_concat_activation_count > 0) {
+    PrettyLogDetail("--- fused %d conv_concat with %s activation",
+                    found_conv_concat_activation_count,
+                    act_type);
+  }
+}
+
 ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() {
   AddOpCompat(OpCompat("conv2d"))
       .AddInput("Input")
@@ -136,6 +218,20 @@ ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() {
       .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
       .End();
 
+  AddOpCompat(OpCompat("concat"))
+      .AddInput("X")
+      .End()
+      .AddInput("AxisTensor")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumGE(0)
+      .End();
+
   AddOpCompat(OpCompat("relu"))
       .AddInput("X")
       .IsTensor()
@@ -276,6 +372,7 @@ REGISTER_PASS_CAPABILITY(conv_activation_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("conv2d", 1)
+            .EQ("concat", 0)
             .EQ("abs", 0)
             .LE("clip", 1)
             .EQ("gelu", 0)
"gelu_tanh" + : "gelu_erf"; + conv_op->SetAttr("fuse_alpha", 0.0f); + conv_op->SetAttr("fuse_beta", 0.0f); + } + conv_op->SetAttr("fuse_activation", act_type); + } + + concat_op->Op()->SetOutput("Out", {activation_out->Name()}); + GraphSafeRemoveNodes(graph, {activation_op, concat_out}); + IR_NODE_LINK_TO(concat_op, activation_out); + + found_conv_concat_activation_count++; + }; + gpd(graph, handler); + AddStatis(found_conv_concat_activation_count); + if ((!Has("disable_logs") || !Get("disable_logs")) && + found_conv_concat_activation_count > 0) { + PrettyLogDetail("--- fused %d conv_concat with %s activation", + found_conv_concat_activation_count, + act_type); + } +} + ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() { AddOpCompat(OpCompat("conv2d")) .AddInput("Input") @@ -136,6 +218,20 @@ ConvActivationMkldnnFusePass::ConvActivationMkldnnFusePass() { .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .End(); + AddOpCompat(OpCompat("concat")) + .AddInput("X") + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(0) + .End(); + AddOpCompat(OpCompat("relu")) .AddInput("X") .IsTensor() @@ -276,6 +372,7 @@ REGISTER_PASS_CAPABILITY(conv_activation_mkldnn_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("conv2d", 1) + .EQ("concat", 0) .EQ("abs", 0) .LE("clip", 1) .EQ("gelu", 0) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h index 11925e1992..b50fa8997f 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h @@ -34,6 +34,8 @@ class ConvActivationMkldnnFusePass : public FusePassBase { void FuseConvAct(Graph *graph, const std::string &conv_type, std::string &act_type) const; + + void FuseConvConcatAct(Graph *graph, std::string &act_type) const; }; } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc deleted file mode 100644 index 239c699114..0000000000 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
deleted file mode 100644
index 239c699114..0000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc
+++ /dev/null
@@ -1,197 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
-
-#include <vector>
-
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-ConvConcatReLUFusePass::ConvConcatReLUFusePass() {
-  AddOpCompat(OpCompat("conv2d"))
-      .AddInput("Input")
-      .IsTensor()
-      .End()
-      .AddInput("Filter")
-      .IsTensor()
-      .End()
-      .AddInput("Bias")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddInput("ResidualData")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddOutput("Output")
-      .IsTensor()
-      .End()
-      .AddAttr("strides")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("paddings")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("padding_algorithm")
-      .IsOptional()
-      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
-      .End()
-      .AddAttr("groups")
-      .IsNumGE(1)
-      .End()
-      .AddAttr("dilations")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
-      .End();
-
-  AddOpCompat(OpCompat("concat"))
-      .AddInput("X")  // Input("X"): vector<Tensor>
-      .End()
-      .AddInput("AxisTensor")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddOutput("Out")
-      .IsTensor()
-      .End()
-      .AddAttr("axis")
-      .IsNumGE(0)
-      .End();
-
-  AddOpCompat(OpCompat("relu"))
-      .AddInput("X")
-      .IsTensor()
-      .End()
-      .AddOutput("Out")
-      .IsTensor()
-      .End();
-}
-
-void ConvConcatReLUFusePass::FindConcatWithConvs(
-    ir::Graph* graph,
-    std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
-  GraphPatternDetector gpd;
-  patterns::ConcatReLU concat_relu_pattern{gpd.mutable_pattern(),
-                                           "concat_relu"};
-  concat_relu_pattern();
-
-  int found_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    VLOG(4) << "Find Concats with Convs";
-    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, concat_relu_pattern);
-    GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, concat_relu_pattern);
-
-    auto concat_inputs = concat_op->inputs;
-
-    for (auto node : concat_inputs) {
-      auto prev_op_node = node->inputs;
-      PADDLE_ENFORCE_EQ(prev_op_node.size(),
-                        1,
-                        platform::errors::InvalidArgument(
-                            "Node(%s) input size(%d) must be 1.",
-                            node->Name(),
-                            prev_op_node.size()));
-      auto* conv_op = prev_op_node[0];
-      if (conv_op->Op()->Type() != "conv2d") return;
-
-      FuseOptions fuse_option = FindFuseOption(*conv_op, *relu_op);
-      if (fuse_option == DO_NOT_FUSE) {
-        return;
-      }
-    }
-
-    (*concat_with_convs_counter)[concat_op] = concat_inputs.size();
-    found_count++;
-  };
-  gpd(graph, handler);
-  AddStatis(found_count);
-}
-
-void ConvConcatReLUFusePass::FuseConvConcatReLU(
-    ir::Graph* graph,
-    std::unordered_map<const Node*, int>* concat_with_convs_counter) const {
-  GraphPatternDetector gpd;
-  auto pattern = gpd.mutable_pattern();
-  patterns::ConvConcatReLU conv_concat_relu(pattern, name_scope_);
-  conv_concat_relu();
-
-  int found_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    VLOG(4) << "handle ConvConcatReLU fuse";
-
-    GET_IR_NODE_FROM_SUBGRAPH(conv_op, conv_op, conv_concat_relu);
-    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_concat_relu);
-    GET_IR_NODE_FROM_SUBGRAPH(concat_op, concat_op, conv_concat_relu);
-    GET_IR_NODE_FROM_SUBGRAPH(concat_out, concat_out, conv_concat_relu);
-    GET_IR_NODE_FROM_SUBGRAPH(relu_op, relu_op, conv_concat_relu);
-    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_concat_relu);
-
-    if (!concat_with_convs_counter->count(concat_op)) {
-      VLOG(4) << "this concat has input from non-conv2d operator";
-      return;
-    }
-
-    // Transform Conv node into ConvReLU node.
-    OpDesc* conv_desc = conv_op->Op();
-    conv_desc->SetAttr("fuse_activation", std::string("relu"));
-
-    // Remove ReLU when all Convs were transformed.
-    auto number_of_unfused_convs_left =
-        --(*concat_with_convs_counter)[concat_op];
-    if (number_of_unfused_convs_left == 0) {
-      OpDesc* concat_desc = concat_op->Op();
-      concat_desc->SetOutput("Out",
-                             std::vector<std::string>({relu_out->Name()}));
-      GraphSafeRemoveNodes(graph, {relu_op, concat_out});
-      IR_NODE_LINK_TO(concat_op, relu_out);
-    }
-
-    found_count++;
-  };
-  gpd(graph, handler);
-  AddStatis(found_count);
-}
-
-void ConvConcatReLUFusePass::ApplyImpl(ir::Graph* graph) const {
-  PADDLE_ENFORCE_NOT_NULL(
-      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init(name_scope_, graph);
-
-  std::unordered_map<const Node*, int> concat_with_convs_counter;
-  FindConcatWithConvs(graph, &concat_with_convs_counter);
-  FuseConvConcatReLU(graph, &concat_with_convs_counter);
-}
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
-
-REGISTER_PASS(conv_concat_relu_mkldnn_fuse_pass,
-              paddle::framework::ir::ConvConcatReLUFusePass);
-
-REGISTER_PASS_CAPABILITY(conv_concat_relu_mkldnn_fuse_pass)
-    .AddCombination(
-        paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d", 1)
-            .EQ("concat", 0)
-            .EQ("relu", 0));
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h
deleted file mode 100644
index af372dbf97..0000000000
--- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <string>
-#include <unordered_map>
-
-#include "paddle/fluid/framework/ir/fuse_pass_base.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-/*
- * Fuse the (multi conv) -> Concat -> ReLU -> next_op
- * to a:
- * (multi ConvReLU) -> Concat -> next_op.
- */
-class ConvConcatReLUFusePass : public FusePassBase {
- public:
-  ConvConcatReLUFusePass();
-  virtual ~ConvConcatReLUFusePass() {}
-
- protected:
-  void ApplyImpl(ir::Graph* graph) const override;
-
-  void FindConcatWithConvs(
-      Graph* graph,
-      std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
-
-  void FuseConvConcatReLU(
-      Graph* graph,
-      std::unordered_map<const Node*, int>* concat_with_convs_counter) const;
-
-  const std::string name_scope_{"conv_concat_relu_mkldnn_fuse"};
-};
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
index 8210bfeba4..6ab8708c7a 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc
@@ -14,7 +14,7 @@
 
 #include <gtest/gtest.h>
 
-#include "paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h"
+#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 
 namespace paddle {
@@ -47,6 +47,7 @@ void SetOp(ProgramDesc* prog,
     op->SetOutput("Out", outputs);
   } else if (type == "concat") {
     op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetAttr("axis", 0);
     op->SetInput("X", inputs);
     op->SetOutput("Out", outputs);
   }
@@ -103,7 +104,7 @@ void MainTest(const ProgramDesc& prog, bool fuse_relu) {
 
   int original_nodes_num = graph->Nodes().size();
 
-  auto pass = PassRegistry::Instance().Get("conv_concat_relu_mkldnn_fuse_pass");
+  auto pass = PassRegistry::Instance().Get("conv_activation_mkldnn_fuse_pass");
   graph.reset(pass->Apply(graph.release()));
 
   int current_nodes_num = graph->Nodes().size();
@@ -167,4 +168,4 @@ TEST(ConvConcatReLUFusePass, convs_and_pool_before_concat) {
 }  // namespace framework
 }  // namespace paddle
 
-USE_PASS(conv_concat_relu_mkldnn_fuse_pass);
+USE_PASS(conv_activation_mkldnn_fuse_pass);
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index d7e0850857..82df50bcae 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -303,7 +303,6 @@ void CpuPassStrategy::EnableMKLDNN() {
         // TODO(baoachun): Need to support 5-dimensional input.
// "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", - "conv_concat_relu_mkldnn_fuse_pass", "conv_activation_mkldnn_fuse_pass", // "scale_matmul_fuse_pass", // "reshape_transpose_matmul_mkldnn_fuse_pass", // @@ -396,7 +395,6 @@ void CpuPassStrategy::EnableMkldnnInt8() { passes_.push_back("conv_bias_mkldnn_fuse_pass"); passes_.push_back("conv_transpose_bias_mkldnn_fuse_pass"); passes_.push_back("conv_elementwise_add_mkldnn_fuse_pass"); - passes_.push_back("conv_concat_relu_mkldnn_fuse_pass"); passes_.push_back("conv_activation_mkldnn_fuse_pass"); passes_.push_back("fc_fuse_pass"); passes_.push_back("repeated_fc_relu_fuse_pass"); diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 9fb14e4e72..a9ace73d5b 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -439,7 +439,6 @@ class Quant2Int8MkldnnPass(object): graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_transpose_bias_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass') - graph = self._apply_pass(graph, 'conv_concat_relu_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'conv_activation_mkldnn_fuse_pass') graph = self._apply_pass(graph, 'fc_fuse_pass', ['use_gpu', 'use_fc_padding'], [False, False]) diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py index 2a313bbdaa..a5d2738869 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,140 +12,136 @@ # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
index 2a313bbdaa..a5d2738869 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_conv_concat_relu_mkldnn_fuse_pass.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,140 +12,136 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from auto_scan_test import PassAutoScanTest, SkipReasons
-from program_config import TensorConfig, ProgramConfig
+from auto_scan_test import PassAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
 import numpy as np
-import paddle.inference as paddle_infer
 from functools import partial
-from typing import Optional, List, Callable, Dict, Any, Set
 import unittest
-
-import hypothesis
-from hypothesis import given, settings, seed, example, assume
 import hypothesis.strategies as st
 
 
-class TestConvConcatReluMkldnnFusePass(PassAutoScanTest):
-
-    def is_program_valid(self, program_config: ProgramConfig) -> bool:
-        return True
+class TestConvConcatActivationMkldnnFusePass(PassAutoScanTest):
 
     def sample_program_config(self, draw):
-        data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
-        dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
-        padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
-        groups = draw(st.sampled_from([1, 2, 4]))
-        paddings = draw(st.sampled_from([[0, 3], [1, 2, 3, 4]]))
-        strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
+        data_format = draw(st.sampled_from(['NCHW', 'NHWC']))
+        dilations = draw(st.sampled_from([[2, 2]]))
+        padding_algorithm = draw(st.sampled_from(['VALID']))
+        groups = draw(st.sampled_from([4]))
+        paddings = draw(st.sampled_from([[0, 3]]))
+        strides = draw(st.sampled_from([[1, 2]]))
         axis = draw(st.sampled_from([0]))
-        batch_size = draw(st.integers(min_value=1, max_value=4))
-
-        def generate_input(attrs):
-            if attrs[0]['data_format'] == "NCHW":
-                return np.random.random([attrs[2]['batch_size'], 48, 64,
-                                         64]).astype(np.float32)
-            else:
-                return np.random.random([attrs[2]['batch_size'], 64, 64,
-                                         48]).astype(np.float32)
-
-        def generate_weight():
-            return np.random.random([16, int(48 / groups), 3,
-                                     3]).astype(np.float32)
-
-        attrs = [{
-            "data_format": data_format,
-            "dilations": dilations,
-            "padding_algorithm": padding_algorithm,
-            "groups": groups,
-            "paddings": paddings,
-            "strides": strides
-        }, {
-            "axis": axis
-        }, {
-            'batch_size': batch_size
-        }]
-
-        ops_config = [{
-            "op_type": "conv2d",
-            "op_inputs": {
-                "Input": ["input_data1"],
-                "Filter": ["input_weight"]
-            },
-            "op_outputs": {
-                "Output": ["conv1_output"]
-            },
-            "op_attrs": {
-                "data_format": attrs[0]['data_format'],
-                "dilations": attrs[0]['dilations'],
-                "padding_algorithm": attrs[0]['padding_algorithm'],
-                "groups": attrs[0]['groups'],
-                "paddings": attrs[0]['paddings'],
-                "strides": attrs[0]['strides']
-            }
-        }, {
-            "op_type": "conv2d",
-            "op_inputs": {
-                "Input": ["input_data2"],
-                "Filter": ["input_weight"]
-            },
-            "op_outputs": {
-                "Output": ["conv2_output"]
-            },
-            "op_attrs": {
-                "data_format": attrs[0]['data_format'],
-                "dilations": attrs[0]['dilations'],
-                "padding_algorithm": attrs[0]['padding_algorithm'],
-                "groups": attrs[0]['groups'],
-                "paddings": attrs[0]['paddings'],
-                "strides": attrs[0]['strides']
-            }
-        }, {
-            "op_type": "concat",
-            "op_inputs": {
-                "X": ["conv1_output", "conv2_output"]
-            },
-            "op_outputs": {
-                "Out": ["concat_output"]
-            },
-            "op_attrs": {
-                'axis': attrs[1]['axis']
-            }
-        }, {
-            "op_type": "relu",
-            "op_inputs": {
-                "X": ["concat_output"]
-            },
-            "op_outputs": {
-                "Out": ["relu_output"]
-            },
-            "op_attrs": {}
-        }]
-
-        ops = self.generate_op_config(ops_config)
+        activation_type = draw(
+            st.sampled_from([
+                'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
+                'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
+                'leaky_relu'
+            ]))
+
+        def generate_data(input_type):
+            if input_type == 'NCHW':
+                return np.random.random([16, 48, 64, 64]).astype(np.float32)
+            elif input_type == 'NHWC':
+                return np.random.random([16, 64, 64, 48]).astype(np.float32)
+            elif input_type == 'weights':
+                return np.random.random([16, int(48 / groups), 3,
+                                         3]).astype(np.float32)
+
+        conv2d_op1 = OpConfig(type='conv2d',
+                              inputs={
+                                  'Input': ['conv_input_1'],
+                                  'Filter': ['conv_weights_1']
+                              },
+                              outputs={'Output': ['conv_output_1']},
+                              attrs={
+                                  'data_format': data_format,
+                                  'dilations': dilations,
+                                  'padding_algorithm': padding_algorithm,
+                                  'groups': groups,
+                                  'paddings': paddings,
+                                  'strides': strides
+                              })
+
+        conv2d_op2 = OpConfig(type='conv2d',
+                              inputs={
+                                  'Input': ['conv_input_2'],
+                                  'Filter': ['conv_weights_2']
+                              },
+                              outputs={'Output': ['conv_output_2']},
+                              attrs={
+                                  'data_format': data_format,
+                                  'dilations': dilations,
+                                  'padding_algorithm': padding_algorithm,
+                                  'groups': groups,
+                                  'paddings': paddings,
+                                  'strides': strides
+                              })
+
+        concat_op = OpConfig(type='concat',
+                             inputs={'X': ['conv_output_1', 'conv_output_2']},
+                             outputs={'Out': ['concat_output']},
+                             attrs={'axis': axis})
+
+        if activation_type == 'relu6':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['concat_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     threshold=draw(
+                                         st.floats(min_value=1.0,
+                                                   max_value=10.0)))
+        elif activation_type == 'leaky_relu':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['concat_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     alpha=draw(
+                                         st.floats(min_value=0.1,
+                                                   max_value=1.0)))
+        elif activation_type == 'swish':
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['concat_output']},
+                                     outputs={'Out': ['activation_output']},
+                                     beta=draw(
+                                         st.floats(min_value=0.1,
+                                                   max_value=1.0)))
+        elif activation_type == 'clip':
+            activation_op = OpConfig(
+                activation_type,
+                inputs={'X': ['concat_output']},
+                outputs={'Out': ['activation_output']},
+                min=draw(st.floats(min_value=0.1, max_value=0.49)),
+                max=draw(st.floats(min_value=0.5, max_value=1.0)))
+        else:
+            activation_op = OpConfig(activation_type,
+                                     inputs={'X': ['concat_output']},
+                                     outputs={'Out': ['activation_output']})
+
+        model_net = [conv2d_op1, conv2d_op2, concat_op, activation_op]
 
         program_config = ProgramConfig(
-            ops=ops,
-            weights={
-                "input_weight": TensorConfig(data_gen=partial(generate_weight))
-            },
+            ops=model_net,
             inputs={
-                "input_data1":
-                TensorConfig(data_gen=partial(generate_input, attrs)),
-                "input_data2":
-                TensorConfig(data_gen=partial(generate_input, attrs))
+                'conv_input_1':
+                TensorConfig(data_gen=partial(generate_data, data_format)),
+                'conv_input_2':
+                TensorConfig(data_gen=partial(generate_data, data_format))
+            },
+            weights={
+                'conv_weights_1':
+                TensorConfig(data_gen=partial(generate_data, 'weights')),
+                'conv_weights_2':
+                TensorConfig(data_gen=partial(generate_data, 'weights'))
             },
-            outputs=["relu_output"])
+            outputs=['activation_output'])
 
         return program_config
 
     def sample_predictor_configs(self, program_config):
         config = self.create_inference_config(use_mkldnn=True)
-        yield config, ["conv2d", "conv2d", "concat"], (1e-5, 1e-5)
+        yield config, ['conv2d', 'conv2d', 'concat'], (1e-5, 1e-5)
 
     def test(self):
         self.run_and_statis(quant=False,
-                            passes=["conv_concat_relu_mkldnn_fuse_pass"])
+                            passes=['conv_activation_mkldnn_fuse_pass'])
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     unittest.main()
-- 
GitLab
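For context on where the attributes written by this pass are consumed: the oneDNN conv2d kernel translates fuse_activation / fuse_alpha / fuse_beta into an eltwise post-op on the convolution primitive. A simplified sketch, assuming the oneDNN 2.x post-ops API (the kernel-side variable names here are illustrative, not part of this patch):

    // Simplified sketch: mapping the attributes set by the fuse pass onto a
    // oneDNN eltwise post-op appended to the convolution primitive.
    dnnl::post_ops post_operations;
    if (fuse_activation == "relu") {
      post_operations.append_eltwise(
          1.0f /*scale*/, dnnl::algorithm::eltwise_relu, fuse_alpha, fuse_beta);
    } else if (fuse_activation == "gelu_tanh") {
      post_operations.append_eltwise(
          1.0f, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
    } else if (fuse_activation == "gelu_erf") {
      post_operations.append_eltwise(
          1.0f, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f);
    }
    dnnl::primitive_attr conv_attrs;
    conv_attrs.set_post_ops(post_operations);

Because the activation is executed as a post-op of the convolution, the standalone activation node (and the intermediate concat output it read) can be dropped from the graph, which is exactly what the handler in FuseConvConcatAct does.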