未验证 提交 f6653c71 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Fix to conv2d grad with groups (#27006)

* - Added fix to mobilenet

* - compilation fix

* - Fix to conv2d grad oneDNN with groups

test=develop
上级 8fa3d367
...@@ -1055,7 +1055,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -1055,7 +1055,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
astream.wait(); astream.wait();
filter_grad->set_layout(DataLayout::kMKLDNN); filter_grad->set_layout(DataLayout::kMKLDNN);
filter_grad->set_format(GetMKLDNNFormat(*diff_weights_memory_p)); // in OneDNN groups in convolution are treated as separate dimension
// which is not the case in paddlepaddle
auto filter_fmt = GetMKLDNNFormat(*diff_weights_memory_p);
filter_grad->set_format(platform::MKLDNNFormatForSize(
g > 1 ? weights_tz.size() - 1 : weights_tz.size(), filter_fmt));
} }
if (input_grad) { if (input_grad) {
auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive( auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册