提交 8f5a93a0 编写于 作者: J Jacek Czaja 提交者: Tao Luo

- Fix to regression in performance of ResNet-50 training (#21588)

test=develop
上级 9ce0e29d
...@@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -775,8 +775,23 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (conv backward in this case) choose * ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
auto chosen_memory_format = MKLDNNMemoryFormat::any;
// TODO(jczaja): Once GRAD NHWC is working then format 'any'
// should be used exclusively. But till forward pass enforce
// NCHW for training we need to have NCHW here as well
// to avoid performance degradation in relu_grad and pool2d_grad
std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
weights_format = MKLDNNMemoryFormat::any; weights_format = MKLDNNMemoryFormat::any;
// Check the format for user's special output
if (chosen_memory_format != MKLDNNMemoryFormat::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
}
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册