From 0c789ae5fe51c00c9f8355ce9fbb2af9343cfb5f Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Mon, 10 Oct 2022 10:32:00 +0200 Subject: [PATCH] Add fc residual pattern (#46757) * fix fc pattern remove use_bias add residual input switch fix references to pattern * review fixes --- .../framework/ir/graph_pattern_detector.cc | 33 +++++++++++++---- .../framework/ir/graph_pattern_detector.h | 5 ++- .../framework/ir/mkldnn/cpu_quantize_pass.cc | 6 +-- .../fc_elementwise_add_mkldnn_fuse_pass.cc | 6 +-- .../framework/ir/mkldnn/fc_mkldnn_pass.cc | 37 +++++++++++++------ .../framework/ir/mkldnn/fc_mkldnn_pass.h | 1 + 6 files changed, 58 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 0d63ce21211..16c21f4b4e4 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1056,11 +1056,7 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, } } -PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x, - bool with_bias) { - // Create shared nodes. - x->assert_is_op_input("fc", "Input"); - +PDNode *patterns::FCMKLDNN::operator()(bool with_residual_data) { auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc"); // Create variables // Input @@ -1081,8 +1077,31 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x, ->assert_is_op_output("fc", "Out") ->assert_is_only_output_of_op("fc"); - fc_op->LinksFrom({input_var, fc_weight_var, fc_bias_var}) - .LinksTo({fc_out_var}); + std::vector links_from{input_var, fc_weight_var, fc_bias_var}; + if (with_residual_data) { + auto res_fc_var = pattern->NewNode(residual_data_repr()) + ->AsInput() + ->assert_is_op_input("fc") + // assert_is_op_input with two arguments doesn't work + // because ResidualData in FC is set as output with + // SetOutput so we do custom assert output + ->assert_more([&](Node *x) { + for (auto *op : x->outputs) + if (IsNthOutput(x, op, "ResidualData", 0)) + return true; + return false; + }); + links_from.push_back(res_fc_var); + } else { + fc_op->assert_more([&](Node *x) { + if (!HasOutput(x, "ResidualData") || + x->Op()->Output("ResidualData").size() == 0) + return true; + return false; + }); + } + + fc_op->LinksFrom(links_from).LinksTo({fc_out_var}); return fc_out_var; } diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index b2eb740b9ac..aca2a64888c 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -592,12 +592,12 @@ struct FC : public PatternBase { // op: fc // named node: // fc -// w, bias, output +// w, bias, output, residual_data struct FCMKLDNN : public PatternBase { FCMKLDNN(PDPattern* pattern, const std::string& name_scope) : PatternBase(pattern, name_scope, "fc_mkldnn") {} - PDNode* operator()(PDNode* x, bool with_bias); + PDNode* operator()(bool with_residual_data); // declare operator node's name PATTERN_DECL_NODE(fc); @@ -606,6 +606,7 @@ struct FCMKLDNN : public PatternBase { PATTERN_DECL_NODE(weights); PATTERN_DECL_NODE(bias); PATTERN_DECL_NODE(output); + PATTERN_DECL_NODE(residual_data); }; // Embedding diff --git a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc index f40eeea1736..5ec22e2e88a 100644 --- a/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.cc @@ -471,11 +471,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::FCMKLDNN fc_pattern{pattern, name_scope_}; - auto* fc_input = gpd.mutable_pattern() - ->NewNode("fc_quantizer/input") - ->AsInput() - ->assert_is_op_input("fc", "Input"); - fc_pattern(fc_input, false); + fc_pattern(false /* with_residual */); int quantize_fc_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, diff --git a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc index 2046b30ba38..e0de720d049 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.cc @@ -62,11 +62,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC( GraphPatternDetector gpd; auto pattern = gpd.mutable_pattern(); patterns::FCMKLDNN fc_pattern{pattern, name_scope}; - bool fc_has_bias = true; - auto fc_output = fc_pattern( - gpd.mutable_pattern()->NewNode("fc")->AsInput()->assert_is_op_input( - "fc", "Input"), - fc_has_bias); + auto fc_output = fc_pattern(false /* with residual */); patterns::ResidualElementwise elementwise_pattern{ pattern, name_scope, fc_as_x}; diff --git a/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc index 6efa9f6b749..a2f8c14d1a2 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc +++ b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/pretty_log.h" namespace paddle { namespace framework { @@ -28,26 +29,26 @@ namespace ir { class Graph; -void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { - PADDLE_ENFORCE_NOT_NULL(graph, - platform::errors::InvalidArgument( - "Pointer to graph argument should not be NULL.")); - Init("fc_mkldnn_pass", graph); +namespace { +void LogEnabledOps(const int counter, const std::string& details) { + std::string msg_ss{"--- enabled FC MKL-DNN for "}; + msg_ss += counter + " fc ops " + details; + string::PrettyLogDetail(msg_ss.c_str()); +} +} // namespace +void FCMKLDNNPass::ApplyPass(ir::Graph* graph, bool with_residual) const { GraphPatternDetector gpd; - auto* x = gpd.mutable_pattern() - ->NewNode("fc_mkldnn_pass/x") - ->AsInput() - ->assert_is_op_input("fc", "Input"); patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass"); - fc_pattern(x, true /*with bias*/); + fc_pattern(with_residual); int found_fc_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { VLOG(4) << "Handle FC MKL-DNN pass"; if (!(graph->Has("use_mkldnn") && graph->Get("use_mkldnn"))) { - VLOG(3) << "do not perform fc fuse"; + VLOG(3) << "do not enable FC MKL-DNN because it doesn't have use_mkldnn " + "attribute."; return; } GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern); @@ -77,6 +78,20 @@ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { gpd(graph, handler); AddStatis(found_fc_count); + + LogEnabledOps(found_fc_count, + (with_residual ? "with residual connection" + : "without residual connection")); +} + +void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL(graph, + platform::errors::InvalidArgument( + "Pointer to graph argument should not be NULL.")); + Init("fc_mkldnn_pass", graph); + + ApplyPass(graph, true); + ApplyPass(graph, false); } } // namespace ir diff --git a/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h index df02250394a..9367e08e7c7 100644 --- a/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h +++ b/paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h @@ -34,6 +34,7 @@ class FCMKLDNNPass : public FusePassBase { protected: void ApplyImpl(ir::Graph* graph) const; + void ApplyPass(ir::Graph* graph, bool with_residual) const; }; } // namespace ir -- GitLab