diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index b2cffc3f9063c1fd3b33baa9c740c2402fa00080..be17bf9a03fc1908fccb2bb6a5f32ce49db3353d 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -295,8 +295,7 @@ class BatchNormKernel
     bool global_stats = test_mode || use_global_stats;
 
     const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
-    const DataLayout data_layout =
-        framework::StringToDataLayout(data_layout_str);
+    DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
 
     const auto *x = ctx.Input<Tensor>("X");
     const auto &x_dims = x->dims();
@@ -332,6 +331,12 @@ class BatchNormKernel
     saved_mean->mutable_data<T>(ctx.GetPlace());
     saved_variance->mutable_data<T>(ctx.GetPlace());
 
+    // When the input dimension is 2 and the format is NCHW, the input can be
+    // regarded as NHWC.
+    if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
+      data_layout = DataLayout::kNHWC;
+    }
+
     if (!global_stats) {
       // saved_xx is used just in this batch of data
       EigenVectorArrayMap<T> saved_mean_e(
@@ -578,8 +583,7 @@ class BatchNormGradKernel
     bool use_global_stats = ctx.Attr<bool>("use_global_stats");
     const bool is_test = ctx.Attr<bool>("is_test");
     const float epsilon = ctx.Attr<float>("epsilon");
-    const DataLayout data_layout =
-        framework::StringToDataLayout(data_layout_str);
+    DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
 
     auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
@@ -633,6 +637,12 @@ class BatchNormGradKernel
                                : x_dims[x_dims.size() - 1]);
     const int sample_size = x->numel() / N / C;
 
+    // When the input dimension is 2 and the format is NCHW, the input can be
+    // regarded as NHWC.
+    if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
+      data_layout = DataLayout::kNHWC;
+    }
+
     // init output
     if (d_x) {
       d_x->mutable_data<T>(ctx.GetPlace());
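
Why the retag is safe: a 2-D input has shape [N, C] with no spatial dimensions, so NCHW and NHWC describe the same row-major memory. Element (n, c) sits at offset n * C + c under either layout, and switching the tag only routes the kernel to the NHWC code path, which expects the channel axis to be last. The sketch below (illustrative only, not Paddle code; the two offset helpers are hypothetical names) checks that the NCHW and NHWC offset formulas coincide when H == W == 1.

```cpp
#include <cassert>
#include <cstddef>

// Offset of element (n, c, h, w) in a row-major NCHW buffer.
std::size_t OffsetNCHW(std::size_t n, std::size_t c, std::size_t h,
                       std::size_t w, std::size_t C, std::size_t H,
                       std::size_t W) {
  return ((n * C + c) * H + h) * W + w;
}

// Offset of the same element in a row-major NHWC buffer.
std::size_t OffsetNHWC(std::size_t n, std::size_t c, std::size_t h,
                       std::size_t w, std::size_t C, std::size_t H,
                       std::size_t W) {
  return ((n * H + h) * W + w) * C + c;
}

int main() {
  // A 2-D [N, C] input is a 4-D tensor with H == W == 1.
  const std::size_t N = 4, C = 3, H = 1, W = 1;
  for (std::size_t n = 0; n < N; ++n) {
    for (std::size_t c = 0; c < C; ++c) {
      // With no spatial extent, both layouts place (n, c) at n * C + c,
      // so retagging the layout never changes which element is read.
      assert(OffsetNCHW(n, c, 0, 0, C, H, W) ==
             OffsetNHWC(n, c, 0, 0, C, H, W));
    }
  }
  return 0;
}
```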