未验证 提交 597b7305 编写于 作者: Y Yan Chunwei 提交者: GitHub

refine/fc lstm fusion link (#13158)

上级 1e7ccf9f
...@@ -13,37 +13,37 @@ ...@@ -13,37 +13,37 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( std::string GenNodeName(const std::string& prefix, const std::string& name) {
std::unique_ptr<ir::Graph> graph) const { return prefix + "/" + name;
GraphPatternDetector gpd; }
auto* pattern = gpd.mutable_pattern();
std::unordered_set<int> fused_ops({// first lstm
13, 15, 16,
// second lstm
23, 25, 26});
pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
"any_node");
std::unordered_set<Node*> marked_nodes; void BuildPattern(PDPattern* pattern, const std::string& name_scope,
bool with_fc_bias) {
PDNode* x = pattern->NewNode(name_scope, "x")
->assert_is_op_input("mul")
->assert_var_not_persistable();
auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::LSTM(pattern, name_scope, fc_out);
// LOG(INFO) << "\n" << pattern->DotString();
}
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
Graph* g) { bool with_fc_bias) {
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node")); BuildPattern(pattern, name_scope, with_fc_bias);
marked_nodes.insert(id);
};
gpd(graph.get(), handler);
// Create New OpDesc // Create New OpDesc
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h, auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
int bias, int hidden, int cell, int xx) { int bias, int hidden, int cell, int xx, int fc_bias) {
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x); #define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
GET_NODE(input); GET_NODE(input);
GET_NODE(weight_x); GET_NODE(weight_x);
...@@ -61,12 +61,33 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -61,12 +61,33 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
SET_IN(WeightX, weight_x); SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h); SET_IN(WeightH, weight_h);
SET_IN(Bias, bias); SET_IN(Bias, bias);
#undef GET_NODE
#undef SET_IN #undef SET_IN
if (with_fc_bias) {
// Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE(scope);
const std::string& new_bias_var = name_scope + "_bias.new";
auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE(bias_var);
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
auto* lstm_bias_var = scope->FindVar(bias_n->Name());
PADDLE_ENFORCE(lstm_bias_var);
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
bias_tensor->Resize(lstm_bias_tensor.dims());
GET_NODE(fc_bias);
auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < bias_tensor->numel(); i++) {
data[i] =
fc_bias_tensor.data<float>()[i] + lstm_bias_tensor.data<float>()[i];
}
op_desc.SetInput("Bias", {new_bias_var});
}
VLOG(4) << "hidden_n: " << hidden_n->Name(); #undef GET_NODE
VLOG(4) << "cell: " << cell_n->Name();
VLOG(4) << "xx: " << xx_n->Name();
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
op_desc.SetInput("C0", {}); op_desc.SetInput("C0", {});
...@@ -76,7 +97,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -76,7 +97,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"}); op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"}); op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse")); op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("use_peepholes", false); op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
auto* op = graph->CreateOpNode(&op_desc); auto* op = graph->CreateOpNode(&op_desc);
#define LINK_TO(a, b) \ #define LINK_TO(a, b) \
...@@ -89,33 +110,71 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -89,33 +110,71 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
LINK_TO(op, hidden_n); LINK_TO(op, hidden_n);
#undef LINK_TO #undef LINK_TO
return op; return op;
}; };
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19); int fusion_count{0};
lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);
// remove all the nodes auto fc_no_bias_handler = [&](
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
for (auto* node : marked_nodes) { #define GET_NODE(name__) \
graph->RemoveNode(const_cast<Node*>(node)); std::string name__##key = name_scope + "/" + #name__; \
} auto* name__##n = pattern->RetrieveNode(name__##key); \
PADDLE_ENFORCE(name__##n); \
PADDLE_ENFORCE(subgraph.count(name__##n)); \
Node* name__##_n = subgraph.at(name__##n); \
int name__ __attribute__((unused)) = name__##_n->id();
for (auto* node : graph->Nodes()) { GET_NODE(x);
for (auto it = node->inputs.begin(); it != node->inputs.end();) { GET_NODE(w);
if (marked_nodes.count(*it)) { GET_NODE(mul);
it = const_cast<Node*>(node)->inputs.erase(it); GET_NODE(fc_out);
} else GET_NODE(Weight);
it++; GET_NODE(lstm);
} GET_NODE(Bias);
for (auto it = node->outputs.begin(); it != node->outputs.end();) { GET_NODE(Hidden);
if (marked_nodes.count(*it)) { GET_NODE(Cell);
it = const_cast<Node*>(node)->outputs.erase(it);
} else if (with_fc_bias) {
it++; GET_NODE(fc_bias);
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias);
} else {
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1);
} }
} #undef GET_NODE
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, fc_no_bias_handler);
return fusion_count;
}
std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
false /*with_fc_bias*/);
AddStatis(fusion_count);
return graph;
}
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
std::unique_ptr<ir::Graph> graph) const {
FusePassBase::Init(name_scope_, graph.get());
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
true /*with_fc_bias*/);
AddStatis(fusion_count);
return graph; return graph;
} }
...@@ -123,4 +182,5 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl( ...@@ -123,4 +182,5 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulLstmFusePass);
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass); REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);
...@@ -12,20 +12,34 @@ ...@@ -12,20 +12,34 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
class FCLstmFusePass : public Pass { // The MulLstmFusePass and MulLstmFusePass will fuse to the same FusionLstm op.
// Just FC without bias
class FCLstmFusePass : public FusePassBase {
public: public:
virtual ~FCLstmFusePass() {} virtual ~FCLstmFusePass() {}
protected: protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const; std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"fc_lstm_fuse"};
};
class MulLstmFusePass : public FusePassBase {
public:
virtual ~MulLstmFusePass() {}
protected:
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
const std::string name_scope_{"fc_nobias_lstm_fuse"};
}; };
} // namespace ir } // namespace ir
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_traits.h" #include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -71,7 +72,11 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) { ...@@ -71,7 +72,11 @@ void PDPattern::AddEdge(PDNode* a, PDNode* b) {
void GraphPatternDetector::operator()(Graph* graph, void GraphPatternDetector::operator()(Graph* graph,
GraphPatternDetector::handle_t handler) { GraphPatternDetector::handle_t handler) {
if (!MarkPDNodesInGraph(*graph)) return; if (!MarkPDNodesInGraph(*graph)) {
LOG(INFO) << "Mark failed";
return;
}
auto subgraphs = DetectPatterns(); auto subgraphs = DetectPatterns();
UniquePatterns(&subgraphs); UniquePatterns(&subgraphs);
RemoveOverlappedMatch(&subgraphs); RemoveOverlappedMatch(&subgraphs);
...@@ -87,7 +92,7 @@ void GraphPatternDetector::operator()(Graph* graph, ...@@ -87,7 +92,7 @@ void GraphPatternDetector::operator()(Graph* graph,
} }
bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
VLOG(4) << "mark pdnodes in graph"; VLOG(3) << "mark pdnodes in graph";
if (graph.Nodes().empty()) return false; if (graph.Nodes().empty()) return false;
for (auto& node : GraphTraits::DFS(graph)) { for (auto& node : GraphTraits::DFS(graph)) {
...@@ -107,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) { ...@@ -107,6 +112,7 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
} }
} }
VLOG(3) << pdnodes2nodes_.size() << " nodes marked"; VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
return !pdnodes2nodes_.empty(); return !pdnodes2nodes_.empty();
} }
...@@ -357,7 +363,9 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type, ...@@ -357,7 +363,9 @@ PDNode* PDNode::assert_is_op_nth_input(const std::string& op_type,
assert_is_op_input(op_type); assert_is_op_input(op_type);
asserts_.emplace_back([=](Node* x) { asserts_.emplace_back([=](Node* x) {
for (auto* op : x->outputs) { for (auto* op : x->outputs) {
if (IsNthInput(x, op, argument, nth)) return true; if (op->IsOp() && op->Op()->Type() == op_type &&
IsNthInput(x, op, argument, nth))
return true;
} }
return false; return false;
}); });
...@@ -368,7 +376,9 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type, ...@@ -368,7 +376,9 @@ PDNode* PDNode::assert_is_op_nth_output(const std::string& op_type,
assert_is_var(); assert_is_var();
asserts_.emplace_back([=](Node* x) { asserts_.emplace_back([=](Node* x) {
for (auto* op : x->inputs) { for (auto* op : x->inputs) {
if (IsNthOutput(x, op, argument, nth)) return true; if (op->IsOp() && op->Op()->Type() == op_type &&
IsNthOutput(x, op, argument, nth))
return true;
} }
return false; return false;
}); });
...@@ -412,6 +422,12 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) { ...@@ -412,6 +422,12 @@ PDNode* PDNode::assert_is_op_output(const std::string& op_type) {
}); });
return this; return this;
} }
PDNode* PDNode::assert_is_op_output(const std::string& op_type,
const std::string& argument) {
assert_is_var();
assert_is_op_nth_output(op_type, argument, 0);
return this;
}
PDNode* PDNode::assert_is_op_input(const std::string& op_type) { PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
assert_is_var(); assert_is_var();
asserts_.emplace_back([=](Node* x) { asserts_.emplace_back([=](Node* x) {
...@@ -424,6 +440,12 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) { ...@@ -424,6 +440,12 @@ PDNode* PDNode::assert_is_op_input(const std::string& op_type) {
}); });
return this; return this;
} }
PDNode* PDNode::assert_is_op_input(const std::string& op_type,
const std::string& argument) {
assert_is_var();
assert_is_op_nth_input(op_type, argument, 0);
return this;
}
PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) { PDNode* PDNode::assert_op_has_n_inputs(const std::string& op_type, size_t n) {
assert_is_op(op_type); assert_is_op(op_type);
asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; }); asserts_.emplace_back([=](Node* x) { return x->inputs.size() == n; });
...@@ -439,6 +461,128 @@ PDNode* PDNode::assert_more(PDNode::teller_t&& teller) { ...@@ -439,6 +461,128 @@ PDNode* PDNode::assert_more(PDNode::teller_t&& teller) {
return this; return this;
} }
bool VarLinksToOp(Node* node, const std::string& op_type) {
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->Op()->Input(argument).size() <= nth) return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->Op()->Output(argument).size() <= nth) return false;
return var->Name() == op->Op()->Output(argument)[nth];
}
void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it);
} else
it++;
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->outputs.erase(it);
} else
it++;
}
}
}
bool VarLinksFromOp(Node* node, const std::string& op_type) {
for (auto* out : node->inputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
PDNode* patterns::FC(PDPattern* pattern, const std::string& name_scope,
PDNode* x, bool with_bias) {
// Create Operators
PDNode* elementwise_add_op{nullptr};
auto* mul_op = pattern->NewNode(name_scope, "mul")->assert_is_op("mul");
if (with_bias) {
elementwise_add_op = pattern->NewNode(name_scope, "elementwise_add")
->assert_is_op("elementwise_add");
}
// Create variables
// w
auto* mul_weight_var = pattern->NewNode(name_scope, "w")
->AsInput()
->assert_is_persistable_var()
->assert_is_op_nth_input("mul", "Y", 0);
PDNode* mul_out_var{nullptr};
if (with_bias) {
// intermediate variable, will be removed in the IR after fuse.
mul_out_var = pattern->NewNode(name_scope, "mul_out")
->AsIntermediate()
->assert_is_only_output_of_op("mul")
->assert_is_op_input("elementwise_add");
}
PDNode *bias{nullptr}, *fc_out{nullptr};
if (with_bias) {
// bias
bias = pattern->NewNode(name_scope, "fc_bias")
->assert_is_op_input("elementwise_add")
->AsInput();
// output
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput()
->assert_is_op_output("elementwise_add");
} else {
fc_out = pattern->NewNode(name_scope, "fc_out")
->AsOutput()
->assert_is_op_output("mul");
}
if (with_bias) {
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({mul_out_var});
elementwise_add_op->LinksFrom({mul_out_var, bias}).LinksTo({fc_out});
} else {
mul_op->LinksFrom({mul_weight_var, x}).LinksTo({fc_out});
}
return fc_out;
}
PDNode* patterns::LSTM(PDPattern* pattern, const std::string& name_scope,
PDNode* x) {
x->assert_is_op_input("lstm", "Input");
auto* lstm_op = pattern->NewNode(name_scope, "lstm")->assert_is_op("lstm");
#define NEW_NODE(arg__, io__) \
auto* arg__ = pattern->NewNode(name_scope, #arg__) \
->assert_is_op_##io__("lstm", #arg__);
// Currently, the H0 and C0 are optional
// TODO(Superjomn) upgrade the fuse framework to support optional.
// NEW_NODE(H0, input);
// NEW_NODE(C0, input);
NEW_NODE(Weight, input);
NEW_NODE(Bias, input);
NEW_NODE(Hidden, output);
NEW_NODE(Cell, output);
NEW_NODE(BatchGate, output);
NEW_NODE(BatchCellPreAct, output);
lstm_op->LinksFrom({x, Weight, Bias});
lstm_op->LinksTo({Hidden, Cell, BatchGate, BatchCellPreAct});
return Hidden;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -95,7 +95,11 @@ struct PDNode { ...@@ -95,7 +95,11 @@ struct PDNode {
PDNode* assert_var_not_persistable(); PDNode* assert_var_not_persistable();
PDNode* assert_is_persistable_var(); PDNode* assert_is_persistable_var();
PDNode* assert_is_op_output(const std::string& op_type); PDNode* assert_is_op_output(const std::string& op_type);
PDNode* assert_is_op_output(const std::string& op_type,
const std::string& argument);
PDNode* assert_is_op_input(const std::string& op_type); PDNode* assert_is_op_input(const std::string& op_type);
PDNode* assert_is_op_input(const std::string& op_type,
const std::string& argument);
PDNode* assert_is_op_nth_input(const std::string& op_type, PDNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth); const std::string& argument, int nth);
PDNode* assert_is_op_nth_output(const std::string& op_type, PDNode* assert_is_op_nth_output(const std::string& op_type,
...@@ -167,6 +171,9 @@ class PDPattern { ...@@ -167,6 +171,9 @@ class PDPattern {
PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID()); PDNode* NewNode(PDNode::teller_t&& teller, const std::string& name = NewID());
PDNode* NewNode(const std::string& name = NewID()); PDNode* NewNode(const std::string& name = NewID());
PDNode* NewNode(const std::string& prefix, const std::string& name) {
return NewNode(prefix + "/" + name);
}
PDNode* RetrieveNode(const std::string& id) const; PDNode* RetrieveNode(const std::string& id) const;
const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; } const std::vector<std::unique_ptr<PDNode>>& nodes() const { return nodes_; }
...@@ -257,64 +264,36 @@ class GraphPatternDetector { ...@@ -257,64 +264,36 @@ class GraphPatternDetector {
// some helper methods. // some helper methods.
// Op's input. // Tell if a var links to an Op
static bool VarLinksToOp(Node* node, const std::string& op_type) { bool VarLinksToOp(Node* node, const std::string& op_type);
for (auto* out : node->outputs) {
if (out->IsOp() && out->Op()->Type() == op_type) { // Tell if an op links to a var
return true; bool VarLinksFromOp(Node* node, const std::string& op_type);
}
}
return false;
}
// Op's output.
static bool VarLinksFromOp(Node* node, const std::string& op_type) {
for (auto* out : node->inputs) {
if (out->IsOp() && out->Op()->Type() == op_type) {
return true;
}
}
return false;
}
// Check whether a var node is a op node's nth input. // Check whether a var node is a op node's nth input.
static bool IsNthInput(Node* var, Node* op, const std::string& argument, bool IsNthInput(Node* var, Node* op, const std::string& argument, size_t nth);
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->inputs.size() <= nth) return false;
return var->Name() == op->Op()->Input(argument)[nth];
}
static bool IsNthOutput(Node* var, Node* op, const std::string& argument,
size_t nth) {
PADDLE_ENFORCE(var->IsVar());
PADDLE_ENFORCE(op->IsOp());
if (op->inputs.size() <= nth) return false;
return var->Name() == op->Op()->Output(argument)[nth];
}
static void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes) {
for (auto* node : nodes) {
graph->RemoveNode(const_cast<Node*>(node));
}
for (auto* node : graph->Nodes()) { // Tell whether a var node is a op node's nth output.
for (auto it = node->inputs.begin(); it != node->inputs.end();) { bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);
if (nodes.count(*it)) {
it = const_cast<Node*>(node)->inputs.erase(it); // Graph safely remove some nodes, will automatically clean up the edges.
} else void GraphSafeRemoveNodes(Graph* graph,
it++; const std::unordered_set<const Node*>& nodes);
}
for (auto it = node->outputs.begin(); it != node->outputs.end();) { // Some pre-defined patterns those can be reused in multiple passes.
if (nodes.count(*it)) { namespace patterns {
it = const_cast<Node*>(node)->outputs.erase(it);
} else // FC with bias
it++; // op: mul + elementwise_add
} // named nodes:
} // mul, elementwise_add
} // w, mul_out, bias, fc_out
PDNode* FC(PDPattern* pattern, const std::string& name_scope, PDNode* x,
bool with_bias);
PDNode* LSTM(PDPattern* pattern, const std::string& name_scope, PDNode* x);
} // namespace patterns
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -42,6 +42,13 @@ class GraphVizPass : public Pass { ...@@ -42,6 +42,13 @@ class GraphVizPass : public Pass {
marked_nodes_t ConsumeMarkedNodes(Graph* graph) const; marked_nodes_t ConsumeMarkedNodes(Graph* graph) const;
}; };
static GraphVizPass::marked_nodes_t& GetMarkedNodes(Graph* graph) {
if (!graph->Has(kGraphvizMarkedNodeAttr)) {
graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
}
return graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -109,6 +109,7 @@ void Analyzer::Run(Argument* argument) { ...@@ -109,6 +109,7 @@ void Analyzer::Run(Argument* argument) {
"infer_clean_graph_pass", "graph_viz_pass", // "infer_clean_graph_pass", "graph_viz_pass", //
"attention_lstm_fuse_pass", "graph_viz_pass", // "attention_lstm_fuse_pass", "graph_viz_pass", //
"fc_lstm_fuse_pass", "graph_viz_pass", // "fc_lstm_fuse_pass", "graph_viz_pass", //
"mul_lstm_fuse_pass", "graph_viz_pass", //
"seq_concat_fc_fuse_pass", "graph_viz_pass", // "seq_concat_fc_fuse_pass", "graph_viz_pass", //
"fc_fuse_pass", "graph_viz_pass" // "fc_fuse_pass", "graph_viz_pass" //
......
...@@ -329,6 +329,7 @@ void TestDituRNNPrediction(const std::string &model_path, ...@@ -329,6 +329,7 @@ void TestDituRNNPrediction(const std::string &model_path,
ASSERT_TRUE(fuse_statis.count("fc")); ASSERT_TRUE(fuse_statis.count("fc"));
EXPECT_EQ(fuse_statis.at("fc"), 1); EXPECT_EQ(fuse_statis.at("fc"), 1);
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 1);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册