未验证 提交 ae67dcea 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #13366 from luotao1/fusion_lstm_bug

fix fusion_lstm unique_name bug
...@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
if (with_fc_bias) { if (with_fc_bias) {
// Add FC-bias with LSTM-bias and create a new weight // Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE(scope); PADDLE_ENFORCE(scope);
const std::string& new_bias_var = name_scope + "_bias.new"; const std::string& new_bias_var = patterns::UniqueKey("NewBias");
auto* bias_var = scope->Var(new_bias_var); auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE(bias_var); PADDLE_ENFORCE(bias_var);
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>(); auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
...@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern); GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
...@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
fc_bias); fc_bias);
// Remove unneeded nodes. // Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes( std::unordered_set<const Node*> marked_nodes(
{mul, lstm, elementwise_add}); {mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes); GraphSafeRemoveNodes(graph, marked_nodes);
} else { } else {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern); GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h" #include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h" #include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program, ...@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
void IRPassManager::Apply(const std::vector<std::string> &passes) { void IRPassManager::Apply(const std::vector<std::string> &passes) {
// Apply all the passes // Apply all the passes
std::string pre_pass; std::string pre_pass;
int pass_num = 0;
for (const std::string &pass_name : passes) { for (const std::string &pass_name : passes) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name); PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name); auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
if (pass_name == "graph_viz_pass") { if (pass_name == "graph_viz_pass") {
std::string dot_file_path = std::string dot_file_path = std::to_string(pass_num) + "_ir_" +
"ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot"; (pre_pass.empty() ? "origin" : pre_pass) +
".dot";
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path))); pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
pass_num++;
} }
graph_ = pass->Apply(std::move(graph_)); graph_ = pass->Apply(std::move(graph_));
pre_pass = pass_name; pre_pass = pass_name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册