From 78ab656c5eecda8cd892cbeb0ccd36341499c09e Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Fri, 9 Jul 2021 16:52:59 +0800 Subject: [PATCH] depthwise_conv_mkl_pass (#33936) --- .../ir/mkldnn/depthwise_conv_mkldnn_pass.cc | 45 +++++++++++++++++++ .../ir/mkldnn/depthwise_conv_mkldnn_pass.h | 1 + .../depthwise_conv_mkldnn_pass_tester.cc | 8 +++- 3 files changed, 53 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc index 39f47406a77..039094c2709 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.cc @@ -31,6 +31,47 @@ class Graph; PADDLE_ENFORCE_NOT_NULL( \ id, platform::errors::InvalidArgument("Subgraph has no node %s.", #id)); +DepthwiseConvMKLDNNPass::DepthwiseConvMKLDNNPass() { + AddOpCompat(OpCompat("depthwise_conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsOptional() + .IsTensor() + .End() + .AddInput("ResidualData") + .IsOptional() + .IsTensor() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + // mobilenet-ssd has no "padding_algorithm" + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NHWC", "NCHW", "AnyLayout"}) + .End(); +} + void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); @@ -45,6 +86,10 @@ void DepthwiseConvMKLDNNPass::ApplyImpl(ir::Graph* graph) const { int found_depthwise_conv_mkldnn_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "Pass op compat failed."; + return; + } VLOG(3) << "handle DepthwiseConvMKLDNN fuse"; GET_NODE(depthwise_conv, (*pattern)); depthwise_conv->Op()->SetType("conv2d"); diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h index 0f4ecc71ad7..06ce5a41b6c 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass.h @@ -24,6 +24,7 @@ class Graph; class DepthwiseConvMKLDNNPass : public FusePassBase { public: + DepthwiseConvMKLDNNPass(); virtual ~DepthwiseConvMKLDNNPass() {} protected: diff --git a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc index c6c72ba33d6..06940b38ea8 100644 --- a/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/depthwise_conv_mkldnn_pass_tester.cc @@ -29,10 +29,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, op->SetType(type); op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("name", name); + op->SetAttr("groups", 1); + op->SetAttr("padding_algorithm", std::string("EXPLICIT")); + op->SetAttr("data_format", std::string("NCHW")); + op->SetAttr("strides", std::vector({1, 1})); + op->SetAttr("dilations", std::vector({1, 1})); + op->SetAttr("paddings", std::vector({0, 0})); op->SetInput("Input", {inputs[0]}); op->SetInput("Filter", {inputs[1]}); op->SetInput("Bias", {inputs[2]}); - op->SetOutput("Out", outputs); + op->SetOutput("Output", outputs); } // (a, weights, bias)->depthwise conv mkldnn->b -- GitLab