未验证 提交 0c789ae5 编写于 作者: S Sylwester Fraczek 提交者: GitHub

Add fc residual pattern (#46757)

* fix fc pattern

remove use_bias
add residual input switch
fix references to pattern

* review fixes
上级 8a5f17e8
...@@ -1056,11 +1056,7 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x, ...@@ -1056,11 +1056,7 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
} }
} }
PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x, PDNode *patterns::FCMKLDNN::operator()(bool with_residual_data) {
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("fc", "Input");
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc"); auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables // Create variables
// Input // Input
...@@ -1081,8 +1077,31 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x, ...@@ -1081,8 +1077,31 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
->assert_is_op_output("fc", "Out") ->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc"); ->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({input_var, fc_weight_var, fc_bias_var}) std::vector<PDNode *> links_from{input_var, fc_weight_var, fc_bias_var};
.LinksTo({fc_out_var}); if (with_residual_data) {
auto res_fc_var = pattern->NewNode(residual_data_repr())
->AsInput()
->assert_is_op_input("fc")
// assert_is_op_input with two arguments doesn't work
// because ResidualData in FC is set as output with
// SetOutput so we do custom assert output
->assert_more([&](Node *x) {
for (auto *op : x->outputs)
if (IsNthOutput(x, op, "ResidualData", 0))
return true;
return false;
});
links_from.push_back(res_fc_var);
} else {
fc_op->assert_more([&](Node *x) {
if (!HasOutput(x, "ResidualData") ||
x->Op()->Output("ResidualData").size() == 0)
return true;
return false;
});
}
fc_op->LinksFrom(links_from).LinksTo({fc_out_var});
return fc_out_var; return fc_out_var;
} }
......
...@@ -592,12 +592,12 @@ struct FC : public PatternBase { ...@@ -592,12 +592,12 @@ struct FC : public PatternBase {
// op: fc // op: fc
// named node: // named node:
// fc // fc
// w, bias, output // w, bias, output, residual_data
struct FCMKLDNN : public PatternBase { struct FCMKLDNN : public PatternBase {
FCMKLDNN(PDPattern* pattern, const std::string& name_scope) FCMKLDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_mkldnn") {} : PatternBase(pattern, name_scope, "fc_mkldnn") {}
PDNode* operator()(PDNode* x, bool with_bias); PDNode* operator()(bool with_residual_data);
// declare operator node's name // declare operator node's name
PATTERN_DECL_NODE(fc); PATTERN_DECL_NODE(fc);
...@@ -606,6 +606,7 @@ struct FCMKLDNN : public PatternBase { ...@@ -606,6 +606,7 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE(weights); PATTERN_DECL_NODE(weights);
PATTERN_DECL_NODE(bias); PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(output); PATTERN_DECL_NODE(output);
PATTERN_DECL_NODE(residual_data);
}; };
// Embedding // Embedding
......
...@@ -471,11 +471,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const { ...@@ -471,11 +471,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::FCMKLDNN fc_pattern{pattern, name_scope_}; patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
auto* fc_input = gpd.mutable_pattern() fc_pattern(false /* with_residual */);
->NewNode("fc_quantizer/input")
->AsInput()
->assert_is_op_input("fc", "Input");
fc_pattern(fc_input, false);
int quantize_fc_count = 0; int quantize_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......
...@@ -62,11 +62,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC( ...@@ -62,11 +62,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern(); auto pattern = gpd.mutable_pattern();
patterns::FCMKLDNN fc_pattern{pattern, name_scope}; patterns::FCMKLDNN fc_pattern{pattern, name_scope};
bool fc_has_bias = true; auto fc_output = fc_pattern(false /* with residual */);
auto fc_output = fc_pattern(
gpd.mutable_pattern()->NewNode("fc")->AsInput()->assert_is_op_input(
"fc", "Input"),
fc_has_bias);
patterns::ResidualElementwise elementwise_pattern{ patterns::ResidualElementwise elementwise_pattern{
pattern, name_scope, fc_as_x}; pattern, name_scope, fc_as_x};
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h" #include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -28,26 +29,26 @@ namespace ir { ...@@ -28,26 +29,26 @@ namespace ir {
class Graph; class Graph;
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { namespace {
PADDLE_ENFORCE_NOT_NULL(graph, void LogEnabledOps(const int counter, const std::string& details) {
platform::errors::InvalidArgument( std::string msg_ss{"--- enabled FC MKL-DNN for "};
"Pointer to graph argument should not be NULL.")); msg_ss += counter + " fc ops " + details;
Init("fc_mkldnn_pass", graph); string::PrettyLogDetail(msg_ss.c_str());
}
} // namespace
void FCMKLDNNPass::ApplyPass(ir::Graph* graph, bool with_residual) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("fc_mkldnn_pass/x")
->AsInput()
->assert_is_op_input("fc", "Input");
patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass"); patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
fc_pattern(x, true /*with bias*/); fc_pattern(with_residual);
int found_fc_count = 0; int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "Handle FC MKL-DNN pass"; VLOG(4) << "Handle FC MKL-DNN pass";
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) { if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
VLOG(3) << "do not perform fc fuse"; VLOG(3) << "do not enable FC MKL-DNN because it doesn't have use_mkldnn "
"attribute.";
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
...@@ -77,6 +78,20 @@ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const { ...@@ -77,6 +78,20 @@ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
gpd(graph, handler); gpd(graph, handler);
AddStatis(found_fc_count); AddStatis(found_fc_count);
LogEnabledOps(found_fc_count,
(with_residual ? "with residual connection"
: "without residual connection"));
}
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
Init("fc_mkldnn_pass", graph);
ApplyPass(graph, true);
ApplyPass(graph, false);
} }
} // namespace ir } // namespace ir
......
...@@ -34,6 +34,7 @@ class FCMKLDNNPass : public FusePassBase { ...@@ -34,6 +34,7 @@ class FCMKLDNNPass : public FusePassBase {
protected: protected:
void ApplyImpl(ir::Graph* graph) const; void ApplyImpl(ir::Graph* graph) const;
void ApplyPass(ir::Graph* graph, bool with_residual) const;
}; };
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册