diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 10eac94d797a92a2dc4db55087b20cca5c1618ba..a00d183a83af386709c4231498b0e3471b42d794 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -85,6 +85,7 @@ if(WITH_MKLDNN)
     pass_library(depthwise_conv_mkldnn_pass base mkldnn)
     pass_library(conv_bias_mkldnn_fuse_pass inference mkldnn)
     pass_library(conv_relu_mkldnn_fuse_pass inference mkldnn)
+    pass_library(conv_brelu_mkldnn_fuse_pass inference mkldnn)
     pass_library(conv_elementwise_add_mkldnn_fuse_pass inference mkldnn)
     pass_library(cpu_quantize_placement_pass base mkldnn)
     pass_library(cpu_quantize_pass inference mkldnn)
@@ -114,6 +115,7 @@ if (WITH_MKLDNN)
   cc_test(test_depthwise_conv_mkldnn_pass SRCS mkldnn/depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)
   cc_test(test_conv_bias_mkldnn_fuse_pass SRCS mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc DEPS conv_bias_mkldnn_fuse_pass naive_executor)
   cc_test(test_conv_relu_mkldnn_fuse_pass SRCS mkldnn/conv_relu_mkldnn_fuse_pass_tester.cc DEPS conv_relu_mkldnn_fuse_pass)
+  cc_test(test_conv_brelu_mkldnn_fuse_pass SRCS mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc DEPS conv_brelu_mkldnn_fuse_pass)
   cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass)
   cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
   cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 0dcf064902d1c1c6cb034421cedea0387b6e0505..789eee8aa1e47ea164e3a6ba70ea85955eece37a 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -785,6 +785,33 @@ PDNode *patterns::ConvReLU::operator()(
   return relu_out_var;
 }
 
+PDNode *patterns::ConvBReLU::operator()(
+    paddle::framework::ir::PDNode *conv_input) {
+  // Create Operators
+  conv_input->assert_is_op_input("conv2d", "Input");
+  auto *conv_op = pattern->NewNode(conv_repr())->assert_is_op("conv2d");
+  auto *brelu_op = pattern->NewNode(brelu_repr())->assert_is_op("relu6");
+  // Create variables
+  // Filter
+  auto *conv_weight_var = pattern->NewNode(conv_weight_repr())
+                              ->AsInput()
+                              ->assert_is_persistable_var()
+                              ->assert_is_op_input("conv2d", "Filter");
+  // Intermediate variable, will be removed from the IR after the fuse.
+  auto *conv_out_var = pattern->NewNode(conv_out_repr())
+                           ->AsIntermediate()
+                           ->assert_is_only_output_of_op("conv2d")
+                           ->assert_is_op_input("relu6");
+  // output
+  auto *brelu_out_var = pattern->NewNode(brelu_out_repr())
+                            ->AsOutput()
+                            ->assert_is_op_output("relu6");
+
+  conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
+  brelu_op->LinksFrom({conv_out_var}).LinksTo({brelu_out_var});
+  return brelu_out_var;
+}
+
 PDNode *patterns::SeqConvEltAddRelu::operator()(
     paddle::framework::ir::PDNode *seqconv_input) {
   // Create Operators
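For context before the kernel changes later in this patch: bounded ReLU (relu6 when the threshold is 6.0) clamps the convolution output from both sides. A minimal reference sketch of the semantics the fused kernel must reproduce (illustrative only, not part of the patch):

#include <algorithm>

// Bounded ReLU: y = min(max(x, 0), threshold); relu6 is the threshold == 6.0f case.
inline float bounded_relu(float x, float threshold) {
  return std::min(std::max(x, 0.0f), threshold);
}

This is exactly the clamp that mkldnn::algorithm::eltwise_bounded_relu applies when attached as a convolution post-op further down in this patch.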
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 907371b56b06dcd66297adedea6c17b61d9b5e38..1147f1e8ce00294cd0e7886e257c9d7a41ca289c 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -449,6 +449,27 @@ struct ConvReLU : public PatternBase {
   PATTERN_DECL_NODE(relu_out);
 };
 
+// CONV with ReLU6 (bounded ReLU)
+// op: conv + relu6
+// named nodes:
+// conv_input, conv_weight,
+// conv_out, conv,
+// brelu_out, brelu
+struct ConvBReLU : public PatternBase {
+  ConvBReLU(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "conv_bounded_relu") {}
+
+  PDNode* operator()(PDNode* conv_input);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(conv);
+  PATTERN_DECL_NODE(brelu);
+  // declare variable node's name
+  PATTERN_DECL_NODE(conv_weight);
+  PATTERN_DECL_NODE(conv_out);
+  PATTERN_DECL_NODE(brelu_out);
+};
+
 // SEQCONV with Elementwise_Add ReLU
 // op: seqconv + elementwise_add + relu
 // named nodes:
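To make the pattern concrete, the rewrite that the pass in the next file performs on a matched subgraph can be pictured as follows (illustrative sketch using the pattern's node names; `t` stands for the relu6 "threshold" attribute):

// Before the fuse:                       After the fuse:
//
//   conv_input   conv_weight               conv_input   conv_weight
//         \       /                              \       /
//         [conv2d]                         [conv2d, fuse_brelu = true,
//            |                              fuse_brelu_threshold = t]
//         conv_out   (intermediate)                  |
//            |                                   brelu_out
//   [relu6, threshold = t]
//            |
//         brelu_out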
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..dd9d448634806377b5f62b045f2ff59f65529780
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.cc
@@ -0,0 +1,71 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void ConvBReLUFusePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE(graph);
+  FusePassBase::Init("conv_bounded_relu_mkldnn_fuse", graph);
+
+  GraphPatternDetector gpd;
+  auto* conv_input = gpd.mutable_pattern()
+                         ->NewNode("conv_bounded_relu_mkldnn_fuse/conv_input")
+                         ->AsInput()
+                         ->assert_is_op_input("conv2d", "Input");
+  patterns::ConvBReLU conv_brelu_pattern(gpd.mutable_pattern(),
+                                         "conv_bounded_relu_mkldnn_fuse");
+  conv_brelu_pattern(conv_input);
+
+  int found_conv_brelu_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    VLOG(4) << "handle ConvBReLU fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
+                              conv_brelu_pattern);  // Filter
+    GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_brelu_pattern);  // tmp
+    GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_brelu_pattern);  // CONV op
+    GET_IR_NODE_FROM_SUBGRAPH(brelu_out, brelu_out, conv_brelu_pattern);  // Out
+    GET_IR_NODE_FROM_SUBGRAPH(brelu, brelu, conv_brelu_pattern);  // ReLU6 op
+
+    // Transform the Conv node into a ConvBReLU node.
+    OpDesc* desc = conv->Op();
+    desc->SetOutput("Output", std::vector<std::string>({brelu_out->Name()}));
+    desc->SetAttr("fuse_brelu", true);
+    desc->SetAttr("fuse_brelu_threshold", brelu->Op()->GetAttr("threshold"));
+
+    GraphSafeRemoveNodes(graph, {brelu, conv_out});
+
+    PADDLE_ENFORCE(subgraph.count(conv_input));
+    IR_NODE_LINK_TO(conv, brelu_out);
+    found_conv_brelu_count++;
+  };
+
+  gpd(graph, handler);
+
+  AddStatis(found_conv_brelu_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(conv_brelu_mkldnn_fuse_pass,
+              paddle::framework::ir::ConvBReLUFusePass);
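Once registered via REGISTER_PASS, the pass can be fetched and applied by name, which is how the tester below drives it. A minimal usage sketch (assuming `graph` is a std::unique_ptr<ir::Graph> already built from a ProgramDesc, as in the tester):

// Retrieve the pass from the global registry and run it on an ir::Graph.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "conv_brelu_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));  // the fused graph replaces the original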
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..c898be69caf049d2de14f13714036a8f45508f98
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h
@@ -0,0 +1,39 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+/*
+ * Fuse conv2d and ReLU6 (bounded ReLU) into a conv2d op with fuse_brelu attributes.
+ */
+class ConvBReLUFusePass : public FusePassBase {
+ public:
+  virtual ~ConvBReLUFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
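After ApplyImpl runs, a matched conv2d carries the activation as attributes instead of a separate relu6 op. Schematically, the rewritten OpDesc looks like this (illustrative, mirroring the handler above):

// conv2d OpDesc after the fuse:
//   Output               = {brelu_out->Name()}   // rewired from conv_out
//   fuse_brelu           = true
//   fuse_brelu_threshold = relu6's "threshold" attribute (6.0 by default)
// The relu6 op node and the intermediate conv_out variable are removed.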
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5a546bfaedadf4d7038a0636098936c2ffd7ed72
--- /dev/null
+++ b/paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass_tester.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/conv_brelu_mkldnn_fuse_pass.h"
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/op_proto_maker.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
+           const std::vector<std::string>& inputs,
+           const std::vector<std::string>& outputs, bool use_mkldnn = false) {
+  auto* op = prog->MutableBlock(0)->AppendOp();
+  op->SetType(type);
+  if (type == "conv2d") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetAttr("name", name);
+    op->SetInput("Input", {inputs[0]});
+    op->SetInput("Filter", {inputs[1]});
+    op->SetInput("Bias", {inputs[2]});
+  } else if (type == "relu6") {
+    op->SetAttr("use_mkldnn", use_mkldnn);
+    if (use_mkldnn) {
+      op->SetAttr("threshold", 6.0f);
+    }
+    op->SetInput("X", inputs);
+  }
+  op->SetOutput("Out", outputs);
+  op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(OpRole::kForward));
+}
+
+// a->OP0->b
+// b->OP1->c
+// (c, weights, bias)->conv->f
+// (f)->brelu->g
+ProgramDesc BuildProgramDesc() {
+  ProgramDesc prog;
+  for (auto& v :
+       std::vector<std::string>({"a", "b", "c", "weights", "bias", "f", "g",
+                                 "h", "weights2", "bias2", "k", "l"})) {
+    auto* var = prog.MutableBlock(0)->Var(v);
+    var->SetType(proto::VarType::SELECTED_ROWS);
+    if (v == "weights" || v == "bias") {
+      var->SetPersistable(true);
+    }
+  }
+
+  SetOp(&prog, "OP0", "op0", std::vector<std::string>({"a"}),
+        std::vector<std::string>({"b"}));
+  SetOp(&prog, "OP1", "op1", std::vector<std::string>({"b"}),
+        std::vector<std::string>({"c"}));
+  // conv+brelu, both with MKL-DNN
+  SetOp(&prog, "conv2d", "conv1",
+        std::vector<std::string>({"c", "weights", "bias"}),
+        std::vector<std::string>({"f"}), true);
+  SetOp(&prog, "relu6", "relu1", std::vector<std::string>({"f"}),
+        std::vector<std::string>({"g"}), true);
+  SetOp(&prog, "OP3", "op3", std::vector<std::string>({"g"}),
+        std::vector<std::string>({"h"}));
+  // conv+brelu, only one with MKL-DNN
+  SetOp(&prog, "conv2d", "conv2",
+        std::vector<std::string>({"h", "weights2", "bias2"}),
+        std::vector<std::string>({"k"}), true);
+  SetOp(&prog, "relu6", "relu2", std::vector<std::string>({"k"}),
+        std::vector<std::string>({"l"}));
+
+  return prog;
+}
+
+TEST(ConvBReLUFusePass, basic) {
+  auto prog = BuildProgramDesc();
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
+
+  auto pass = PassRegistry::Instance().Get("conv_brelu_mkldnn_fuse_pass");
+
+  int original_nodes_num = graph->Nodes().size();
+
+  graph.reset(pass->Apply(graph.release()));
+
+  int current_nodes_num = graph->Nodes().size();
+
+  // Remove 3 Nodes: CONV, BRELU, conv_out
+  // Add 1 Node: ConvBReLU
+  EXPECT_EQ(original_nodes_num - 2, current_nodes_num);
+
+  // Assert the fused conv_brelu op is in the newly generated graph
+  int conv_brelu_count = 0;
+
+  for (auto* node : graph->Nodes()) {
+    if (node->IsOp() && node->Op()->Type() == "conv2d") {
+      auto* op = node->Op();
+      ASSERT_TRUE(op->HasAttr("use_mkldnn"));
+      EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
+      // check that only the "conv1" convolution is fused
+      auto op_name = boost::get<std::string>(op->GetAttr("name"));
+      if (op_name == "conv1") {
+        ASSERT_TRUE(op->HasAttr("fuse_brelu"));
+        ASSERT_TRUE(op->HasAttr("fuse_brelu_threshold"));
+
+        bool fuse_brelu = boost::get<bool>(op->GetAttr("fuse_brelu"));
+        if (fuse_brelu) {
+          ++conv_brelu_count;
+          float fuse_brelu_threshold =
+              boost::get<float>(op->GetAttr("fuse_brelu_threshold"));
+          EXPECT_EQ(fuse_brelu_threshold, 6.0f);
+        }
+      } else if (op_name == "conv2") {
+        ASSERT_FALSE(op->HasAttr("fuse_brelu"));
+      }
+    }
+  }
+  EXPECT_EQ(conv_brelu_count, 1);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(conv_brelu_mkldnn_fuse_pass);
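The pass-builder change below appends the new pass to the CPU MKL-DNN pipeline, so inference users pick it up through EnableMKLDNN(). A usage sketch (assuming the C++ AnalysisConfig inference API of this Paddle version; the model path is a placeholder):

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

paddle::AnalysisConfig config;
config.SetModel("path/to/model");  // placeholder model directory
config.EnableMKLDNN();  // CpuPassStrategy now includes conv_brelu_mkldnn_fuse_pass
auto predictor = paddle::CreatePaddlePredictor(config);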
Used " "whenever convolution output is as an input to residual " diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc index faf518005c8cb0958dd5b0bbfc5c6fc4b3c2b582..28db85c3ec0b7cc138164fabea097a58af67caeb 100644 --- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc @@ -119,9 +119,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector dilations = ctx.Attr>("dilations"); bool fuse_relu = ctx.Attr("fuse_relu"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); + bool fuse_brelu = false; + float fuse_brelu_threshold = 6.0; int groups = ctx.Attr("groups"); - bool is_conv3d = strides.size() == 3U; + if (!is_conv3d) { + fuse_brelu = ctx.Attr("fuse_brelu"); + fuse_brelu_threshold = ctx.Attr("fuse_brelu_threshold"); + } // TODO(tpatejko): add support for dilation PADDLE_ENFORCE( is_conv3d @@ -142,8 +147,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { // Get unique name for storing MKLDNN primitives const std::string key = platform::ConvMKLDNNHandler::GetHash( - src_tz, weights_tz, strides, paddings, dilations, groups, - ctx.op().Input("Input") + ctx.op().Input("Filter")); + src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations, + groups, ctx.op().Input("Input") + ctx.op().Input("Filter")); std::vector pipeline; @@ -194,11 +199,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { bias_tz, platform::MKLDNNGetDataType(), memory::format::x); conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine, - fuse_relu, fuse_residual_conn, fwd_prop_kind); + fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold, + fwd_prop_kind); } else { conv_pd = handler.AcquireConvolutionPrimitiveDescriptor( src_md, weights_md, boost::none, dst_md, strides, paddings, - mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind); + mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu, + fuse_brelu_threshold, fwd_prop_kind); } // create mkldnn memory from input tensors (data/weights) @@ -317,13 +324,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { int groups = ctx.Attr("groups"); bool fuse_relu = ctx.Attr("fuse_relu"); bool fuse_residual_conn = ctx.Attr("fuse_residual_connection"); - + bool fuse_brelu = ctx.Attr("fuse_brelu"); bool force_fp32_output = ctx.Attr("force_fp32_output"); if (fuse_residual_conn) { PADDLE_ENFORCE(force_fp32_output != true, "residual fusion does not support force output with fp32"); } - bool is_conv3d = strides.size() == 3U; // TODO(tpatejko): add support for dilation PADDLE_ENFORCE( @@ -334,6 +340,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { "dilation in convolution is not implemented yet"); PADDLE_ENFORCE(is_conv3d != true, "int8 does not support conv3d currently"); + PADDLE_ENFORCE(fuse_brelu != true, + "int8 does not support conv/relu6 fusion currently"); const T* input_data = input->data(); @@ -341,15 +349,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel { std::vector weights_tz = paddle::framework::vectorize2int(filter->dims()); int g = std::max(groups, 1); + GetWeightsTz(weights_tz, g, is_conv3d); std::vector dst_tz = paddle::framework::vectorize2int(output->dims()); mkldnn::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(input->type()); - auto dst_dt = fuse_relu ? 
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index faf518005c8cb0958dd5b0bbfc5c6fc4b3c2b582..28db85c3ec0b7cc138164fabea097a58af67caeb 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -119,9 +119,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+    bool fuse_brelu = false;
+    float fuse_brelu_threshold = 6.0;
     int groups = ctx.Attr<int>("groups");
-
     bool is_conv3d = strides.size() == 3U;
+    if (!is_conv3d) {
+      fuse_brelu = ctx.Attr<bool>("fuse_brelu");
+      fuse_brelu_threshold = ctx.Attr<float>("fuse_brelu_threshold");
+    }
     // TODO(tpatejko): add support for dilation
     PADDLE_ENFORCE(
         is_conv3d
@@ -142,8 +147,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     // Get unique name for storing MKLDNN primitives
     const std::string key = platform::ConvMKLDNNHandler::GetHash(
-        src_tz, weights_tz, strides, paddings, dilations, groups,
-        ctx.op().Input("Input") + ctx.op().Input("Filter"));
+        src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
+        groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
 
     std::vector<mkldnn::primitive> pipeline;
@@ -194,11 +199,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
       conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
           src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
-          fuse_relu, fuse_residual_conn, fwd_prop_kind);
+          fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold,
+          fwd_prop_kind);
     } else {
       conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
           src_md, weights_md, boost::none, dst_md, strides, paddings,
-          mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind);
+          mkldnn_engine, fuse_relu, fuse_residual_conn, fuse_brelu,
+          fuse_brelu_threshold, fwd_prop_kind);
     }
 
     // create mkldnn memory from input tensors (data/weights)
@@ -317,13 +324,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     int groups = ctx.Attr<int>("groups");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-
+    bool fuse_brelu = ctx.Attr<bool>("fuse_brelu");
    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
     if (fuse_residual_conn) {
       PADDLE_ENFORCE(force_fp32_output != true,
                      "residual fusion does not support force output with fp32");
     }
-
     bool is_conv3d = strides.size() == 3U;
     // TODO(tpatejko): add support for dilation
     PADDLE_ENFORCE(
@@ -334,6 +340,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             "dilation in convolution is not implemented yet");
     PADDLE_ENFORCE(is_conv3d != true,
                    "int8 does not support conv3d currently");
+    PADDLE_ENFORCE(fuse_brelu != true,
+                   "int8 does not support conv/relu6 fusion currently");
 
     const T* input_data = input->data<T>();
 
@@ -341,15 +349,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
     int g = std::max(groups, 1);
+
     GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
     mkldnn::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
-    auto dst_dt = fuse_relu ? paddle::framework::ToMKLDNNDataType(
-                                  framework::DataTypeTrait<uint8_t>::DataType)
-                            : paddle::framework::ToMKLDNNDataType(
-                                  framework::DataTypeTrait<int8_t>::DataType);
+
+    auto dst_dt = (fuse_relu) ? paddle::framework::ToMKLDNNDataType(
+                                    framework::DataTypeTrait<uint8_t>::DataType)
+                              : paddle::framework::ToMKLDNNDataType(
+                                    framework::DataTypeTrait<int8_t>::DataType);
 
     if (force_fp32_output) {
       dst_dt = paddle::framework::ToMKLDNNDataType(
@@ -367,8 +377,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     key.reserve(MaxKeyLength);
     platform::ConvMKLDNNHandler::AppendKey(
         &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
-        input->format(), fuse_relu, fuse_residual_conn,
+        input->format(), fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
         ctx.op().Input("Input") + ctx.op().Input("Filter"));
+
     const std::string key_conv_pd = key + "@conv_pd";
 
     bool need_s8_to_u8 = false;
@@ -449,22 +460,24 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         bias_tz = paddle::framework::vectorize2int(bias->dims());
         auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
                                                memory::format::x);
-        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                       strides, paddings, mkldnn_engine,
-                                       fuse_relu, fuse_residual_conn,
-                                       output_shift_scale, sum_scale, is_test);
+
+        conv_pd = ConvFwdPrimitiveDesc(
+            src_md, weights_md, bias_md, dst_md, strides, paddings,
+            mkldnn_engine, fuse_relu, fuse_residual_conn, false /*fuse_brelu*/,
+            0.0 /*fuse_brelu_threshold*/, output_shift_scale, sum_scale,
+            is_test);
+
       } else {
-        conv_pd =
-            ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
-                                 mkldnn_engine, fuse_relu, fuse_residual_conn,
-                                 output_shift_scale, sum_scale, is_test);
+        conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
+                                       paddings, mkldnn_engine, fuse_relu,
+                                       fuse_residual_conn, false /*fuse_brelu*/,
+                                       0.0 /*fuse_brelu_threshold*/,
+                                       output_shift_scale, sum_scale, is_test);
       }
       // Save conv_pd/src_memory/weights_memory for backward pass
       dev_ctx.SetBlob(key_conv_pd, conv_pd);
-
       handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
                                                     mkldnn_engine, key));
-
       // create mkldnn memory from input tensors (data/weights)
       user_src_memory_p =
           handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
@@ -632,11 +645,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  private:
   mkldnn::primitive_attr CreatePostOps(
       bool fuse_relu, bool fuse_residual_conn,
-      const std::vector<float> output_shift_scale, float sum_scale) const {
+      const std::vector<float> output_shift_scale, float sum_scale,
+      bool fuse_brelu, float fuse_brelu_threshold) const {
     mkldnn::primitive_attr conv_attr;
     mkldnn::post_ops post_operations;
     int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
     conv_attr.set_output_scales(mask, output_shift_scale);
+
     if (fuse_residual_conn) {
       post_operations.append_sum(sum_scale);
     }
@@ -647,6 +662,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                      negative_slope, placeholder);
     }
+    if (fuse_brelu) {
+      constexpr float scale = 1.0f;
+      constexpr float placeholder = 0.0f;  // beta
+      post_operations.append_eltwise(scale,
+                                     mkldnn::algorithm::eltwise_bounded_relu,
+                                     fuse_brelu_threshold, placeholder);
+    }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
   }
@@ -656,7 +678,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const memory::desc& dst, const std::vector<int>& strides,
       const std::vector<int>& paddings, const mkldnn::engine& engine,
       const bool fuse_relu,
-      const bool fuse_residual_conn,
+      const bool fuse_residual_conn, const bool fuse_brelu,
+      const float fuse_brelu_threshold,
       const std::vector<float> output_shift_scale, const float sum_scale,
       bool is_test) const {
     memory::dims stride_dims = {strides[0], strides[1]};
@@ -668,9 +691,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto conv_desc = mkldnn::convolution_forward::desc(
         propagation, mkldnn::convolution_direct, src, weights, dst,
         stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
-
-    mkldnn::primitive_attr conv_attr = CreatePostOps(
-        fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
+    mkldnn::primitive_attr conv_attr =
+        CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
+                      sum_scale, fuse_brelu, fuse_brelu_threshold);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
@@ -685,7 +708,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       const std::vector<int>& strides, const std::vector<int>& paddings,
       const mkldnn::engine& engine, const bool fuse_relu,
-      const bool fuse_residual_conn,
+      const bool fuse_residual_conn, const bool fuse_brelu,
+      const float fuse_brelu_threshold,
       const std::vector<float> output_shift_scale, const float sum_scale,
       bool is_test) const {
     memory::dims stride_dims = {strides[0], strides[1]};
@@ -698,8 +722,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         propagation, mkldnn::convolution_direct, src, weights, bias, dst,
         stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
 
-    mkldnn::primitive_attr conv_attr = CreatePostOps(
-        fuse_relu, fuse_residual_conn, output_shift_scale, sum_scale);
+    mkldnn::primitive_attr conv_attr =
+        CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
+                      sum_scale, fuse_brelu, fuse_brelu_threshold);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
@@ -762,7 +787,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz =
         paddle::framework::vectorize2int(output_grad->dims());
-
+    bool fuse_relu = ctx.Attr<bool>("fuse_relu");
+    bool fuse_brelu = false;
+    if (!is_conv3d) {
+      fuse_brelu = ctx.Attr<bool>("fuse_brelu");
+    }
     auto src_format = input->format();
     mkldnn::memory::format weights_format =
         GetWeightsFormat(filter->format(), g, is_conv3d);
@@ -771,8 +800,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     // as well as attributes of primitive to be created
     // This name will be used as key when saving info into device context
     const std::string key = platform::ConvMKLDNNHandler::GetHash(
-        src_tz, weights_tz, strides, paddings, dilations, groups,
-        ctx.op().Input("Input") + ctx.op().Input("Filter"));
+        src_tz, weights_tz, fuse_relu, fuse_brelu, strides, paddings, dilations,
+        groups, ctx.op().Input("Input") + ctx.op().Input("Filter"));
 
     const std::string key_conv_pd = key + "@conv_pd";
     std::vector<mkldnn::primitive> pipeline;
diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
index 30d2469eeaf6938f1f93730b8b645ca2cfe97364..95494bce5a667142a5a850d3f0f44013fd8dd1b1 100644
--- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc
@@ -166,11 +166,11 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x);
       conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
           src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
-          fuse_relu, false, fwd_prop_kind);
+          fuse_relu, false, false, 0.0, fwd_prop_kind);
     } else {
       conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
           src_md, weights_md, boost::none, dst_md, strides, paddings,
-          mkldnn_engine, fuse_relu, false, fwd_prop_kind);
+          mkldnn_engine, fuse_relu, false, false, 0.0, fwd_prop_kind);
     }
 
     // create mkldnn memory from input tensors (data/weights)
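The mkldnn_reuse.h changes below add the fuse flags to the primitive cache key. This matters because post-ops are baked into a cached primitive: two convolutions that differ only in fused activation must not share one. An illustrative sketch of the collision the extra key fields prevent (hypothetical key fragments):

// Two conv2d ops with identical shapes/strides/paddings but different fusion:
//   old key:  "8,3,224,224;64,3,3,3;1,1;..."                 -> same key, wrong primitive reused
//   new key:  "8,3,224,224;64,3,3,3;1,1;...;relu=1;brelu=0"
//             "8,3,224,224;64,3,3,3;1,1;...;relu=0;brelu=1"  -> distinct keys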
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h
index ba3a82b4b07f4dcb3f0037e398c146ab167d7b57..4011f08cea8c49559f7411048c68648dfb84291a 100644
--- a/paddle/fluid/platform/mkldnn_reuse.h
+++ b/paddle/fluid/platform/mkldnn_reuse.h
@@ -212,25 +212,29 @@ class MKLDNNHandler {
     dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<T>(output_data)));
   }
 
-  static void AppendKey(std::string* key,
-                        const mkldnn::memory::dims& input_dims,
-                        const mkldnn::memory::dims& weights_dims,
-                        const std::vector<int>& strides,
-                        const std::vector<int>& paddings,
-                        const std::vector<int>& dilations, const int& groups,
-                        const mkldnn::memory::data_type& srcdt,
-                        const mkldnn::memory::format& format, const bool& relu,
-                        const bool& residual, const std::string& suffix) {
+  static void AppendKey(
+      std::string* key, const mkldnn::memory::dims& input_dims,
+      const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
+      const std::vector<int>& paddings, const std::vector<int>& dilations,
+      const int& groups, const mkldnn::memory::data_type& srcdt,
+      const mkldnn::memory::format& format, const bool& relu,
+      const bool& residual, const bool& brelu, const std::string& suffix) {
     AppendKeyDims(key, input_dims);
+
     AppendKeyDims(key, weights_dims);
+
     AppendKeyVec(key, strides);
+
     AppendKeyVec(key, paddings);
+
     AppendKeyVec(key, dilations);
+
     AppendKey(key, std::to_string(groups));
     AppendKey(key, std::to_string(srcdt));
     AppendKey(key, std::to_string(format));
     AppendKey(key, std::to_string(relu));
     AppendKey(key, std::to_string(residual));
+    AppendKey(key, std::to_string(brelu));
     AppendKey(key, suffix);
   }
@@ -562,8 +566,9 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
                                                scale_data, mask);
   }
 
-  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
-                                       bool fuse_residual_conn = false) const {
+  mkldnn::primitive_attr CreatePostOps(bool fuse_relu, bool fuse_residual_conn,
+                                       bool fuse_brelu,
+                                       float fuse_brelu_threshold) const {
     mkldnn::primitive_attr conv_attr;
     mkldnn::post_ops post_operations;
     // Fusion with Elementwise layer relies on adding a sum post-operation with
@@ -583,6 +588,14 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
       post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
                                      negative_slope, placeholder);
     }
+
+    if (fuse_brelu) {
+      constexpr float scale = 1.0f;
+      constexpr float placeholder = 0.0f;
+      post_operations.append_eltwise(scale,
+                                     mkldnn::algorithm::eltwise_bounded_relu,
+                                     fuse_brelu_threshold, placeholder);
+    }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
   }
@@ -594,6 +607,7 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
       const mkldnn::memory::desc& dst, const std::vector<int>& strides,
       const std::vector<int>& paddings, const mkldnn::engine& engine,
       const bool fuse_relu, const bool fuse_residual_conn,
+      const bool fuse_brelu, const float fuse_brelu_threshold,
       mkldnn::prop_kind fwd_prop_kind) {
     const std::string key_conv_pd = key_ + "@conv_pd";
 
@@ -614,8 +628,8 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
           weights, dst, stride_dims, padding_dims, padding_dims,
           mkldnn::padding_kind::zero);
 
-      mkldnn::primitive_attr conv_attr =
-          CreatePostOps(fuse_relu, fuse_residual_conn);
+      mkldnn::primitive_attr conv_attr = CreatePostOps(
+          fuse_relu, fuse_residual_conn, fuse_brelu, fuse_brelu_threshold);
 
       conv_pd_.reset(
           new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
@@ -714,6 +728,22 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
     return conv_bwd_data_p;
   }
 
+  // Generate keys for storing/retrieving primitives for this operator
+  // TODO(jczaja): Make hashing function more optimal
+  static std::string GetHash(mkldnn::memory::dims& input_dims,    // NOLINT
+                             mkldnn::memory::dims& weights_dims,  // NOLINT
+                             const bool& fuse_relu,               // NOLINT
+                             const bool& fuse_brelu,              // NOLINT
+                             std::vector<int>& strides,           // NOLINT
+                             std::vector<int>& paddings,          // NOLINT
+                             std::vector<int>& dilations,         // NOLINT
+                             int groups, const std::string& suffix) {
+    return dims2str(input_dims) + dims2str(weights_dims) +
+           std::to_string(fuse_relu) + std::to_string(fuse_brelu) +
+           dims2str(strides) + dims2str(paddings) + dims2str(dilations) +
+           std::to_string(groups) + suffix;
+  }
+
   // Generate keys for storing/retriving primitives for this operator
   // TODO(jczaja): Make hashing function more optimial
   static std::string GetHash(mkldnn::memory::dims& input_dims,  // NOLINT
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py
index 28b670d7ab3267a03157b7e617504eb9a35656aa..6e4f0166121a6478399973d2c7a3aa7e1cb5506c 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_mkldnn_op.py
@@ -57,6 +57,8 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
         self.fuse_bias = False
         self.bias_size = None
         self.fuse_relu = False
+        self.fuse_brelu = False
+        self.fuse_brelu_threshold = 6.0
         self.fuse_residual_connection = False
         self.input_residual_size = None
         TestConv2dOp.setUp(self)
@@ -84,15 +86,38 @@ class TestConv2dMKLDNNOp(TestConv2dOp):
 
         if self.fuse_relu:
             output = np.maximum(output, 0).astype(self.dsttype)
+        if self.fuse_brelu:
+            output = np.minimum(
+                np.maximum(output, 0),
+                self.fuse_brelu_threshold).astype(self.dsttype)
 
         output = output.astype(self.dtype)
 
         self.attrs['fuse_bias'] = self.fuse_bias
         self.attrs['fuse_relu'] = self.fuse_relu
+        self.attrs['fuse_brelu'] = self.fuse_brelu
+        self.attrs['fuse_brelu_threshold'] = self.fuse_brelu_threshold
         self.attrs['fuse_residual_connection'] = self.fuse_residual_connection
 
         self.outputs['Output'] = output
 
 
+class TestWithbreluFusion(TestConv2dMKLDNNOp):
+    def init_test_case(self):
+        TestConv2dMKLDNNOp.init_test_case(self)
+        self.fuse_brelu = True
+        self.fuse_brelu_threshold = 6.0
+        self.dsttype = np.float32
+
+    def test_check_grad(self):
+        pass
+
+    def test_check_grad_no_filter(self):
+        pass
+
+    def test_check_grad_no_input(self):
+        pass
+
+
 class TestWithFuse(TestConv2dMKLDNNOp):
     def init_test_case(self):
         TestConv2dMKLDNNOp.init_test_case(self)