未验证 提交 f4a6c539 编写于 作者: J Jacek Czaja 提交者: GitHub

Conv grad to use set_mem_desc() (#46459)

* - Conv grad changed for MD

* - lint

* - compilation fix

* yet another lint
上级 cdcc0013
...@@ -1048,14 +1048,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> { ...@@ -1048,14 +1048,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
{DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); {DNNL_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
astream.wait(); astream.wait();
filter_grad->set_layout(framework::DataLayout::kMKLDNN);
// in OneDNN groups in convolution are treated as separate dimension
// which is not the case in paddlepaddle
auto filter_fmt = platform::GetMKLDNNFormat(*diff_weights_memory_p);
// For convolution with groups convert from blocked to NCHW // For convolution with groups convert from blocked to NCHW
// otherwise there will be problems in next operators working on this data // otherwise there will be problems in next operators working on this data
if (g > 1) { if (g > 1) {
// in OneDNN groups in convolution are treated as separate dimension
// which is not the case in paddlepaddle
dnnl::memory::data_type in_type = framework::ToMKLDNNDataType( dnnl::memory::data_type in_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(filter->dtype())); framework::TransToProtoVarType(filter->dtype()));
// for 3d conv with groups (six dimensional data reorder to goidhw) // for 3d conv with groups (six dimensional data reorder to goidhw)
...@@ -1094,9 +1092,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> { ...@@ -1094,9 +1092,12 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
dnnl::memory::format_tag target_format = dnnl::memory::format_tag target_format =
weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw weights_tz.size() == 6 ? dnnl::memory::format_tag::oidhw
: dnnl::memory::format_tag::oihw; : dnnl::memory::format_tag::oihw;
filter_grad->set_format(target_format); filter_grad->set_mem_desc(
dnnl::memory::desc(phi::vectorize<int64_t>(filter_grad->dims()),
in_type,
target_format));
} else { } else {
filter_grad->set_format(filter_fmt); filter_grad->set_mem_desc(diff_weights_memory_p->get_desc());
} }
} }
if (input_grad) { if (input_grad) {
...@@ -1119,8 +1120,7 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> { ...@@ -1119,8 +1120,7 @@ class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
{DNNL_ARG_DIFF_SRC, *diff_src_memory_p}}); {DNNL_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait(); astream.wait();
input_grad->set_layout(framework::DataLayout::kMKLDNN); input_grad->set_mem_desc(diff_src_memory_p->get_desc());
input_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p));
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册