From 2281ebf0f3c50a3ba5398632a3e3bc344ca634f2 Mon Sep 17 00:00:00 2001 From: guomingz Date: Wed, 22 May 2019 11:23:55 +0800 Subject: [PATCH] Enable the convolution/relu6(bounded_relu) fusion for FP32 on Intel platform. (#17130) * Relu6 is the bottleneck op for Mobilenet-v2. As the mkldnn supports the conv/relu6 fusion, we implement it fusion via cpass way. Due to the int8 enabling for this fusion will be supported in MKLDNN v0.20, so this PR is focused on the fp32 optimization. Below table shows the benchmark(FPS) which measured on skx-8180(28 cores) Batch size | with fusion | without fusion -- | -- | -- 1 | 214.7 | 53.4 50 | 1219.727 | 137.280 test=develop * Fix the format issue test=develop * Add the missing nolint comments. test=develop * Fix the typos. test=develop * Register the conv_brelu_mkldnn_fuse_pass for the MKLDNN engine. test=develop * Adjust the indentation. test=develop * Add the test_conv_brelu_mkldnn_fuse_pass case. test=develop * Slightly update the code per Baidu comments. Let the parameter definition embedded into the code. That's will make the code easy to understand. test=develop --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + .../framework/ir/graph_pattern_detector.cc | 27 ++++ .../framework/ir/graph_pattern_detector.h | 21 +++ .../ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc | 71 +++++++++ .../ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h | 39 +++++ .../conv_brelu_mkldnn_fuse_pass_tester.cc | 135 ++++++++++++++++++ .../inference/api/paddle_pass_builder.cc | 3 +- paddle/fluid/operators/conv_op.cc | 6 + .../fluid/operators/mkldnn/conv_mkldnn_op.cc | 95 +++++++----- .../mkldnn/conv_transpose_mkldnn_op.cc | 4 +- paddle/fluid/platform/mkldnn_reuse.h | 56 ++++++-- .../unittests/mkldnn/test_conv2d_mkldnn_op.py | 25 ++++ 12 files changed, 435 insertions(+), 49 deletions(-) create mode 100644 paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc create mode 100644 paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h create mode 100644 paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 10eac94d797..a00d183a83a 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -85,6 +85,7 @@ if(WITH_MKLDNN) pass_library(depthwise_conv_mkldnn_pass base mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn) pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn) + pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn) pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn) pass_library(cpu_quantize_placement_pass base mkldnn) pass_library(cpu_quantize_pass inference mkldnn) @@ -114,6 +115,7 @@ if (WITH_MKLDNN) cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass) cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor) cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass) + cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass) cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass) cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass) cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 0dcf064902d..789eee8aa1e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -785,6 +785,33 @@ PDNode *patterns::ConvReLU::operator()( return relu_out_var; } +PDNode *patterns::ConvBReLU::operator()( + paddle::framework::ir::PDNode *conv_input) { + // Create Operators + conv_input->assert_is_op_input("conv2d", "Input"); + auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d"); + auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6"); + // Create variables + // Filter + auto *conv_weight_var = pattern->NewNode(conv_weight_repr()) + ->AsInput() + ->assert_is_persistable_var() + ->assert_is_op_input("conv2d", "Filter"); + // intermediate variable, will be removed in the IR after fuse. + auto *conv_out_var = pattern->NewNode(conv_out_repr()) + ->AsIntermediate() + ->assert_is_only_output_of_op("conv2d") + ->assert_is_op_input("relu6"); + // output + auto *brelu_out_var = pattern->NewNode(brelu_out_repr()) + ->AsOutput() + ->assert_is_op_output("relu6"); + + conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var}); + brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var}); + return brelu_out_var; +} + PDNode *patterns::SeqConvEltAddRelu::operator()( paddle::framework::ir::PDNode *seqconv_input) { // Create Operators diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 907371b56b0..1147f1e8ce0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -449,6 +449,27 @@ struct ConvReLU : public PatternBase { PATTERN_DECL_NODE(relu_out); }; +// CONV with ReLU6 +// op: conv + relu6 +// named nodes: +// conv_input, conv_weight, +// conv_out, conv, +// relu6_out, relu6 +struct ConvBReLU : public PatternBase { + ConvBReLU(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "conv_bounded_relu") {} + + PDNode* operator()(PDNode* conv_input); + + // declare operator node's name + PATTERN_DECL_NODE(conv); + PATTERN_DECL_NODE(brelu); + // declare variable node's name + PATTERN_DECL_NODE(conv_weight); + PATTERN_DECL_NODE(conv_out); + PATTERN_DECL_NODE(brelu_out); +}; + // SEQCONV with Elementwise_Add ReLU // op: seqconv + elementwise_add + relu // named nodes: diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc new file mode 100644 index 00000000000..dd9d4486348 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc @@ -0,0 +1,71 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h" +#include +#include +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE(graph); + FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph); + + GraphPatternDetector gpd; + auto* conv_input = gpd.mutable_pattern() + ->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input") + ->AsInput() + ->assert_is_op_input("conv2d", "Input"); + patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(), + "conv_bounded_relu_mkldnn_fuse"); + conv_brelu_pattern(conv_input); + + int found_conv_brelu_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "handle ConvBoundedReLUFusePass fuse"; + GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, + conv_brelu_pattern); // Filter + GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern); // tmp + GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern); // CONV op + GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern); // Out + GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern); // ReLU op + + // Transform Conv node into ConvBReLU node. + OpDesc* desc = conv->Op(); + desc->SetOutput("Output", std::vector({brelu_out->Name()})); + desc->SetAttr("fuse_brelu", true); + desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold")); + + GraphSafeRemoveNodes(graph, {brelu, conv_out}); + + PADDLE_ENFORCE(subgraph.count(conv_input)); + IR_NODE_LINK_TO(conv, brelu_out); + found_conv_brelu_count++; + }; + + gpd(graph, handler); + + AddStatis(found_conv_brelu_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(conv_brelu_mkldnn_fuse_pass, + paddle::framework::ir::ConvBReLUFusePass); diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h new file mode 100644 index 00000000000..c898be69caf --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h @@ -0,0 +1,39 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the CONV and ReLU6 to a ConvReLU6Op. + */ +class ConvBReLUFusePass : public FusePassBase { + public: + virtual ~ConvBReLUFusePass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc new file mode 100644 index 00000000000..5a546bfaeda --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h" + +#include +#include "paddle/fluid/framework/op_proto_maker.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, + const std::vector& inputs, + const std::vector& outputs, bool use_mkldnn = false) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(type); + if (type == "conv2d") { + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetAttr("name", name); + op->SetInput("Input", {inputs[0]}); + op->SetInput("Filter", {inputs[1]}); + op->SetInput("Bias", {inputs[2]}); + } else if (type == "relu6") { + op->SetAttr("use_mkldnn", use_mkldnn); + if (use_mkldnn) { + op->SetAttr("threshold", 6.0f); + } + op->SetInput("X", inputs); + } + op->SetOutput("Out", outputs); + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); +} + +// a->OP0->b +// b->OP1->c +// (c, weights, bias)->conv->f +// (f)->brelu->g +ProgramDesc BuildProgramDesc() { + ProgramDesc prog; + for (auto& v : + std::vector({"a", "b", "c", "weights", "bias", "f", "g", + "h", "weights2", "bias2", "k", "l"})) { + auto* var = prog.MutableBlock(0)->Var(v); + var->SetType(proto::VarType::SELECTED_ROWS); + if (v == "weights" || v == "bias") { + var->SetPersistable(true); + } + } + + SetOp(&prog, "OP0", "op0", std::vector({"a"}), + std::vector({"b"})); + SetOp(&prog, "OP1", "op1", std::vector({"b"}), + std::vector({"c"})); + // conv+brelu, both with MKL-DNN + SetOp(&prog, "conv2d", "conv1", + std::vector({"c", "weights", "bias"}), + std::vector({"f"}), true); + SetOp(&prog, "relu6", "relu1", std::vector({"f"}), + std::vector({"g"}), true); + SetOp(&prog, "OP3", "op3", std::vector({"g"}), + std::vector({"h"})); + // conv+brelu, only one with MKL-DNN + SetOp(&prog, "conv2d", "conv2", + std::vector({"h", "weights2", "bias2"}), + std::vector({"k"}), true); + SetOp(&prog, "relu6", "relu2", std::vector({"k"}), + std::vector({"l"})); + + return prog; +} + +TEST(ConvBReLUFusePass, basic) { + auto prog = BuildProgramDesc(); + + std::unique_ptr graph(new ir::Graph(prog)); + + auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass"); + + int original_nodes_num = graph->Nodes().size(); + + graph.reset(pass->Apply(graph.release())); + + int current_nodes_num = graph->Nodes().size(); + + // Remove 3 Nodes: CONV, BRELU, conv_out + // Add 1 Node: ConvBReLU + EXPECT_EQ(original_nodes_num - 2, current_nodes_num); + + // Assert conv_brelu op in newly generated graph + int conv_brelu_count = 0; + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "conv2d") { + auto* op = node->Op(); + ASSERT_TRUE(op->HasAttr("use_mkldnn")); + EXPECT_TRUE(boost::get(op->GetAttr("use_mkldnn"))); + // check if only "conv1" convolution is fused + auto op_name = boost::get(op->GetAttr("name")); + if (op_name == "conv1") { + ASSERT_TRUE(op->HasAttr("fuse_brelu")); + ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold")); + + bool fuse_brelu = boost::get(op->GetAttr("fuse_brelu")); + if (fuse_brelu) { + ++conv_brelu_count; + float fuse_brelu_threshold = + boost::get(op->GetAttr("fuse_brelu_threshold")); + EXPECT_EQ(fuse_brelu_threshold, 6.0f); + } + } else if (op_name == "conv2") { + ASSERT_FALSE(op->HasAttr("fuse_brelu")); + } + } + } + EXPECT_EQ(conv_brelu_count, 1); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(conv_brelu_mkldnn_fuse_pass); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 914a07432e6..b8167b0ab32 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -153,7 +153,8 @@ void CpuPassStrategy::EnableMKLDNN() { "conv_bias_mkldnn_fuse_pass", // "conv3d_bias_mkldnn_fuse_pass", // "conv_elementwise_add_mkldnn_fuse_pass", - "conv_relu_mkldnn_fuse_pass"})) { + "conv_relu_mkldnn_fuse_pass", // + "conv_brelu_mkldnn_fuse_pass"})) { passes_.push_back(pass); } } diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 5b923f8a5eb..a6b8d0c0ace 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -209,6 +209,12 @@ void Conv2DOpMaker::Make() { .SetDefault(false); AddAttr("fuse_relu", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); + AddAttr("fuse_brelu", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); + AddAttr("fuse_brelu_threshold", + "(float, default false 6.0) Only used in mkldnn kernel") + .SetDefault(6.0f); AddAttr("fuse_residual_connection", "(bool, default false) Only used in mkldnn kernel. Used " "whenever convolution output is as an input to residual " diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index faf518005c8..28db85c3ec0 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -119,9 +119,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector dilations = ctx.Attr>("dilations"); bool fuse_relu = ctx.Attr("fuse_relu"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); + bool fuse_brelu = false; + float fuse_brelu_threshold = 6.0; int groups = ctx.Attr("groups"); - bool is_conv3d = strides.size() == 3U; + if (!is_conv3d) { + fuse_brelu = ctx.Attr("fuse_brelu"); + fuse_brelu_threshold = ctx.Attr("fuse_brelu_threshold"); + } // TODO(tpatejko): add support for dilation PADDLE_ENFORCE( is_conv3d @@ -142,8 +147,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // Get unique name for storing MKLDNN primitives const std::string key = platform::ConvMKLDNNHandler::GetHash( - src_tz, weights_tz, strides, paddings, dilations, groups, - ctx.op().Input("Input") + ctx.op().Input("Filter")); + src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations, + groups, ctx.op().Input("Input") + ctx.op().Input("Filter")); std::vector pipeline; @@ -194,11 +199,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz, platform::MKLDNNGetDataType(), memory::format::x); conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn, fwd_prop_kind); + fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold, + fwd_prop_kind); } else { conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, boost::none, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind); + mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu, + fuse_brelu_threshold, fwd_prop_kind); } // create mkldnn memory from input tensors (data/weights) @@ -317,13 +324,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { int groups = ctx.Attr("groups"); bool fuse_relu = ctx.Attr("fuse_relu"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); - + bool fuse_brelu = ctx.Attr("fuse_brelu"); bool force_fp32_output = ctx.Attr("force_fp32_output"); if (fuse_residual_conn) { PADDLE_ENFORCE(force_fp32_output != true, "residual fusion does not support force output with fp32"); } - bool is_conv3d = strides.size() == 3U; // TODO(tpatejko): add support for dilation PADDLE_ENFORCE( @@ -334,6 +340,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { "dilation in convolution is not implemented yet"); PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently"); + PADDLE_ENFORCE(fuse_brelu != true, + "int8 does not support conv/relu6 fusion currently"); const T* input_data = input->data(); @@ -341,15 +349,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector weights_tz = paddle::framework::vectorize2int(filter->dims()); int g = std::max(groups, 1); + GetWeightsTz(weights_tz, g, is_conv3d); std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); - auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType) - : paddle::framework::ToMKLDNNDataType( - framework::DataTypeTrait::DataType); + + auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait::DataType) + : paddle::framework::ToMKLDNNDataType( + framework::DataTypeTrait::DataType); if (force_fp32_output) { dst_dt = paddle::framework::ToMKLDNNDataType( @@ -367,8 +377,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { key.reserve(MaxKeyLength); platform::ConvMKLDNNHandler::AppendKey( &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, - input->format(), fuse_relu, fuse_residual_conn, + input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/, ctx.op().Input("Input") + ctx.op().Input("Filter")); + const std::string key_conv_pd = key + "@conv_pd"; bool need_s8_to_u8 = false; @@ -449,22 +460,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz = paddle::framework::vectorize2int(bias->dims()); auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32, memory::format::x); - conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, - strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn, - output_shift_scale, sum_scale, is_test); + + conv_pd = ConvFwdPrimitiveDesc( + src_md, weights_md, bias_md, dst_md, strides, paddings, + mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/, + 0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale, + is_test); + } else { - conv_pd = - ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn, - output_shift_scale, sum_scale, is_test); + conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, + paddings, mkldnn_engine, fuse_relu, + fuse_residual_conn, false /*fuse_brelu*/, + 0.0 /*fuse_brelu_threshold*/, + output_shift_scale, sum_scale, is_test); } // Save conv_pd/src_memory/weights_memory for backward pass dev_ctx.SetBlob(key_conv_pd, conv_pd); - handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key)); - // create mkldnn memory from input tensors (data/weights) user_src_memory_p = handler->AcquireSrcMemory(user_src_md, to_void_cast(input_data)); @@ -632,11 +645,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { private: mkldnn::primitive_attr CreatePostOps( bool fuse_relu, bool fuse_residual_conn, - const std::vector output_shift_scale, float sum_scale) const { + const std::vector output_shift_scale, float sum_scale, + bool fuse_brelu, float fuse_brelu_threshold) const { mkldnn::primitive_attr conv_attr; mkldnn::post_ops post_operations; int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0; conv_attr.set_output_scales(mask, output_shift_scale); + if (fuse_residual_conn) { post_operations.append_sum(sum_scale); } @@ -647,6 +662,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, negative_slope, placeholder); } + if (fuse_brelu) { + constexpr float scale = 1.0f; + constexpr float placeholder = 0.0f; // beta + post_operations.append_eltwise(scale, + mkldnn::algorithm::eltwise_bounded_relu, + fuse_brelu_threshold, placeholder); + } conv_attr.set_post_ops(post_operations); return conv_attr; } @@ -656,7 +678,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const memory::desc& dst, const std::vector& strides, const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn, + const bool fuse_residual_conn, const bool fuse_brelu, + const float fuse_brelu_threshold, const std::vector output_shift_scale, const float sum_scale, bool is_test) const { memory::dims stride_dims = {strides[0], strides[1]}; @@ -668,9 +691,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { auto conv_desc = mkldnn::convolution_forward::desc( propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - - mkldnn::primitive_attr conv_attr = CreatePostOps( - fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale); + mkldnn::primitive_attr conv_attr = + CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale, + sum_scale, fuse_brelu, fuse_brelu_threshold); auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( conv_desc, conv_attr, engine); @@ -685,7 +708,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { const std::vector& strides, const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, - const bool fuse_residual_conn, + const bool fuse_residual_conn, const bool fuse_brelu, + const float fuse_brelu_threshold, const std::vector output_shift_scale, const float sum_scale, bool is_test) const { memory::dims stride_dims = {strides[0], strides[1]}; @@ -698,8 +722,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { propagation, mkldnn::convolution_direct, src, weights, bias, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - mkldnn::primitive_attr conv_attr = CreatePostOps( - fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale); + mkldnn::primitive_attr conv_attr = + CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale, + sum_scale, fuse_brelu, fuse_brelu_threshold); auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc( conv_desc, conv_attr, engine); @@ -762,7 +787,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { GetWeightsTz(weights_tz, g, is_conv3d); std::vector dst_tz = paddle::framework::vectorize2int(output_grad->dims()); - + bool fuse_relu = ctx.Attr("fuse_relu"); + bool fuse_brelu = false; + if (!is_conv3d) { + fuse_brelu = ctx.Attr("fuse_brelu"); + } auto src_format = input->format(); mkldnn::memory::format weights_format = GetWeightsFormat(filter->format(), g, is_conv3d); @@ -771,8 +800,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel { // as well as attributes of primitive to be created // This name will be used as key when saving info into device context const std::string key = platform::ConvMKLDNNHandler::GetHash( - src_tz, weights_tz, strides, paddings, dilations, groups, - ctx.op().Input("Input") + ctx.op().Input("Filter")); + src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations, + groups, ctx.op().Input("Input") + ctx.op().Input("Filter")); const std::string key_conv_pd = key + "@conv_pd"; std::vector pipeline; diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index 30d2469eeaf..95494bce5a6 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -166,11 +166,11 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz, platform::MKLDNNGetDataType(), mkldnn::memory::format::x); conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, false, fwd_prop_kind); + fuse_relu, false, false, 0.0, fwd_prop_kind); } else { conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, boost::none, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, false, fwd_prop_kind); + mkldnn_engine, fuse_relu, false, false, 0.0, fwd_prop_kind); } // create mkldnn memory from input tensors (data/weights) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index ba3a82b4b07..4011f08cea8 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -212,25 +212,29 @@ class MKLDNNHandler { dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast(output_data))); } - static void AppendKey(std::string* key, - const mkldnn::memory::dims& input_dims, - const mkldnn::memory::dims& weights_dims, - const std::vector& strides, - const std::vector& paddings, - const std::vector& dilations, const int& groups, - const mkldnn::memory::data_type& srcdt, - const mkldnn::memory::format& format, const bool& relu, - const bool& residual, const std::string& suffix) { + static void AppendKey( + std::string* key, const mkldnn::memory::dims& input_dims, + const mkldnn::memory::dims& weights_dims, const std::vector& strides, + const std::vector& paddings, const std::vector& dilations, + const int& groups, const mkldnn::memory::data_type& srcdt, + const mkldnn::memory::format& format, const bool& relu, + const bool& residual, const bool& brelu, const std::string& suffix) { AppendKeyDims(key, input_dims); + AppendKeyDims(key, weights_dims); + AppendKeyVec(key, strides); + AppendKeyVec(key, paddings); + AppendKeyVec(key, dilations); + AppendKey(key, std::to_string(groups)); AppendKey(key, std::to_string(srcdt)); AppendKey(key, std::to_string(format)); AppendKey(key, std::to_string(relu)); AppendKey(key, std::to_string(residual)); + AppendKey(key, std::to_string(brelu)); AppendKey(key, suffix); } @@ -562,8 +566,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { scale_data, mask); } - mkldnn::primitive_attr CreatePostOps(bool fuse_relu, - bool fuse_residual_conn = false) const { + mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn, + bool fuse_brelu, + float fuse_brelu_threshold) const { mkldnn::primitive_attr conv_attr; mkldnn::post_ops post_operations; // Fusion with Elementwise layer relies on adding a sum post-operation with @@ -583,6 +588,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu, negative_slope, placeholder); } + + if (fuse_brelu) { + constexpr float scale = 1.0f; + constexpr float placeholder = 0.0f; + post_operations.append_eltwise(scale, + mkldnn::algorithm::eltwise_bounded_relu, + fuse_brelu_threshold, placeholder); + } conv_attr.set_post_ops(post_operations); return conv_attr; } @@ -594,6 +607,7 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { const mkldnn::memory::desc& dst, const std::vector& strides, const std::vector& paddings, const mkldnn::engine& engine, const bool fuse_relu, const bool fuse_residual_conn, + const bool fuse_brelu, const float fuse_brelu_threshold, mkldnn::prop_kind fwd_prop_kind) { const std::string key_conv_pd = key_ + "@conv_pd"; @@ -614,8 +628,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { weights, dst, stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero); - mkldnn::primitive_attr conv_attr = - CreatePostOps(fuse_relu, fuse_residual_conn); + mkldnn::primitive_attr conv_attr = CreatePostOps( + fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold); conv_pd_.reset( new typename forward_t::primitive_desc(conv_desc, conv_attr, engine)); @@ -714,6 +728,22 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { return conv_bwd_data_p; } + // Generate keys for storing/retriving primitives for this operator + // TODO(jczaja): Make hashing function more optimial + static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT + mkldnn::memory::dims& weights_dims, // NOLINT + const bool& fuse_relu, // NOLINT + const bool& fuse_brelu, // NOLINT + std::vector& strides, // NOLINT + std::vector& paddings, // NOLINT + std::vector& dilations, // NOLINT + int groups, const std::string& suffix) { + return dims2str(input_dims) + dims2str(weights_dims) + + std::to_string(fuse_relu) + std::to_string(fuse_brelu) + + dims2str(strides) + dims2str(paddings) + dims2str(dilations) + + std::to_string(groups) + suffix; + } + // Generate keys for storing/retriving primitives for this operator // TODO(jczaja): Make hashing function more optimial static std::string GetHash(mkldnn::memory::dims& input_dims, // NOLINT diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py index 28b670d7ab3..6e4f0166121 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py @@ -57,6 +57,8 @@ class TestConv2dMKLDNNOp(TestConv2dOp): self.fuse_bias = False self.bias_size = None self.fuse_relu = False + self.fuse_brelu = False + self.fuse_brelu_threshold = 6.0 self.fuse_residual_connection = False self.input_residual_size = None TestConv2dOp.setUp(self) @@ -84,15 +86,38 @@ class TestConv2dMKLDNNOp(TestConv2dOp): if self.fuse_relu: output = np.maximum(output, 0).astype(self.dsttype) + if self.fuse_brelu: + output = np.minimum( + np.maximum(output, 0), + self.fuse_brelu_threshold).astype(self.dsttype) output = output.astype(self.dtype) self.attrs['fuse_bias'] = self.fuse_bias self.attrs['fuse_relu'] = self.fuse_relu + self.attrs['fuse_brelu'] = self.fuse_brelu + self.attrs['fuse_brelu_threshold'] = self.fuse_brelu_threshold self.attrs['fuse_residual_connection'] = self.fuse_residual_connection self.outputs['Output'] = output +class TestWithbreluFusion(TestConv2dMKLDNNOp): + def init_test_case(self): + TestConv2dMKLDNNOp.init_test_case(self) + self.fuse_brelu = True + self.fuse_brelu_threshold = 6.0 + self.dsttype = np.float32 + + def test_check_grad(self): + pass + + def test_check_grad_no_filter(self): + pass + + def test_check_grad_no_input(self): + pass + + class TestWithFuse(TestConv2dMKLDNNOp): def init_test_case(self): TestConv2dMKLDNNOp.init_test_case(self) -- GitLab