From d4606bcb22e7b516541333cb0dfdc375bdb8ac54 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Mon, 24 Dec 2018 14:17:35 +0800 Subject: [PATCH] Fix the exception when tensor format is x test=develop --- .../operators/elementwise/elementwise_mul_mkldnn_op.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc index 4c73a70ed1c..04e8800bbc8 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc @@ -134,16 +134,18 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel { const bool are_inputs_in_same_format = x->format() == y->format(); const bool is_x_nchw = x->format() == memory::format::nchw; const bool is_x_nc = x->format() == memory::format::nc; + const bool is_x_x = x->format() == memory::format::x; const bool is_y_nchw = y->format() == memory::format::nchw; const bool is_y_nc = y->format() == memory::format::nc; + const bool is_y_x = y->format() == memory::format::x; if (!are_inputs_in_same_format) { using platform::MKLDNNDeviceContext; auto& dev_ctx = ctx.template device_context(); const auto& mkldnn_engine = dev_ctx.GetEngine(); - if (!(is_x_nchw || is_x_nc)) + if (!(is_x_nchw || is_x_nc || is_x_x)) ReorderInput(const_cast(x), ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4); - if (!(is_y_nchw || is_y_nc)) + if (!(is_y_nchw || is_y_nc || is_y_x)) ReorderInput(const_cast(y), ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4); } -- GitLab