diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 13b07a8e8fb7be3b494cbf4240e36aa77ef3380a..7fe980bf40641865a7145f9d53e56455e11b9104 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -429,6 +429,33 @@ std::vector Tensor::shape() const { PADDLE_ENFORCE_NOT_NULL( tensor_, paddle::platform::errors::PreconditionNotMet( "Not found tensor called %s in the scope", name_)); +// mkldnn may do a layout transform internally, so we need to reorder before +// returning +#ifdef PADDLE_WITH_MKLDNN + if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) { + paddle::framework::DataLayout out_layout = + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout(); + // Set default as NCHW in case not specified + out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout + ? paddle::framework::DataLayout::kNCHW + : out_layout; + // In these data layouts, the channel dimension is either in the 2nd + // position (nChw) + // or last (nhwC), so for dim==2 these layouts are the same and nothing should + // be done. Similarly for dim==1, where there is just one possible + // combination. 
+ if (tensor->dims().size() < 3) + return paddle::framework::vectorize(tensor->dims()); + if (out_layout == paddle::framework::DataLayout::kNHWC) { + auto dims = paddle::framework::vectorize(tensor->dims()); + std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); + return dims; + } else { + return paddle::framework::vectorize(tensor->dims()); + } + } +#endif return paddle::framework::vectorize(tensor->dims()); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py index 434b89135be76cacc3d9b08d19d99b75a0e7bb9b..67e97b0a3752ec319ec1bc2d1daa20e19db586d6 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py @@ -183,7 +183,9 @@ class TestConvBnFusePass(PassAutoScanTest): def add_ignore_pass_case(self): def teller1(program_config, predictor_config): - if program_config.ops[0].attrs['data_format'] == "NHWC": + if program_config.ops[0].attrs[ + 'data_format'] == "NHWC" and not predictor_config.mkldnn_enabled( + ): return True return False