From cc5d4b1a2864d1dd84bb29c2f171acea2055eeae Mon Sep 17 00:00:00 2001
From: feng_shuai
Date: Thu, 1 Jul 2021 10:31:19 +0800
Subject: [PATCH] Conv relu mkldnn fuse pass (#33664)

---
 .../conv_activation_mkldnn_fuse_pass.cc        | 112 ++++++++++++++++++
 .../mkldnn/conv_activation_mkldnn_fuse_pass.h  |   5 +
 ...conv_activation_mkldnn_fuse_pass_tester.cc  |  11 +-
 paddle/fluid/operators/compat/hard_swish.pbtxt |  12 ++
 paddle/fluid/operators/compat/leaky_relu.pbtxt |  12 ++
 paddle/fluid/operators/compat/relu.pbtxt       |   4 +
 paddle/fluid/operators/compat/relu6.pbtxt      |  14 ++-
 paddle/fluid/operators/compat/swish.pbtxt      |   4 +
 8 files changed, 172 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 7c749d9274..79a31e5cdc 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -49,6 +49,11 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
+
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass op compat failed.";
+      return;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                               conv_activation_pattern);  // Filter
     GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
@@ -97,6 +102,113 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
 
   AddStatis(found_conv_activation_count);
 }
+
+ConvActivationFusePass::ConvActivationFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      // IsStringIn({"EXPLICIT", "SAME", "VALID"}); MobileNetV2 does not
+      // have this attribute
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      // IsStringIn({"NHWC", "NCHW"}); MobileNetV2 does not have this attribute
+      .AddAttr("data_format")
+      .IsOptional()
+      .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
+      .End();
+
+  AddOpCompat(OpCompat("relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+}
+Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
+  AddOpCompat(OpCompat("leaky_relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // float, default=0.02
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End();
+}
+Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
+  AddOpCompat(OpCompat("relu6"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // default = 6.0f
+      .AddAttr("threshold")
+      .IsType<float>()
+      .End();
+}
+Conv2DSwishFusePass::Conv2DSwishFusePass() {
+  AddOpCompat(OpCompat("swish"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+}
+Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
+  AddOpCompat(OpCompat("hard_swish"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // float, optional, default=6.0
+      .AddAttr("threshold")
+      .IsOptional()
+      .IsType<float>()
+      .End()
+      // float, optional, default=6.0
+      .AddAttr("scale")
+      .IsOptional()
+      .IsType<float>()
+      .End()
+      // float, optional, default=3.0
+      .AddAttr("offset")
+      .IsOptional()
+      .IsType<float>()
+      .End();
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
index 2df27c420f..d22773fb41 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -31,6 +31,7 @@ class Graph;
 
 class ConvActivationFusePass : public FusePassBase {
  public:
+  ConvActivationFusePass();
   virtual ~ConvActivationFusePass() {}
   virtual std::string conv_type() const { return "conv2d"; }
   virtual std::string activation_type() const { return "relu"; }
@@ -44,6 +45,7 @@ class ConvActivationFusePass : public FusePassBase {
  */
 class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
  public:
+  Conv2DLeakyReLUFusePass();
   std::string activation_type() const { return "leaky_relu"; }
 };
 /*
@@ -51,6 +53,7 @@ class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
  */
 class Conv2DReLU6FusePass : public ConvActivationFusePass {
  public:
+  Conv2DReLU6FusePass();
   std::string activation_type() const { return "relu6"; }
 };
 /*
@@ -58,6 +61,7 @@ class Conv2DReLU6FusePass : public ConvActivationFusePass {
  */
 class Conv2DSwishFusePass : public ConvActivationFusePass {
  public:
+  Conv2DSwishFusePass();
   std::string activation_type() const { return "swish"; }
 };
 /*
@@ -65,6 +69,7 @@ class Conv2DSwishFusePass : public ConvActivationFusePass {
  */
 class Conv2DHardSwishFusePass : public ConvActivationFusePass {
  public:
+  Conv2DHardSwishFusePass();
   std::string activation_type() const { return "hard_swish"; }
 };
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
index 55bbad7a88..453197cda3 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
 
 #include <gtest/gtest.h>
+#include <string>
 #include "paddle/fluid/framework/op_proto_maker.h"
 
 namespace paddle {
@@ -30,9 +31,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
   op->SetAttr("name", name);
   if (type == "conv2d") {
     op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetAttr("groups", 1);
+    op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
+    op->SetAttr("data_format", std::string("NCHW"));
+    op->SetAttr("strides", std::vector<int>({1, 1}));
+    op->SetAttr("dilations", std::vector<int>({1, 1}));
+    op->SetAttr("paddings", std::vector<int>({0, 0}));
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
     op->SetInput("Bias", {inputs[2]});
+    op->SetOutput("Output", outputs);
   } else if (is_activation) {
     op->SetAttr("use_mkldnn", use_mkldnn);
     op->SetInput("X", inputs);
@@ -43,8 +51,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
     } else if (type == "swish") {
       op->SetAttr("beta", 1.0f);
     }
+    op->SetOutput("Out", outputs);
   }
-  op->SetOutput("Out", outputs);
+
   op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
               static_cast<int>(OpRole::kForward));
 }
diff --git a/paddle/fluid/operators/compat/hard_swish.pbtxt b/paddle/fluid/operators/compat/hard_swish.pbtxt
index ccf387652e..9951513741 100644
--- a/paddle/fluid/operators/compat/hard_swish.pbtxt
+++ b/paddle/fluid/operators/compat/hard_swish.pbtxt
@@ -24,6 +24,18 @@ extra {
     name: "op_role"
     type: INT
   }
+  attrs {
+    name: "use_mkldnn"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "name"
+    type: STRING
+  }
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
   attrs {
     name: "op_role_var"
     type: STRINGS
diff --git a/paddle/fluid/operators/compat/leaky_relu.pbtxt b/paddle/fluid/operators/compat/leaky_relu.pbtxt
index 9df2e59161..8618b72ca8 100644
--- a/paddle/fluid/operators/compat/leaky_relu.pbtxt
+++ b/paddle/fluid/operators/compat/leaky_relu.pbtxt
@@ -16,6 +16,18 @@ extra {
     name: "use_mkldnn"
     type: BOOLEAN
   }
+  attrs {
+    name: "name"
+    type: STRING
+  }
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "is_test"
+    type: BOOLEAN
+  }
   attrs {
     name: "op_role"
     type: INT
diff --git a/paddle/fluid/operators/compat/relu.pbtxt b/paddle/fluid/operators/compat/relu.pbtxt
index 271ed91718..9a184bf03d 100644
--- a/paddle/fluid/operators/compat/relu.pbtxt
+++ b/paddle/fluid/operators/compat/relu.pbtxt
@@ -52,4 +52,8 @@ extra {
     name: "is_test"
     type: BOOLEAN
   }
+  attrs {
+    name: "name"
+    type: STRING
+  }
 }
diff --git a/paddle/fluid/operators/compat/relu6.pbtxt b/paddle/fluid/operators/compat/relu6.pbtxt
index edd2903732..340b130201 100644
--- a/paddle/fluid/operators/compat/relu6.pbtxt
+++ b/paddle/fluid/operators/compat/relu6.pbtxt
@@ -6,16 +6,28 @@ def {
   outputs {
     name: "Out"
   }
+  attrs {
+    name: "threshold"
+    type: FLOAT
+  }
 }
 extra {
   attrs {
-    name: "threshold"
-    type: FLOAT
+    name: "name"
+    type: STRING
+  }
+  attrs {
+    name: "is_test"
+    type: BOOLEAN
   }
   attrs {
     name: "use_mkldnn"
     type: BOOLEAN
   }
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
   attrs {
     name: "op_role"
     type: INT
diff --git a/paddle/fluid/operators/compat/swish.pbtxt b/paddle/fluid/operators/compat/swish.pbtxt
index 4f5ec127e4..1dd8e577d9 100644
--- a/paddle/fluid/operators/compat/swish.pbtxt
+++ b/paddle/fluid/operators/compat/swish.pbtxt
@@ -12,6 +12,10 @@ extra {
     name: "beta"
     type: FLOAT
   }
+  attrs {
+    name: "name"
+    type: STRING
+  }
   attrs {
     name: "use_mkldnn"
     type: BOOLEAN
--
GitLab
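Editor's addendum (not part of the original patch): the constructors added above register declarative op-compat checks, and ApplyImpl() now calls IsCompat() on each matched subgraph, logging a warning and skipping the fusion when a conv2d or activation op carries attributes outside the declared ranges. Below is a minimal sketch of how these MKL-DNN fuse passes are typically exercised through the Paddle Inference C++ API. The model path, include path, and the registered pass name are illustrative assumptions; the REGISTER_PASS macros are not shown in this diff, and enabling MKL-DNN alone normally schedules the conv+activation fuse family already, so the explicit AppendPass call is only for demonstration.

    // Sketch only: enables oneDNN and appends one conv+activation fuse pass.
    #include "paddle_inference_api.h"  // header location depends on the install layout

    int main() {
      paddle_infer::Config config;
      config.SetModel("./model_dir");  // hypothetical model directory
      config.EnableMKLDNN();           // route conv2d/activation ops to oneDNN kernels
      // Pass name assumed from Paddle's usual REGISTER_PASS naming convention;
      // this diff does not show the registration.
      config.pass_builder()->AppendPass("conv_relu_mkldnn_fuse_pass");
      auto predictor = paddle_infer::CreatePredictor(config);
      return predictor != nullptr ? 0 : 1;
    }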