Unverified commit 71e350c5, authored by Adam, committed by GitHub

Fix FC + GRU fuse pass (#25733)

Parent 2d7e7759
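The core of this change, visible in the hunks below, is that the FC bias is now folded into the existing GRU bias tensor in place instead of being copied into a separate fused-bias variable. A minimal, self-contained C++ sketch of that bias recomputation, using plain vectors rather than Paddle tensors (the function and variable names here are illustrative, not part of the pass):

#include <cassert>
#include <vector>

// Sketch of the in-place bias merge the fixed pass performs: the GRU bias
// absorbs the FC bias element by element, so the fused operator can keep the
// original GRU bias variable as its "Bias" input.
void MergeFcBiasIntoGruBias(std::vector<float>* gru_bias,
                            const std::vector<float>& fc_bias) {
  assert(gru_bias->size() == fc_bias.size());  // mirrors the PADDLE_ENFORCE_EQ check
  for (size_t i = 0; i < gru_bias->size(); ++i) {
    (*gru_bias)[i] += fc_bias[i];
  }
}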
@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
GraphPatternDetector gpd;
auto* pattern = gpd.mutable_pattern();
// Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
patterns::GRU gru_pattern(pattern, name_scope);
PDNode* x =
pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
// Create pattern.
patterns::FC fc_pattern(pattern, name_scope);
auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
patterns::GRU gru_pattern(pattern, name_scope);
gru_pattern(fc_out);
// Create New OpDesc
@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
SET_IN(X, x);
SET_IN(WeightX, weight_x);
SET_IN(WeightH, weight_h);
if (with_fc_bias) {
op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
} else {
SET_IN(Bias, bias);
}
SET_IN(Bias, bias);
#undef SET_IN
// TODO(grygielski): Add H0 to the pass
op_desc.SetInput("H0", {});
op_desc.SetOutput("Hidden", {hidden->Name()});
op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
op_desc.SetAttr("origin_mode",
gru->Op()->GetAttrIfExists<bool>("origin_mode"));
// TODO(TJ): This should be a option for infer
op_desc.SetAttr("use_seq", true);
op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));
#define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)})
SET_IMTERMEDIATE_OUT(ReorderedH0);
@@ -68,26 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
#undef SET_IMTERMEDIATE_OUT
auto* op = graph->CreateOpNode(&op_desc);
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
auto& scope = graph->Get<Scope>(kParamScopeAttr);
if (with_fc_bias) {
// Fusion GRU bias = fcbias + grubias
auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
auto* out_bias_tensor =
fusion_bias_var->GetMutable<framework::LoDTensor>();
PADDLE_ENFORCE(fusion_bias_var);
auto* gru_bias_var = scope.FindVar(bias->Name());
auto* fc_bias_var = scope.FindVar(fc_bias->Name());
PADDLE_ENFORCE(gru_bias_var);
PADDLE_ENFORCE(fc_bias_var);
const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
// new bias = fc bias + gru bias
out_bias_tensor->Resize(gru_bias_tenosr.dims());
auto* data = out_bias_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < out_bias_tensor->numel(); i++) {
data[i] =
fc_bias_tensor.data<float>()[i] + gru_bias_tenosr.data<float>()[i];
auto* gru_bias_var = scope->FindVar(bias->Name());
auto* fc_bias_var = scope->FindVar(fc_bias->Name());
PADDLE_ENFORCE_NE(
gru_bias_var, nullptr,
platform::errors::NotFound("GRU bias var has not been found."));
PADDLE_ENFORCE_NE(
fc_bias_var, nullptr,
platform::errors::NotFound("FC bias var has not been found."));
auto* gru_bias_tensor = gru_bias_var->GetMutable<LoDTensor>();
auto* fc_bias_tensor = fc_bias_var->GetMutable<LoDTensor>();
PADDLE_ENFORCE_EQ(
gru_bias_tensor->numel(), fc_bias_tensor->numel(),
platform::errors::PreconditionNotMet(
"GRU and FC biases have to have equal number of elements."));
auto gru_bias_data =
gru_bias_tensor->mutable_data<float>(platform::CPUPlace());
auto* fc_bias_data = fc_bias_tensor->data<float>();
// Recompute GRU bias
for (int i = 0; i < gru_bias_tensor->numel(); ++i) {
gru_bias_data[i] += fc_bias_data[i];
}
}
#undef GET_NODE
@@ -108,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
IR_NODE_LINK_TO(x, op);
IR_NODE_LINK_TO(weight_x, op);
IR_NODE_LINK_TO(weight_h, op);
IR_NODE_LINK_TO(bias, op); // actually should link to new bias if have
IR_NODE_LINK_TO(bias, op);
IR_NODE_LINK_TO(op, hidden);
// h0?
return op;
......
@@ -56,8 +56,6 @@ void SetConfig(AnalysisConfig *cfg) {
cfg->DisableGpu();
cfg->SwitchIrDebug();
cfg->SwitchSpecifyInputNames(false);
// TODO(TJ): fix fusion gru
cfg->pass_builder()->DeletePass("fc_gru_fuse_pass");
}
void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
......
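The hunk above drops the DeletePass("fc_gru_fuse_pass") workaround from the analyzer test, so the test now runs with the default IR pass list. A hedged sketch of the resulting config setup, using only the config calls shown in the diff plus SetModel (the model path is a placeholder assumption, not taken from the test):

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

// Sketch only: a CPU AnalysisConfig that keeps the default pass list, so the
// now-fixed fc_gru_fuse_pass is applied during IR optimization.
void SetConfigSketch(paddle::AnalysisConfig *cfg) {
  cfg->SetModel("/path/to/model");  // placeholder model directory (assumption)
  cfg->DisableGpu();
  cfg->SwitchIrDebug();
  cfg->SwitchSpecifyInputNames(false);
  // No DeletePass("fc_gru_fuse_pass") call anymore: the pass stays enabled.
}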
@@ -183,6 +183,10 @@ void FusionGRUOpMaker::Make() {
"(bool, default: True) "
"whether to use seq mode to compute GRU.")
.SetDefault(true);
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
.SetDefault(false);
AddComment(R"DOC(
The Fusion complete GRU Operator.
This operator fuse the fully-connected operator into GRU,
......