未验证 提交 0666b858 编写于 作者: G Guoxia Wang 提交者: GitHub

Fix bug in fp32 batchnorm_op when using the NHWC data_layout (#37020)

上级 cde335a1
@@ -916,7 +916,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 Tensor transformed_d_y(d_y->type());
 Tensor transformed_d_x;
 if (data_layout == DataLayout::kNHWC &&
-compute_format == DataLayout::kNCHW) {
+compute_format == DataLayout::kNCHW && x_dims.size() > 2) {
 VLOG(3) << "Transform input tensor from NHWC to NCHW.";
 ResizeToChannelFirst<platform::CUDADeviceContext, T>(ctx, x,
 &transformed_x);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册