diff --git a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc index 1755b0f2082071b272f0d74b9d4e3298dcfb4bb9..79551b6d59a2cb4a501e20bce3d72e50a0416bb9 100644 --- a/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc @@ -170,6 +170,9 @@ class FCPrimitiveFactory { // In case of 2 dims, we set the only possible format, nc if (dim_num == 2) { out->set_format(MKLDNNMemoryFormat::nc); + out->set_mem_desc({phi::vectorize(out->dims()), + platform::MKLDNNGetDataType(), + out->format()}); // In case of 3 dims, we generate a format that is based on number // of output dims and the layout of input format (nchw or nhwc). } else if (dim_num == 3) { @@ -185,9 +188,6 @@ class FCPrimitiveFactory { } else { out->set_format(in_format); } - out->set_mem_desc({phi::vectorize(out->dims()), - platform::MKLDNNGetDataType(), - out->format()}); } void UpdateDataPointers(const ExecutionContext& ctx,