From cf3ddd3b29314cb50953b231a6de7eb82563bcab Mon Sep 17 00:00:00 2001
From: TeslaZhao
Date: Tue, 22 Jun 2021 10:59:08 +0800
Subject: [PATCH] Pass compat of conv_transpose_bias_mkldnn_fuse_pass (#33708)

---
 .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc   | 102 ++++++++++++++++++
 .../ir/mkldnn/conv_bias_mkldnn_fuse_pass.h    |   2 +
 .../conv_bias_mkldnn_fuse_pass_tester.cc      |  14 ++-
 3 files changed, 117 insertions(+), 1 deletion(-)

diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
index c804eeb9fc..8d73a35bf0 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.cc
@@ -25,6 +25,102 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+ConvBiasFusePass::ConvBiasFusePass() {
+  AddOpCompat(OpCompat("conv2d"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .IsOptional()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC"})
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(-1)
+      .End();
+}
+
+Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
+  AddOpCompat(OpCompat("conv2d_transpose"))
+      .AddInput("Input")
+      .IsTensor()
+      .End()
+      .AddInput("Filter")
+      .IsTensor()
+      .End()
+      .AddInput("Bias")
+      .IsTensor()
+      .End()
+      .AddOutput("Output")
+      .IsTensor()
+      .End()
+      .AddAttr("output_padding")
+      .End()
+      .AddAttr("output_size")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("groups")
+      .IsNumGE(1)
+      .End()
+      .AddAttr("dilations")
+      .End()
+      .AddAttr("strides")
+      .End()
+      .AddAttr("paddings")
+      .End()
+      .AddAttr("padding_algorithm")
+      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
+      .End()
+      .AddAttr("data_format")
+      .IsStringIn({"NCHW", "NHWC"})
+      .End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X")
+      .IsTensor()
+      .End()
+      .AddInput("Y")
+      .IsTensor()
+      .End()
+      .AddOutput("Out")
+      .IsTensor()
+      .End()
+      .AddAttr("axis")
+      .IsNumEQ(-1)
+      .End();
+}
+
 template <typename BinaryOperation>
 LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
                                BinaryOperation f) {
@@ -80,6 +176,12 @@ void ConvBiasFusePass::ApplyImpl(ir::Graph* graph) const {
         subgraph.count(conv_input), 0,
         platform::errors::NotFound("Detector did not find conv input."));
 
+    // check compat
+    if (!IsCompat(subgraph, g)) {
+      VLOG(3) << "Pass in op compat failed.";
+      return;
+    }
+
     // check if fuse can be done and if MKL-DNN should be used
     FuseOptions fuse_option = FindFuseOption(*conv, *eltwise);
     if (fuse_option == DO_NOT_FUSE || fuse_option == FUSE_NATIVE) {
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h
index 9a83310ebf..20c683c094 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass.h
@@ -29,6 +29,7 @@ class Graph;
 
 class ConvBiasFusePass : public FusePassBase {
  public:
+  ConvBiasFusePass();
   virtual ~ConvBiasFusePass() {}
   virtual std::string type() const { return "conv2d"; }
 
@@ -41,6 +42,7 @@ class ConvBiasFusePass : public FusePassBase {
  */
 class Conv2DTransposeBiasFusePass : public ConvBiasFusePass {
  public:
+  Conv2DTransposeBiasFusePass();
   std::string type() const override { return "conv2d_transpose"; }
 };
 
diff --git a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
index 455350d2f7..80a9ef7eda 100644
--- a/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
+++ b/paddle/fluid/framework/ir/mkldnn/conv_bias_mkldnn_fuse_pass_tester.cc
@@ -31,8 +31,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
   auto* op = prog->MutableBlock(0)->AppendOp();
   op->SetType(type);
   if (type == "conv2d") {
+    const std::vector<int> strides({1, 1});
+    const std::vector<int> paddings({0, 0});
+    const std::vector<int> dilations({1, 1});
     op->SetAttr("use_mkldnn", true);
     op->SetAttr("name", name);
+    op->SetAttr("strides", strides);
+    op->SetAttr("groups", 1);
+    op->SetAttr("paddings", paddings);
+    op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
+    op->SetAttr("dilations", dilations);
+    op->SetAttr("data_format", std::string("NCHW"));
+
+    op->SetOutput("Output", outputs);
     op->SetInput("Input", {inputs[0]});
     op->SetInput("Filter", {inputs[1]});
     if (inputs.size() > 2)
@@ -41,10 +52,11 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
       op->SetInput("Bias", {});
   } else if (type == "elementwise_add") {
     op->SetAttr("use_mkldnn", true);
+    op->SetAttr("axis", -1);
     op->SetInput("X", {inputs[0]});
     op->SetInput("Y", {inputs[1]});
+    op->SetOutput("Out", outputs);
   }
-  op->SetOutput("Out", outputs);
   op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
               static_cast<int>(OpRole::kForward));
 }
-- 
GitLab
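
Note on the mechanism: AddOpCompat() registers, per op type, a chained set of
input/output/attribute constraints in the pass constructor, and IsCompat(subgraph, g)
later checks each matched op against them before the fusion rewrites the graph; if
any constraint fails, the pass returns early (the VLOG(3) branch added to ApplyImpl
above). The self-contained C++ sketch below imitates that fluent-checker pattern
using the elementwise_add rule from this patch (axis must equal -1). It is an
illustration only: OpCompatSketch and OpDescSketch are invented stand-ins, not
Paddle's real OpCompat/OpDesc classes.

#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

// Invented stand-in for an op description: a type name plus integer attributes.
struct OpDescSketch {
  std::string type;
  std::map<std::string, int> int_attrs;
};

// Invented stand-in for OpCompat: accumulates one predicate per attribute via
// a chained interface, then validates a whole op, as IsCompat() does above.
class OpCompatSketch {
 public:
  explicit OpCompatSketch(std::string type) : type_(std::move(type)) {}

  OpCompatSketch& AddAttr(const std::string& name) {
    current_ = name;
    rules_[name] = [](int) { return true; };  // default rule: presence only
    return *this;
  }
  OpCompatSketch& IsNumGE(int bound) {
    rules_[current_] = [bound](int v) { return v >= bound; };
    return *this;
  }
  OpCompatSketch& IsNumEQ(int expected) {
    rules_[current_] = [expected](int v) { return v == expected; };
    return *this;
  }
  OpCompatSketch& End() { return *this; }  // chaining separator, as in OpCompat

  // True only if the op has the expected type and every registered attribute
  // is present and satisfies its predicate.
  bool IsCompat(const OpDescSketch& op) const {
    if (op.type != type_) return false;
    for (const auto& rule : rules_) {
      auto it = op.int_attrs.find(rule.first);
      if (it == op.int_attrs.end() || !rule.second(it->second)) return false;
    }
    return true;
  }

 private:
  std::string type_;
  std::string current_;
  std::map<std::string, std::function<bool(int)>> rules_;
};

int main() {
  // Mirrors the elementwise_add constraint registered by this patch.
  OpCompatSketch compat("elementwise_add");
  compat.AddAttr("axis").IsNumEQ(-1).End();

  const OpDescSketch ok{"elementwise_add", {{"axis", -1}}};
  const OpDescSketch bad{"elementwise_add", {{"axis", 1}}};
  std::cout << compat.IsCompat(ok) << "\n";   // 1: fusion may proceed
  std::cout << compat.IsCompat(bad) << "\n";  // 0: the pass would bail out
  return 0;
}

Building the predicate table once in the pass constructor and consulting it per
matched subgraph keeps the compat check cheap, which matters because a fuse pass
may visit many candidate subgraphs in a single graph. This also explains the
tester changes above: once IsCompat() is enforced, the test's conv2d and
elementwise_add ops must carry every checked attribute (strides, paddings,
groups, axis, and so on) or the pass would now skip them.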