diff --git a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc index 2b451da7bfa8b089c0f891ce42fbc293b19ac4b1..9dca4d1b29f9f3ef51559383efa3e0a18965ef05 100644 --- a/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc @@ -36,9 +36,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ->assert_var_not_persistable(); patterns::FC fc_pattern(pattern, name_scope); - // fc_out is a tmp var, will be removed after fuse, so marked as intermediate. - auto* fc_out = - fc_pattern(x, with_fc_bias, /* with_relu */ false)->AsIntermediate(); + auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false); patterns::LSTM lstm_pattern(pattern, name_scope); lstm_pattern(fc_out); @@ -58,28 +56,25 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, // Add FC-bias with LSTM-bias and create a new weight PADDLE_ENFORCE_NOT_NULL( scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); - const std::string& new_bias_var = patterns::UniqueKey("NewBias"); - auto* bias_var = scope->Var(new_bias_var); - PADDLE_ENFORCE_NOT_NULL(bias_var, platform::errors::InvalidArgument( - "Bias var ptr cannot be nullptr.")); - auto* bias_tensor = bias_var->GetMutable(); auto* lstm_bias_var = scope->FindVar(bias->Name()); + auto* fc_bias_var = scope->FindVar(fc_bias->Name()); PADDLE_ENFORCE_NOT_NULL(lstm_bias_var, platform::errors::InvalidArgument( "Lstm bias var ptr cannot be nullptr.")); - const auto& lstm_bias_tensor = lstm_bias_var->Get(); - bias_tensor->Resize(lstm_bias_tensor.dims()); - - auto* fc_bias_var = scope->FindVar(fc_bias->Name()); + PADDLE_ENFORCE_NOT_NULL(fc_bias_var, + platform::errors::InvalidArgument( + "FC bias var ptr cannot be nullptr.")); + auto* lstm_bias_tensor = + lstm_bias_var->GetMutable(); const auto& fc_bias_tensor = fc_bias_var->Get(); - auto* data = bias_tensor->mutable_data(platform::CPUPlace()); + auto lstm_bias_data = + lstm_bias_tensor->mutable_data(platform::CPUPlace()); + auto* fc_bias_data = fc_bias_tensor.data(); - for (int i = 0; i < bias_tensor->numel(); i++) { - data[i] = - fc_bias_tensor.data()[i] + lstm_bias_tensor.data()[i]; + for (int i = 0; i < lstm_bias_tensor->numel(); i++) { + lstm_bias_data[i] += fc_bias_data[i]; } - op_desc.SetInput("Bias", {new_bias_var}); } op_desc.SetInput("H0", {}); @@ -114,6 +109,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, IR_NODE_LINK_TO(weight_h, op); IR_NODE_LINK_TO(bias, op); IR_NODE_LINK_TO(op, hidden); + IR_NODE_LINK_TO(op, cell); + IR_NODE_LINK_TO(op, xx); #define IR_NODE(x) \ VarDesc key_##x(x); \