未验证 提交 cc5d4b1a 编写于 作者: F feng_shuai 提交者: GitHub

Conv relu mkldnn fuse pass (#33664)

上级 79e75bc5
...@@ -49,6 +49,11 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -49,6 +49,11 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse"; VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
conv_activation_pattern); // Filter conv_activation_pattern); // Filter
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out,
...@@ -97,6 +102,113 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -97,6 +102,113 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_conv_activation_count); AddStatis(found_conv_activation_count);
} }
ConvActivationFusePass::ConvActivationFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsOptional()
.IsTensor()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
// IsStringIn({"EXPLICIT", "SAME", "VALID"}), MobileNetV2 has no this
// attribute
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
// IsStringIn({"NHWC", "NCHW"}) MobileNetV2 has no this attribute
.AddAttr("data_format")
.IsOptional()
.IsStringIn({"NHWC", "NCHW", "AnyLayout"})
.End();
AddOpCompat(OpCompat("relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DLeakyReLUFusePass::Conv2DLeakyReLUFusePass() {
AddOpCompat(OpCompat("leaky_relu"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// float, default=0.02
.AddAttr("alpha")
.IsType<float>()
.End();
}
Conv2DReLU6FusePass::Conv2DReLU6FusePass() {
AddOpCompat(OpCompat("relu6"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// default = 6.0f
.AddAttr("threshold")
.IsType<float>()
.End();
}
Conv2DSwishFusePass::Conv2DSwishFusePass() {
AddOpCompat(OpCompat("swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
Conv2DHardSwishFusePass::Conv2DHardSwishFusePass() {
AddOpCompat(OpCompat("hard_swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// float, optional, default=6.0
.AddAttr("threshold")
.IsOptional()
.IsType<float>()
.End()
// float, optional, default=6.0
.AddAttr("scale")
.IsOptional()
.IsType<float>()
.End()
// float, optional, default=3.0
.AddAttr("offset")
.IsOptional()
.IsType<float>()
.End();
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -31,6 +31,7 @@ class Graph; ...@@ -31,6 +31,7 @@ class Graph;
class ConvActivationFusePass : public FusePassBase { class ConvActivationFusePass : public FusePassBase {
public: public:
ConvActivationFusePass();
virtual ~ConvActivationFusePass() {} virtual ~ConvActivationFusePass() {}
virtual std::string conv_type() const { return "conv2d"; } virtual std::string conv_type() const { return "conv2d"; }
virtual std::string activation_type() const { return "relu"; } virtual std::string activation_type() const { return "relu"; }
...@@ -44,6 +45,7 @@ class ConvActivationFusePass : public FusePassBase { ...@@ -44,6 +45,7 @@ class ConvActivationFusePass : public FusePassBase {
*/ */
class Conv2DLeakyReLUFusePass : public ConvActivationFusePass { class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
public: public:
Conv2DLeakyReLUFusePass();
std::string activation_type() const { return "leaky_relu"; } std::string activation_type() const { return "leaky_relu"; }
}; };
/* /*
...@@ -51,6 +53,7 @@ class Conv2DLeakyReLUFusePass : public ConvActivationFusePass { ...@@ -51,6 +53,7 @@ class Conv2DLeakyReLUFusePass : public ConvActivationFusePass {
*/ */
class Conv2DReLU6FusePass : public ConvActivationFusePass { class Conv2DReLU6FusePass : public ConvActivationFusePass {
public: public:
Conv2DReLU6FusePass();
std::string activation_type() const { return "relu6"; } std::string activation_type() const { return "relu6"; }
}; };
/* /*
...@@ -58,6 +61,7 @@ class Conv2DReLU6FusePass : public ConvActivationFusePass { ...@@ -58,6 +61,7 @@ class Conv2DReLU6FusePass : public ConvActivationFusePass {
*/ */
class Conv2DSwishFusePass : public ConvActivationFusePass { class Conv2DSwishFusePass : public ConvActivationFusePass {
public: public:
Conv2DSwishFusePass();
std::string activation_type() const { return "swish"; } std::string activation_type() const { return "swish"; }
}; };
/* /*
...@@ -65,6 +69,7 @@ class Conv2DSwishFusePass : public ConvActivationFusePass { ...@@ -65,6 +69,7 @@ class Conv2DSwishFusePass : public ConvActivationFusePass {
*/ */
class Conv2DHardSwishFusePass : public ConvActivationFusePass { class Conv2DHardSwishFusePass : public ConvActivationFusePass {
public: public:
Conv2DHardSwishFusePass();
std::string activation_type() const { return "hard_swish"; } std::string activation_type() const { return "hard_swish"; }
}; };
} // namespace ir } // namespace ir
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
namespace paddle { namespace paddle {
...@@ -30,9 +31,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -30,9 +31,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("name", name); op->SetAttr("name", name);
if (type == "conv2d") { if (type == "conv2d") {
op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("groups", 1);
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("data_format", std::string("NCHW"));
op->SetAttr("strides", std::vector<int>({1, 1}));
op->SetAttr("dilations", std::vector<int>({1, 1}));
op->SetAttr("paddings", std::vector<int>({0, 0}));
op->SetInput("Input", {inputs[0]}); op->SetInput("Input", {inputs[0]});
op->SetInput("Filter", {inputs[1]}); op->SetInput("Filter", {inputs[1]});
op->SetInput("Bias", {inputs[2]}); op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", outputs);
} else if (is_activation) { } else if (is_activation) {
op->SetAttr("use_mkldnn", use_mkldnn); op->SetAttr("use_mkldnn", use_mkldnn);
op->SetInput("X", inputs); op->SetInput("X", inputs);
...@@ -43,8 +51,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name, ...@@ -43,8 +51,9 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
} else if (type == "swish") { } else if (type == "swish") {
op->SetAttr("beta", 1.0f); op->SetAttr("beta", 1.0f);
} }
}
op->SetOutput("Out", outputs); op->SetOutput("Out", outputs);
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward)); static_cast<int>(OpRole::kForward));
} }
......
...@@ -24,6 +24,18 @@ extra { ...@@ -24,6 +24,18 @@ extra {
name: "op_role" name: "op_role"
type: INT type: INT
} }
attrs {
name: "use_mkldnn"
type: BOOLEAN
}
attrs {
name: "name"
type: STRING
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs { attrs {
name: "op_role_var" name: "op_role_var"
type: STRINGS type: STRINGS
......
...@@ -16,6 +16,18 @@ extra { ...@@ -16,6 +16,18 @@ extra {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "name"
type: STRING
}
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "is_test"
type: BOOLEAN
}
attrs { attrs {
name: "op_role" name: "op_role"
type: INT type: INT
......
...@@ -52,4 +52,8 @@ extra { ...@@ -52,4 +52,8 @@ extra {
name: "is_test" name: "is_test"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "name"
type: STRINGS
}
} }
...@@ -6,16 +6,28 @@ def { ...@@ -6,16 +6,28 @@ def {
outputs { outputs {
name: "Out" name: "Out"
} }
attrs {
name: "threshold"
type: FLOAT
}
} }
extra { extra {
attrs { attrs {
name: "threshold" name: "name"
type: STRING
}
attrs {
name: "is_test"
type: FLOAT type: FLOAT
} }
attrs { attrs {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs { attrs {
name: "op_role" name: "op_role"
type: INT type: INT
......
...@@ -12,6 +12,10 @@ extra { ...@@ -12,6 +12,10 @@ extra {
name: "beta" name: "beta"
type: FLOAT type: FLOAT
} }
attrs {
name: "name"
type: STRING
}
attrs { attrs {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册