提交 6f0b2b19 编写于 作者: Z Zhang Ting 提交者: lanxianghit

[cherry-pick] fix the bug of conv_transpose cudnn kernel, test=release/1.6 (#20958) (#20974)

fix the bug of conv_transpose cudnn kernel:cherry-pick #20958
上级 692a04ec
...@@ -72,7 +72,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -72,7 +72,7 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
const T* filter_data = filter->data<T>(); const T* filter_data = filter->data<T>();
const std::string data_layout_str = ctx.Attr<std::string>("data_format"); const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const paddle::operators::DataLayout data_layout = const paddle::operators::DataLayout data_layout =
(data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC); (data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);
// if channel_last, transpose to channel_first // if channel_last, transpose to channel_first
Tensor input_transpose; Tensor input_transpose;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册