diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc index cdc6520ad6f4454fd06779f55e8d07bb3ca2b032..6a9c64e3a7f24d7d8f1848a959a0be8ab7544e5e 100644 --- a/paddle/fluid/framework/ir/fc_fuse_pass.cc +++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc @@ -93,6 +93,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { bool use_gpu = Has("use_gpu") ? Get("use_gpu") : false; bool use_fc_padding = Has("use_fc_padding") ? Get("use_fc_padding") : true; + const std::string& w_name = patterns::UniqueKey(w->Name()); + VarDesc w_key(w_name); + w_key.SetPersistable(true); + auto* w_node = g->CreateVarNode(&w_key); if (!use_gpu && use_fc_padding) { auto* scope = param_scope(); auto* weight = scope->FindVar(w->Name())->GetMutable(); @@ -102,20 +106,25 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { int w_h = weight_dims[0]; int w_w = weight_dims[1]; if (w_h % 128 == 0 && w_w % 128 == 0) { + auto* w_var = scope->Var(w_name); + auto* w_tensor = w_var->GetMutable(); + auto* weight_data_tmp = new float[weight_num]; for (int i = 0; i < w_h; i++) { memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w, w_w * sizeof(float)); } - weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4}); + w_tensor->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4}); auto* weight_data_new = - weight->mutable_data(platform::CPUPlace()); + w_tensor->mutable_data(platform::CPUPlace()); for (int i = 0; i < w_h; i++) { memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w, w_w * sizeof(float)); } delete[] weight_data_tmp; + desc.SetInput("W", {w_name}); desc.SetAttr("padding_weights", true); + desc.Flush(); } } @@ -147,7 +156,12 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { } IR_NODE_LINK_TO(subgraph.at(x), fc_node); - IR_NODE_LINK_TO(w, fc_node); + if (desc.GetAttrIfExists("padding_weights")) { + IR_NODE_LINK_TO(w_node, fc_node); + } else { + GraphSafeRemoveNodes(g, {w_node}); + IR_NODE_LINK_TO(w, fc_node); + } IR_NODE_LINK_TO(bias, fc_node); if (with_relu) { IR_NODE_LINK_TO(fc_node, relu_out);