未验证 提交 93b53f86 编写于 作者: W Wangzheee 提交者: GitHub

[pass_enhance] depthwise_conv_bn_fuse_pass (#33896)

上级 9314743d
...@@ -149,17 +149,21 @@ ConvBNFusePass::ConvBNFusePass() { ...@@ -149,17 +149,21 @@ ConvBNFusePass::ConvBNFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddInput("Bias") .AddInput("Bias")
.IsTensor()
.IsOptional() .IsOptional()
.End() .End()
.AddInput("ResidualData") .AddInput("ResidualData")
.IsTensor()
.IsOptional() .IsOptional()
.End() .End()
.AddOutput("Output") .AddOutput("Output")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("strides") .AddAttr("strides")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("paddings") .AddAttr("paddings")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("padding_algorithm") .AddAttr("padding_algorithm")
.IsOptional() .IsOptional()
...@@ -169,6 +173,7 @@ ConvBNFusePass::ConvBNFusePass() { ...@@ -169,6 +173,7 @@ ConvBNFusePass::ConvBNFusePass() {
.IsNumGE(1) .IsNumGE(1)
.End() .End()
.AddAttr("dilations") .AddAttr("dilations")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
...@@ -205,6 +210,10 @@ ConvBNFusePass::ConvBNFusePass() { ...@@ -205,6 +210,10 @@ ConvBNFusePass::ConvBNFusePass() {
.AddOutput("Y") .AddOutput("Y")
.IsTensor() .IsTensor()
.End() .End()
.AddOutput("ReserveSpace")
.IsTensor()
.IsOptional()
.End()
.AddAttr("epsilon") .AddAttr("epsilon")
.IsNumLE(0.001f) .IsNumLE(0.001f)
.IsNumGE(0.0f) .IsNumGE(0.0f)
...@@ -375,17 +384,21 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() { ...@@ -375,17 +384,21 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddInput("Bias") .AddInput("Bias")
.IsTensor()
.IsOptional() .IsOptional()
.End() .End()
.AddInput("ResidualData") .AddInput("ResidualData")
.IsTensor()
.IsOptional() .IsOptional()
.End() .End()
.AddOutput("Output") .AddOutput("Output")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("strides") .AddAttr("strides")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("paddings") .AddAttr("paddings")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("padding_algorithm") .AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"}) .IsStringIn({"EXPLICIT", "SAME", "VALID"})
...@@ -395,6 +408,7 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() { ...@@ -395,6 +408,7 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
.IsNumGE(1) .IsNumGE(1)
.End() .End()
.AddAttr("dilations") .AddAttr("dilations")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
...@@ -431,6 +445,10 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() { ...@@ -431,6 +445,10 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
.AddOutput("Y") .AddOutput("Y")
.IsTensor() .IsTensor()
.End() .End()
.AddOutput("ReserveSpace")
.IsTensor()
.IsOptional()
.End()
.AddAttr("epsilon") .AddAttr("epsilon")
.IsNumLE(0.001f) .IsNumLE(0.001f)
.IsNumGE(0.0f) .IsNumGE(0.0f)
...@@ -575,31 +593,85 @@ ConvTransposeBNFusePass::ConvTransposeBNFusePass() { ...@@ -575,31 +593,85 @@ ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddInput("Bias") .AddInput("Bias")
.IsTensor()
.IsOptional() .IsOptional()
.End() .End()
.AddOutput("Output") .AddOutput("Output")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides") .AddAttr("strides")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("paddings") .AddAttr("paddings")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("padding_algorithm") .AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"}) .IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional() .IsOptional()
.End() .End()
.AddAttr("groups") .AddAttr("groups")
.IsNumGE(1) .IsNumGE(1)
.End() .End()
.AddAttr("dilations") .AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End(); .End();
} }
ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() { DepthwiseConvBNFusePass::DepthwiseConvBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose")) AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input") .AddInput("Input")
.IsTensor() .IsTensor()
.End() .End()
...@@ -607,23 +679,31 @@ ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() { ...@@ -607,23 +679,31 @@ ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddInput("Bias") .AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional() .IsOptional()
.End() .End()
.AddOutput("Output") .AddOutput("Output")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("strides") .AddAttr("strides")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("paddings") .AddAttr("paddings")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("padding_algorithm") .AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional() .IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End() .End()
.AddAttr("groups") .AddAttr("groups")
.IsNumGE(1) .IsNumGE(1)
.End() .End()
.AddAttr("dilations") .AddAttr("dilations")
.IsType<std::vector<int>>()
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
......
...@@ -17,8 +17,6 @@ ...@@ -17,8 +17,6 @@
#include <string> #include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -27,12 +25,10 @@ namespace ir { ...@@ -27,12 +25,10 @@ namespace ir {
/* /*
* Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp. * Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp.
*/ */
class Graph;
class ConvBNFusePass : public FusePassBase { class ConvBNFusePass : public FusePassBase {
public: public:
ConvBNFusePass(); ConvBNFusePass();
virtual ~ConvBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; } virtual std::string conv_type() const { return "conv2d"; }
protected: protected:
...@@ -43,7 +39,6 @@ class ConvBNFusePass : public FusePassBase { ...@@ -43,7 +39,6 @@ class ConvBNFusePass : public FusePassBase {
class ConvEltwiseAddBNFusePass : public FusePassBase { class ConvEltwiseAddBNFusePass : public FusePassBase {
public: public:
ConvEltwiseAddBNFusePass(); ConvEltwiseAddBNFusePass();
virtual ~ConvEltwiseAddBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; } virtual std::string conv_type() const { return "conv2d"; }
protected: protected:
...@@ -54,19 +49,18 @@ class ConvEltwiseAddBNFusePass : public FusePassBase { ...@@ -54,19 +49,18 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
class ConvTransposeBNFusePass : public ConvBNFusePass { class ConvTransposeBNFusePass : public ConvBNFusePass {
public: public:
ConvTransposeBNFusePass(); ConvTransposeBNFusePass();
virtual ~ConvTransposeBNFusePass() {}
std::string conv_type() const { return "conv2d_transpose"; } std::string conv_type() const { return "conv2d_transpose"; }
}; };
class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass { class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public: public:
ConvTransposeEltwiseAddBNFusePass(); ConvTransposeEltwiseAddBNFusePass();
virtual ~ConvTransposeEltwiseAddBNFusePass() {}
std::string conv_type() const { return "conv2d_transpose"; } std::string conv_type() const { return "conv2d_transpose"; }
}; };
class DepthwiseConvBNFusePass : public ConvBNFusePass { class DepthwiseConvBNFusePass : public ConvBNFusePass {
public: public:
DepthwiseConvBNFusePass();
std::string conv_type() const { return "depthwise_conv2d"; } std::string conv_type() const { return "depthwise_conv2d"; }
}; };
......
...@@ -196,7 +196,7 @@ FCElementwiseLayerNormFusePass::FCElementwiseLayerNormFusePass() { ...@@ -196,7 +196,7 @@ FCElementwiseLayerNormFusePass::FCElementwiseLayerNormFusePass() {
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("axis") .AddAttr("axis")
.IsNumEQ(-1) .IsIntIn({-1, 0})
.End(); .End();
} }
......
...@@ -84,15 +84,18 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { ...@@ -84,15 +84,18 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.End() .End()
.AddInput("Bias") .AddInput("Bias")
.IsTensor() .IsTensor()
.IsOptional()
.End() .End()
.AddOutput("Output") .AddOutput("Output")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("output_padding") .AddAttr("output_padding")
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.IsOptional()
.End() .End()
.AddAttr("output_size") .AddAttr("output_size")
.IsNumGE(1) .IsType<std::vector<int>>()
.IsOptional()
.End() .End()
.AddAttr("groups") .AddAttr("groups")
.IsNumGE(1) .IsNumGE(1)
...@@ -110,7 +113,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() { ...@@ -110,7 +113,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsStringIn({"EXPLICIT", "SAME", "VALID"}) .IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End() .End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End(); .End();
} }
......
...@@ -200,14 +200,12 @@ QuantDequantFusePass::QuantDequantFusePass() { ...@@ -200,14 +200,12 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddOutput("Output") .AddOutput("Output")
.IsTensor() .IsTensor()
.End() .End()
.AddAttr("strides") .AddAttr("output_padding")
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.IsOptional()
.End() .End()
.AddAttr("paddings") .AddAttr("output_size")
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional() .IsOptional()
.End() .End()
.AddAttr("groups") .AddAttr("groups")
...@@ -216,6 +214,15 @@ QuantDequantFusePass::QuantDequantFusePass() { ...@@ -216,6 +214,15 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddAttr("dilations") .AddAttr("dilations")
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End() .End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format") .AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"}) .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End(); .End();
......
type: "depthwise_conv2d"
def {
inputs {
name: "Input"
}
inputs {
name: "Filter"
}
inputs {
name: "Bias"
}
inputs {
name: "ResidualData"
}
outputs {
name: "Output"
}
attrs {
name: "strides"
type: INTS
}
attrs {
name: "paddings"
type: INTS
}
attrs {
name: "padding_algorithm"
type: STRING
}
attrs {
name: "groups"
type: INT
}
attrs {
name: "dilations"
type: INTS
}
attrs {
name: "data_format"
type: STRING
}
}
extra {
attrs {
name: "Input_scale"
type: FLOAT
}
attrs {
name: "quantization_type"
type: STRING
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "skip_quant"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "name"
type: STRING
}
attrs {
name: "use_cudnn"
type: BOOLEAN
}
attrs {
name: "fuse_relu_before_depthwise_conv"
type: BOOLEAN
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "fuse_relu"
type: BOOLEAN
}
attrs {
name: "fuse_brelu"
type: BOOLEAN
}
attrs {
name: "fuse_brelu_threshold"
type: FLOAT
}
attrs {
name: "fuse_activation"
type: STRING
}
attrs {
name: "fuse_alpha"
type: FLOAT
}
attrs {
name: "fuse_beta"
type: FLOAT
}
attrs {
name: "use_addto"
type: BOOLEAN
}
attrs {
name: "fuse_residual_connection"
type: BOOLEAN
}
attrs {
name: "Scale_in"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "Scale_in_eltwise"
type: FLOAT
}
attrs {
name: "Scale_weights"
type: FLOATS
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
}
attrs {
name: "workspace_size_MB"
type: INT
}
attrs {
name: "exhaustive_search"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册