From 98899b73d280583c90f08cb6a4eda407e9770f0e Mon Sep 17 00:00:00 2001
From: Adam <38704900+grygielski@users.noreply.github.com>
Date: Fri, 24 Jul 2020 06:38:48 +0200
Subject: [PATCH] Fix FC + GRU fuse pass (#25687)

---
 paddle/fluid/framework/ir/fc_gru_fuse_pass.cc | 72 +++++++++----------
 .../tests/api/analyzer_vis_tester.cc          |  2 -
 paddle/fluid/operators/fused/fusion_gru_op.cc |  4 ++
 3 files changed, 38 insertions(+), 40 deletions(-)

diff --git a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
index 08dd0302b4..a2185cdc55 100644
--- a/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_gru_fuse_pass.cc
@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
-  // Create pattern.
-  patterns::FC fc_pattern(pattern, name_scope);
-  patterns::GRU gru_pattern(pattern, name_scope);
-
   PDNode* x =
       pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
 
+  // Create pattern.
+  patterns::FC fc_pattern(pattern, name_scope);
   auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
   fc_out->AsIntermediate();  // fc_out is a tmp var, will be removed after fuse.
+
+  patterns::GRU gru_pattern(pattern, name_scope);
   gru_pattern(fc_out);
 
   // Create New OpDesc
@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   SET_IN(X, x);
   SET_IN(WeightX, weight_x);
   SET_IN(WeightH, weight_h);
-  if (with_fc_bias) {
-    op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
-  } else {
-    SET_IN(Bias, bias);
-  }
+  SET_IN(Bias, bias);
 #undef SET_IN
+  // TODO(grygielski): Add H0 to the pass
   op_desc.SetInput("H0", {});
   op_desc.SetOutput("Hidden", {hidden->Name()});
   op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
+  op_desc.SetAttr("origin_mode",
+                  gru->Op()->GetAttrIfExists<bool>("origin_mode"));
   // TODO(TJ): This should be a option for infer
   op_desc.SetAttr("use_seq", true);
+  op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
+  op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));
 
 #define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)})
   SET_IMTERMEDIATE_OUT(ReorderedH0);
@@ -68,35 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef SET_IMTERMEDIATE_OUT
 
   auto* op = graph->CreateOpNode(&op_desc);
-  PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr), true,
-                    platform::errors::InvalidArgument(
-                        "Graph have no attr kParamScopeAttr."));
-  auto& scope = graph->Get<Scope>(kParamScopeAttr);
   if (with_fc_bias) {
-    // Fusion GRU bias = fcbias + grubias
-    auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
-    auto* out_bias_tensor =
-        fusion_bias_var->GetMutable<framework::LoDTensor>();
-    PADDLE_ENFORCE_NOT_NULL(
-        fusion_bias_var,
-        platform::errors::InvalidArgument(
-            "Fusion bias variable's pointer cannot be nullptr."));
-    auto* gru_bias_var = scope.FindVar(bias->Name());
-    auto* fc_bias_var = scope.FindVar(fc_bias->Name());
-    PADDLE_ENFORCE_NOT_NULL(gru_bias_var,
-                            platform::errors::InvalidArgument(
-                                "Gru bias var ptr cannot be nullptr."));
-    PADDLE_ENFORCE_NOT_NULL(fc_bias_var,
-                            platform::errors::InvalidArgument(
-                                "Fc bias var ptr cannot be nullptr."));
-    const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
-    const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
-    // new bias = fc bias + gru bias
-    out_bias_tensor->Resize(gru_bias_tenosr.dims());
-    auto* data = out_bias_tensor->mutable_data<float>(platform::CPUPlace());
-    for (int i = 0; i < out_bias_tensor->numel(); i++) {
-      data[i] =
-          fc_bias_tensor.data<float>()[i] + gru_bias_tenosr.data<float>()[i];
+    auto* gru_bias_var = scope->FindVar(bias->Name());
+    auto* fc_bias_var = scope->FindVar(fc_bias->Name());
+    PADDLE_ENFORCE_NE(
+        gru_bias_var, nullptr,
+        platform::errors::NotFound("GRU bias var has not been found."));
+    PADDLE_ENFORCE_NE(
+        fc_bias_var, nullptr,
+        platform::errors::NotFound("FC bias var has not been found."));
+
+    auto* gru_bias_tensor = gru_bias_var->GetMutable<framework::LoDTensor>();
+    auto* fc_bias_tensor = fc_bias_var->GetMutable<framework::LoDTensor>();
+    PADDLE_ENFORCE_EQ(
+        gru_bias_tensor->numel(), fc_bias_tensor->numel(),
+        platform::errors::PreconditionNotMet(
+            "GRU and FC biases have to have equal number of elements."));
+
+    auto gru_bias_data =
+        gru_bias_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* fc_bias_data = fc_bias_tensor->data<float>();
+
+    // Recompute GRU bias
+    for (int i = 0; i < gru_bias_tensor->numel(); ++i) {
+      gru_bias_data[i] += fc_bias_data[i];
     }
   }
 #undef GET_NODE
@@ -117,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   IR_NODE_LINK_TO(x, op);
   IR_NODE_LINK_TO(weight_x, op);
   IR_NODE_LINK_TO(weight_h, op);
-  IR_NODE_LINK_TO(bias, op);  // actually should link to new bias if have
+  IR_NODE_LINK_TO(bias, op);
   IR_NODE_LINK_TO(op, hidden);
   // h0?
   return op;
diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
index 5f65229ecd..65755b7b15 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
@@ -56,8 +56,6 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->DisableGpu();
   cfg->SwitchIrDebug();
   cfg->SwitchSpecifyInputNames(false);
-  // TODO(TJ): fix fusion gru
-  cfg->pass_builder()->DeletePass("fc_gru_fuse_pass");
 }
 
 void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index 32eeae9a01..f6c8316e2e 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -183,6 +183,10 @@ void FusionGRUOpMaker::Make() {
           "(bool, default: True) "
           "whether to use seq mode to compute GRU.")
       .SetDefault(true);
+  AddAttr<bool>("origin_mode",
+                "bool"
+                "use origin mode in article https://arxiv.org/abs/1412.3555")
+      .SetDefault(false);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuse the fully-connected operator into GRU,
--
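Note on the bias folding in fc_gru_fuse_pass.cc: the in-place addition is sound because, in the unfused graph, the GRU consumes the FC output x * W_fc + b_fc and then adds its own gate bias b_gru to the same pre-activation. Once W_fc becomes WeightX of the fused op, the two biases collapse into a single vector b_fc + b_gru; both have 3 * hidden_size elements, which the PADDLE_ENFORCE_EQ above guards. A minimal standalone sketch of that arithmetic (plain C++ with illustrative names, not Paddle APIs):

#include <cassert>
#include <cstddef>
#include <vector>

// Fold an FC bias into a GRU gate bias, mirroring the loop in the pass:
// (x * W + b_fc) + b_gru == x * W + (b_fc + b_gru), applied elementwise.
std::vector<float> FoldBiases(const std::vector<float>& fc_bias,
                              std::vector<float> gru_bias) {
  assert(fc_bias.size() == gru_bias.size());  // both hold 3 * hidden_size
  for (std::size_t i = 0; i < gru_bias.size(); ++i) {
    gru_bias[i] += fc_bias[i];  // same step as gru_bias_data[i] += fc_bias_data[i]
  }
  return gru_bias;
}

Folding into the existing GRU bias tensor, rather than creating a NEW_NAME(bias) variable as the removed code did, also lets the fused op keep linking to the original bias node, which is why the "actually should link to new bias" comment could be dropped.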
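Note on the forwarded attributes: fusion_gru previously had no origin_mode attribute at all (this patch adds it), and activation/gate_activation were not copied by the pass, so a model using a non-default GRU convention could change numerics after fusion. Copying them from the matched gru op keeps the fused op equivalent to the pattern it replaces. The sketch below shows the state update that origin_mode toggles; the mapping of the flag to each formula is an assumption based on Paddle's GRU documentation, not something this patch states.

// One GRU state update for a single hidden unit (u = update gate,
// h_prev = previous hidden state, h_cand = candidate state).
// ASSUMPTION: the flag-to-formula mapping follows Paddle's GRU docs.
float GruStateUpdate(float u, float h_prev, float h_cand, bool origin_mode) {
  if (origin_mode) {
    // Original encoder-decoder formulation: the gate preserves the old state.
    return u * h_prev + (1.0f - u) * h_cand;
  }
  // Paddle's default convention: the gate admits the new candidate.
  return (1.0f - u) * h_prev + u * h_cand;
}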