diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index f2f4d3fee053a1e5bacd3c2165dba960f3befea4..0c1f288325c815d7ed2ebf4fb499fa411759520a 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -130,6 +130,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { z->set_layout(DataLayout::kMKLDNN); z->set_format(x->format()); + + // convert to nchw format to align with native version + using platform::MKLDNNDeviceContext; + auto& dev_ctx = ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + ReorderInput(const_cast(z), ctx.GetPlace(), mkldnn_engine, + z->dims().size() == 4); } else { // Fallback to naive version: const bool are_inputs_in_same_format = x->format() == y->format(); diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py index 738715dd70181988028adff1c50be3a52199c312..04486119cbd39a47bed2163ca3b9eb6827bfbe62 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py @@ -28,7 +28,8 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): self.y = np.random.rand(1, 16).astype(self.dtype) self.out = x * self.y.reshape(1, 16, 1, 1) - self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) + +# self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2) def setUp(self): super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp()