提交 b12322ce 编写于 作者: L luotao1

fix fusion_lstm unique_name bug

上级 62a98210
......@@ -51,7 +51,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
if (with_fc_bias) {
// Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE(scope);
const std::string& new_bias_var = name_scope + "_bias.new";
const std::string& new_bias_var = patterns::UniqueKey("NewBias");
auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE(bias_var);
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
......@@ -120,7 +120,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
......@@ -136,7 +135,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
fc_bias);
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes(
{mul, lstm, elementwise_add});
{mul, lstm, elementwise_add, fc_bias});
GraphSafeRemoveNodes(graph, marked_nodes);
} else {
GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/inference/analysis/ir_pass_manager.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
......@@ -37,13 +38,16 @@ IRPassManager::IRPassManager(const ProgramDesc &program,
void IRPassManager::Apply(const std::vector<std::string> &passes) {
// Apply all the passes
std::string pre_pass;
int pass_num = 0;
for (const std::string &pass_name : passes) {
PrettyLogEndl(Style::H2(), "--- Running IR pass [%s]", pass_name);
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
if (pass_name == "graph_viz_pass") {
std::string dot_file_path =
"ir_" + (pre_pass.empty() ? "origin" : pre_pass) + ".dot";
std::string dot_file_path = std::to_string(pass_num) + "_ir_" +
(pre_pass.empty() ? "origin" : pre_pass) +
".dot";
pass->Set("graph_viz_path", new std::string(std::move(dot_file_path)));
pass_num++;
}
graph_ = pass->Apply(std::move(graph_));
pre_pass = pass_name;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册