未验证 提交 93b53f86 编写于 作者: W Wangzheee 提交者: GitHub

[pass_enhance] depthwise_conv_bn_fuse_pass (#33896)

上级 9314743d
......@@ -149,17 +149,21 @@ ConvBNFusePass::ConvBNFusePass() {
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
......@@ -169,6 +173,7 @@ ConvBNFusePass::ConvBNFusePass() {
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
......@@ -205,6 +210,10 @@ ConvBNFusePass::ConvBNFusePass() {
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("ReserveSpace")
.IsTensor()
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
......@@ -375,17 +384,21 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
......@@ -395,6 +408,7 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
......@@ -431,6 +445,10 @@ ConvEltwiseAddBNFusePass::ConvEltwiseAddBNFusePass() {
.AddOutput("Y")
.IsTensor()
.End()
.AddOutput("ReserveSpace")
.IsTensor()
.IsOptional()
.End()
.AddAttr("epsilon")
.IsNumLE(0.001f)
.IsNumGE(0.0f)
......@@ -575,31 +593,85 @@ ConvTransposeBNFusePass::ConvTransposeBNFusePass() {
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Registers the operator-compatibility constraints checked before fusing
// conv2d_transpose + elementwise_add + batch_norm. Each AddInput/AddOutput/
// AddAttr chain declares what the pass accepts for the conv2d_transpose op;
// IsOptional() marks entries that may be absent on a given op instance.
ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
// Bias is optional: conv2d_transpose may or may not carry one.
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
// Shape-controlling attributes; the optional ones default when missing.
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
AddOpCompat(OpCompat("conv2d_transpose"))
DepthwiseConvBNFusePass::DepthwiseConvBNFusePass() {
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
......@@ -607,23 +679,31 @@ ConvTransposeEltwiseAddBNFusePass::ConvTransposeEltwiseAddBNFusePass() {
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
......
......@@ -17,8 +17,6 @@
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
......@@ -27,12 +25,10 @@ namespace ir {
/*
* Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp.
*/
class Graph;
class ConvBNFusePass : public FusePassBase {
public:
ConvBNFusePass();
virtual ~ConvBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
protected:
......@@ -43,7 +39,6 @@ class ConvBNFusePass : public FusePassBase {
class ConvEltwiseAddBNFusePass : public FusePassBase {
public:
ConvEltwiseAddBNFusePass();
virtual ~ConvEltwiseAddBNFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
protected:
......@@ -54,19 +49,18 @@ class ConvEltwiseAddBNFusePass : public FusePassBase {
// Variant of ConvBNFusePass that targets conv2d_transpose + batch_norm.
// Only the conv op type differs; the fuse logic lives in the base class.
class ConvTransposeBNFusePass : public ConvBNFusePass {
public:
ConvTransposeBNFusePass();
virtual ~ConvTransposeBNFusePass() {}
// Hook consumed by the base pass to pick which conv op to match.
std::string conv_type() const { return "conv2d_transpose"; }
};
// Variant of ConvEltwiseAddBNFusePass that targets
// conv2d_transpose + elementwise_add + batch_norm.
class ConvTransposeEltwiseAddBNFusePass : public ConvEltwiseAddBNFusePass {
public:
ConvTransposeEltwiseAddBNFusePass();
virtual ~ConvTransposeEltwiseAddBNFusePass() {}
// Hook consumed by the base pass to pick which conv op to match.
std::string conv_type() const { return "conv2d_transpose"; }
};
// Variant of ConvBNFusePass that fuses depthwise_conv2d + batch_norm.
// Only the conv op type differs; the fuse logic lives in the base class.
class DepthwiseConvBNFusePass : public ConvBNFusePass {
public:
DepthwiseConvBNFusePass();
// `override` lets the compiler verify this actually overrides the
// virtual ConvBNFusePass::conv_type hook (same behavior either way).
std::string conv_type() const override { return "depthwise_conv2d"; }
};
......
......@@ -196,7 +196,7 @@ FCElementwiseLayerNormFusePass::FCElementwiseLayerNormFusePass() {
.IsTensor()
.End()
.AddAttr("axis")
.IsNumEQ(-1)
.IsIntIn({-1, 0})
.End();
}
......
......@@ -84,15 +84,18 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsNumGE(1)
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
......@@ -110,7 +113,7 @@ Conv2DTransposeBiasFusePass::Conv2DTransposeBiasFusePass() {
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC"})
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
......
......@@ -200,14 +200,12 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("paddings")
.AddAttr("output_size")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.IsOptional()
.End()
.AddAttr("groups")
......@@ -216,6 +214,15 @@ QuantDequantFusePass::QuantDequantFusePass() {
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
......
# Compatibility signature for the depthwise_conv2d operator.
# NOTE(review): the `def` section appears to list the inputs/outputs/attrs
# that compat checks validate, while `extra` lists auxiliary attributes that
# are tolerated but not part of the core definition -- confirm against the
# op-signature framework before relying on this split.
type: "depthwise_conv2d"
def {
inputs {
name: "Input"
}
inputs {
name: "Filter"
}
# Bias and ResidualData are presumably optional inputs; optionality is not
# expressed in this format -- verify against the op registration.
inputs {
name: "Bias"
}
inputs {
name: "ResidualData"
}
outputs {
name: "Output"
}
attrs {
name: "strides"
type: INTS
}
attrs {
name: "paddings"
type: INTS
}
attrs {
name: "padding_algorithm"
type: STRING
}
attrs {
name: "groups"
type: INT
}
attrs {
name: "dilations"
type: INTS
}
attrs {
name: "data_format"
type: STRING
}
}
# Auxiliary attributes: quantization, MKL-DNN/cuDNN toggles, fused-activation
# parameters, and framework bookkeeping (op_role, op_namescope, ...).
extra {
attrs {
name: "Input_scale"
type: FLOAT
}
attrs {
name: "quantization_type"
type: STRING
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "out_threshold"
type: FLOAT
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "skip_quant"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "name"
type: STRING
}
attrs {
name: "use_cudnn"
type: BOOLEAN
}
attrs {
name: "fuse_relu_before_depthwise_conv"
type: BOOLEAN
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs {
name: "mkldnn_data_type"
type: STRING
}
attrs {
name: "fuse_relu"
type: BOOLEAN
}
attrs {
name: "fuse_brelu"
type: BOOLEAN
}
attrs {
name: "fuse_brelu_threshold"
type: FLOAT
}
attrs {
name: "fuse_activation"
type: STRING
}
attrs {
name: "fuse_alpha"
type: FLOAT
}
attrs {
name: "fuse_beta"
type: FLOAT
}
attrs {
name: "use_addto"
type: BOOLEAN
}
attrs {
name: "fuse_residual_connection"
type: BOOLEAN
}
attrs {
name: "Scale_in"
type: FLOAT
}
attrs {
name: "Scale_out"
type: FLOAT
}
attrs {
name: "Scale_in_eltwise"
type: FLOAT
}
attrs {
name: "Scale_weights"
type: FLOATS
}
attrs {
name: "force_fp32_output"
type: BOOLEAN
}
attrs {
name: "workspace_size_MB"
type: INT
}
attrs {
name: "exhaustive_search"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
}
attrs {
name: "op_role_var"
type: STRINGS
}
attrs {
name: "op_namescope"
type: STRING
}
attrs {
name: "op_callstack"
type: STRINGS
}
attrs {
name: "op_device"
type: STRING
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册