提交 5558784c 编写于 作者: Y Yancey1989

Merge branch 'develop' of github.com:PaddlePaddle/Paddle into reset_vars_on_pserver

...@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -29,39 +29,27 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
std::unordered_set<Node*> nodes2delete; std::unordered_set<Node*> nodes2delete;
GraphPatternDetector gpd; GraphPatternDetector gpd;
// BuildFCPattern(gpd.mutable_pattern());
auto* x = gpd.mutable_pattern() auto* x = gpd.mutable_pattern()
->NewNode("fc_fuse/x") ->NewNode("fc_fuse/x")
->AsInput() ->AsInput()
->assert_is_op_input("mul", "X"); ->assert_is_op_input("mul", "X");
patterns::FC(gpd.mutable_pattern(), "fc_fuse", x, true /*with bias*/); patterns::FC fc_pattern(gpd.mutable_pattern(), "fc_fuse");
fc_pattern(x, true /*with bias*/);
#define GET_NODE(id) \
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode("fc_fuse/" #id)), \
"pattern has no Node called %s", #id); \
auto* id = subgraph.at(gpd.pattern().RetrieveNode("fc_fuse/" #id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", "fc_fuse/" #id);
int found_fc_count = 0; int found_fc_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
VLOG(4) << "handle FC fuse"; VLOG(4) << "handle FC fuse";
// Currently, there is no FC op available, so I will just simulate the GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
// scenerio. GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
// FC's fusion is simple, just op fuse, no need to process the GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
// parameters. GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
GET_NODE(x); // x GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
GET_NODE(w); // Y GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_NODE(fc_bias); // bias
GET_NODE(fc_out); // Out
GET_NODE(mul); // MUL op
GET_NODE(elementwise_add); // ELEMENT_ADD op
GET_NODE(mul_out); // tmp
#undef GET_NODE
// Create an FC Node. // Create an FC Node.
OpDesc desc; OpDesc desc;
std::string fc_x_in = x->Name(); std::string fc_x_in = subgraph.at(x)->Name();
std::string fc_Y_in = w->Name(); std::string fc_Y_in = w->Name();
std::string fc_bias_in = fc_bias->Name(); std::string fc_bias_in = fc_bias->Name();
std::string fc_out_out = fc_out->Name(); std::string fc_out_out = fc_out->Name();
...@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl( ...@@ -73,7 +61,8 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied. auto fc_node = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out}); GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
IR_NODE_LINK_TO(x, fc_node); PADDLE_ENFORCE(subgraph.count(x));
IR_NODE_LINK_TO(subgraph.at(x), fc_node);
IR_NODE_LINK_TO(w, fc_node); IR_NODE_LINK_TO(w, fc_node);
IR_NODE_LINK_TO(fc_bias, fc_node); IR_NODE_LINK_TO(fc_bias, fc_node);
IR_NODE_LINK_TO(fc_node, fc_out); IR_NODE_LINK_TO(fc_node, fc_out);
......
...@@ -20,52 +20,43 @@ namespace paddle { ...@@ -20,52 +20,43 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
bool with_fc_bias) {
PDNode* x = pattern->NewNode(name_scope, "x")
->assert_is_op_input("mul")
->assert_var_not_persistable();
auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::GRU(pattern, name_scope, fc_out);
VLOG(3) << "fc_gru pattern \n" << pattern->DotString();
}
static int BuildFusion(Graph* graph, const std::string& name_scope, static int BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) { Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
BuildPattern(pattern, name_scope, with_fc_bias); // Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
patterns::GRU gru_pattern(pattern, name_scope);
PDNode* x =
pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
auto* fc_out = fc_pattern(x, with_fc_bias);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
gru_pattern(fc_out);
// Create New OpDesc // Create New OpDesc
auto gru_creater = [&](int gru, int x, int weight_x, int weight_h, int bias, auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
int hidden, int fc_bias) { Node* bias, Node* hidden, Node* fc_bias) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE(x);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(gru);
OpDesc op_desc; OpDesc op_desc;
op_desc.SetType("fusion_gru"); op_desc.SetType("fusion_gru");
#define NEW_NAME(x) name_scope + "/at." #x ".new" #define NEW_NAME(x) name_scope + "/at." #x ".new"
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()}); #define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN(X, x); SET_IN(X, x);
SET_IN(WeightX, weight_x); SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h); SET_IN(WeightH, weight_h);
if (with_fc_bias) { if (with_fc_bias) {
op_desc.SetInput("Bias", {NEW_NAME(bias) + bias_n->Name()}); op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
} else { } else {
SET_IN(Bias, bias); SET_IN(Bias, bias);
} }
#undef SET_IN #undef SET_IN
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()}); op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetAttr("is_reverse", gru_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
// TODO(TJ): This should be a option for infer // TODO(TJ): This should be a option for infer
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
...@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
if (with_fc_bias) { if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias // Fusion GRU bias = fcbias + grubias
auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias_n->Name()); auto* fusion_bias_var = scope->Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor = auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>(); fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var); PADDLE_ENFORCE(fusion_bias_var);
GET_NODE(fc_bias); auto* gru_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(fc_bias_n); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* gru_bias_var = scope->FindVar(bias_n->Name());
auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
PADDLE_ENFORCE(gru_bias_var); PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var); PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>(); const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
...@@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -113,11 +102,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef NEW_NAME #undef NEW_NAME
#undef NEW_IMTERMEDIATE_OUT #undef NEW_IMTERMEDIATE_OUT
IR_NODE_LINK_TO(x_n, op); IR_NODE_LINK_TO(x, op);
IR_NODE_LINK_TO(weight_x_n, op); IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h_n, op); IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias_n, op); // actually should link to new bias if have IR_NODE_LINK_TO(bias, op); // actually should link to new bias if have
IR_NODE_LINK_TO(op, hidden_n); IR_NODE_LINK_TO(op, hidden);
// h0? // h0?
return op; return op;
}; };
...@@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -125,42 +114,35 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
int fusion_count{0}; int fusion_count{0};
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
#define GET_NODE(name__) \ auto* x_n = subgraph.at(x);
std::string name__##key = name_scope + "/" + #name__; \ GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
auto* name__##n = pattern->RetrieveNode(name__##key); \ GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
PADDLE_ENFORCE(name__##n); \ GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
PADDLE_ENFORCE(subgraph.count(name__##n)); \ GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, gru_pattern);
Node* name__##_n = subgraph.at(name__##n); \ GET_IR_NODE_FROM_SUBGRAPH(gru, gru, gru_pattern);
int name__ __attribute__((unused)) = name__##_n->id(); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, gru_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, gru_pattern);
GET_NODE(x);
GET_NODE(w); // fc weight
GET_NODE(mul);
GET_NODE(fc_out);
GET_NODE(Weight);
GET_NODE(gru);
GET_NODE(Bias);
GET_NODE(Hidden);
// nodes need be removed // nodes need be removed
GET_NODE(BatchGate); GET_IR_NODE_FROM_SUBGRAPH(BatchGate, BatchGate, gru_pattern);
GET_NODE(BatchResetHiddenPrev); GET_IR_NODE_FROM_SUBGRAPH(BatchResetHiddenPrev, BatchGate, gru_pattern);
GET_NODE(BatchHidden); GET_IR_NODE_FROM_SUBGRAPH(BatchHidden, BatchGate, gru_pattern);
if (with_fc_bias) { if (with_fc_bias) {
GET_NODE(mul_out); GET_IR_NODE_FROM_SUBGRAPH(mul_out, mul_out, fc_pattern);
GET_NODE(fc_bias); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
GET_NODE(elementwise_add); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
gru_creater(gru, x, w, Weight, Bias, Hidden, fc_bias);
gru_creater(gru, x_n, w, Weight, Bias, Hidden, fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul_n, gru_n, elementwise_add_n, fc_bias_n, fc_out_n, mul_out_n, {mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate,
BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n}); BatchResetHiddenPrev, BatchHidden});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
gru_creater(gru, x, w, Weight, Bias, Hidden, -1); gru_creater(gru, x_n, w, Weight, Bias, Hidden, nullptr);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul_n, gru_n, BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n}); {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} }
#undef GET_NODE #undef GET_NODE
......
...@@ -20,45 +20,29 @@ namespace paddle { ...@@ -20,45 +20,29 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static std::string GenNodeName(const std::string& prefix, int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
const std::string& name) {
return prefix + "/" + name;
}
static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
bool with_fc_bias) { bool with_fc_bias) {
PDNode* x = pattern->NewNode(name_scope, "x")
->assert_is_op_input("mul")
->assert_var_not_persistable();
auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::LSTM(pattern, name_scope, fc_out);
// LOG(INFO) << "\n" << pattern->DotString();
}
static int BuildFusion(Graph* graph, const std::string& name_scope,
Scope* scope, bool with_fc_bias) {
GraphPatternDetector gpd; GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern(); auto* pattern = gpd.mutable_pattern();
BuildPattern(pattern, name_scope, with_fc_bias); // Build pattern
PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
->assert_is_op_input("mul")
->assert_var_not_persistable();
patterns::FC fc_pattern(pattern, name_scope);
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
auto* fc_out = fc_pattern(x, with_fc_bias)->AsIntermediate();
patterns::LSTM lstm_pattern(pattern, name_scope);
lstm_pattern(fc_out);
// Create New OpDesc // Create New OpDesc
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h, auto lstm_creator = [&](Node* lstm, Node* input, Node* weight_x,
int bias, int hidden, int cell, int xx, int fc_bias) { Node* weight_h, Node* bias, Node* hidden, Node* cell,
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x); Node* xx, Node* fc_bias) {
GET_NODE(input);
GET_NODE(weight_x);
GET_NODE(weight_h);
GET_NODE(bias);
GET_NODE(hidden);
GET_NODE(cell);
GET_NODE(xx);
GET_NODE(lstm);
OpDesc op_desc; OpDesc op_desc;
op_desc.SetType("fusion_lstm"); op_desc.SetType("fusion_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__##_n->Name()}); #define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
SET_IN(X, input); SET_IN(X, input);
SET_IN(WeightX, weight_x); SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h); SET_IN(WeightH, weight_h);
...@@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -71,13 +55,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto* bias_var = scope->Var(new_bias_var); auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE(bias_var); PADDLE_ENFORCE(bias_var);
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>(); auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
auto* lstm_bias_var = scope->FindVar(bias_n->Name()); auto* lstm_bias_var = scope->FindVar(bias->Name());
PADDLE_ENFORCE(lstm_bias_var); PADDLE_ENFORCE(lstm_bias_var);
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>(); const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
bias_tensor->Resize(lstm_bias_tensor.dims()); bias_tensor->Resize(lstm_bias_tensor.dims());
GET_NODE(fc_bias); auto* fc_bias_var = scope->FindVar(fc_bias->Name());
auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>(); const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace()); auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace());
...@@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -88,31 +71,36 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
} }
op_desc.SetInput("Bias", {new_bias_var}); op_desc.SetInput("Bias", {new_bias_var});
} }
#undef GET_NODE
// Create temp variables. // Create temp variables.
scope->Var(name_scope + "/BatchedInput.new") const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
->GetMutable<framework::LoDTensor>(); const std::string BatchedCellPreAct =
scope->Var(name_scope + "/BatchCellPreAct.new") patterns::UniqueKey("BatchedCellPreAct");
->GetMutable<framework::LoDTensor>(); const std::string BatchedGate = patterns::UniqueKey("BatchedGate");
scope->Var(name_scope + "/BatchedGate.new")
->GetMutable<framework::LoDTensor>(); scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {}); op_desc.SetInput("C0", {});
op_desc.SetOutput("Hidden", {hidden_n->Name()}); op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetOutput("Cell", {cell_n->Name()}); op_desc.SetOutput("Cell", {cell->Name()});
op_desc.SetOutput("XX", {xx_n->Name()}); op_desc.SetOutput("XX", {xx->Name()});
op_desc.SetOutput("BatchedGate", {name_scope + "/BatchedGate.new"}); op_desc.SetOutput("BatchedGate", {BatchedGate});
op_desc.SetOutput("BatchCellPreAct", {name_scope + "/BatchCellPreAct.new"}); op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
op_desc.SetOutput("BatchedInput", {name_scope + "/BatchedInput.new"}); op_desc.SetOutput("BatchedInput", {BatchedInput});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes")); op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
// TODO(TJ): get from attr // TODO(TJ): get from attr
op_desc.SetAttr("use_seq", true); op_desc.SetAttr("use_seq", true);
#define TMP_NAME(x) "at.new.tmp." #x PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)}) auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x) \
const std::string x = patterns::UniqueKey(#x); \
op_desc.SetOutput(#x, {x}); \
scope->Var(x)->GetMutable<LoDTensor>()
OP_SET_OUT(BatchedCell); OP_SET_OUT(BatchedCell);
OP_SET_OUT(BatchedHidden); OP_SET_OUT(BatchedHidden);
OP_SET_OUT(ReorderedH0); OP_SET_OUT(ReorderedH0);
...@@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -120,22 +108,11 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef OP_SET_OUT #undef OP_SET_OUT
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr)); IR_NODE_LINK_TO(input, op);
auto* scope = graph->Get<Scope*>(kParamScopeAttr); IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op);
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>() IR_NODE_LINK_TO(bias, op);
TMP_NEW(BatchedCell); IR_NODE_LINK_TO(op, hidden);
TMP_NEW(BatchedHidden);
TMP_NEW(ReorderedH0);
TMP_NEW(ReorderedC0);
#undef TMP_NEW
#undef TMP_NAME
IR_NODE_LINK_TO(input_n, op);
IR_NODE_LINK_TO(weight_x_n, op);
IR_NODE_LINK_TO(weight_h_n, op);
IR_NODE_LINK_TO(bias_n, op);
IR_NODE_LINK_TO(op, hidden_n);
return op; return op;
}; };
...@@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope, ...@@ -143,39 +120,32 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
#define GET_NODE(name__) \
std::string name__##key = name_scope + "/" + #name__; \
auto* name__##n = pattern->RetrieveNode(name__##key); \
PADDLE_ENFORCE(name__##n); \
PADDLE_ENFORCE(subgraph.count(name__##n)); \
Node* name__##_n = subgraph.at(name__##n); \
int name__ __attribute__((unused)) = name__##_n->id();
GET_NODE(x);
GET_NODE(w);
GET_NODE(mul);
GET_NODE(fc_out);
GET_NODE(Weight);
GET_NODE(lstm);
GET_NODE(Bias);
GET_NODE(Hidden);
GET_NODE(Cell);
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);
if (with_fc_bias) { if (with_fc_bias) {
GET_NODE(fc_bias); GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
GET_NODE(elementwise_add); GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias); GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul_n, lstm_n, elementwise_add_n}); {mul, lstm, elementwise_add});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1); GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
lstm_creator(lstm, subgraph.at(x), w, Weight, Bias, Hidden, Cell, fc_out,
nullptr);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n}); std::unordered_set<const Node*> marked_nodes({mul, lstm});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} }
#undef GET_NODE
++fusion_count; ++fusion_count;
}; };
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h" #include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/printf.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { ...@@ -106,8 +107,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
for (auto& pdnode : pattern_.nodes()) { for (auto& pdnode : pattern_.nodes()) {
if (!pdnodes2nodes_.count(pdnode.get())) { if (!pdnodes2nodes_.count(pdnode.get())) {
VLOG(4) << pdnode->name() << " can't find matched Node, early stop"; VLOG(4) << pdnode->name() << " can't find matched Node, early stop";
// return false;
return false;
} }
} }
for (auto& item : pdnodes2nodes_) { for (auto& item : pdnodes2nodes_) {
...@@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) { ...@@ -517,87 +517,89 @@ bool VarLinksFromOp(Node* node, const std::string& op_type) {
return false; return false;
} }
PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope, PDNode* patterns::FC::operator()(paddle::framework::ir::PDNode* x,
PDNode* x, bool with_bias) { bool with_bias) {
// mul op // Create shared nodes.
auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul"); x->assert_is_op_input("mul", "X");
auto* mul_weight_var = pattern->NewNode(name_scope, "w") auto* mul = pattern->NewNode(mul_repr())->assert_is_op("mul");
auto* mul_w_var = pattern->NewNode(w_repr())
->AsInput() ->AsInput()
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_is_op_input("mul", "Y"); ->assert_is_op_input("mul", "Y");
PDNode* fc_out{nullptr}; auto* mul_out_var =
if (with_bias) { pattern->NewNode(mul_out_repr())->assert_is_op_output("mul");
PDNode* elementwise_add_op{nullptr};
PDNode *mul_out_var{nullptr}, *bias{nullptr}; if (!with_bias) { // not with bias
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add") // Add links.
mul->LinksFrom({x, mul_w_var}).LinksTo({mul_out_var});
return mul_out_var;
} else { // with bias
mul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
// Create operators.
auto* elementwise_add = pattern->NewNode(elementwise_add_repr())
->assert_is_op("elementwise_add"); ->assert_is_op("elementwise_add");
// intermediate variable, will be removed in the IR after fuse. // Create variables.
mul_out_var = pattern->NewNode(name_scope, "mul_out") auto* bias = pattern->NewNode(bias_repr())
->AsIntermediate() ->assert_is_op_input("elementwise_add")
->assert_is_only_output_of_op("mul") ->AsInput();
->assert_is_op_input("elementwise_add");
// bias auto* fc_out = pattern->NewNode(Out_repr())
bias = pattern->NewNode(name_scope, "fc_bias")
->AsInput()
->assert_is_op_input("elementwise_add");
// output
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput() ->AsOutput()
->assert_is_op_output("elementwise_add"); ->assert_is_op_output("elementwise_add");
mul_op->LinksFrom({x, mul_weight_var}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out}); mul->LinksFrom({mul_w_var, x}).LinksTo({mul_out_var});
} else { elementwise_add->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput()
->assert_is_op_output("mul");
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
}
return fc_out; return fc_out;
}
} }
#define NEW_NODE(op__, arg__, io__) \ PDNode* patterns::LSTM::operator()(PDNode* x) {
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
->assert_is_op_##io__(#op__, #arg__);
PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
PDNode* x) {
x->assert_is_op_input("lstm", "Input"); x->assert_is_op_input("lstm", "Input");
auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm"); auto* lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("lstm", #arg__);
// Currently, the H0 and C0 are optional // Currently, the H0 and C0 are optional
// TODO(Superjomn) upgrade the fuse framework to support optional. // TODO(Superjomn) upgrade the fuse framework to support optional.
// NEW_NODE(H0, input); // NEW_NODE(H0, input);
// NEW_NODE(C0, input); // NEW_NODE(C0, input);
NEW_NODE(lstm, Weight, input); NEW_NODE(Weight, input);
NEW_NODE(lstm, Bias, input); NEW_NODE(Bias, input);
NEW_NODE(lstm, Hidden, output); NEW_NODE(Hidden, output);
NEW_NODE(lstm, Cell, output); NEW_NODE(Cell, output);
NEW_NODE(lstm, BatchGate, output); NEW_NODE(BatchGate, output);
NEW_NODE(lstm, BatchCellPreAct, output); NEW_NODE(BatchCellPreAct, output);
#undef NEW_NODE
lstm_op->LinksFrom({x, Weight, Bias}); lstm_op->LinksFrom({x, Weight, Bias});
lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct}); lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct});
return Hidden; return Hidden;
} }
PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope, PDNode* patterns::GRU::operator()(PDNode* x) {
PDNode* x) {
x->assert_is_op_input("gru", "Input"); x->assert_is_op_input("gru", "Input");
auto* gru_op = pattern->NewNode(name_scope, "gru")->assert_is_op("gru"); auto* gru_op = pattern->NewNode(gru_repr())->assert_is_op("gru");
#define NEW_NODE(arg__, io__) \
auto* arg__ = \
pattern->NewNode(arg__##_repr())->assert_is_op_##io__("gru", #arg__);
NEW_NODE(gru, Weight, input); NEW_NODE(Weight, input);
// TODO(Superjomn): upgrade the fuse framework to support optional. // TODO(Superjomn): upgrade the fuse framework to support optional.
// H0 and bias are optional // H0 and bias are optional
NEW_NODE(gru, Bias, input); // also optional NEW_NODE(Bias, input); // also optional
// NEW_NODE(H0, input); // NEW_NODE(H0, input);
NEW_NODE(gru, Hidden, output); NEW_NODE(Hidden, output);
// below are intermediate // below are intermediate
NEW_NODE(gru, BatchGate, output); NEW_NODE(BatchGate, output);
NEW_NODE(gru, BatchResetHiddenPrev, output); NEW_NODE(BatchResetHiddenPrev, output);
NEW_NODE(gru, BatchHidden, output); NEW_NODE(BatchHidden, output);
#undef NEW_NODE
BatchGate->AsIntermediate(); BatchGate->AsIntermediate();
BatchResetHiddenPrev->AsIntermediate(); BatchResetHiddenPrev->AsIntermediate();
...@@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope, ...@@ -607,7 +609,6 @@ PDNode* patterns::GRU(PDPattern* pattern, const std::string& name_scope,
gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden}); gru_op->LinksTo({Hidden, BatchGate, BatchResetHiddenPrev, BatchHidden});
return Hidden; return Hidden;
} }
#undef NEW_NODE
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph, ...@@ -286,22 +286,148 @@ void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes); const std::unordered_set<const Node*>& nodes);
// Some pre-defined patterns those can be reused in multiple passes. // Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better reusage
// accross different fusion.
namespace patterns { namespace patterns {
struct KeyCounter {
static KeyCounter& Instance() {
static KeyCounter x;
return x;
}
int IncCounter(const std::string& key) { return dic_[key]++; }
private:
std::unordered_map<std::string, size_t> dic_;
};
// Generate a unique PDNode's name with name_scope and id.
// The format is {name_scope}/{repr}/{id}/{name}
static std::string PDNodeName(const std::string& name_scope,
const std::string& repr, size_t id,
const std::string& name) {
return string::Sprintf("%s/%s/%d/%s", name_scope, repr, id, name);
}
// Generate a unique PDNode's name.
// The format is {name_scope}/{repr}/{id}
static std::string PDNodeName(const std::string& name_scope,
const std::string& repr) {
return string::Sprintf("%s/%s/%d", name_scope, repr,
KeyCounter::Instance().IncCounter(repr));
}
// Generate a unique key. It can be used for a universally unique temporary
// name.
// The format is {repr}/{id}
static std::string UniqueKey(const std::string& repr) {
return string::Sprintf("%s/%d", repr,
KeyCounter::Instance().IncCounter(repr));
}
// Declare a PDNode in a pattern, will create two methods:
// std::string xxx_repr(); return this PDNode's string id.
// PDNode* xxx_n(); return the corresponding PDNode.
#define PATTERN_DECL_NODE(name__) \
std::string name__##_repr() const { \
return PDNodeName(name_scope_, repr_, id_, #name__); \
} \
PDNode* name__##_n() const { return pattern->RetrieveNode(name__##_repr()); }
// Get an ir::Node* from the matched subgraph.
// var: variable.
// arg: the argument declared by PATTERN_DECL_NODE in a pattern definition.
// pat: the pattern object.
#define GET_IR_NODE_FROM_SUBGRAPH(var, arg, pat) \
PADDLE_ENFORCE(subgraph.count(pat.arg##_n()), \
"Node not found for PDNode %s", pat.arg##_repr()); \
Node* var = subgraph.at(pat.arg##_n()); \
PADDLE_ENFORCE(var, "node %s not exists in the sub-graph", #arg)
// The base class of all the patterns.
struct PatternBase {
PatternBase(PDPattern* pattern, const std::string& name_scope,
const std::string& repr)
: pattern(pattern),
name_scope_(name_scope),
repr_(repr),
id_(KeyCounter::Instance().IncCounter(repr)) {}
PDPattern* pattern;
protected:
std::string name_scope_;
std::string repr_;
size_t id_;
};
// FC with bias // FC with bias
// op: mul + elementwise_add // op: mul + elementwise_add
// named nodes: // named nodes:
// mul, elementwise_add // mul, elementwise_add
// w, mul_out, bias, fc_out // w, mul_out, bias, fc_out
PDNode* FC(PDPattern* pattern, const std::string& name_scope, PDNode* x, struct FC : public PatternBase {
bool with_bias); FC(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fc") {}
PDNode* operator()(PDNode* x, bool with_bias);
// declare operator node's name
PATTERN_DECL_NODE(fc);
PATTERN_DECL_NODE(mul);
PATTERN_DECL_NODE(elementwise_add);
// declare variable node's name
PATTERN_DECL_NODE(w);
PATTERN_DECL_NODE(mul_out); // (x,w) -> mul_out
PATTERN_DECL_NODE(bias);
PATTERN_DECL_NODE(Out);
};
struct LSTM : public PatternBase {
LSTM(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "lstm") {}
PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x); PDNode* operator()(PDNode* x);
PDNode* GRU(PDPattern* pattern, const std::string& name_scope, PDNode* x); // Operators
PATTERN_DECL_NODE(lstm);
// Inputs
PATTERN_DECL_NODE(Input);
PATTERN_DECL_NODE(H0);
PATTERN_DECL_NODE(C0);
PATTERN_DECL_NODE(Weight);
PATTERN_DECL_NODE(Bias);
// Outputs
PATTERN_DECL_NODE(Hidden);
PATTERN_DECL_NODE(Cell);
PATTERN_DECL_NODE(BatchGate);
PATTERN_DECL_NODE(BatchCellPreAct);
};
struct GRU : public PatternBase {
GRU(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "lstm") {}
PDNode* operator()(PDNode* x);
// Operators
PATTERN_DECL_NODE(gru);
// Inputs
PATTERN_DECL_NODE(Bias);
PATTERN_DECL_NODE(Weight);
// Outputs
PATTERN_DECL_NODE(BatchGate);
PATTERN_DECL_NODE(BatchResetHiddenPrev);
PATTERN_DECL_NODE(BatchHidden);
PATTERN_DECL_NODE(Hidden);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other.
#define IR_NODE_LINK_TO(a, b) \ #define IR_NODE_LINK_TO(a, b) \
a->outputs.push_back(b); \ a->outputs.push_back(b); \
b->inputs.push_back(a); b->inputs.push_back(a);
......
...@@ -192,6 +192,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl( ...@@ -192,6 +192,8 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
auto* id = subgraph.at(pattern.RetrieveNode(#id)); \ auto* id = subgraph.at(pattern.RetrieveNode(#id)); \
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id); PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
int fuse_count{0};
detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph, detector(graph.get(), [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) { Graph* graph) {
VLOG(4) << "get one concat pattern"; VLOG(4) << "get one concat pattern";
...@@ -239,8 +241,12 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl( ...@@ -239,8 +241,12 @@ std::unique_ptr<ir::Graph> SeqConcatFcFusePass::ApplyImpl(
marked_nodes.erase(sequence_expand1_in); marked_nodes.erase(sequence_expand1_in);
marked_nodes.erase(fc_out); marked_nodes.erase(fc_out);
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
++fuse_count;
}); });
AddStatis(fuse_count);
return graph; return graph;
} }
......
...@@ -48,18 +48,18 @@ function (inference_download_and_uncompress install_dir url gz_filename) ...@@ -48,18 +48,18 @@ function (inference_download_and_uncompress install_dir url gz_filename)
message(STATUS "finish downloading ${gz_filename}") message(STATUS "finish downloading ${gz_filename}")
endfunction(inference_download_and_uncompress) endfunction(inference_download_and_uncompress)
set(DITU_RNN_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fmodel.tar.gz") set(RNN1_MODEL_URL "http://paddle-inference-dist.bj.bcebos.com/rnn1%2Fmodel.tar.gz")
set(DITU_RNN_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/ditu_rnn_fluid%2Fdata.txt.tar.gz") set(RNN1_DATA_URL "http://paddle-inference-dist.bj.bcebos.com/rnn1%2Fdata.txt.tar.gz")
set(DITU_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/ditu_rnn" CACHE PATH "Ditu RNN model and data root." FORCE) set(RNN1_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo/rnn1" CACHE PATH "RNN1 model and data root." FORCE)
if (NOT EXISTS ${DITU_INSTALL_DIR} AND WITH_TESTING) if (NOT EXISTS ${RNN1_INSTALL_DIR} AND WITH_TESTING)
inference_download_and_uncompress(${DITU_INSTALL_DIR} ${DITU_RNN_MODEL_URL} "ditu_rnn_fluid%2Fmodel.tar.gz") inference_download_and_uncompress(${RNN1_INSTALL_DIR} ${RNN1_MODEL_URL} "rnn1%2Fmodel.tar.gz")
inference_download_and_uncompress(${DITU_INSTALL_DIR} ${DITU_RNN_DATA_URL} "ditu_rnn_fluid%2Fdata.txt.tar.gz") inference_download_and_uncompress(${RNN1_INSTALL_DIR} ${RNN1_DATA_URL} "rnn1%2Fdata.txt.tar.gz")
endif() endif()
inference_analysis_test(test_analyzer SRCS analyzer_tester.cc inference_analysis_test(test_analyzer SRCS analyzer_tester.cc
EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor EXTRA_DEPS paddle_inference_api paddle_fluid_api ir_pass_manager analysis_predictor
ARGS --infer_ditu_rnn_model=${DITU_INSTALL_DIR}/model ARGS --infer_model=${RNN1_INSTALL_DIR}/model
--infer_ditu_rnn_data=${DITU_INSTALL_DIR}/data.txt) --infer_data=${RNN1_INSTALL_DIR}/data.txt)
inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc) inference_analysis_test(test_data_flow_graph SRCS data_flow_graph_tester.cc)
inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc) inference_analysis_test(test_data_flow_graph_to_fluid_pass SRCS data_flow_graph_to_fluid_pass_tester.cc)
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
#include "paddle/fluid/inference/api/paddle_inference_pass.h" #include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/utils/singleton.h" #include "paddle/fluid/inference/utils/singleton.h"
DEFINE_string(infer_ditu_rnn_model, "", "model path for ditu RNN"); DEFINE_string(infer_model, "", "model path");
DEFINE_string(infer_ditu_rnn_data, "", "data path for ditu RNN"); DEFINE_string(infer_data, "", "data path");
DEFINE_int32(batch_size, 10, "batch size."); DEFINE_int32(batch_size, 10, "batch size.");
DEFINE_int32(repeat, 1, "Running the inference program repeat times."); DEFINE_int32(repeat, 1, "Running the inference program repeat times.");
DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads."); DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
...@@ -223,17 +223,6 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data, ...@@ -223,17 +223,6 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
} // namespace } // namespace
const float ditu_rnn_target_data[] = {
104.711, 11.2431, 1.35422, 0, 0, 0, 0, 0,
27.7039, 1.41486, 7.09526, 0, 0, 0, 0, 0,
7.6481, 6.5324, 56.383, 2.88018, 8.92918, 132.007, 4.27429, 2.02934,
14.1727, 10.7461, 25.0616, 16.0197, 14.4163, 16.9199, 6.75517, 0,
80.0249, 4.77739, 0, 0, 0, 0, 0, 0,
47.5643, 2.67029, 8.76252, 0, 0, 0, 0, 0,
51.8822, 4.4411, 0, 0, 0, 0, 0, 0,
10.7286, 12.0595, 10.6672, 0, 0, 0, 0, 0,
93.5771, 3.84641, 0, 0, 0, 0, 0, 0,
169.426, 0, 0, 0, 0, 0, 0, 0};
void CompareResult(const std::vector<PaddleTensor> &outputs, void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<PaddleTensor> &base_outputs) { const std::vector<PaddleTensor> &base_outputs) {
PADDLE_ENFORCE_GT(outputs.size(), 0); PADDLE_ENFORCE_GT(outputs.size(), 0);
...@@ -255,11 +244,10 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -255,11 +244,10 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
} }
} }
// Test with a really complicate model. // Test with a really complicate model.
void TestDituRNNPrediction(bool use_analysis, bool activate_ir, void TestRNN1Prediction(bool use_analysis, bool activate_ir, int num_threads) {
int num_threads) {
AnalysisConfig config; AnalysisConfig config;
config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__"; config.prog_file = FLAGS_infer_model + "/__model__";
config.param_file = FLAGS_infer_ditu_rnn_model + "/param"; config.param_file = FLAGS_infer_model + "/param";
config.use_gpu = false; config.use_gpu = false;
config.device = 0; config.device = 0;
config.specify_input_name = true; config.specify_input_name = true;
...@@ -267,6 +255,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, ...@@ -267,6 +255,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
PADDLE_ENFORCE(config.ir_mode == PADDLE_ENFORCE(config.ir_mode ==
AnalysisConfig::IrPassMode::kExclude); // default AnalysisConfig::IrPassMode::kExclude); // default
config.ir_passes.clear(); // Do not exclude any pass. config.ir_passes.clear(); // Do not exclude any pass.
int batch_size = FLAGS_batch_size; int batch_size = FLAGS_batch_size;
int num_times = FLAGS_repeat; int num_times = FLAGS_repeat;
...@@ -276,7 +265,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, ...@@ -276,7 +265,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>( CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
config); config);
std::vector<PaddleTensor> input_slots; std::vector<PaddleTensor> input_slots;
DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size); DataRecord data(FLAGS_infer_data, batch_size);
// Prepare inputs. // Prepare inputs.
PrepareInputs(&input_slots, &data, batch_size); PrepareInputs(&input_slots, &data, batch_size);
std::vector<PaddleTensor> outputs, base_outputs; std::vector<PaddleTensor> outputs, base_outputs;
...@@ -306,7 +295,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, ...@@ -306,7 +295,7 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
threads.emplace_back([&, tid]() { threads.emplace_back([&, tid]() {
// Each thread should have local input_slots and outputs. // Each thread should have local input_slots and outputs.
std::vector<PaddleTensor> input_slots; std::vector<PaddleTensor> input_slots;
DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size); DataRecord data(FLAGS_infer_data, batch_size);
PrepareInputs(&input_slots, &data, batch_size); PrepareInputs(&input_slots, &data, batch_size);
std::vector<PaddleTensor> outputs; std::vector<PaddleTensor> outputs;
Timer timer; Timer timer;
...@@ -346,30 +335,29 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir, ...@@ -346,30 +335,29 @@ void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
ASSERT_TRUE(fuse_statis.count("fc_fuse")); ASSERT_TRUE(fuse_statis.count("fc_fuse"));
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1); EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
EXPECT_EQ(fuse_statis.at("seq_concat_fc_fuse"), 1);
EXPECT_EQ(num_ops, EXPECT_EQ(num_ops,
13); // After graph optimization, only 13 operators exists. 13); // After graph optimization, only 13 operators exists.
} }
} }
// Inference with analysis and IR, easy for profiling independently. // Inference with analysis and IR, easy for profiling independently.
TEST(Analyzer, DituRNN) { TEST(Analyzer, rnn1) { TestRNN1Prediction(true, true, FLAGS_num_threads); }
TestDituRNNPrediction(true, true, FLAGS_num_threads);
}
// Other unit-tests of DituRNN, test different options of use_analysis, // Other unit-tests of RNN1, test different options of use_analysis,
// activate_ir and multi-threads. // activate_ir and multi-threads.
TEST(Analyzer, DituRNN_tests) { TEST(Analyzer, RNN_tests) {
int num_threads[2] = {1, 4}; int num_threads[2] = {1, 4};
for (auto i : num_threads) { for (auto i : num_threads) {
// Directly infer with the original model. // Directly infer with the original model.
TestDituRNNPrediction(false, false, i); TestRNN1Prediction(false, false, i);
// Inference with the original model with the analysis turned on, the // Inference with the original model with the analysis turned on, the
// analysis // analysis
// module will transform the program to a data flow graph. // module will transform the program to a data flow graph.
TestDituRNNPrediction(true, false, i); TestRNN1Prediction(true, false, i);
// Inference with analysis and IR. The IR module will fuse some large // Inference with analysis and IR. The IR module will fuse some large
// kernels. // kernels.
TestDituRNNPrediction(true, true, i); TestRNN1Prediction(true, true, i);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册