未验证 提交 61fef975 编写于 作者: L liu zhengxi 提交者: GitHub

Fix fc padding bug during inference fusion (#22860)

* fix fc padding during fusion, test=develop

* fix optim model inference after SaveOptimModel, test=develop
上级 4b40edf3
...@@ -93,6 +93,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { ...@@ -93,6 +93,10 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
bool use_gpu = Has("use_gpu") ? Get<bool>("use_gpu") : false; bool use_gpu = Has("use_gpu") ? Get<bool>("use_gpu") : false;
bool use_fc_padding = bool use_fc_padding =
Has("use_fc_padding") ? Get<bool>("use_fc_padding") : true; Has("use_fc_padding") ? Get<bool>("use_fc_padding") : true;
const std::string& w_name = patterns::UniqueKey(w->Name());
VarDesc w_key(w_name);
w_key.SetPersistable(true);
auto* w_node = g->CreateVarNode(&w_key);
if (!use_gpu && use_fc_padding) { if (!use_gpu && use_fc_padding) {
auto* scope = param_scope(); auto* scope = param_scope();
auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>(); auto* weight = scope->FindVar(w->Name())->GetMutable<LoDTensor>();
...@@ -102,20 +106,25 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { ...@@ -102,20 +106,25 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
int w_h = weight_dims[0]; int w_h = weight_dims[0];
int w_w = weight_dims[1]; int w_w = weight_dims[1];
if (w_h % 128 == 0 && w_w % 128 == 0) { if (w_h % 128 == 0 && w_w % 128 == 0) {
auto* w_var = scope->Var(w_name);
auto* w_tensor = w_var->GetMutable<framework::LoDTensor>();
auto* weight_data_tmp = new float[weight_num]; auto* weight_data_tmp = new float[weight_num];
for (int i = 0; i < w_h; i++) { for (int i = 0; i < w_h; i++) {
memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w, memcpy(weight_data_tmp + i * w_w, weight_data + i * w_w,
w_w * sizeof(float)); w_w * sizeof(float));
} }
weight->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4}); w_tensor->Resize(DDim{weight_dims[0] + 4, weight_dims[1] + 4});
auto* weight_data_new = auto* weight_data_new =
weight->mutable_data<float>(platform::CPUPlace()); w_tensor->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < w_h; i++) { for (int i = 0; i < w_h; i++) {
memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w, memcpy(weight_data_new + i * (w_w + 4), weight_data_tmp + i * w_w,
w_w * sizeof(float)); w_w * sizeof(float));
} }
delete[] weight_data_tmp; delete[] weight_data_tmp;
desc.SetInput("W", {w_name});
desc.SetAttr("padding_weights", true); desc.SetAttr("padding_weights", true);
desc.Flush();
} }
} }
...@@ -147,7 +156,12 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const { ...@@ -147,7 +156,12 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
} }
IR_NODE_LINK_TO(subgraph.at(x), fc_node); IR_NODE_LINK_TO(subgraph.at(x), fc_node);
IR_NODE_LINK_TO(w, fc_node); if (desc.GetAttrIfExists<bool>("padding_weights")) {
IR_NODE_LINK_TO(w_node, fc_node);
} else {
GraphSafeRemoveNodes(g, {w_node});
IR_NODE_LINK_TO(w, fc_node);
}
IR_NODE_LINK_TO(bias, fc_node); IR_NODE_LINK_TO(bias, fc_node);
if (with_relu) { if (with_relu) {
IR_NODE_LINK_TO(fc_node, relu_out); IR_NODE_LINK_TO(fc_node, relu_out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册