From 52d6b306ecd6ebc7db84b91049387428c26e2411 Mon Sep 17 00:00:00 2001 From: wenbin Date: Thu, 10 Feb 2022 16:19:45 +0800 Subject: [PATCH] mkldnn layout issue fix (#39422) * mkldnn conv fix * definetion --- .../inference/api/details/zero_copy_tensor.cc | 27 +++++++++++++++++++ .../ir/inference/test_conv_bn_fuse_pass.py | 4 ++- 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/api/details/zero_copy_tensor.cc b/paddle/fluid/inference/api/details/zero_copy_tensor.cc index 13b07a8e8fb..7fe980bf406 100644 --- a/paddle/fluid/inference/api/details/zero_copy_tensor.cc +++ b/paddle/fluid/inference/api/details/zero_copy_tensor.cc @@ -429,6 +429,33 @@ std::vector Tensor::shape() const { PADDLE_ENFORCE_NOT_NULL( tensor_, paddle::platform::errors::PreconditionNotMet( "Not found tensor called %s in the scope", name_)); +// mkldnn may does layout transform internally, so need to reorder before +// return +#ifdef PADDLE_WITH_MKLDNN + if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) { + paddle::framework::DataLayout out_layout = + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout(); + // Set default as NCHW in case not specified + out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout + ? paddle::framework::DataLayout::kNCHW + : out_layout; + // In these data layouts, channel dimension is either on 2nd position: nChw + // or + // at last nhwC, so for dim==2 these layouts are the same and nothing should + // be done. Similarly for dim==1 when you have just one possible + // combination. + if (tensor->dims().size() < 3) + return paddle::framework::vectorize(tensor->dims()); + if (out_layout == paddle::framework::DataLayout::kNHWC) { + auto dims = paddle::framework::vectorize(tensor->dims()); + std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); + return dims; + } else { + return paddle::framework::vectorize(tensor->dims()); + } + } +#endif return paddle::framework::vectorize(tensor->dims()); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py index 434b89135be..67e97b0a375 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_conv_bn_fuse_pass.py @@ -183,7 +183,9 @@ class TestConvBnFusePass(PassAutoScanTest): def add_ignore_pass_case(self): def teller1(program_config, predictor_config): - if program_config.ops[0].attrs['data_format'] == "NHWC": + if program_config.ops[0].attrs[ + 'data_format'] == "NHWC" and not predictor_config.mkldnn_enabled( + ): return True return False -- GitLab