提交 a9694bd3 编写于 作者: L Leo Zhao

convert output to nchw format to align with native version in avx512 mode

test = develop
resolve #16764
上级 85363848
...@@ -130,6 +130,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -130,6 +130,13 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
z->set_layout(DataLayout::kMKLDNN); z->set_layout(DataLayout::kMKLDNN);
z->set_format(x->format()); z->set_format(x->format());
// convert to nchw format to align with native version
using platform::MKLDNNDeviceContext;
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
ReorderInput<T>(const_cast<Tensor*>(z), ctx.GetPlace(), mkldnn_engine,
z->dims().size() == 4);
} else { } else {
// Fallback to naive version: // Fallback to naive version:
const bool are_inputs_in_same_format = x->format() == y->format(); const bool are_inputs_in_same_format = x->format() == y->format();
......
...@@ -28,7 +28,8 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp): ...@@ -28,7 +28,8 @@ class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
self.y = np.random.rand(1, 16).astype(self.dtype) self.y = np.random.rand(1, 16).astype(self.dtype)
self.out = x * self.y.reshape(1, 16, 1, 1) self.out = x * self.y.reshape(1, 16, 1, 1)
self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
# self.out = self.out.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
def setUp(self): def setUp(self):
super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp() super(TestElementwiseMulMKLDNNOp_BroadcastNCHW16c, self).setUp()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册