From b837689e97d1527ed46488f8bbd73d80b9d87e60 Mon Sep 17 00:00:00 2001
From: Adam <38704900+grygielski@users.noreply.github.com>
Date: Thu, 15 Aug 2019 10:52:19 +0200
Subject: [PATCH] Add generalized Conv+Activation MKLDNN fuse pass creation
 (#19072)

test=develop
---
 paddle/fluid/framework/ir/CMakeLists.txt      |   6 +-
 .../framework/ir/graph_pattern_detector.cc    |  29 ++++
 .../framework/ir/graph_pattern_detector.h     |  22 +++
 .../conv_activation_mkldnn_fuse_pass.cc       | 101 ++++++++++++
 .../mkldnn/conv_activation_mkldnn_fuse_pass.h |  55 +++++++
 ...conv_activation_mkldnn_fuse_pass_tester.cc | 145 ++++++++++++++++++
 .../conv_concat_relu_mkldnn_fuse_pass.cc      |   2 +-
 ...onv_concat_relu_mkldnn_fuse_pass_tester.cc |   7 +-
 .../conv_elementwise_add_mkldnn_fuse_pass.cc  |  15 +-
 .../fluid/inference/api/mkldnn_quantizer.cc   |  11 +-
 .../inference/api/paddle_pass_builder.cc      |   5 +-
 paddle/fluid/operators/conv_op.cc             |  16 ++
 paddle/fluid/operators/conv_transpose_op.cc   |   8 +
 .../fluid/operators/mkldnn/conv_mkldnn_op.cc  |  67 ++++----
 .../mkldnn/conv_transpose_mkldnn_op.cc        |   9 +-
 paddle/fluid/platform/mkldnn_reuse.h          |  38 ++---
 .../mkldnn/test_conv2d_int8_mkldnn_op.py      |  32 ++--
 .../unittests/mkldnn/test_conv2d_mkldnn_op.py |  23 +--
 .../mkldnn/test_conv2d_transpose_mkldnn_op.py |  12 +-
 19 files changed, 495 insertions(+), 108 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
 create mode 100644 paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index ad4ede2b110..44dde061851 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -86,8 +86,7 @@ if(WITH_MKLDNN)
   pass_library(mkldnn_placement_pass base mkldnn)
   pass_library(depthwise_conv_mkldnn_pass base mkldnn)
   pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
-  pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
-  pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
+  pass_library(conv_activation_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_concat_relu_mkldnn_fuse_pass inference mkldnn)
   pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
   pass_library(fc_mkldnn_pass inference mkldnn)
@@ -127,8 +126,7 @@ endif()
 if (WITH_MKLDNN)
   cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
   cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
-  cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
-  cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
+  cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
   cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
   cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
   cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 2670c129116..302c53a1d03 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -771,6 +771,35 @@ PDNode *patterns::ConvBN::operator()(paddle::framework::ir::PDNode *conv_input,
   return bn_out_var;
 }
 
+PDNode *patterns::ConvActivation::operator()(
+    paddle::framework::ir::PDNode *conv_input, std::string conv_type,
+    std::string activation_type) {
+  // Create Operators
+  conv_input->assert_is_op_input(conv_type, "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op(conv_type);
+  auto *activation_op =
+      pattern->NewNode(activation_repr())->assert_is_op(activation_type);
+  // Create variables
+  // Filter
+  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
+                              ->AsInput()
+                              ->assert_is_persistable_var()
+                              ->assert_is_op_input(conv_type, "Filter");
+  // intermediate variable, will be removed in the IR after fuse.
+  auto *conv_out_var = pattern->NewNode(conv_out_repr())
+                           ->AsIntermediate()
+                           ->assert_is_only_output_of_op(conv_type)
+                           ->assert_is_op_input(activation_type);
+  // output
+  auto *activation_out_var = pattern->NewNode(activation_out_repr())
+                                 ->AsOutput()
+                                 ->assert_is_op_output(activation_type);
+
+  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
+  activation_op->LinksFrom({conv_out_var}).LinksTo({activation_out_var});
+  return activation_out_var;
+}
+
 PDNode *patterns::ConvReLU::operator()(
     paddle::framework::ir::PDNode *conv_input) {
   // Create Operators
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index d2ad12fca07..de091ce87ed 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -431,6 +431,28 @@ struct ConvBN : public PatternBase {
   PATTERN_DECL_NODE(bn_saved_variance);
 };
 
+// Conv with Activation
+// op: conv + activation
+// named nodes:
+// conv_input, conv_weight,
+// conv_out, conv,
+// activation_out, activation
+struct ConvActivation : public PatternBase {
+  ConvActivation(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "conv_activation") {}
+
+  PDNode* operator()(PDNode* conv_input, std::string conv_type = "conv2d",
+                     std::string activation_type = "relu");
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(conv);
+  PATTERN_DECL_NODE(activation);
+  // declare variable node's name
+  PATTERN_DECL_NODE(conv_weight);
+  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(activation_out);
+};
+
 // CONV with ReLU
 // op: conv + relu
 // named nodes:
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
new file mode 100644
index 00000000000..947beeccbdb
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -0,0 +1,101 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph, "graph cannot be nullptr.");
+  FusePassBase::Init("conv_activation_mkldnn_fuse", graph);
+
+  GraphPatternDetector gpd;
+  auto* conv_input = gpd.mutable_pattern()
+                         ->NewNode("conv_activation_mkldnn_fuse/conv_input")
+                         ->AsInput()
+                         ->assert_is_op_input(conv_type(), "Input");
+  patterns::ConvActivation conv_activation_pattern(
+      gpd.mutable_pattern(), "conv_activation_mkldnn_fuse");
+  conv_activation_pattern(conv_input, conv_type(), activation_type());
+
+  int found_conv_activation_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
+                              conv_activation_pattern);  // Filter
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
+                              conv_activation_pattern);  // tmp
+    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_activation_pattern);  // CONV op
+    GET_IR_NODE_FROM_SUBGRAPH(activation_out, activation_out,
+                              conv_activation_pattern);  // Out
+    GET_IR_NODE_FROM_SUBGRAPH(activation, activation,
+                              conv_activation_pattern);  // Activation op
+
+    // Transform Conv node into ConvActivation node.
+    OpDesc* desc = conv->Op();
+    desc->SetOutput("Output",
+                    std::vector<std::string>({activation_out->Name()}));
+
+    desc->SetAttr("fuse_activation", activation_type());
+
+    // MKLDNN ops use alpha and beta as activation parameters but paddle ops
+    // are not generalized
+    if (activation_type() == "relu6") {
+      desc->SetAttr("fuse_alpha",
+                    boost::get<float>(activation->Op()->GetAttr("threshold")));
+    } else {
+      desc->SetAttr("fuse_alpha",
+                    activation->Op()->HasAttr("alpha")
+                        ? boost::get<float>(activation->Op()->GetAttr("alpha"))
+                        : 0.0f);
+    }
+    desc->SetAttr("fuse_beta",
+                  activation->Op()->HasAttr("beta")
+                      ? boost::get<float>(activation->Op()->GetAttr("beta"))
+                      : 0.0f);
+
+    GraphSafeRemoveNodes(graph, {activation, conv_out});
+
+    PADDLE_ENFORCE_GT(subgraph.count(conv_input), 0UL,
+                      "subgraph has to contain conv_input node.");
+    IR_NODE_LINK_TO(conv, activation_out);
+    found_conv_activation_count++;
+  };
+
+  gpd(graph, handler);
+
+  AddStatis(found_conv_activation_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(conv_activation_mkldnn_fuse_pass,
+              paddle::framework::ir::ConvActivationFusePass);
+
+REGISTER_PASS(conv_relu_mkldnn_fuse_pass,
+              paddle::framework::ir::ConvActivationFusePass);
+
+REGISTER_PASS(conv_leaky_relu_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv2DLeakyReLUFusePass);
+
+REGISTER_PASS(conv_relu6_mkldnn_fuse_pass,
+              paddle::framework::ir::Conv2DReLU6FusePass);
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
new file mode 100644
index 00000000000..7c6dc238a55
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+/*
+ * Fuse Conv and Activation base class.
+ */
+class ConvActivationFusePass : public FusePassBase {
+ public:
+  virtual ~ConvActivationFusePass() {}
+  virtual std::string conv_type() const { return "conv2d"; }
+  virtual std::string activation_type() const { return "relu"; }
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+  const std::string name_scope_{"conv_activation_mkldnn_fuse"};
+};
+/*
+ * Fuse Conv and LeakyReLU class
+ */
+class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
+ public:
+  std::string activation_type() const { return "leaky_relu"; }
+};
+/*
+ * Fuse Conv and BoundedReLU class
+ */
+class Conv2DReLU6FusePass : public ConvActivationFusePass {
+ public:
+  std::string activation_type() const { return "relu6"; }
+};
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
new file mode 100644
index 00000000000..12611fb84a0
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" + +#include +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, + const std::vector& inputs, + const std::vector& outputs, bool is_activation = false, + bool use_mkldnn = false) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + op->SetAttr("name", name); + if (type == "conv2d") { + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetInput("Input", {inputs[0]}); + op->SetInput("Filter", {inputs[1]}); + op->SetInput("Bias", {inputs[2]}); + } else if (is_activation) { + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetInput("X", inputs); + if (type == "leaky_relu") { + op->SetAttr("alpha", 0.02f); + } else if (type == "relu6") { + op->SetAttr("threshold", 6.0f); + } + } + op->SetOutput("Out", outputs); + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); +} + +// a->OP0->b +// b->OP1->c +// (c, weights, bias)->conv->f +// (f)->activation->g +ProgramDesc BuildProgramDesc(std::string activation) { + ProgramDesc prog; + for (auto& v : + std::vector({"a", "b", "c", "weights", "bias", "f", "g", + "h", "weights2", "bias2", "k", "l", "m"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::SELECTED_ROWS); + if (v == "weights" || v == "bias" || v == "weights2" || v == "bias2") { + var->SetPersistable(true); + } + } + + SetOp(&prog, "OP0", "op0", std::vector({"a"}), + std::vector({"b"})); + SetOp(&prog, "OP1", "op1", std::vector({"b"}), + std::vector({"c"})); + // conv+activation, both with MKL-DNN + SetOp(&prog, "conv2d", "conv1", + std::vector({"c", "weights", "bias"}), + std::vector({"f"}), false, true); + SetOp(&prog, activation, "activation1", std::vector({"f"}), + std::vector({"g"}), true, true); + SetOp(&prog, "OP3", "op3", std::vector({"g"}), + std::vector({"h"})); + // conv+activation, only one with MKL-DNN + SetOp(&prog, "conv2d", "conv2", + std::vector({"h", "weights2", "bias2"}), + std::vector({"k"}), false, true); + SetOp(&prog, "activation", "activation2", std::vector({"k"}), + std::vector({"l"}), true, false); + SetOp(&prog, "OP4", "op4", std::vector({"l"}), + std::vector({"m"})); + + return prog; +} + +void MainTest(std::string activation) { + auto prog = BuildProgramDesc(activation); + + std::unique_ptr graph(new ir::Graph(prog)); + + auto pass = + PassRegistry::Instance().Get("conv_" + activation + "_mkldnn_fuse_pass"); + + int original_nodes_num = graph->Nodes().size(); + + graph.reset(pass->Apply(graph.release())); + + int current_nodes_num = graph->Nodes().size(); + + // Remove 3 Nodes: CONV, activation, conv_out + // Add 1 Node: ConvActivation + EXPECT_EQ(original_nodes_num - 2, current_nodes_num); + + // Assert conv_activation op in newly generated graph + int conv_activation_count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "conv2d") { + auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(boost::get(op->GetAttr("use_mkldnn"))); + auto op_name = boost::get(op->GetAttr("name")); + std::string fuse_activation = + op->HasAttr("fuse_activation") + ? 
boost::get(op->GetAttr("fuse_activation")) + : ""; + if (fuse_activation == activation) { + ++conv_activation_count; + } + // check if only "conv1" convolution is fused + if (op_name == "conv1") { + ASSERT_TRUE(op->HasAttr("fuse_activation")); + } else if (op_name == "conv2") { + ASSERT_FALSE(op->HasAttr("fuse_activation")); + } + } + } + EXPECT_EQ(conv_activation_count, 1); +} + +TEST(ConvActivationFusePass, conv_relu_fuse_pass) { MainTest("relu"); } +TEST(ConvActivationFusePass, conv_leaky_relu_fuse_pass) { + MainTest("leaky_relu"); +} +TEST(ConvActivationFusePass, conv_relu6_fuse_pass) { MainTest("relu6"); } + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(conv_activation_mkldnn_fuse_pass); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index a037a6bf909..9e8f0f0c46c 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -83,7 +83,7 @@ void ConvConcatReLUFusePass::FuseConvConcatReLU( // Transform Conv node into ConvReLU node. OpDesc* conv_desc = conv_op->Op(); - conv_desc->SetAttr("fuse_relu", true); + conv_desc->SetAttr("fuse_activation", std::string("relu")); // Remove ReLU when all Convs were transformed. auto number_of_unfused_convs_left = diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc index 0d7ddac8884..ee00a39596a 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc @@ -28,7 +28,7 @@ void SetOp(ProgramDesc* prog, const std::string& type, op->SetType(type); if (type == "conv2d") { op->SetAttr("use_mkldnn", use_mkldnn); - op->SetAttr("fuse_relu", false); + op->SetAttr("fuse_activation", std::string("")); op->SetInput("Input", {inputs[0]}); op->SetInput("Filter", {inputs[1]}); if (inputs.size() > 2) { @@ -109,8 +109,9 @@ void MainTest(const ProgramDesc& prog, bool fuse_relu) { if (node->IsOp()) { auto* op = node->Op(); if (op->Type() == "conv2d") { - ASSERT_TRUE(op->HasAttr("fuse_relu")); - bool fuse_relu_attr = boost::get(op->GetAttr("fuse_relu")); + ASSERT_TRUE(op->HasAttr("fuse_activation")); + bool fuse_relu_attr = + (boost::get(op->GetAttr("fuse_activation")) == "relu"); EXPECT_EQ(fuse_relu, fuse_relu_attr); } else if (op->Type() == "relu") { relu_count++; diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index ef7874c1c0b..ff1a3dc9c34 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -109,8 +109,11 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( if (!IsReachable(graph, elementwise_add_identity, conv_output)) return; - auto fuse_relu = HasAttribute(*conv_op, "fuse_relu"); - if (fuse_relu && *fuse_relu) return; + std::string fuse_activation = + conv_op->Op()->HasAttr("fuse_activation") + ? 
boost::get(conv_op->Op()->GetAttr("fuse_activation")) + : ""; + if (fuse_activation == "relu" || fuse_activation == "relu6") return; conv_op->Op()->SetInput("ResidualData", {elementwise_add_identity->Name()}); conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); @@ -179,8 +182,12 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( return; } - auto fuse_relu = HasAttribute(*residual_conv_op, "fuse_relu"); - if (fuse_relu && *fuse_relu) return; + std::string fuse_activation = + residual_conv_op->Op()->HasAttr("fuse_activation") + ? boost::get( + residual_conv_op->Op()->GetAttr("fuse_activation")) + : ""; + if (fuse_activation == "relu" || fuse_activation == "relu6") return; residual_conv_op->Op()->SetInput("ResidualData", {projection_node->Name()}); residual_conv_op->Op()->SetOutput("Output", {elementwise_add_out->Name()}); diff --git a/paddle/fluid/inference/api/mkldnn_quantizer.cc b/paddle/fluid/inference/api/mkldnn_quantizer.cc index fea56f01cb5..179e002f7dd 100644 --- a/paddle/fluid/inference/api/mkldnn_quantizer.cc +++ b/paddle/fluid/inference/api/mkldnn_quantizer.cc @@ -68,10 +68,13 @@ bool AnalysisPredictor::MkldnnQuantizer::CalculateScales() { if (is_output) { if (op->Type() == "conv2d") { // output of conv2d with relu must be unsigned - is_unsigned = (op->HasAttr("fuse_relu") && - boost::get(op->GetAttr("fuse_relu"))) || - (op->HasAttr("fuse_brelu") && - boost::get(op->GetAttr("fuse_brelu"))); + std::string fuse_activation = + op->HasAttr("fuse_activation") + ? boost::get( + op->GetAttr("fuse_activation")) + : ""; + is_unsigned = + (fuse_activation == "relu" || fuse_activation == "relu6"); } else if (op->Type() == "relu") { is_unsigned = true; } else if (op->Type() == "transpose2" || diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index e6ad93ae1bf..239161bc9ef 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -180,8 +180,9 @@ void CpuPassStrategy::EnableMKLDNN() { "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", "conv_concat_relu_mkldnn_fuse_pass", - "conv_relu_mkldnn_fuse_pass", // - "conv_brelu_mkldnn_fuse_pass", // + "conv_relu_mkldnn_fuse_pass", // + "conv_leaky_relu_mkldnn_fuse_pass", // + "conv_relu6_mkldnn_fuse_pass", // // Disabled due to topology-dependent speed-up // "fc_mkldnn_pass" })) { diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index d2036c611ed..cdecd816524 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -215,6 +215,14 @@ void Conv2DOpMaker::Make() { AddAttr("fuse_brelu_threshold", "(float, default false 6.0) Only used in mkldnn kernel") .SetDefault(6.0f); + AddAttr("fuse_activation", + "(string, default \"\") Only used in mkldnn kernel") + .SetDefault(""); + AddAttr("fuse_alpha", + "(float, default 0.0) Only used in mkldnn kernel") + .SetDefault(0.0f); + AddAttr("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel") + .SetDefault(0.0f); AddAttr("fuse_residual_connection", "(bool, default false) Only used in mkldnn kernel. 
Used " "whenever convolution output is as an input to residual " @@ -352,6 +360,14 @@ void Conv3DOpMaker::Make() { .SetDefault(false); AddAttr("fuse_relu", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("fuse_activation", + "(string, default \"\") Only used in mkldnn kernel") + .SetDefault(""); + AddAttr("fuse_alpha", + "(float, default 0.0) Only used in mkldnn kernel") + .SetDefault(0.0f); + AddAttr("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel") + .SetDefault(0.0f); AddAttr("fuse_residual_connection", "(bool, default false) Only used in mkldnn kernel. Used " "whenever convolution output is as an input to residual " diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 01afdd28078..e76c57abc63 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -170,6 +170,14 @@ void Conv2DTransposeOpMaker::Make() { .SetDefault(false); AddAttr("fuse_relu", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("fuse_activation", + "(string, default \"\") Only used in mkldnn kernel") + .SetDefault(""); + AddAttr("fuse_alpha", + "(float, default 0.0) Only used in mkldnn kernel") + .SetDefault(0.0f); + AddAttr("fuse_beta", "(float, default 0.0) Only used in mkldnn kernel") + .SetDefault(0.0f); AddAttr( "data_format", "(string, default NCHW) Only used in " diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index cdb827b39ba..25d8e4ee953 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -71,13 +71,14 @@ inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format, static mkldnn::memory::data_type GetDstType(bool is_int8, bool force_fp32_output, - bool fuse_relu, bool fuse_brelu, + std::string fuse_activation, bool fuse_residual_conn, const Tensor* residual_param) { auto dst_dt = mkldnn::memory::data_type::f32; // uint8_t, int8_t, float if (is_int8) { - dst_dt = (fuse_relu || fuse_brelu) ? mkldnn::memory::data_type::u8 - : mkldnn::memory::data_type::s8; + dst_dt = (fuse_activation == "relu" || fuse_activation == "relu6") + ? 
mkldnn::memory::data_type::u8 + : mkldnn::memory::data_type::s8; if (force_fp32_output) { dst_dt = mkldnn::memory::data_type::f32; } @@ -100,12 +101,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { if (!is_INT8) { ComputeFP32(ctx); } else { - bool fuse_relu = ctx.Attr("fuse_relu"); + std::string fuse_activation = ctx.Attr("fuse_activation"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); - bool fuse_brelu = ctx.Attr("fuse_brelu"); bool force_fp32_output = ctx.Attr("force_fp32_output"); auto residual_param = ctx.Input("ResidualData"); - auto dst_dt = GetDstType(true, force_fp32_output, fuse_relu, fuse_brelu, + auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation, fuse_residual_conn, residual_param); if (dst_dt == mkldnn::memory::data_type::f32) { ComputeINT8(ctx); @@ -150,16 +150,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector strides = ctx.Attr>("strides"); std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); - bool fuse_relu = ctx.Attr("fuse_relu"); + std::string fuse_activation = ctx.Attr("fuse_activation"); + float fuse_alpha = ctx.Attr("fuse_alpha"); + float fuse_beta = ctx.Attr("fuse_beta"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); - bool fuse_brelu = false; - float fuse_brelu_threshold = 6.0; int groups = ctx.Attr("groups"); bool is_conv3d = strides.size() == 3U; - if (!is_conv3d) { - fuse_brelu = ctx.Attr("fuse_brelu"); - fuse_brelu_threshold = ctx.Attr("fuse_brelu_threshold"); - } + // TODO(tpatejko): add support for dilation PADDLE_ENFORCE( is_conv3d @@ -180,7 +177,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // Get unique name for storing MKLDNN primitives const std::string key = platform::ConvMKLDNNHandler::GetHash( - src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations, + src_tz, weights_tz, fuse_activation, strides, paddings, dilations, groups, ctx.op().Input("Input") + ctx.op().Input("Filter")); std::vector pipeline; @@ -232,13 +229,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz, platform::MKLDNNGetDataType(), memory::format::x); conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold, + fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn, fwd_prop_kind); } else { conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, boost::none, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu, - fuse_brelu_threshold, fwd_prop_kind); + mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, + fuse_residual_conn, fwd_prop_kind); } // create mkldnn memory from input tensors (data/weights) @@ -355,12 +352,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector paddings = ctx.Attr>("paddings"); std::vector dilations = ctx.Attr>("dilations"); int groups = ctx.Attr("groups"); - bool fuse_relu = ctx.Attr("fuse_relu"); + std::string fuse_activation = ctx.Attr("fuse_activation"); + float fuse_alpha = ctx.Attr("fuse_alpha"); + float fuse_beta = ctx.Attr("fuse_beta"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); - bool fuse_brelu = ctx.Attr("fuse_brelu"); - float fuse_brelu_threshold = ctx.Attr("fuse_brelu_threshold"); bool force_fp32_output = ctx.Attr("force_fp32_output"); - bool unsigned_output = fuse_relu || fuse_brelu; + bool 
unsigned_output = + (fuse_activation == "relu" || fuse_activation == "relu6"); PADDLE_ENFORCE(!fuse_residual_conn || !force_fp32_output, "residual fusion does not support force output with fp32"); @@ -394,7 +392,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { key.reserve(MaxKeyLength); platform::ConvMKLDNNHandler::AppendKey( &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, - input->format(), fuse_relu, fuse_residual_conn, fuse_brelu, + input->format(), fuse_activation, fuse_residual_conn, ctx.op().Input("Input") + ctx.op().Input("Filter")); const std::string key_conv_pd = key + "@conv_pd"; @@ -484,6 +482,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { handler.reset( new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key)); // create a conv primitive descriptor and save it for usage in backward + + // TODO(grygielski) if INT8 brelu post-op will be available, just delete + // whole if statement + if (fuse_activation == "relu6") { + fuse_activation = "relu"; + fuse_alpha = 0.0f; + } + auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training; @@ -493,13 +499,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { mkldnn::memory::format::x); conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu, - fuse_brelu_threshold, propagation, output_shift_scale, sum_scale); + mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, + fuse_residual_conn, propagation, output_shift_scale, sum_scale); } else { conv_pd = handler->AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, boost::none, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu, - fuse_brelu_threshold, propagation, output_shift_scale, sum_scale); + mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, + fuse_residual_conn, propagation, output_shift_scale, sum_scale); } // create mkldnn memory from input tensors (data/weights) @@ -681,11 +687,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { GetWeightsTz(weights_tz, g, is_conv3d); std::vector dst_tz = paddle::framework::vectorize2int(output_grad->dims()); - bool fuse_relu = ctx.Attr("fuse_relu"); - bool fuse_brelu = false; - if (!is_conv3d) { - fuse_brelu = ctx.Attr("fuse_brelu"); - } auto src_format = input->format(); mkldnn::memory::format weights_format = GetWeightsFormat(filter->format(), g, is_conv3d); @@ -694,8 +695,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { // as well as attributes of primitive to be created // This name will be used as key when saving info into device context const std::string key = platform::ConvMKLDNNHandler::GetHash( - src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations, - groups, ctx.op().Input("Input") + ctx.op().Input("Filter")); + src_tz, weights_tz, "", strides, paddings, dilations, groups, + ctx.op().Input("Input") + ctx.op().Input("Filter")); const std::string key_conv_pd = key + "@conv_pd"; std::vector pipeline; diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index 6d5982ab3f8..86be8f5aced 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -142,7 +142,9 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { 
std::string data_format = ctx.Attr("data_format"); auto chosen_memory_format = platform::data_format_to_memory_format(data_format); - bool fuse_relu = ctx.Attr("fuse_relu"); + std::string fuse_activation = ctx.Attr("fuse_activation"); + float fuse_alpha = ctx.Attr("fuse_alpha"); + float fuse_beta = ctx.Attr("fuse_beta"); auto src_md = platform::MKLDNNMemDesc( src_tz, platform::MKLDNNGetDataType(), chosen_memory_format); @@ -166,11 +168,12 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz, platform::MKLDNNGetDataType(), mkldnn::memory::format::x); conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, false, false, 0.0, fwd_prop_kind); + fuse_activation, fuse_alpha, fuse_beta, false, fwd_prop_kind); } else { conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, boost::none, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, false, false, 0.0, fwd_prop_kind); + mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, false, + fwd_prop_kind); } // create mkldnn memory from input tensors (data/weights) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index 935c4f734f4..b3d7ff31909 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -217,8 +217,8 @@ class MKLDNNHandler { const mkldnn::memory::dims& weights_dims, const std::vector& strides, const std::vector& paddings, const std::vector& dilations, const int& groups, const mkldnn::memory::data_type& srcdt, - const mkldnn::memory::format& format, const bool& relu, - const bool& residual, const bool& brelu, const std::string& suffix) { + const mkldnn::memory::format& format, const std::string& fuse_activation, + const bool& residual, const std::string& suffix) { AppendKeyDims(key, input_dims); AppendKeyDims(key, weights_dims); @@ -232,9 +232,8 @@ class MKLDNNHandler { AppendKey(key, std::to_string(groups)); AppendKey(key, std::to_string(srcdt)); AppendKey(key, std::to_string(format)); - AppendKey(key, std::to_string(relu)); + AppendKey(key, fuse_activation); AppendKey(key, std::to_string(residual)); - AppendKey(key, std::to_string(brelu)); AppendKey(key, suffix); } @@ -1179,9 +1178,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { } mkldnn::primitive_attr CreatePostOps( - bool fuse_relu, bool fuse_residual_conn, bool fuse_brelu, - float fuse_brelu_threshold, - const std::vector output_shift_scale = {}, + std::string fuse_activation, float fuse_alpha, float fuse_beta, + bool fuse_residual_conn, const std::vector output_shift_scale = {}, float sum_scale = 1.0f) const { mkldnn::primitive_attr conv_attr; mkldnn::post_ops post_operations; @@ -1199,20 +1197,17 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { } // Fusion with ReLU layer is executed through the PostOps feature. Create a // PostOps object and configure it to execute an eltwise relu operation. 
- if (fuse_relu) { + if (fuse_activation == "relu" || fuse_activation == "leaky_relu") { constexpr float scale = 1.0f; - constexpr float negative_slope = 0.0f; - constexpr float placeholder = 0.0f; post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, - negative_slope, placeholder); + fuse_alpha, fuse_beta); } - if (fuse_brelu) { + if (fuse_activation == "relu6") { constexpr float scale = 1.0f; - constexpr float placeholder = 0.0f; post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_bounded_relu, - fuse_brelu_threshold, placeholder); + fuse_alpha, fuse_beta); } conv_attr.set_post_ops(post_operations); return conv_attr; @@ -1224,9 +1219,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { boost::optional bias, const mkldnn::memory::desc& dst, const std::vector& strides, const std::vector& paddings, const mkldnn::engine& engine, - const bool fuse_relu, const bool fuse_residual_conn, - const bool fuse_brelu, const float fuse_brelu_threshold, - mkldnn::prop_kind fwd_prop_kind, + const std::string& fuse_activation, float fuse_alpha, float fuse_beta, + const bool fuse_residual_conn, mkldnn::prop_kind fwd_prop_kind, const std::vector output_shift_scale = {}, const float sum_scale = 1.0f) { // Conv PD has to be passed to Grad op that @@ -1259,8 +1253,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { padding_dims, mkldnn::padding_kind::zero); mkldnn::primitive_attr conv_attr = - CreatePostOps(fuse_relu, fuse_residual_conn, fuse_brelu, - fuse_brelu_threshold, output_shift_scale, sum_scale); + CreatePostOps(fuse_activation, fuse_alpha, fuse_beta, + fuse_residual_conn, output_shift_scale, sum_scale); conv_pd_.reset(new typename forward_t::primitive_desc( conv_desc, conv_attr, engine)); @@ -1343,14 +1337,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { // TODO(jczaja): Make hashing function more optimial static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT mkldnn::memory::dims& weights_dims, // NOLINT - const bool& fuse_relu, // NOLINT - const bool& fuse_brelu, // NOLINT + const std::string& fuse_activation, // NOLINT std::vector& strides, // NOLINT std::vector& paddings, // NOLINT std::vector& dilations, // NOLINT int groups, const std::string& suffix) { - return dims2str(input_dims) + dims2str(weights_dims) + - std::to_string(fuse_relu) + std::to_string(fuse_brelu) + + return dims2str(input_dims) + dims2str(weights_dims) + fuse_activation + dims2str(strides) + dims2str(paddings) + dims2str(dilations) + std::to_string(groups) + suffix; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py index b9ef447b56f..c2ebec04a15 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py @@ -83,12 +83,12 @@ class TestConv2dInt8Op(TestConv2dOp): input_residual, self.input_residual_size).astype( self.srctype) * (self.scale_out / self.scale_in_eltwise )) - if self.fuse_relu: + if self.fuse_activation == "relu": output = np.maximum(output_tmp, 0).astype(self.dsttype) else: output = output_tmp.astype(self.dsttype) else: - if self.fuse_relu: + if self.fuse_activation == "relu": output = np.maximum(np.round(output1 - output2), 0).astype(self.dsttype) else: @@ -113,12 +113,12 @@ class TestConv2dInt8Op(TestConv2dOp): input_residual, self.input_residual_size).astype( np.int32) * (self.scale_out / self.scale_in_eltwise )) - if 
self.fuse_relu: + if self.fuse_activation == "relu": output = np.maximum(output_tmp_res, 0).astype(self.dsttype) else: output = output_tmp_res.astype(self.dsttype) else: - if self.fuse_relu: + if self.fuse_activation == "relu": output = np.maximum(output1_tmp, 0).astype(self.dsttype) else: output = output1_tmp.astype(self.dsttype) @@ -145,7 +145,7 @@ class TestConv2dInt8Op(TestConv2dOp): 'Scale_out': self.scale_out, 'Scale_weights': self.scale_weights, 'Scale_in_eltwise': self.scale_in_eltwise, - 'fuse_relu': self.fuse_relu, + 'fuse_activation': self.fuse_activation, 'fuse_residual_connection': self.fuse_residual } self.outputs = {'Output': output} @@ -178,7 +178,7 @@ class TestConv2dInt8Op(TestConv2dOp): self.dsttype = np.int8 def init_fuse_relu(self): - self.fuse_relu = True + self.fuse_activation = "relu" def init_fuse_residual(self): self.fuse_residual = True @@ -262,11 +262,11 @@ class TestWithInput1x1Filter1x1(TestConv2dInt8Op): self.groups = 3 -def init_data_type_with_fusion(self, input_dt, fuse_relu, fuse_residual): +def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual): self.srctype = input_dt - self.dsttype = np.uint8 if fuse_relu else np.int8 + self.dsttype = np.uint8 if fuse_activation == "relu" else np.int8 - self.fuse_relu = fuse_relu + self.fuse_activation = fuse_activation self.fuse_residual = fuse_residual @@ -277,43 +277,43 @@ def create_test_int8_class(parent): class TestS8U8Case(parent): def init_data_type(self): - init_data_type_with_fusion(self, np.int8, True, False) + init_data_type_with_fusion(self, np.int8, "relu", False) #--------------------test conv2d s8 in and s8 out-------------------- class TestS8S8Case(parent): def init_data_type(self): - init_data_type_with_fusion(self, np.int8, False, False) + init_data_type_with_fusion(self, np.int8, "", False) #--------------------test conv2d u8 in and s8 out-------------------- class TestU8S8Case(parent): def init_data_type(self): - init_data_type_with_fusion(self, np.uint8, False, False) + init_data_type_with_fusion(self, np.uint8, "", False) #--------------------test conv2d u8 in and u8 out without residual fuse-------------------- class TestU8U8Case(parent): def init_data_type(self): - init_data_type_with_fusion(self, np.uint8, True, False) + init_data_type_with_fusion(self, np.uint8, "relu", False) #--------------------test conv2d s8 in and u8 out with residual fuse-------------------- class TestS8U8ResCase(parent): def init_data_type(self): - init_data_type_with_fusion(self, np.int8, True, True) + init_data_type_with_fusion(self, np.int8, "relu", True) #--------------------test conv2d s8 in and s8 out with residual fuse-------------------- class TestS8S8ResCase(parent): def init_data_type(self): - init_data_type_with_fusion(self, np.int8, False, True) + init_data_type_with_fusion(self, np.int8, "", True) #--------------------test conv2d u8 in and s8 out with residual fuse-------------------- class TestU8S8ResCase(parent): def init_data_type(self): - init_data_type_with_fusion(self, np.uint8, False, True) + init_data_type_with_fusion(self, np.uint8, "", True) cls_name_s8u8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "1") cls_name_s8s8 = "{0}_relu_{1}_residual_0".format(parent.__name__, "0") diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py index 6e4f0166121..756d10a9c7d 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py +++ 
b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py @@ -56,8 +56,9 @@ class TestConv2dMKLDNNOp(TestConv2dOp): def setUp(self): self.fuse_bias = False self.bias_size = None - self.fuse_relu = False - self.fuse_brelu = False + self.fuse_activation = "" + self.fuse_alpha = 0 + self.fuse_beta = 0 self.fuse_brelu_threshold = 6.0 self.fuse_residual_connection = False self.input_residual_size = None @@ -83,18 +84,18 @@ class TestConv2dMKLDNNOp(TestConv2dOp): self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype( input_residual) - if self.fuse_relu: + if self.fuse_activation == "relu": output = np.maximum(output, 0).astype(self.dsttype) - if self.fuse_brelu: - output = np.minimum( - np.maximum(output, 0), - self.fuse_brelu_threshold).astype(self.dsttype) + if self.fuse_activation == "relu6": + output = np.minimum(np.maximum(output, 0), + self.fuse_alpha).astype(self.dsttype) output = output.astype(self.dtype) self.attrs['fuse_bias'] = self.fuse_bias - self.attrs['fuse_relu'] = self.fuse_relu - self.attrs['fuse_brelu'] = self.fuse_brelu + self.attrs['fuse_activation'] = self.fuse_activation + self.attrs['fuse_alpha'] = self.fuse_alpha + self.attrs['fuse_beta'] = self.fuse_beta self.attrs['fuse_brelu_threshold'] = self.fuse_brelu_threshold self.attrs['fuse_residual_connection'] = self.fuse_residual_connection @@ -104,8 +105,8 @@ class TestConv2dMKLDNNOp(TestConv2dOp): class TestWithbreluFusion(TestConv2dMKLDNNOp): def init_test_case(self): TestConv2dMKLDNNOp.init_test_case(self) - self.fuse_brelu = True - self.fuse_brelu_threshold = 6.0 + self.fuse_activation = "relu6" + self.fuse_alpha = 6.0 self.dsttype = np.float32 def test_check_grad(self): diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py index cc72df51f1e..33f5ea7ad6f 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py @@ -51,7 +51,9 @@ class TestConv2dTransposeMKLDNNOp(TestConv2dTransposeOp): self.pad = [0, 0] self.fuse_bias = False self.bias_size = None - self.fuse_relu = False + self.fuse_activation = "" + self.fuse_alpha = 0.0 + self.fuse_beta = 0.0 self.stride = [1, 1] self.dilations = [1, 1] self.input_size = [2, 3, 5, 5] # NCHW @@ -71,11 +73,13 @@ class TestConv2dTransposeMKLDNNOp(TestConv2dTransposeOp): self.attrs['fuse_bias'] = self.fuse_bias self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias) - if self.fuse_relu: + if self.fuse_activation == "relu": output = np.maximum(output, 0).astype(self.dtype) + output = output.astype(self.dtype) - self.attrs['fuse_bias'] = self.fuse_bias - self.attrs['fuse_relu'] = self.fuse_relu + self.attrs['fuse_activation'] = self.fuse_activation + self.attrs['fuse_alpha'] = self.fuse_alpha + self.attrs['fuse_beta'] = self.fuse_beta self.outputs['Output'] = output -- GitLab
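A note on extending the new pass (an editor's sketch for readers of this patch, not part of the change itself): because ConvActivationFusePass exposes only conv_type() and activation_type(), a further conv+activation fusion needs little more than a small subclass and a REGISTER_PASS entry. The snippet below is hypothetical — the pass name conv_swish_mkldnn_fuse_pass and the Conv2DSwishFusePass class do not exist in this patch, and "swish" would only take effect if ConvMKLDNNTemplateHandler::CreatePostOps() were additionally taught to map it onto an MKL-DNN eltwise post-op; as written, the patch wires up only "relu", "leaky_relu" and "relu6".

// Hypothetical sketch, mirroring Conv2DLeakyReLUFusePass / Conv2DReLU6FusePass
// from conv_activation_mkldnn_fuse_pass.h above.
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"

namespace paddle {
namespace framework {
namespace ir {

// Reuses the base class pattern matching and the fuse_activation/fuse_alpha/
// fuse_beta attribute rewriting; only the matched activation op type changes.
class Conv2DSwishFusePass : public ConvActivationFusePass {
 public:
  std::string activation_type() const { return "swish"; }
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(conv_swish_mkldnn_fuse_pass,
              paddle::framework::ir::Conv2DSwishFusePass);

Such a pass would also have to be declared with pass_library in paddle/fluid/framework/ir/CMakeLists.txt and appended in CpuPassStrategy::EnableMKLDNN(), mirroring the conv_leaky_relu_mkldnn_fuse_pass and conv_relu6_mkldnn_fuse_pass entries added above.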