diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
index 7c749d9274299a2af3d7cbab98be5b362cabbc6e..79a31e5cdc7b33d7d562708ff381948919758910 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc
@@ -49,6 +49,11 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
+
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Pass op compat failed.";
+      return;
+    }
     GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
                               conv_activation_pattern);  // Filter
     GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
@@ -97,6 +102,113 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(found_conv_activation_count);
 }
 
+ConvActivationFusePass::ConvActivationFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsOptional()
+      .IsTensor()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .IsType<std::vector<int>>()
+      .End()
+      .AddAttr("paddings")
+      .IsType<std::vector<int>>()
+      .End()
+      // IsStringIn({"EXPLICIT", "SAME", "VALID"}); MobileNetV2 does not have
+      // this attribute
+      .AddAttr("padding_algorithm")
+      .IsOptional()
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .IsType<std::vector<int>>()
+      .End()
+      // IsStringIn({"NHWC", "NCHW"}); MobileNetV2 does not have this attribute
+      .AddAttr("data_format")
+      .IsOptional()
+      .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
+      .End();
+
+  AddOpCompat(OpCompat("relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+}
+Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
+  AddOpCompat(OpCompat("leaky_relu"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // float, default=0.02
+      .AddAttr("alpha")
+      .IsType<float>()
+      .End();
+}
+Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
+  AddOpCompat(OpCompat("relu6"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // default = 6.0f
+      .AddAttr("threshold")
+      .IsType<float>()
+      .End();
+}
+Conv2DSwishFusePass::Conv2DSwishFusePass() {
+  AddOpCompat(OpCompat("swish"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End();
+}
+Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
+  AddOpCompat(OpCompat("hard_swish"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      // float, optional, default=6.0
+      .AddAttr("threshold")
+      .IsOptional()
+      .IsType<float>()
+      .End()
+      // float, optional, default=6.0
+      .AddAttr("scale")
+      .IsOptional()
+      .IsType<float>()
+      .End()
+      // float, optional, default=3.0
+      .AddAttr("offset")
+      .IsOptional()
+      .IsType<float>()
+      .End();
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
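Reviewer note: every constructor above registers an OpCompat schema with the same builder chain — declare each tensor input/output slot, then constrain each attribute, closing every clause with .End() — and ApplyImpl now skips any matched subgraph that fails IsCompat(). A further fused activation would follow the identical recipe. A minimal sketch (hypothetical, not part of this patch; Conv2DTanhFusePass and a conv2d+tanh fuse are illustrative only):

class Conv2DTanhFusePass : public ConvActivationFusePass {
 public:
  Conv2DTanhFusePass() {
    // Register the compat schema for the activation op this variant fuses;
    // the conv2d schema is inherited from the base-class constructor.
    AddOpCompat(OpCompat("tanh"))
        .AddInput("X")
        .IsTensor()
        .End()
        .AddOutput("Out")
        .IsTensor()
        .End();
  }
  std::string activation_type() const { return "tanh"; }
};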
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
index 2df27c420f6ecab56d5067ad0ef4a7f042f68a09..d22773fb41904afa17832224169f5430b94055c6 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h
@@ -31,6 +31,7 @@ class Graph;
 
 class ConvActivationFusePass : public FusePassBase {
  public:
+  ConvActivationFusePass();
   virtual ~ConvActivationFusePass() {}
   virtual std::string conv_type() const { return "conv2d"; }
   virtual std::string activation_type() const { return "relu"; }
@@ -44,6 +45,7 @@ class ConvActivationFusePass : public FusePassBase {
  */
 class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
  public:
+  Conv2DLeakyReLUFusePass();
   std::string activation_type() const { return "leaky_relu"; }
 };
 /*
@@ -51,6 +53,7 @@ class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
  */
 class Conv2DReLU6FusePass : public ConvActivationFusePass {
  public:
+  Conv2DReLU6FusePass();
   std::string activation_type() const { return "relu6"; }
 };
 /*
@@ -58,6 +61,7 @@ class Conv2DReLU6FusePass : public ConvActivationFusePass {
  */
 class Conv2DSwishFusePass : public ConvActivationFusePass {
  public:
+  Conv2DSwishFusePass();
   std::string activation_type() const { return "swish"; }
 };
 /*
@@ -65,6 +69,7 @@ class Conv2DSwishFusePass : public ConvActivationFusePass {
  */
 class Conv2DHardSwishFusePass : public ConvActivationFusePass {
  public:
+  Conv2DHardSwishFusePass();
   std::string activation_type() const { return "hard_swish"; }
 };
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
index 55bbad7a8875afc955af03ccecc796efa885e438..453197cda391542f41adcbeab55147b401d242f3 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
 
 #include <gtest/gtest.h>
+#include <vector>
 #include "paddle/fluid/framework/op_proto_maker.h"
 
 namespace paddle {
@@ -30,9 +31,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
   op->SetAttr("name", name);
   if (type == "conv2d") {
     op->SetAttr("use_mkldnn", use_mkldnn);
+    op->SetAttr("groups", 1);
+    op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
+    op->SetAttr("data_format", std::string("NCHW"));
+    op->SetAttr("strides", std::vector<int>({1, 1}));
+    op->SetAttr("dilations", std::vector<int>({1, 1}));
+    op->SetAttr("paddings", std::vector<int>({0, 0}));
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
     op->SetInput("Bias", {inputs[2]});
+    op->SetOutput("Output", outputs);
   } else if (is_activation) {
     op->SetAttr("use_mkldnn", use_mkldnn);
     op->SetInput("X", inputs);
@@ -43,8 +51,9 @@
     } else if (type == "swish") {
       op->SetAttr("beta", 1.0f);
     }
+    op->SetOutput("Out", outputs);
   }
-  op->SetOutput("Out", outputs);
+
   op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
               static_cast<int>(OpRole::kForward));
 }
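Reviewer note: with the IsCompat() gate in place, the test graph must describe conv2d exactly as the registered schema expects — every checked attribute present and in range, and the result wired to conv2d's real "Output" slot rather than the generic "Out" the helper used before. Assembled from the hunks above, the conv2d branch of SetOp now reads:

if (type == "conv2d") {
  op->SetAttr("use_mkldnn", use_mkldnn);
  // Attributes required by the conv2d OpCompat schema registered in the pass.
  op->SetAttr("groups", 1);
  op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
  op->SetAttr("data_format", std::string("NCHW"));
  op->SetAttr("strides", std::vector<int>({1, 1}));
  op->SetAttr("dilations", std::vector<int>({1, 1}));
  op->SetAttr("paddings", std::vector<int>({0, 0}));
  op->SetInput("Input", {inputs[0]});
  op->SetInput("Filter", {inputs[1]});
  op->SetInput("Bias", {inputs[2]});
  op->SetOutput("Output", outputs);  // conv2d's output slot is "Output", not "Out"
}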
diff --git a/paddle/fluid/operators/compat/hard_swish.pbtxt b/paddle/fluid/operators/compat/hard_swish.pbtxt
index ccf387652ed32569aa35fe6bf7a5d155c2364b98..9951513741a61a8245296fe378b02aced3c17793 100644
--- a/paddle/fluid/operators/compat/hard_swish.pbtxt
+++ b/paddle/fluid/operators/compat/hard_swish.pbtxt
@@ -24,6 +24,18 @@ extra {
     name: "op_role"
     type: INT
   }
+  attrs {
+    name: "use_mkldnn"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "name"
+    type: STRING
+  }
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
   attrs {
     name: "op_role_var"
     type: STRINGS
diff --git a/paddle/fluid/operators/compat/leaky_relu.pbtxt b/paddle/fluid/operators/compat/leaky_relu.pbtxt
index 9df2e5916118c534530c0c7d0a12b3dabe0a1cb9..8618b72ca87485480b0f46d3091b32d6bb39611b 100644
--- a/paddle/fluid/operators/compat/leaky_relu.pbtxt
+++ b/paddle/fluid/operators/compat/leaky_relu.pbtxt
@@ -16,6 +16,18 @@ extra {
     name: "use_mkldnn"
     type: BOOLEAN
   }
+  attrs {
+    name: "name"
+    type: STRING
+  }
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
+  attrs {
+    name: "is_test"
+    type: BOOLEAN
+  }
   attrs {
     name: "op_role"
     type: INT
diff --git a/paddle/fluid/operators/compat/relu.pbtxt b/paddle/fluid/operators/compat/relu.pbtxt
index 271ed91718cee4ce98c8912023c6281350e778a8..9a184bf03d0a6d68c40d3fa3da3107c0ac03f262 100644
--- a/paddle/fluid/operators/compat/relu.pbtxt
+++ b/paddle/fluid/operators/compat/relu.pbtxt
@@ -52,4 +52,8 @@ extra {
     name: "is_test"
     type: BOOLEAN
   }
+  attrs {
+    name: "name"
+    type: STRING
+  }
 }
diff --git a/paddle/fluid/operators/compat/relu6.pbtxt b/paddle/fluid/operators/compat/relu6.pbtxt
index edd29037324430702ba70e9632d72f01b339b390..340b13020144a83edc4b26fdee8ec33e2c8cbb15 100644
--- a/paddle/fluid/operators/compat/relu6.pbtxt
+++ b/paddle/fluid/operators/compat/relu6.pbtxt
@@ -6,16 +6,28 @@ def {
   outputs {
     name: "Out"
   }
+  attrs {
+    name: "threshold"
+    type: FLOAT
+  }
 }
 extra {
   attrs {
-    name: "threshold"
-    type: FLOAT
+    name: "name"
+    type: STRING
+  }
+  attrs {
+    name: "is_test"
+    type: BOOLEAN
   }
   attrs {
     name: "use_mkldnn"
     type: BOOLEAN
   }
+  attrs {
+    name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
+    type: BOOLEAN
+  }
   attrs {
     name: "op_role"
     type: INT
diff --git a/paddle/fluid/operators/compat/swish.pbtxt b/paddle/fluid/operators/compat/swish.pbtxt
index 4f5ec127e489794742f88d5589847e598956b981..1dd8e577d9c738f20f7f6fc038019b1cfca133af 100644
--- a/paddle/fluid/operators/compat/swish.pbtxt
+++ b/paddle/fluid/operators/compat/swish.pbtxt
@@ -12,6 +12,10 @@ extra {
     name: "beta"
     type: FLOAT
   }
+  attrs {
+    name: "name"
+    type: STRING
+  }
   attrs {
     name: "use_mkldnn"
     type: BOOLEAN
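Reviewer note: in these compat .pbtxt files, the def block declares the attributes that belong to an op's core definition — which is why relu6's threshold moves from extra into def once the pass starts validating it — while extra lists bookkeeping attributes (name, use_mkldnn, is_test, op_role, @ENABLE_CACHE_RUNTIME_CONTEXT@) that may appear on an op without affecting compatibility. For context, a rough sketch of how one of these registered passes is fetched and run; the helper below is illustrative, not from this patch, and the registry name comes from the REGISTER_PASS entry at the bottom of conv_activation_mkldnn_fuse_pass.cc, which this diff does not touch:

#include <memory>

#include "paddle/fluid/framework/ir/pass.h"

// Illustrative only: look up the fuse pass by its registered name and run it,
// replacing the graph with the fused version.
void ApplyConvActivationFuse(
    std::unique_ptr<paddle::framework::ir::Graph>* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "conv_activation_mkldnn_fuse_pass");
  graph->reset(pass->Apply(graph->release()));
}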