Unverified commit 98899b73, authored by Adam, committed by GitHub

Fix FC + GRU fuse pass (#25687)

Parent commit 0e23dc3a
@@ -26,15 +26,15 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
-  // Create pattern.
-  patterns::FC fc_pattern(pattern, name_scope);
-  patterns::GRU gru_pattern(pattern, name_scope);
-
   PDNode* x =
       pattern->NewNode(patterns::UniqueKey("x"))->assert_var_not_persistable();
 
+  // Create pattern.
+  patterns::FC fc_pattern(pattern, name_scope);
   auto* fc_out = fc_pattern(x, with_fc_bias, /* with_relu */ false);
   fc_out->AsIntermediate();  // fc_out is a tmp var, will be removed after fuse.
+  patterns::GRU gru_pattern(pattern, name_scope);
   gru_pattern(fc_out);
 
   // Create New OpDesc
@@ -48,17 +48,18 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   SET_IN(X, x);
   SET_IN(WeightX, weight_x);
   SET_IN(WeightH, weight_h);
-  if (with_fc_bias) {
-    op_desc.SetInput("Bias", {NEW_NAME(bias) + bias->Name()});
-  } else {
-    SET_IN(Bias, bias);
-  }
+  SET_IN(Bias, bias);
 #undef SET_IN
+  // TODO(grygielski): Add H0 to the pass
   op_desc.SetInput("H0", {});
   op_desc.SetOutput("Hidden", {hidden->Name()});
   op_desc.SetAttr("is_reverse", gru->Op()->GetAttr("is_reverse"));
+  op_desc.SetAttr("origin_mode",
+                  gru->Op()->GetAttrIfExists<bool>("origin_mode"));
   // TODO(TJ): This should be a option for infer
   op_desc.SetAttr("use_seq", true);
+  op_desc.SetAttr("activation", gru->Op()->GetAttr("activation"));
+  op_desc.SetAttr("gate_activation", gru->Op()->GetAttr("gate_activation"));
 
 #define SET_IMTERMEDIATE_OUT(key) op_desc.SetOutput(#key, {NEW_NAME(key)})
   SET_IMTERMEDIATE_OUT(ReorderedH0);
@@ -68,35 +69,30 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
 #undef SET_IMTERMEDIATE_OUT
 
   auto* op = graph->CreateOpNode(&op_desc);
-  PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr), true,
-                    platform::errors::InvalidArgument(
-                        "Graph have no attr kParamScopeAttr."));
-  auto& scope = graph->Get<Scope>(kParamScopeAttr);
   if (with_fc_bias) {
-    // Fusion GRU bias = fcbias + grubias
-    auto* fusion_bias_var = scope.Var(NEW_NAME(bias) + bias->Name());
-    auto* out_bias_tensor =
-        fusion_bias_var->GetMutable<framework::LoDTensor>();
-    PADDLE_ENFORCE_NOT_NULL(
-        fusion_bias_var,
-        platform::errors::InvalidArgument(
-            "Fusion bias variable's pointer cannot be nullptr."));
-    auto* gru_bias_var = scope.FindVar(bias->Name());
-    auto* fc_bias_var = scope.FindVar(fc_bias->Name());
-    PADDLE_ENFORCE_NOT_NULL(gru_bias_var,
-                            platform::errors::InvalidArgument(
-                                "Gru bias var ptr cannot be nullptr."));
-    PADDLE_ENFORCE_NOT_NULL(fc_bias_var,
-                            platform::errors::InvalidArgument(
-                                "Fc bias var ptr cannot be nullptr."));
-    const auto& gru_bias_tenosr = gru_bias_var->Get<framework::LoDTensor>();
-    const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
-    // new bias = fc bias + gru bias
-    out_bias_tensor->Resize(gru_bias_tenosr.dims());
-    auto* data = out_bias_tensor->mutable_data<float>(platform::CPUPlace());
-    for (int i = 0; i < out_bias_tensor->numel(); i++) {
-      data[i] =
-          fc_bias_tensor.data<float>()[i] + gru_bias_tenosr.data<float>()[i];
+    auto* gru_bias_var = scope->FindVar(bias->Name());
+    auto* fc_bias_var = scope->FindVar(fc_bias->Name());
+    PADDLE_ENFORCE_NE(
+        gru_bias_var, nullptr,
+        platform::errors::NotFound("GRU bias var has not been found."));
+    PADDLE_ENFORCE_NE(
+        fc_bias_var, nullptr,
+        platform::errors::NotFound("FC bias var has not been found."));
+
+    auto* gru_bias_tensor = gru_bias_var->GetMutable<LoDTensor>();
+    auto* fc_bias_tensor = fc_bias_var->GetMutable<LoDTensor>();
+    PADDLE_ENFORCE_EQ(
+        gru_bias_tensor->numel(), fc_bias_tensor->numel(),
+        platform::errors::PreconditionNotMet(
+            "GRU and FC biases have to have equal number of elements."));
+
+    auto gru_bias_data =
+        gru_bias_tensor->mutable_data<float>(platform::CPUPlace());
+    auto* fc_bias_data = fc_bias_tensor->data<float>();
+
+    // Recompute GRU bias
+    for (int i = 0; i < gru_bias_tensor->numel(); ++i) {
+      gru_bias_data[i] += fc_bias_data[i];
     }
   }
 #undef GET_NODE
@@ -117,7 +113,7 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
   IR_NODE_LINK_TO(x, op);
   IR_NODE_LINK_TO(weight_x, op);
   IR_NODE_LINK_TO(weight_h, op);
-  IR_NODE_LINK_TO(bias, op);  // actually should link to new bias if have
+  IR_NODE_LINK_TO(bias, op);
   IR_NODE_LINK_TO(op, hidden);
   // h0?
   return op;
......
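The key change in the pass above is that, instead of materializing a NEW_NAME(bias) variable holding fcbias + grubias, the FC bias is now added into the existing GRU bias tensor in place, so the fused op can keep its Bias input pointing at the original gru bias (which is why the earlier hunk can drop the if (with_fc_bias) branch around SET_IN(Bias, bias)). Below is a minimal standalone sketch of why that element-wise merge preserves the computation; it is a toy program, not PaddlePaddle code, and the names xw, hw, b_fc and b_gru are made up for illustration (the equal-length requirement mirrors the new numel check in the pass).

// Toy check: folding an FC bias into the following GRU gate bias.
// Unfused:  gate = (x*W_fc + b_fc) + h_prev*W_h + b_gru
// Fused:    gate =  x*W_fc         + h_prev*W_h + (b_gru + b_fc)
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<float> xw = {0.5f, -1.0f, 2.0f};     // x * W_fc (no bias yet)
  std::vector<float> hw = {0.1f, 0.2f, -0.3f};     // h_prev * W_h
  std::vector<float> b_fc = {0.01f, 0.02f, 0.03f}; // FC bias
  std::vector<float> b_gru = {-0.5f, 0.5f, 0.0f};  // GRU gate bias

  for (std::size_t i = 0; i < xw.size(); ++i) {
    // FC applies its own bias, then the GRU adds its gate bias.
    float unfused = (xw[i] + b_fc[i]) + hw[i] + b_gru[i];
    // FC bias was merged into the GRU bias beforehand (the pass's loop).
    float fused = xw[i] + hw[i] + (b_gru[i] + b_fc[i]);
    assert(std::fabs(unfused - fused) < 1e-6f);
  }
  return 0;
}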
@@ -56,8 +56,6 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->DisableGpu();
   cfg->SwitchIrDebug();
   cfg->SwitchSpecifyInputNames(false);
-  // TODO(TJ): fix fusion gru
-  cfg->pass_builder()->DeletePass("fc_gru_fuse_pass");
 }
 
 void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
......
@@ -183,6 +183,10 @@ void FusionGRUOpMaker::Make() {
             "(bool, default: True) "
             "whether to use seq mode to compute GRU.")
       .SetDefault(true);
+  AddAttr<bool>("origin_mode",
+                "bool"
+                "use origin mode in article https://arxiv.org/abs/1412.3555")
+      .SetDefault(false);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuse the fully-connected operator into GRU,
......
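The origin_mode attribute added to FusionGRUOpMaker above, and copied from the matched gru op by the fuse pass, selects between the two common GRU hidden-state update conventions. The sketch below is written from the generic GRU definitions rather than quoted from the operator documentation, and it deliberately does not assert which boolean value maps to which formula.

% u_t: update gate, \tilde{h}_t: candidate state, \odot: element-wise product.
h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t
\qquad \text{or} \qquad
h_t = (1 - u_t) \odot h_{t-1} + u_t \odot \tilde{h}_t

Because the two conventions give different outputs for the same weights, the fused op must carry whatever setting the original gru op used; that is what the new op_desc.SetAttr("origin_mode", gru->Op()->GetAttrIfExists<bool>("origin_mode")) line in the pass does, with .SetDefault(false) in the maker covering ops that never set the attribute.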