From dd33d28d3c2c71149fcaf9cd4c7710398b31dd12 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Tue, 6 Jul 2021 14:59:15 +0800 Subject: [PATCH] [pass_enhance] conv_elementwise_add_mkldnn_fuse_pass (#33931) --- .../conv_activation_mkldnn_fuse_pass.cc | 6 +- .../conv_elementwise_add_mkldnn_fuse_pass.cc | 83 +++++++++++-- .../conv_elementwise_add_mkldnn_fuse_pass.h | 10 +- ...elementwise_add_mkldnn_fuse_pass_tester.cc | 116 +++++++++++++----- .../ir/mkldnn/cpu_quantize_squash_pass.cc | 2 + .../operators/compat/elementwise_add.pbtxt | 4 + 6 files changed, 178 insertions(+), 43 deletions(-) diff --git a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc index 79a31e5cdc..aaae505edd 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_activation_mkldnn_fuse_pass.cc @@ -51,7 +51,7 @@ void ConvActivationFusePass::ApplyImpl(ir::Graph* graph) const { VLOG(4) << "handle " + conv_type() + "+" + activation_type() + " fuse"; if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Pass op compat failed."; + LOG(WARNING) << "conv_activation_mkldnn_fuse_pass op compat failed."; return; } GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, @@ -114,6 +114,10 @@ ConvActivationFusePass::ConvActivationFusePass() { .IsOptional() .IsTensor() .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() .AddOutput("Output") .IsTensor() .End() diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc index fa1544f780..bd65ad8e64 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.cc @@ -81,16 +81,72 @@ boost::optional HasAttribute(const Node& op, const std::string& attr) { return boost::none; } +ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() { + AddOpCompat(OpCompat("conv2d")) + .AddInput("Input") + .IsTensor() + .End() + .AddInput("Filter") + .IsTensor() + .End() + .AddInput("Bias") + .IsTensor() + .IsOptional() + .End() + .AddInput("ResidualData") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Output") + .IsTensor() + .End() + .AddAttr("strides") + .IsType>() + .End() + .AddAttr("paddings") + .IsType>() + .End() + .AddAttr("padding_algorithm") + .IsOptional() + .IsStringIn({"EXPLICIT", "SAME", "VALID"}) + .End() + .AddAttr("groups") + .IsNumGE(1) + .End() + .AddAttr("dilations") + .IsType>() + .End() + .AddAttr("data_format") + .IsStringIn({"NCHW", "NHWC", "AnyLayout"}) + .End(); + + AddOpCompat(OpCompat("elementwise_add")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsIntIn({-1, 0}) + .End(); +} + ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::IdentityFuseHandle( const ResidualConnectionMKLDNNFusePass::CanFuseFunc& can_fuse_func, const ResidualConnectionMKLDNNFusePass::IdentityConvFunc& get_node_from_conv_op, const ResidualConnectionMKLDNNFusePass::IdentityElementwiseAddFunc& - get_node_from_elementwise_add_op) + get_node_from_elementwise_add_op, + const ResidualConnectionMKLDNNFusePass* pass) : fusion_stats{std::make_shared(0)}, can_fuse_func{can_fuse_func}, get_node_from_conv_op{get_node_from_conv_op}, - get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {} + get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, + pass_{pass} {} void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { @@ -102,6 +158,11 @@ void ResidualConnectionMKLDNNFusePass::IdentityFuseHandle::operator()( Node* elementwise_add_op; Node* elementwise_add_identity; Node* elementwise_add_out; + if (!pass_->IsCompat(subgraph, graph)) { + LOG(WARNING) + << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; + return; + } std::tie(conv_op, conv_input, conv_filter, conv_output) = get_node_from_conv_op(subgraph); @@ -133,12 +194,14 @@ ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::ProjectionFuseHandle( const ResidualConnectionMKLDNNFusePass::ProjectionConvFunc& get_node_from_conv_y_op, const ResidualConnectionMKLDNNFusePass::ProjectionElementwiseAddFunc& - get_node_from_elementwise_add_op) + get_node_from_elementwise_add_op, + const ResidualConnectionMKLDNNFusePass* pass) : fusion_stats{std::make_shared(0)}, can_fuse_func{can_fuse_func}, get_node_from_conv_x_op{get_node_from_conv_x_op}, get_node_from_conv_y_op{get_node_from_conv_y_op}, - get_node_from_elementwise_add_op{get_node_from_elementwise_add_op} {} + get_node_from_elementwise_add_op{get_node_from_elementwise_add_op}, + pass_{pass} {} void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { @@ -155,6 +218,12 @@ void ResidualConnectionMKLDNNFusePass::ProjectionFuseHandle::operator()( Node* elementwise_add_op; Node* elementwise_add_out; + if (!pass_->IsCompat(subgraph, graph)) { + LOG(WARNING) + << "conv_elementwise_add_mkldnn_fuse_pass in op compat failed."; + return; + } + std::tie(conv_x_op, conv_x_input, conv_x_filter, conv_x_output) = get_node_from_conv_x_op(subgraph); std::tie(conv_y_op, conv_y_input, conv_y_filter, conv_y_output) = @@ -247,7 +316,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsX( [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { return GetNodesFromConv(conv_pattern, subgraph); }, - get_node_from_elementwise_add); + get_node_from_elementwise_add, this); } GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( @@ -284,7 +353,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseConvAsY( [this, &conv_pattern](const GraphPatternDetector::subgraph_t& subgraph) { return GetNodesFromConv(conv_pattern, subgraph); }, - get_node_from_elementwise_add); + get_node_from_elementwise_add, this); } GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( @@ -325,7 +394,7 @@ GraphWithStats ResidualConnectionMKLDNNFusePass::FuseProjectionConv( &conv_y_pattern](const GraphPatternDetector::subgraph_t& subgraph) { return GetNodesFromConv(conv_y_pattern, subgraph); }, - get_node_from_elementwise_add); + get_node_from_elementwise_add, this); } void ResidualConnectionMKLDNNFusePass::ApplyImpl(graph_ptr graph) const { diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h index 2ba4c80678..5b4f941836 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h @@ -84,7 +84,6 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { auto can_fuse = [this](Node* op1, Node* op2) -> bool { return this->FindFuseOption(*op1, *op2) == FUSE_MKLDNN; }; - auto fuse_handle = HandleType{can_fuse, std::forward(op_funcs)...}; (*gpd)(graph, fuse_handle); @@ -96,7 +95,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { IdentityFuseHandle( const CanFuseFunc& can_fuse_func, const IdentityConvFunc& get_node_from_conv_op, - const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op); + const IdentityElementwiseAddFunc& get_node_from_elementwise_add_op, + const ResidualConnectionMKLDNNFusePass* pass); void operator()(const GraphPatternDetector::subgraph_t& subgraph, Graph* graph); @@ -107,6 +107,7 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { CanFuseFunc can_fuse_func; IdentityConvFunc get_node_from_conv_op; IdentityElementwiseAddFunc get_node_from_elementwise_add_op; + const ResidualConnectionMKLDNNFusePass* pass_; }; struct ProjectionFuseHandle { @@ -114,7 +115,8 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { const CanFuseFunc& can_fuse_func, const ProjectionConvFunc& get_node_from_conv_x_op, const ProjectionConvFunc& get_node_from_conv_y_op, - const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op); + const ProjectionElementwiseAddFunc& get_node_from_elementwise_add_op, + const ResidualConnectionMKLDNNFusePass* pass); void operator()(const GraphPatternDetector::subgraph_t& subgraph, Graph* graph); @@ -126,9 +128,11 @@ class ResidualConnectionMKLDNNFusePass : public FusePassBase { ProjectionConvFunc get_node_from_conv_x_op; ProjectionConvFunc get_node_from_conv_y_op; ProjectionElementwiseAddFunc get_node_from_elementwise_add_op; + const ResidualConnectionMKLDNNFusePass* pass_; }; public: + ResidualConnectionMKLDNNFusePass(); virtual ~ResidualConnectionMKLDNNFusePass() {} protected: diff --git a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc index eafc81cc81..c86c6350a1 100644 --- a/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc +++ b/paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h" #include "paddle/fluid/framework/ir/pass_test_util.h" +#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -25,16 +26,67 @@ namespace ir { constexpr int nodes_removed = 3; constexpr int nodes_added = 1; +OpDesc* Create_Op_con2d(ProgramDesc* prog, const std::string& op_type_name, + const std::vector& inputs, + const std::vector& outputs, + const bool use_mkldnn = true) { + auto* op = prog->MutableBlock(0)->AppendOp(); + const std::vector strides({1, 1}); + const std::vector paddings({0, 0}); + const std::vector dilations({1, 1}); + op->SetType(op_type_name); + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetAttr("strides", strides); + op->SetAttr("groups", 1); + op->SetAttr("paddings", paddings); + op->SetAttr("padding_algorithm", std::string("EXPLICIT")); + op->SetAttr("dilations", dilations); + op->SetAttr("data_format", std::string("NCHW")); + + for (const auto& input : inputs) { + op->SetInput(input.first, {input.second}); + } + for (const auto& output : outputs) { + op->SetOutput(output.first, {output.second}); + } + + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); + return op; +} + +OpDesc* Create_Op_elemntwise_add( + ProgramDesc* prog, const std::string& op_type_name, + const std::vector& inputs, + const std::vector& outputs, + bool use_mkldnn = true) { + auto* op = prog->MutableBlock(0)->AppendOp(); + op->SetType(op_type_name); + op->SetAttr("use_mkldnn", use_mkldnn); + op->SetAttr("axis", -1); + + for (const auto& input : inputs) { + op->SetInput(input.first, {input.second}); + } + for (const auto& output : outputs) { + op->SetOutput(output.first, {output.second}); + } + + op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), + static_cast(OpRole::kForward)); + return op; +} + TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) { auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); - test::CreateOp(&prog, "conv2d", - {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {{"Output", "c"}}); - test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, - {{"Out", "d"}}); + Create_Op_con2d(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {{"Output", "c"}}); + Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, + {{"Out", "d"}}); test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); Graph graph(prog); @@ -53,17 +105,17 @@ TEST(ConvElementwiseAddMKLDNNFusePass, test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); // right branch - test::CreateOp(&prog, "conv2d", - {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {{"Output", "c"}}); + Create_Op_con2d(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {{"Output", "c"}}); // left branch - test::CreateOp(&prog, "conv2d", - {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}}, - {{"Output", "f"}}); + Create_Op_con2d(&prog, "conv2d", + {{"Input", "a"}, {"Bias", "bias2"}, {"Filter", "weights2"}}, + {{"Output", "f"}}); - test::CreateOp(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, - {{"Out", "d"}}); + Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "f"}, {"Y", "c"}}, + {{"Out", "d"}}); test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); Graph graph(prog); @@ -80,10 +132,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); - test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, - {{"Output", "c"}}); - test::CreateOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, - {{"Out", "d"}}); + Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {{"Output", "c"}}); + Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, + {{"Out", "d"}}); test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); Graph graph(prog); @@ -100,12 +152,12 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) { test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"}); test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); - test::CreateOp(&prog, "conv2d", - {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, - {{"Output", "c"}}); + Create_Op_con2d(&prog, "conv2d", + {{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}}, + {{"Output", "c"}}); - test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, - {{"Out", "d"}}); + Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, + {{"Out", "d"}}); test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); Graph graph(prog); @@ -122,10 +174,10 @@ TEST(ConvElementwiseAddMKLDNNFusePass, auto prog = test::BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"}); test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); - test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, - {{"Output", "c"}}); - test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, - {{"Out", "d"}}); + Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {{"Output", "c"}}); + Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, + {{"Out", "d"}}); test::CreateOp(&prog, "relu", {{"X", "d"}}, {{"Out", "e"}}); Graph graph(prog); @@ -142,14 +194,14 @@ TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) { test::BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"}); test::CreateOp(&prog, "sigmoid", {{"X", "a"}}, {{"Out", "b"}}); - test::CreateOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, - {{"Output", "c"}}); + Create_Op_con2d(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}}, + {{"Output", "c"}}); - test::CreateOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}}, - {{"Output", "e"}}); + Create_Op_con2d(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}}, + {{"Output", "e"}}); - test::CreateOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, - {{"Out", "f"}}); + Create_Op_elemntwise_add(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, + {{"Out", "f"}}); test::CreateOp(&prog, "relu", {{"X", "f"}}, {{"Out", "g"}}); Graph graph(prog); diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc index b0153ced9c..2483a506a8 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_squash_pass.cc @@ -67,6 +67,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() { .AddAttr("paddings") .End() .AddAttr("padding_algorithm") + .IsOptional() .IsStringIn({"EXPLICIT", "SAME", "VALID"}) .End() .AddAttr("groups") @@ -75,6 +76,7 @@ CPUQuantizeSquashPass::CPUQuantizeSquashPass() { .AddAttr("dilations") .End() .AddAttr("data_format") + .IsOptional() .IsStringIn({"NCHW", "NHWC"}) .End(); } diff --git a/paddle/fluid/operators/compat/elementwise_add.pbtxt b/paddle/fluid/operators/compat/elementwise_add.pbtxt index 6a3d0a9b3a..25da11905d 100644 --- a/paddle/fluid/operators/compat/elementwise_add.pbtxt +++ b/paddle/fluid/operators/compat/elementwise_add.pbtxt @@ -15,6 +15,10 @@ def { } } extra { + attrs { + name: "@ENABLE_CACHE_RUNTIME_CONTEXT@" + type: BOOLEAN + } attrs { name: "out_threshold" type: FLOAT -- GitLab