diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 0511eb42a073ac305634110a71a35e501f062132..f07ab5a33b87d7945e5fcdf8f3644f0711ce643b 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      * ('any') which lets a primitive (convolution in this case) choose
      * the memory format preferred for best performance
      */
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    auto chosen_memory_format =
+        platform::data_format_to_memory_format(data_format);
+
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
 
     // create a conv primitive descriptor and save it for usage in backward
     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
@@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
      * ('any') which lets a primitive (conv backward in this case) choose
      * the memory format preferred for best performance
      */
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    auto chosen_memory_format =
+        platform::data_format_to_memory_format(data_format);
+
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
 
     // Retrieve conv_pd from device context
     auto conv_pd =
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index a8f93e6848a1db1f5aa0ee266a076af2b5d0c964..10a3ad256b17ba41380cdc0377905d03188cbaa3 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -223,7 +223,7 @@ class MKLDNNHandler {
   static std::string GetHash(mkldnn::memory::dims& operand_dims,  // NOLINT
                              const std::string& suffix) {
     return dims2str(operand_dims) + suffix;
-  };
+  }
 
  protected:
   static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
@@ -251,5 +251,17 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
   return data_format;
 }
 
+inline mkldnn::memory::format data_format_to_memory_format(
+    const std::string& data_format) {
+  switch (framework::StringToDataLayout(data_format)) {
+    case framework::DataLayout::kNHWC:
+      return mkldnn::memory::format::nhwc;
+    case framework::DataLayout::kNCHW:
+      return mkldnn::memory::format::nchw;
+    default:
+      return mkldnn::memory::format::any;
+  }
+}
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py
index db6be21baaa54d33af9f5c44d1815e4b389eb884..d0de7ad52c8a851c16cbbbf544d479f696dee136 100644
--- a/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/test_conv2d_mkldnn_op.py
@@ -20,16 +20,19 @@ from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride
 class TestMKLDNN(TestConv2dOp):
     def init_kernel_type(self):
         self.use_mkldnn = True
+        self.data_format = "NCHW"
 
 
 class TestMKLDNNWithPad(TestWithPad):
     def init_kernel_type(self):
         self.use_mkldnn = True
+        self.data_format = "NCHW"
 
 
 class TestMKLDNNWithStride(TestWithStride):
     def init_kernel_type(self):
         self.use_mkldnn = True
+        self.data_format = "NCHW"
 
 
 if __name__ == '__main__':
diff --git a/python/paddle/fluid/tests/unittests/test_conv2d_op.py b/python/paddle/fluid/tests/unittests/test_conv2d_op.py
index a478649541ba9828e55c4239090d5aee554223ac..f5b034312fd1f060877edf660a6bcf3fb493f7f7 100644
--- a/python/paddle/fluid/tests/unittests/test_conv2d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_conv2d_op.py
@@ -66,6 +66,7 @@ class TestConv2dOp(OpTest):
         self.op_type = "conv2d"
         self.use_cudnn = False
         self.use_mkldnn = False
+        self.data_format = "AnyLayout"
         self.dtype = np.float32
         self.init_kernel_type()
         self.init_group()
@@ -93,7 +94,8 @@ class TestConv2dOp(OpTest):
             'groups': self.groups,
             'dilations': self.dilations,
             'use_cudnn': self.use_cudnn,
-            'use_mkldnn': self.use_mkldnn
+            'use_mkldnn': self.use_mkldnn,
+            'data_format': self.data_format
         }
 
         self.outputs = {'Output': output}
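Below is a minimal, standalone sketch (not part of the patch) of the dispatch pattern the new data_format_to_memory_format() helper follows: parse the "data_format" attribute string into a layout enum, map the known layouts to concrete memory formats, and fall back to "any" so the primitive can still pick its preferred format. DataLayout, MemoryFormat, and the function names here are hypothetical stand-ins for framework::DataLayout, mkldnn::memory::format, and the Paddle/MKLDNN helpers; they only illustrate the behaviour.

    #include <iostream>
    #include <string>

    // Hypothetical stand-ins for framework::DataLayout and mkldnn::memory::format.
    enum class DataLayout { kNHWC, kNCHW, kAnyLayout };
    enum class MemoryFormat { nhwc, nchw, any };

    // Stand-in for framework::StringToDataLayout: unknown strings map to kAnyLayout.
    DataLayout StringToDataLayout(const std::string& s) {
      if (s == "NHWC") return DataLayout::kNHWC;
      if (s == "NCHW") return DataLayout::kNCHW;
      return DataLayout::kAnyLayout;
    }

    // Same shape as the helper added to mkldnn_helper.h: known layouts get a
    // concrete format, everything else (e.g. "AnyLayout") keeps the old
    // memory::format::any behaviour.
    MemoryFormat DataFormatToMemoryFormat(const std::string& data_format) {
      switch (StringToDataLayout(data_format)) {
        case DataLayout::kNHWC:
          return MemoryFormat::nhwc;
        case DataLayout::kNCHW:
          return MemoryFormat::nchw;
        default:
          return MemoryFormat::any;
      }
    }

    int main() {
      std::cout << (DataFormatToMemoryFormat("NCHW") == MemoryFormat::nchw) << "\n";      // 1
      std::cout << (DataFormatToMemoryFormat("AnyLayout") == MemoryFormat::any) << "\n";  // 1
      return 0;
    }

With this mapping, the updated tests that set data_format = "NCHW" request an explicit nchw memory descriptor, while the base TestConv2dOp default of "AnyLayout" preserves the previous behaviour of letting the convolution primitive choose its format.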