未验证 提交 cc5d4b1a 编写于 作者: F feng_shuai 提交者: GitHub

Conv relu mkldnn fuse pass (#33664)

上级 79e75bc5
......@@ -49,6 +49,11 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_activation_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
......@@ -97,6 +102,113 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_activation_count);
}
ConvActivationFusePass::ConvActivationFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.IsTensor()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
// IsStringIn({"EXPLICIT", "SAME", "VALID"}), MobileNetV2 has no this
// attribute
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
.AddAttr("data_format")
.IsOptional()
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
AddOpCompat(OpCompat("leaky_relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// float, default=0.02
.AddAttr("alpha")
.IsType<float>()
.End();
}
Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
AddOpCompat(OpCompat("relu6"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// default = 6.0f
.AddAttr("threshold")
.IsType<float>()
.End();
}
Conv2DSwishFusePass::Conv2DSwishFusePass() {
AddOpCompat(OpCompat("swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
AddOpCompat(OpCompat("hard_swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// float, optional, default=6.0
.AddAttr("threshold")
.IsOptional()
.IsType<float>()
.End()
// float, optional, default=6.0
.AddAttr("scale")
.IsOptional()
.IsType<float>()
.End()
// float, optional, default=3.0
.AddAttr("offset")
.IsOptional()
.IsType<float>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -31,6 +31,7 @@ class Graph;
class ConvActivationFusePass : public FusePassBase {
public:
ConvActivationFusePass();
virtual ~ConvActivationFusePass() {}
virtual std::string conv_type() const { return "conv2d"; }
virtual std::string activation_type() const { return "relu"; }
......@@ -44,6 +45,7 @@ class ConvActivationFusePass : public FusePassBase {
*/
class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
public:
Conv2DLeakyReLUFusePass();
std::string activation_type() const { return "leaky_relu"; }
};
/*
......@@ -51,6 +53,7 @@ class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
*/
class Conv2DReLU6FusePass : public ConvActivationFusePass {
public:
Conv2DReLU6FusePass();
std::string activation_type() const { return "relu6"; }
};
/*
......@@ -58,6 +61,7 @@ class Conv2DReLU6FusePass : public ConvActivationFusePass {
*/
class Conv2DSwishFusePass : public ConvActivationFusePass {
public:
Conv2DSwishFusePass();
std::string activation_type() const { return "swish"; }
};
/*
......@@ -65,6 +69,7 @@ class Conv2DSwishFusePass : public ConvActivationFusePass {
*/
class Conv2DHardSwishFusePass : public ConvActivationFusePass {
public:
Conv2DHardSwishFusePass();
std::string activation_type() const { return "hard_swish"; }
};
} // namespace ir
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle {
......@@ -30,9 +31,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("groups", 1);
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("strides", std::vector<int>({1, 1}));
op->SetAttr("dilations", std::vector<int>({1, 1}));
op->SetAttr("paddings", std::vector<int>({0, 0}));
op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", outputs);
} else if (is_activation) {
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs);
......@@ -43,8 +51,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
} else if (type == "swish") {
op->SetAttr("beta", 1.0f);
}
op->SetOutput("Out", outputs);
}
op->SetOutput("Out", outputs);
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
}
......
......@@ -24,6 +24,18 @@ extra {
name: "op_role"
type: INT
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "name"
type: STRING
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "op_role_var"
type: STRINGS
......
......@@ -16,6 +16,18 @@ extra {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "name"
type: STRING
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
......
......@@ -52,4 +52,8 @@ extra {
name: "is_test"
type: BOOLEAN
}
attrs {
name: "name"
type: STRINGS
}
}
......@@ -6,16 +6,28 @@ def {
outputs {
name: "Out"
}
attrs {
name: "threshold"
type: FLOAT
}
}
extra {
attrs {
name: "threshold"
name: "name"
type: STRING
}
attrs {
name: "is_test"
type: FLOAT
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "op_role"
type: INT
......
......@@ -12,6 +12,10 @@ extra {
name: "beta"
type: FLOAT
}
attrs {
name: "name"
type: STRING
}
attrs {
name: "use_mkldnn"
type: BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册