From cec3cfba8d802b77b3083711d4c41925e462aecd Mon Sep 17 00:00:00 2001
From: liu zhengxi <380185688@qq.com>
Date: Tue, 10 Mar 2020 10:35:38 +0800
Subject: [PATCH] Fix fc padding bug during inference fusion (#22860) (#22921)

* fix fc padding during fusion, test=develop

* fix optim model inference after SaveOptimModel, test=develop
---
 paddle/fluid/framework/ir/fc_fuse_pass.cc | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/framework/ir/fc_fuse_pass.cc b/paddle/fluid/framework/ir/fc_fuse_pass.cc
index ed575272f1..650977c6b2 100644
--- a/paddle/fluid/framework/ir/fc_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/fc_fuse_pass.cc
@@ -100,22 +100,31 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   bool use_gpu = Has("use_gpu") ? Get<bool>("use_gpu") : false;
   bool use_fc_padding =
       Has("use_fc_padding") ? Get<bool>("use_fc_padding") : true;
+  const std::string& w_name = patterns::UniqueKey(w->Name());
+  VarDesc w_key(w_name);
+  w_key.SetPersistable(true);
+  auto* w_node = g->CreateVarNode(&w_key);
   if (!use_gpu && use_fc_padding) {
     if (w_h % 128 == 0 && w_w % 128 == 0) {
+      auto* w_var = scope->Var(w_name);
+      auto* w_tensor = w_var->GetMutable<LoDTensor>();
+
       auto* weight_data_tmp = new float[weight_num];
       for (int i = 0; i < w_h; i++) {
         memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w,
                w_w * sizeof(float));
       }
-      weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
+      w_tensor->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
       auto* weight_data_new =
-          weight->mutable_data<float>(platform::CPUPlace());
+          w_tensor->mutable_data<float>(platform::CPUPlace());
       for (int i = 0; i < w_h; i++) {
         memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w,
                w_w * sizeof(float));
       }
       delete[] weight_data_tmp;
+      desc.SetInput("W", {w_name});
       desc.SetAttr("padding_weights", true);
+      desc.Flush();
     }
   }
 
@@ -147,7 +156,12 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
   }
 
   IR_NODE_LINK_TO(subgraph.at(x), fc_node);
-  IR_NODE_LINK_TO(w, fc_node);
+  if (desc.GetAttrIfExists<bool>("padding_weights")) {
+    IR_NODE_LINK_TO(w_node, fc_node);
+  } else {
+    GraphSafeRemoveNodes(g, {w_node});
+    IR_NODE_LINK_TO(w, fc_node);
+  }
   IR_NODE_LINK_TO(bias, fc_node);
   if (with_relu) {
     IR_NODE_LINK_TO(fc_node, relu_out);
-- 
GitLab
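
For context, the first hunk repacks the row-major w_h x w_w weight matrix into a
(w_h + 4) x (w_w + 4) tensor, copying each source row at a destination stride of
w_w + 4 so the fused FC kernel can operate on the padded layout. Below is a
minimal standalone sketch of that copy pattern; the helper name pad_fc_weights
is hypothetical (not part of the Paddle source), and zero-filling the pad region
is an added assumption for determinism (the patch itself leaves the pad elements
as mutable_data allocates them).

#include <cstring>
#include <vector>

// Hypothetical helper mirroring the patch's memcpy loop: repack a row-major
// h x w float matrix into an (h + 4) x (w + 4) buffer with row stride w + 4.
// Only the original h x w block is copied; the pad region stays zero here.
std::vector<float> pad_fc_weights(const float* src, int h, int w) {
  std::vector<float> dst(static_cast<size_t>(h + 4) * (w + 4), 0.0f);
  for (int i = 0; i < h; ++i) {
    std::memcpy(dst.data() + static_cast<size_t>(i) * (w + 4),
                src + static_cast<size_t>(i) * w,
                w * sizeof(float));
  }
  return dst;
}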