From 33edb62a7752fb1a134bcb7d09ae931bce28937e Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 1 Jul 2021 20:36:38 +0800 Subject: [PATCH] pass_enhance_conv_concat_relu_mkldnn (#33867) --- .../conv_concat_relu_mkldnn_fuse_pass.cc | 62 ++++++++++++++++++- .../conv_concat_relu_mkldnn_fuse_pass.h | 5 +- 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc index c4d7a12037..5fbfef08b7 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.cc @@ -23,7 +23,67 @@ namespace paddle { namespace framework { namespace ir { -class Graph; +ConvConcatReLUFusePass::ConvConcatReLUFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("concat")) + .AddInput("X") // Input("X"): vector + .End() + .AddInput("AxisTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumGE(0) + .End(); + + AddOpCompat(OpCompat("relu")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End(); +} void ConvConcatReLUFusePass::FindConcatWithConvs( ir::Graph* graph, diff --git a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h index f1faa84f3d..af372dbf97 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_concat_relu_mkldnn_fuse_pass.h @@ -18,9 +18,6 @@ #include #include "paddle/fluid/framework/ir/fuse_pass_base.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/ir/graph_pattern_detector.h" -#include "paddle/fluid/framework/ir/pass.h" namespace paddle { namespace framework { @@ -31,10 +28,10 @@ namespace ir { * to a: * (multi ConvReLU) -> Concat -> next_op. */ -class Graph; class ConvConcatReLUFusePass : public FusePassBase { public: + ConvConcatReLUFusePass(); virtual ~ConvConcatReLUFusePass() {} protected: -- GitLab