未验证 提交 a0666b9d 编写于 作者: J jakpiase 提交者: GitHub

Split op oneDNN AVX2 fix (#33944)

* added checking if md uses blocking format

* minor change

* removed unnecessary line
上级 84e813e3
...@@ -78,9 +78,18 @@ class SplitOp : public framework::OperatorWithKernel { ...@@ -78,9 +78,18 @@ class SplitOp : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(), // OneDNN uses blocking format, which cannot be always
framework::DataLayout::kMKLDNN, // supported with reorders, because if blocked dimension is not divisible
framework::LibraryType::kMKLDNN); // by
// 8 or 16(depending on which blocking format is used) submemory cannot be
// created, so in that scenario a fallback is needed
auto tmp_md = dnnl::memory::desc(
framework::vectorize(ctx.Input<Tensor>("X")->dims()),
dnnl::memory::data_type::f32, ctx.Input<Tensor>("X")->format());
if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
} }
#endif #endif
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册