未验证 提交 efc3b182 编写于 作者: W Wojciech Uss 提交者: GitHub

a fix for the fc_lstm_fuse_pass (#28709)

上级 3b0dd5f6
...@@ -36,9 +36,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -36,9 +36,7 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
->assert_var_not_persistable(); ->assert_var_not_persistable();
patterns::FC fc_pattern(pattern, name_scope); patterns::FC fc_pattern(pattern, name_scope);
// fc_out is a tmp var, will be removed after fuse, so marked as intermediate. auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
auto* fc_out =
fc_pattern(x, with_fc_bias, /* with_relu */ false)->AsIntermediate();
patterns::LSTM lstm_pattern(pattern, name_scope); patterns::LSTM lstm_pattern(pattern, name_scope);
lstm_pattern(fc_out); lstm_pattern(fc_out);
...@@ -58,28 +56,25 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -58,28 +56,25 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
// Add FC-bias with LSTM-bias and create a new weight // Add FC-bias with LSTM-bias and create a new weight
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
const std::string& new_bias_var = patterns::UniqueKey("NewBias");
auto* bias_var = scope->Var(new_bias_var);
PADDLE_ENFORCE_NOT_NULL(bias_var, platform::errors::InvalidArgument(
"Bias var ptr cannot be nullptr."));
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
auto* lstm_bias_var = scope->FindVar(bias->Name()); auto* lstm_bias_var = scope->FindVar(bias->Name());
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
PADDLE_ENFORCE_NOT_NULL(lstm_bias_var, PADDLE_ENFORCE_NOT_NULL(lstm_bias_var,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Lstm bias var ptr cannot be nullptr.")); "Lstm bias var ptr cannot be nullptr."));
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>(); PADDLE_ENFORCE_NOT_NULL(fc_bias_var,
bias_tensor->Resize(lstm_bias_tensor.dims()); platform::errors::InvalidArgument(
"FC bias var ptr cannot be nullptr."));
auto* fc_bias_var = scope->FindVar(fc_bias->Name()); auto* lstm_bias_tensor =
lstm_bias_var->GetMutable<framework::LoDTensor>();
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>(); const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace()); auto lstm_bias_data =
lstm_bias_tensor->mutable_data<float>(platform::CPUPlace());
auto* fc_bias_data = fc_bias_tensor.data<float>();
for (int i = 0; i < bias_tensor->numel(); i++) { for (int i = 0; i < lstm_bias_tensor->numel(); i++) {
data[i] = lstm_bias_data[i] += fc_bias_data[i];
fc_bias_tensor.data<float>()[i] + lstm_bias_tensor.data<float>()[i];
} }
op_desc.SetInput("Bias", {new_bias_var});
} }
op_desc.SetInput("H0", {}); op_desc.SetInput("H0", {});
...@@ -114,6 +109,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope, ...@@ -114,6 +109,8 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
IR_NODE_LINK_TO(weight_h, op); IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op); IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden); IR_NODE_LINK_TO(op, hidden);
IR_NODE_LINK_TO(op, cell);
IR_NODE_LINK_TO(op, xx);
#define IR_NODE(x) \ #define IR_NODE(x) \
VarDesc key_##x(x); \ VarDesc key_##x(x); \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册