提交 d4606bcb 编写于 作者: Y Yihua Xu

Fix the exception when tensor format is x

test=develop
上级 641313ea
...@@ -134,16 +134,18 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> { ...@@ -134,16 +134,18 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
const bool are_inputs_in_same_format = x->format() == y->format(); const bool are_inputs_in_same_format = x->format() == y->format();
const bool is_x_nchw = x->format() == memory::format::nchw; const bool is_x_nchw = x->format() == memory::format::nchw;
const bool is_x_nc = x->format() == memory::format::nc; const bool is_x_nc = x->format() == memory::format::nc;
const bool is_x_x = x->format() == memory::format::x;
const bool is_y_nchw = y->format() == memory::format::nchw; const bool is_y_nchw = y->format() == memory::format::nchw;
const bool is_y_nc = y->format() == memory::format::nc; const bool is_y_nc = y->format() == memory::format::nc;
const bool is_y_x = y->format() == memory::format::x;
if (!are_inputs_in_same_format) { if (!are_inputs_in_same_format) {
using platform::MKLDNNDeviceContext; using platform::MKLDNNDeviceContext;
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
if (!(is_x_nchw || is_x_nc)) if (!(is_x_nchw || is_x_nc || is_x_x))
ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine, ReorderInput<T>(const_cast<Tensor*>(x), ctx.GetPlace(), mkldnn_engine,
x->dims().size() == 4); x->dims().size() == 4);
if (!(is_y_nchw || is_y_nc)) if (!(is_y_nchw || is_y_nc || is_y_x))
ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine, ReorderInput<T>(const_cast<Tensor*>(y), ctx.GetPlace(), mkldnn_engine,
y->dims().size() == 4); y->dims().size() == 4);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册