From a9694bd3d64efcb1a5642e090d777cd8948aa02c Mon Sep 17 00:00:00 2001 From: Leo Zhao Date: Fri, 12 Apr 2019 16:02:56 +0800 Subject: [PATCH] convert output to nchw format to align with native version in avx512 mode test = develop resolve #16764 --- .../elementwise/mkldnn/elementwise_mul_mkldnn_op.cc | 7 +++++++ .../unittests/mkldnn/test_elementwise_mul_mkldnn_op.py | 3 ++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index f2f4d3fee..0c1f28832 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -130,6 +130,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { z->set_layout(DataLayout::kMKLDNN); z->set_format(x->format()); + + // convert to nchw format to align with native version + using platform::MKLDNNDeviceContext; + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + ReorderInput(const_cast(z), ctx.GetPlace(), mkldnn_engine, + z->dims().size() == 4); } else { // Fallback to naive version: const bool are_inputs_in_same_format = x->format() == y->format(); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py index 738715dd7..04486119c 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py @@ -28,7 +28,8 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): self.y = np.random.rand(1, 16).astype(self.dtype) self.out = x * self.y.reshape(1, 16, 1, 1) - self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + +# self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) def setUp(self): super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp() -- GitLab