提交 f4f85831 编写于 作者: Z Zhang Ting 提交者: lanxianghit

fix the bug of conv_transpose cudnn kernel, test=develop (#20958)

fix the bug of conv_transpose cudnn kernel: before version 1.6, the data_format is AnyLayout in inference model. When use version 1.6 and load the model which is saved by previous version, the error occurs.  This is because the cudnn kernel in version 1.6 is not compitable with Anylayout setting.
上级 7695b713
......@@ -72,7 +72,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
const T* filter_data = filter->data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const paddle::operators::DataLayout data_layout =
(data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC);
(data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);
// if channel_last, transpose to channel_first
Tensor input_transpose;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册