未验证 提交 0c789ae5 编写于 作者: S Sylwester Fraczek 提交者: GitHub

Add fc residual pattern (#46757)

* fix fc pattern

remove use_bias
add residual input switch
fix references to pattern

* review fixes
上级 8a5f17e8
......@@ -1056,11 +1056,7 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
}
}
PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
bool with_bias) {
// Create shared nodes.
x->assert_is_op_input("fc", "Input");
PDNode *patterns::FCMKLDNN::operator()(bool with_residual_data) {
auto *fc_op = pattern->NewNode(fc_repr())->assert_is_op("fc");
// Create variables
// Input
......@@ -1081,8 +1077,31 @@ PDNode *patterns::FCMKLDNN::operator()(paddle::framework::ir::PDNode *x,
->assert_is_op_output("fc", "Out")
->assert_is_only_output_of_op("fc");
fc_op->LinksFrom({input_var, fc_weight_var, fc_bias_var})
.LinksTo({fc_out_var});
std::vector<PDNode *> links_from{input_var, fc_weight_var, fc_bias_var};
if (with_residual_data) {
auto res_fc_var = pattern->NewNode(residual_data_repr())
->AsInput()
->assert_is_op_input("fc")
// assert_is_op_input with two arguments doesn't work
// because ResidualData in FC is set as output with
// SetOutput so we do custom assert output
->assert_more([&](Node *x) {
for (auto *op : x->outputs)
if (IsNthOutput(x, op, "ResidualData", 0))
return true;
return false;
});
links_from.push_back(res_fc_var);
} else {
fc_op->assert_more([&](Node *x) {
if (!HasOutput(x, "ResidualData") ||
x->Op()->Output("ResidualData").size() == 0)
return true;
return false;
});
}
fc_op->LinksFrom(links_from).LinksTo({fc_out_var});
return fc_out_var;
}
......
......@@ -592,12 +592,12 @@ struct FC : public PatternBase {
// op: fc
// named node:
// fc
// w, bias, output
// w, bias, output, residual_data
struct FCMKLDNN : public PatternBase {
FCMKLDNN(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc_mkldnn") {}
PDNode* operator()(PDNode* x, bool with_bias);
PDNode* operator()(bool with_residual_data);
// declare operator node's name
PATTERN_DECL_NODE(fc);
......@@ -606,6 +606,7 @@ struct FCMKLDNN : public PatternBase {
PATTERN_DECL_NODE(weights);
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(output);
PATTERN_DECL_NODE(residual_data);
};
// Embedding
......
......@@ -471,11 +471,7 @@ void CPUQuantizePass::QuantizeFc(Graph* graph) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::FCMKLDNN fc_pattern{pattern, name_scope_};
auto* fc_input = gpd.mutable_pattern()
->NewNode("fc_quantizer/input")
->AsInput()
->assert_is_op_input("fc", "Input");
fc_pattern(fc_input, false);
fc_pattern(false /* with_residual */);
int quantize_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
......
......@@ -62,11 +62,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::FCMKLDNN fc_pattern{pattern, name_scope};
bool fc_has_bias = true;
auto fc_output = fc_pattern(
gpd.mutable_pattern()->NewNode("fc")->AsInput()->assert_is_op_input(
"fc", "Input"),
fc_has_bias);
auto fc_output = fc_pattern(false /* with residual */);
patterns::ResidualElementwise elementwise_pattern{
pattern, name_scope, fc_as_x};
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/ir/mkldnn/fc_mkldnn_pass.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
......@@ -28,26 +29,26 @@ namespace ir {
class Graph;
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
Init("fc_mkldnn_pass", graph);
namespace {
void LogEnabledOps(const int counter, const std::string& details) {
std::string msg_ss{"--- enabled FC MKL-DNN for "};
msg_ss += counter + " fc ops " + details;
string::PrettyLogDetail(msg_ss.c_str());
}
} // namespace
void FCMKLDNNPass::ApplyPass(ir::Graph* graph, bool with_residual) const {
GraphPatternDetector gpd;
auto* x = gpd.mutable_pattern()
->NewNode("fc_mkldnn_pass/x")
->AsInput()
->assert_is_op_input("fc", "Input");
patterns::FCMKLDNN fc_pattern(gpd.mutable_pattern(), "fc_mkldnn_pass");
fc_pattern(x, true /*with bias*/);
fc_pattern(with_residual);
int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Handle FC MKL-DNN pass";
if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
VLOG(3) << "do not perform fc fuse";
VLOG(3) << "do not enable FC MKL-DNN because it doesn't have use_mkldnn "
"attribute.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(fc, fc, fc_pattern);
......@@ -77,6 +78,20 @@ void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
gpd(graph, handler);
AddStatis(found_fc_count);
LogEnabledOps(found_fc_count,
(with_residual ? "with residual connection"
: "without residual connection"));
}
void FCMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
Init("fc_mkldnn_pass", graph);
ApplyPass(graph, true);
ApplyPass(graph, false);
}
} // namespace ir
......
......@@ -34,6 +34,7 @@ class FCMKLDNNPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const;
void ApplyPass(ir::Graph* graph, bool with_residual) const;
};
} // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册