提交 7b5ea706 编写于 作者: H Haihao Shen

Fix the weight md for group convolution

上级 59a90d08
...@@ -566,8 +566,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -566,8 +566,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( src_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format))); src_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format)));
weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( weights_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<float>(), weights_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format)));
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw)));
dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc( dst_md.reset(new mkldnn::memory::desc(platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format))); dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format)));
mds[2] = src_md; mds[2] = src_md;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册