From d93ee063ce6473da487b45feba1899c3efe31744 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C5=82awomir=20Siwek?= Date: Tue, 30 Nov 2021 13:50:20 +0100 Subject: [PATCH] Add new unittests for gIOHW format in conv_transpose_mkldnn_op (#37344) * Add new unittests * Replace I with O channel for filter groups * Undo changes affecting other operators * Fix oneDNN namespace typo * Fix code format error --- .../mkldnn/conv_transpose_mkldnn_op.cc | 10 ++++----- .../mkldnn/test_conv2d_transpose_mkldnn_op.py | 21 +++++++++++++++++++ 2 files changed, 25 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc index 35e35eb4bcb..4a3d1f455bd 100644 --- a/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc @@ -26,14 +26,12 @@ using Tensor = framework::Tensor; using framework::DataLayout; inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) { - auto iohw_weights_tz = framework::vectorize(filter->dims()); - auto weights_tz = iohw_weights_tz; - - // IOHW -> OIHW - weights_tz[0] = iohw_weights_tz[1]; - weights_tz[1] = iohw_weights_tz[0]; + auto weights_tz = framework::vectorize(filter->dims()); int g = std::max(groups, 1); + int g_dim = (g > 1) ? 1 : 0; platform::GetGroupConvWeightsTz(weights_tz, g); + // gIOHW -> gOIHW || IOHW -> OIHW + std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]); return weights_tz; } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py index 86609f015a2..a36fc28013b 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_transpose_mkldnn_op.py @@ -154,6 +154,27 @@ class TestConv2DTransposeMKLDNNWithDilationsExplicitPad( self.padding_algorithm = "EXPLICIT" +class TestMKLDNNWithGroups(TestConv2DTransposeMKLDNNOp): + def init_test_case(self): + TestConv2DTransposeMKLDNNOp.init_test_case(self) + self.pad = [1, 1] + self.groups = 2 + self.input_size = [2, 4, 5, 5] # NCHW + f_c = self.input_size[1] + self.filter_size = [f_c, 3, 3, 3] + + +class TestMKLDNNWithGroups_NHWC(TestConv2DTransposeMKLDNNOp): + def init_test_case(self): + TestConv2DTransposeMKLDNNOp.init_test_case(self) + self.pad = [1, 1] + self.groups = 2 + self.input_size = [2, 5, 5, 4] # NHWC + f_c = self.input_size[-1] + self.filter_size = [f_c, 3, 3, 3] + self.data_format = 'NHWC' + + if __name__ == '__main__': enable_static() unittest.main() -- GitLab