未验证 提交 dd33d28d 编写于 作者: W Wangzheee 提交者: GitHub

[pass_enhance] conv_elementwise_add_mkldnn_fuse_pass (#33931)

上级 ae74c404
......@@ -51,7 +51,7 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const {
VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass op compat failed.";
LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
......@@ -114,6 +114,10 @@ ConvActivationFusePass::ConvActivationFusePass() {
.IsOptional()
.IsTensor()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
......
......@@ -81,16 +81,72 @@ boost::optional<T> HasAttribute(const Node& op, const std::string& attr) {
return boost::none;
}
ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 0})
.End();
}
ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle(
const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func,
const ResidualConnectionMKLDNNFusePass::IdentityConvFunc&
get_node_from_conv_op,
const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc&
get_node_from_elementwise_add_op)
get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass)
: fusion_stats{std::make_shared<int>(0)},
can_fuse_func{can_fuse_func},
get_node_from_conv_op{get_node_from_conv_op},
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
pass_{pass} {}
void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
......@@ -102,6 +158,11 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()(
Node* elementwise_add_op;
Node* elementwise_add_identity;
Node* elementwise_add_out;
if (!pass_->IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
std::tie(conv_op, conv_input, conv_filter, conv_output) =
get_node_from_conv_op(subgraph);
......@@ -133,12 +194,14 @@ ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle(
const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc&
get_node_from_conv_y_op,
const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc&
get_node_from_elementwise_add_op)
get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass)
: fusion_stats{std::make_shared<int>(0)},
can_fuse_func{can_fuse_func},
get_node_from_conv_x_op{get_node_from_conv_x_op},
get_node_from_conv_y_op{get_node_from_conv_y_op},
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {}
get_node_from_elementwise_add_op{get_node_from_elementwise_add_op},
pass_{pass} {}
void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
......@@ -155,6 +218,12 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()(
Node* elementwise_add_op;
Node* elementwise_add_out;
if (!pass_->IsCompat(subgraph, graph)) {
LOG(WARNING)
<< "conv_elementwise_add_mkldnn_fuse_pass in op compat failed.";
return;
}
std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) =
get_node_from_conv_x_op(subgraph);
std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) =
......@@ -247,7 +316,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX(
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_pattern, subgraph);
},
get_node_from_elementwise_add);
get_node_from_elementwise_add, this);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
......@@ -284,7 +353,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY(
[this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_pattern, subgraph);
},
get_node_from_elementwise_add);
get_node_from_elementwise_add, this);
}
GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
......@@ -325,7 +394,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv(
&conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) {
return GetNodesFromConv(conv_y_pattern, subgraph);
},
get_node_from_elementwise_add);
get_node_from_elementwise_add, this);
}
void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const {
......
......@@ -84,7 +84,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
auto can_fuse = [this](Node* op1, Node* op2) -> bool {
return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN;
};
auto fuse_handle = HandleType{can_fuse, std::forward<OpFuncs>(op_funcs)...};
(*gpd)(graph, fuse_handle);
......@@ -96,7 +95,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
IdentityFuseHandle(
const CanFuseFunc& can_fuse_func,
const IdentityConvFunc& get_node_from_conv_op,
const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op);
const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass);
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
......@@ -107,6 +107,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
CanFuseFunc can_fuse_func;
IdentityConvFunc get_node_from_conv_op;
IdentityElementwiseAddFunc get_node_from_elementwise_add_op;
const ResidualConnectionMKLDNNFusePass* pass_;
};
struct ProjectionFuseHandle {
......@@ -114,7 +115,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
const CanFuseFunc& can_fuse_func,
const ProjectionConvFunc& get_node_from_conv_x_op,
const ProjectionConvFunc& get_node_from_conv_y_op,
const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op);
const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op,
const ResidualConnectionMKLDNNFusePass* pass);
void operator()(const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph);
......@@ -126,9 +128,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase {
ProjectionConvFunc get_node_from_conv_x_op;
ProjectionConvFunc get_node_from_conv_y_op;
ProjectionElementwiseAddFunc get_node_from_elementwise_add_op;
const ResidualConnectionMKLDNNFusePass* pass_;
};
public:
ResidualConnectionMKLDNNFusePass();
virtual ~ResidualConnectionMKLDNNFusePass() {}
protected:
......
......@@ -16,6 +16,7 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -25,16 +26,67 @@ namespace ir {
constexpr int nodes_removed = 3;
constexpr int nodes_added = 1;
OpDesc* Create_Op_con2d(ProgramDesc* prog, const std::string& op_type_name,
const std::vector<test::InOutVarNamePair>& inputs,
const std::vector<test::InOutVarNamePair>& outputs,
const bool use_mkldnn = true) {
auto* op = prog->MutableBlock(0)->AppendOp();
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({0, 0});
const std::vector<int> dilations({1, 1});
op->SetType(op_type_name);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("strides", strides);
op->SetAttr("groups", 1);
op->SetAttr("paddings", paddings);
op->SetAttr("padding_algorithm", std::string("EXPLICIT"));
op->SetAttr("dilations", dilations);
op->SetAttr("data_format", std::string("NCHW"));
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
for (const auto& output : outputs) {
op->SetOutput(output.first, {output.second});
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return op;
}
OpDesc* Create_Op_elemntwise_add(
ProgramDesc* prog, const std::string& op_type_name,
const std::vector<test::InOutVarNamePair>& inputs,
const std::vector<test::InOutVarNamePair>& outputs,
bool use_mkldnn = true) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(op_type_name);
op->SetAttr("use_mkldnn", use_mkldnn);
op->SetAttr("axis", -1);
for (const auto& input : inputs) {
op->SetInput(input.first, {input.second});
}
for (const auto& output : outputs) {
op->SetOutput(output.first, {output.second});
}
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward));
return op;
}
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
auto prog =
test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{{"Output", "c"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
{{"Out", "d"}});
Create_Op_con2d(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{{"Output", "c"}});
Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
......@@ -53,17 +105,17 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
// right branch
test::CreateOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{{"Output", "c"}});
Create_Op_con2d(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{{"Output", "c"}});
// left branch
test::CreateOp(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
{{"Output", "f"}});
Create_Op_con2d(&prog, "conv2d",
{{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}},
{{"Output", "f"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}},
{{"Out", "d"}});
Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
......@@ -80,10 +132,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
{{"Out", "d"}});
Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
......@@ -100,12 +152,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{{"Output", "c"}});
Create_Op_con2d(&prog, "conv2d",
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
{{"Output", "c"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
{{"Out", "d"}});
Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
......@@ -122,10 +174,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass,
auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
{{"Out", "d"}});
Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}},
{{"Out", "d"}});
test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}});
Graph graph(prog);
......@@ -142,14 +194,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
test::BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}});
test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
{{"Output", "c"}});
test::CreateOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
{{"Output", "e"}});
Create_Op_con2d(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
{{"Output", "e"}});
test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}},
{{"Out", "f"}});
Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}},
{{"Out", "f"}});
test::CreateOp(&prog, "relu", {{"X", "f"}}, {{"Out", "g"}});
Graph graph(prog);
......
......@@ -67,6 +67,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
.AddAttr("paddings")
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
......@@ -75,6 +76,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() {
.AddAttr("dilations")
.End()
.AddAttr("data_format")
.IsOptional()
.IsStringIn({"NCHW", "NHWC"})
.End();
}
......
......@@ -15,6 +15,10 @@ def {
}
}
extra {
attrs {
name: "@ENABLE_CACHE_RUNTIME_CONTEXT@"
type: BOOLEAN
}
attrs {
name: "out_threshold"
type: FLOAT
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册