diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 3d26c2c570858e11771fc27afabdcce5c0fb9443..b4cf9c48df2a804a90c633b5aeaf15a794223595 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -916,7 +916,7 @@ class BatchNormGradKernel Tensor transformed_d_y(d_y->type()); Tensor transformed_d_x; if (data_layout == DataLayout::kNHWC && - compute_format == DataLayout::kNCHW) { + compute_format == DataLayout::kNCHW && x_dims.size() > 2) { VLOG(3) << "Transform input tensor from NHWC to NCHW."; ResizeToChannelFirst(ctx, x, &transformed_x);