未验证 提交 478a4e85 编写于 作者: Y Yan Chunwei 提交者: GitHub

refactor ir pattern (#13304)

上级 14242eae
...@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unordered_set<Node*> nodes2delete; std::unordered_set<Node*> nodes2delete;
GraphPatternDetector gpd; GraphPatternDetector gpd;
// BuildFCPattern(gpd.mutable_pattern());
auto* x = gpd.mutable_pattern() auto* x = gpd.mutable_pattern()
->NewNode("fc_fuse/x") ->NewNode("fc_fuse/x")
->AsInput() ->AsInput()
->assert_is_op_input("mul", "X"); ->assert_is_op_input("mul", "X");
patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/); patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse");
fc_pattern(x, true /*with bias*/);
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
int found_fc_count = 0; int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle FC fuse"; VLOG(4) << "handle FC fuse";
// Currently, there is no FC op available, so I will just simulate the GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
// scenerio. GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
// FC's fusion is simple, just op fuse, no need to process the GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
// parameters. GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
GET_NODE(x); // x GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_NODE(w); // Y GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_NODE(fc_bias); // bias
GET_NODE(fc_out); // Out
GET_NODE(mul); // MUL op
GET_NODE(elementwise_add); // ELEMENT_ADD op
GET_NODE(mul_out); // tmp
#undef GET_NODE
// Create an FC Node. // Create an FC Node.
OpDesc desc; OpDesc desc;
std::string fc_x_in = x->Name(); std::string fc_x_in = subgraph.at(x)->Name();
std::string fc_Y_in = w->Name(); std::string fc_Y_in = w->Name();
std::string fc_bias_in = fc_bias->Name(); std::string fc_bias_in = fc_bias->Name();
std::string fc_out_out = fc_out->Name(); std::string fc_out_out = fc_out->Name();
...@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out}); GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
IR_NODE_LINK_TO(x, fc_node); PADDLE_ENFORCE(subgraph.count(x));
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
IR_NODE_LINK_TO(w, fc_node); IR_NODE_LINK_TO(w, fc_node);
IR_NODE_LINK_TO(fc_bias, fc_node); IR_NODE_LINK_TO(fc_bias, fc_node);
IR_NODE_LINK_TO(fc_node, fc_out); IR_NODE_LINK_TO(fc_node, fc_out);
......
...@@ -20,52 +20,43 @@ namespace paddle { ...@@ -20,52 +20,43 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
bool with_fc_bias) {
PDNode* x = pattern->NewNode(name_scope, "x")
->assert_is_op_input("mul")
->assert_var_not_persistable();
auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::GRU(pattern, name_scope, fc_out);
VLOG(3) << "fc_gru pattern \n" << pattern->DotString();
}
static int BuildFusion(Graph* graph, const std::string& name_scope, static int BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) { Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
BuildPattern(pattern, name_scope, with_fc_bias); // Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
patterns::GRU gru_pattern(pattern, name_scope);
PDNode* x =
pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
auto* fc_out = fc_pattern(x, with_fc_bias);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
gru_pattern(fc_out);
// Create New OpDesc // Create New OpDesc
auto gru_creater = [&](int gru, int x, int weight_x, int weight_h, int bias, auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
int hidden, int fc_bias) { Node* bias, Node* hidden, Node* fc_bias) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE(x);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(gru);
OpDesc op_desc; OpDesc op_desc;
op_desc.SetType("fusion_gru"); op_desc.SetType("fusion_gru");
#define NEW_NAME(x) name_scope + "/at." #x ".new" #define NEW_NAME(x) name_scope + "/at." #x ".new"
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()}); #define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN(X, x); SET_IN(X, x);
SET_IN(WeightX, weight_x); SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h); SET_IN(WeightH, weight_h);
if (with_fc_bias) { if (with_fc_bias) {
op_desc.SetInput("Bias", {NEW_NAME(bias) + bias_n->Name()}); op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
} else { } else {
SET_IN(Bias, bias); SET_IN(Bias, bias);
} }
#undef SET_IN #undef SET_IN
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()}); op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetAttr("is_reverse", gru_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
// TODO(TJ): This should be a option for infer // TODO(TJ): This should be a option for infer
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
...@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
if (with_fc_bias) { if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias // Fusion GRU bias = fcbias + grubias
auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias_n->Name()); auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor = auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>(); fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var); PADDLE_ENFORCE(fusion_bias_var);
GET_NODE(fc_bias); auto* gru_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(fc_bias_n); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* gru_bias_var = scope->FindVar(bias_n->Name());
auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
PADDLE_ENFORCE(gru_bias_var); PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var); PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>(); const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
...@@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef NEW_NAME #undef NEW_NAME
#undef NEW_IMTERMEDIATE_OUT #undef NEW_IMTERMEDIATE_OUT
IR_NODE_LINK_TO(x_n, op); IR_NODE_LINK_TO(x, op);
IR_NODE_LINK_TO(weight_x_n, op); IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h_n, op); IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias_n, op); // actually should link to new bias if have IR_NODE_LINK_TO(bias, op); // actually should link to new bias if have
IR_NODE_LINK_TO(op, hidden_n); IR_NODE_LINK_TO(op, hidden);
// h0? // h0?
return op; return op;
}; };
...@@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
int fusion_count{0}; int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
#define GET_NODE(name__) \ auto* x_n = subgraph.at(x);
std::string name__##key = name_scope + "/" + #name__; \ GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
auto* name__##n = pattern->RetrieveNode(name__##key); \ GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
PADDLE_ENFORCE(name__##n); \ GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
PADDLE_ENFORCE(subgraph.count(name__##n)); \ GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern);
Node* name__##_n = subgraph.at(name__##n); \ GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);
int name__ __attribute__((unused)) = name__##_n->id(); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern);
GET_NODE(x);
GET_NODE(w); // fc weight
GET_NODE(mul);
GET_NODE(fc_out);
GET_NODE(Weight);
GET_NODE(gru);
GET_NODE(Bias);
GET_NODE(Hidden);
// nodes need be removed // nodes need be removed
GET_NODE(BatchGate); GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern);
GET_NODE(BatchResetHiddenPrev); GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern);
GET_NODE(BatchHidden); GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern);
if (with_fc_bias) { if (with_fc_bias) {
GET_NODE(mul_out); GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_NODE(fc_bias); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_NODE(elementwise_add); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
gru_creater(gru, x, w, Weight, Bias, Hidden, fc_bias);
gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul_n, gru_n, elementwise_add_n, fc_bias_n, fc_out_n, mul_out_n, {mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate,
BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n}); BatchResetHiddenPrev, BatchHidden});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
gru_creater(gru, x, w, Weight, Bias, Hidden, -1); gru_creater(gru, x_n, w, Weight, Bias, Hidden, nullptr);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul_n, gru_n, BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n}); {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} }
#undef GET_NODE #undef GET_NODE
......
...@@ -20,45 +20,29 @@ namespace paddle { ...@@ -20,45 +20,29 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static std::string GenNodeName(const std::string& prefix, int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
const std::string& name) { bool with_fc_bias) {
return prefix + "/" + name; GraphPatternDetector gpd;
} auto* pattern = gpd.mutable_pattern();
static void BuildPattern(PDPattern* pattern, const std::string& name_scope, // Build pattern
bool with_fc_bias) { PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
PDNode* x = pattern->NewNode(name_scope, "x")
->assert_is_op_input("mul") ->assert_is_op_input("mul")
->assert_var_not_persistable(); ->assert_var_not_persistable();
auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias); patterns::FC fc_pattern(pattern, name_scope);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::LSTM(pattern, name_scope, fc_out);
// LOG(INFO) << "\n" << pattern->DotString();
}
static int BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
BuildPattern(pattern, name_scope, with_fc_bias); // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
auto* fc_out = fc_pattern(x, with_fc_bias)->AsIntermediate();
patterns::LSTM lstm_pattern(pattern, name_scope);
lstm_pattern(fc_out);
// Create New OpDesc // Create New OpDesc
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h, auto lstm_creator = [&](Node* lstm, Node* input, Node* weight_x,
int bias, int hidden, int cell, int xx, int fc_bias) { Node* weight_h, Node* bias, Node* hidden, Node* cell,
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x); Node* xx, Node* fc_bias) {
GET_NODE(input);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(cell);
GET_NODE(xx);
GET_NODE(lstm);
OpDesc op_desc; OpDesc op_desc;
op_desc.SetType("fusion_lstm"); op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()}); #define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN(X, input); SET_IN(X, input);
SET_IN(WeightX, weight_x); SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h); SET_IN(WeightH, weight_h);
...@@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* bias_var = scope->Var(new_bias_var); auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE(bias_var); PADDLE_ENFORCE(bias_var);
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>(); auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
auto* lstm_bias_var = scope->FindVar(bias_n->Name()); auto* lstm_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(lstm_bias_var); PADDLE_ENFORCE(lstm_bias_var);
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>(); const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
bias_tensor->Resize(lstm_bias_tensor.dims()); bias_tensor->Resize(lstm_bias_tensor.dims());
GET_NODE(fc_bias); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>(); const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace()); auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace());
...@@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
} }
op_desc.SetInput("Bias", {new_bias_var}); op_desc.SetInput("Bias", {new_bias_var});
} }
#undef GET_NODE
// Create temp variables. // Create temp variables.
scope->Var(name_scope + "/BatchedInput.new") const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
->GetMutable<framework::LoDTensor>(); const std::string BatchedCellPreAct =
scope->Var(name_scope + "/BatchCellPreAct.new") patterns::UniqueKey("BatchedCellPreAct");
->GetMutable<framework::LoDTensor>(); const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
scope->Var(name_scope + "/BatchedGate.new")
->GetMutable<framework::LoDTensor>(); scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {}); op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()}); op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()}); op_desc.SetOutput("Cell", {cell->Name()});
op_desc.SetOutput("XX", {xx_n->Name()}); op_desc.SetOutput("XX", {xx->Name()});
op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"}); op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"}); op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"}); op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr // TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
#define TMP_NAME(x) "at.new.tmp." #x PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)}) auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell); OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden); OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0); OP_SET_OUT(ReorderedH0);
...@@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef OP_SET_OUT #undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); IR_NODE_LINK_TO(input, op);
auto* scope = graph->Get<Scope*>(kParamScopeAttr); IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op);
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>() IR_NODE_LINK_TO(bias, op);
TMP_NEW(BatchedCell); IR_NODE_LINK_TO(op, hidden);
TMP_NEW(BatchedHidden);
TMP_NEW(ReorderedH0);
TMP_NEW(ReorderedC0);
#undef TMP_NEW
#undef TMP_NAME
IR_NODE_LINK_TO(input_n, op);
IR_NODE_LINK_TO(weight_x_n, op);
IR_NODE_LINK_TO(weight_h_n, op);
IR_NODE_LINK_TO(bias_n, op);
IR_NODE_LINK_TO(op, hidden_n);
return op; return op;
}; };
...@@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
PADDLE_ENFORCE(name__##n); \
PADDLE_ENFORCE(subgraph.count(name__##n)); \
Node* name__##_n = subgraph.at(name__##n); \
int name__ __attribute__((unused)) = name__##_n->id();
GET_NODE(x);
GET_NODE(w);
GET_NODE(mul);
GET_NODE(fc_out);
GET_NODE(Weight);
GET_NODE(lstm);
GET_NODE(Bias);
GET_NODE(Hidden);
GET_NODE(Cell);
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
if (with_fc_bias) { if (with_fc_bias) {
GET_NODE(fc_bias); GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_NODE(elementwise_add); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul_n, lstm_n, elementwise_add_n}); {mul, lstm, elementwise_add});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1); GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
nullptr);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n}); std::unordered_set<const Node*> marked_nodes({mul, lstm});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} }
#undef GET_NODE
++fusion_count; ++fusion_count;
}; };
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { ...@@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
for (auto& pdnode : pattern_.nodes()) { for (auto& pdnode : pattern_.nodes()) {
if (!pdnodes2nodes_.count(pdnode.get())) { if (!pdnodes2nodes_.count(pdnode.get())) {
VLOG(4) << pdnode->name() << " can't find matched Node, early stop"; VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
// return false;
return false;
} }
} }
for (auto& item : pdnodes2nodes_) { for (auto& item : pdnodes2nodes_) {
...@@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) { ...@@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
return false; return false;
} }
PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope, PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
PDNode* x, bool with_bias) { bool with_bias) {
// mul op // Create shared nodes.
auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul"); x->assert_is_op_input("mul", "X");
auto* mul_weight_var = pattern->NewNode(name_scope, "w") auto* mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
->AsInput()
->assert_is_persistable_var() auto* mul_w_var = pattern->NewNode(w_repr())
->assert_is_op_input("mul", "Y"); ->AsInput()
->assert_is_persistable_var()
PDNode* fc_out{nullptr}; ->assert_is_op_input("mul", "Y");
if (with_bias) {
PDNode* elementwise_add_op{nullptr}; auto* mul_out_var =
PDNode *mul_out_var{nullptr}, *bias{nullptr}; pattern->NewNode(mul_out_repr())->assert_is_op_output("mul");
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
->assert_is_op("elementwise_add"); if (!with_bias) { // not with bias
// intermediate variable, will be removed in the IR after fuse. // Add links.
mul_out_var = pattern->NewNode(name_scope, "mul_out") mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var});
->AsIntermediate() return mul_out_var;
->assert_is_only_output_of_op("mul")
->assert_is_op_input("elementwise_add"); } else { // with bias
// bias mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
bias = pattern->NewNode(name_scope, "fc_bias") // Create operators.
->AsInput() auto* elementwise_add = pattern->NewNode(elementwise_add_repr())
->assert_is_op_input("elementwise_add"); ->assert_is_op("elementwise_add");
// output // Create variables.
fc_out = pattern->NewNode(name_scope, "fc_out") auto* bias = pattern->NewNode(bias_repr())
->AsOutput() ->assert_is_op_input("elementwise_add")
->assert_is_op_output("elementwise_add"); ->AsInput();
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out}); auto* fc_out = pattern->NewNode(Out_repr())
} else { ->AsOutput()
fc_out = pattern->NewNode(name_scope, "fc_out") ->assert_is_op_output("elementwise_add");
->AsOutput()
->assert_is_op_output("mul"); mul->LinksFrom({mul_w_var, x}).LinksTo({mul_out_var});
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out}); elementwise_add->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
return fc_out;
} }
return fc_out;
} }
#define NEW_NODE(op__, arg__, io__) \ PDNode* patterns::LSTM::operator()(PDNode* x) {
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
->assert_is_op_##io__(#op__, #arg__);
PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
PDNode* x) {
x->assert_is_op_input("lstm", "Input"); x->assert_is_op_input("lstm", "Input");
auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm"); auto* lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("lstm", #arg__);
// Currently, the H0 and C0 are optional // Currently, the H0 and C0 are optional
// TODO(Superjomn) upgrade the fuse framework to support optional. // TODO(Superjomn) upgrade the fuse framework to support optional.
// NEW_NODE(H0, input); // NEW_NODE(H0, input);
// NEW_NODE(C0, input); // NEW_NODE(C0, input);
NEW_NODE(lstm, Weight, input); NEW_NODE(Weight, input);
NEW_NODE(lstm, Bias, input); NEW_NODE(Bias, input);
NEW_NODE(lstm, Hidden, output); NEW_NODE(Hidden, output);
NEW_NODE(lstm, Cell, output); NEW_NODE(Cell, output);
NEW_NODE(lstm, BatchGate, output); NEW_NODE(BatchGate, output);
NEW_NODE(lstm, BatchCellPreAct, output); NEW_NODE(BatchCellPreAct, output);
#undef NEW_NODE
lstm_op->LinksFrom({x, Weight, Bias}); lstm_op->LinksFrom({x, Weight, Bias});
lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct}); lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct});
return Hidden; return Hidden;
} }
PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope, PDNode* patterns::GRU::operator()(PDNode* x) {
PDNode* x) {
x->assert_is_op_input("gru", "Input"); x->assert_is_op_input("gru", "Input");
auto* gru_op = pattern->NewNode(name_scope, "gru")->assert_is_op("gru"); auto* gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru");
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("gru", #arg__);
NEW_NODE(gru, Weight, input); NEW_NODE(Weight, input);
// TODO(Superjomn): upgrade the fuse framework to support optional. // TODO(Superjomn): upgrade the fuse framework to support optional.
// H0 and bias are optional // H0 and bias are optional
NEW_NODE(gru, Bias, input); // also optional NEW_NODE(Bias, input); // also optional
// NEW_NODE(H0, input); // NEW_NODE(H0, input);
NEW_NODE(gru, Hidden, output); NEW_NODE(Hidden, output);
// below are intermediate // below are intermediate
NEW_NODE(gru, BatchGate, output); NEW_NODE(BatchGate, output);
NEW_NODE(gru, BatchResetHiddenPrev, output); NEW_NODE(BatchResetHiddenPrev, output);
NEW_NODE(gru, BatchHidden, output); NEW_NODE(BatchHidden, output);
#undef NEW_NODE
BatchGate->AsIntermediate(); BatchGate->AsIntermediate();
BatchResetHiddenPrev->AsIntermediate(); BatchResetHiddenPrev->AsIntermediate();
...@@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope, ...@@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden}); gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
return Hidden; return Hidden;
} }
#undef NEW_NODE
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph, ...@@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes); const std::unordered_set<const Node*>& nodes);
// Some pre-defined patterns those can be reused in multiple passes. // Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better reusage
// accross different fusion.
namespace patterns { namespace patterns {
struct KeyCounter {
static KeyCounter& Instance() {
static KeyCounter x;
return x;
}
int IncCounter(const std::string& key) { return dic_[key]++; }
private:
std::unordered_map<std::string, size_t> dic_;
};
// Generate a unique PDNode's name with name_scope and id.
// The format is {name_scope}/{repr}/{id}/{name}
static std::string PDNodeName(const std::string& name_scope,
const std::string& repr, size_t id,
const std::string& name) {
return string::Sprintf("%s/%s/%d/%s", name_scope, repr, id, name);
}
// Generate a unique PDNode's name.
// The format is {name_scope}/{repr}/{id}
static std::string PDNodeName(const std::string& name_scope,
const std::string& repr) {
return string::Sprintf("%s/%s/%d", name_scope, repr,
KeyCounter::Instance().IncCounter(repr));
}
// Generate a unique key. It can be used for a universally unique temporary
// name.
// The format is {repr}/{id}
static std::string UniqueKey(const std::string& repr) {
return string::Sprintf("%s/%d", repr,
KeyCounter::Instance().IncCounter(repr));
}
// Declare a PDNode in a pattern, will create two methods:
// std::string xxx_repr(); return this PDNode's string id.
// PDNode* xxx_n(); return the corresponding PDNode.
#define PATTERN_DECL_NODE(name__) \
std::string name__##_repr() const { \
return PDNodeName(name_scope_, repr_, id_, #name__); \
} \
PDNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); }
// Get an ir::Node* from the matched subgraph.
// var: variable.
// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
// pat: the pattern object.
#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \
PADDLE_ENFORCE(subgraph.count(pat.arg##_n()), \
"Node not found for PDNode %s", pat.arg##_repr()); \
Node* var = subgraph.at(pat.arg##_n()); \
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph", #arg)
// The base class of all the patterns.
struct PatternBase {
PatternBase(PDPattern* pattern, const std::string& name_scope,
const std::string& repr)
: pattern(pattern),
name_scope_(name_scope),
repr_(repr),
id_(KeyCounter::Instance().IncCounter(repr)) {}
PDPattern* pattern;
protected:
std::string name_scope_;
std::string repr_;
size_t id_;
};
// FC with bias // FC with bias
// op: mul + elementwise_add // op: mul + elementwise_add
// named nodes: // named nodes:
// mul, elementwise_add // mul, elementwise_add
// w, mul_out, bias, fc_out // w, mul_out, bias, fc_out
PDNode* FC(PDPattern* pattern, const std::string& name_scope, PDNode* x, struct FC : public PatternBase {
bool with_bias); FC(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc") {}
PDNode* operator()(PDNode* x, bool with_bias);
// declare operator node's name
PATTERN_DECL_NODE(fc);
PATTERN_DECL_NODE(mul);
PATTERN_DECL_NODE(elementwise_add);
// declare variable node's name
PATTERN_DECL_NODE(w);
PATTERN_DECL_NODE(mul_out); // (x,w) -> mul_out
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(Out);
};
struct LSTM : public PatternBase {
LSTM(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "lstm") {}
PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x); PDNode* operator()(PDNode* x);
PDNode* GRU(PDPattern* pattern, const std::string& name_scope, PDNode* x); // Operators
PATTERN_DECL_NODE(lstm);
// Inputs
PATTERN_DECL_NODE(Input);
PATTERN_DECL_NODE(H0);
PATTERN_DECL_NODE(C0);
PATTERN_DECL_NODE(Weight);
PATTERN_DECL_NODE(Bias);
// Outputs
PATTERN_DECL_NODE(Hidden);
PATTERN_DECL_NODE(Cell);
PATTERN_DECL_NODE(BatchGate);
PATTERN_DECL_NODE(BatchCellPreAct);
};
struct GRU : public PatternBase {
GRU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "lstm") {}
PDNode* operator()(PDNode* x);
// Operators
PATTERN_DECL_NODE(gru);
// Inputs
PATTERN_DECL_NODE(Bias);
PATTERN_DECL_NODE(Weight);
// Outputs
PATTERN_DECL_NODE(BatchGate);
PATTERN_DECL_NODE(BatchResetHiddenPrev);
PATTERN_DECL_NODE(BatchHidden);
PATTERN_DECL_NODE(Hidden);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other.
#define IR_NODE_LINK_TO(a, b) \ #define IR_NODE_LINK_TO(a, b) \
a->outputs.push_back(b); \ a->outputs.push_back(b); \
b->inputs.push_back(a); b->inputs.push_back(a);
......
...@@ -192,6 +192,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl( ...@@ -192,6 +192,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int fuse_count{0};
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "get one concat pattern"; VLOG(4) << "get one concat pattern";
...@@ -239,8 +241,12 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl( ...@@ -239,8 +241,12 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
marked_nodes.erase(sequence_expand1_in); marked_nodes.erase(sequence_expand1_in);
marked_nodes.erase(fc_out); marked_nodes.erase(fc_out);
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
++fuse_count;
}); });
AddStatis(fuse_count);
return graph; return graph;
} }
......
...@@ -267,6 +267,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, ...@@ -267,6 +267,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
PADDLE_ENFORCE(config.ir_mode == PADDLE_ENFORCE(config.ir_mode ==
AnalysisConfig::IrPassMode::kExclude); // default AnalysisConfig::IrPassMode::kExclude); // default
config.ir_passes.clear(); // Do not exclude any pass. config.ir_passes.clear(); // Do not exclude any pass.
int batch_size = FLAGS_batch_size; int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat; int num_times = FLAGS_repeat;
...@@ -346,6 +347,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, ...@@ -346,6 +347,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
ASSERT_TRUE(fuse_statis.count("fc_fuse")); ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
EXPECT_EQ(num_ops, EXPECT_EQ(num_ops,
13); // After graph optimization, only 13 operators exists. 13); // After graph optimization, only 13 operators exists.
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册