未验证 提交 56759ff4 编写于 作者: C crystal 提交者: GitHub

Optimize batch_norm for 2-D input in NCHW format on CPU (#34585)

上级 a3cc2d0b
...@@ -295,8 +295,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -295,8 +295,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
bool global_stats = test_mode || use_global_stats; bool global_stats = test_mode || use_global_stats;
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout = DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
framework::StringToDataLayout(data_layout_str);
const auto *x = ctx.Input<Tensor>("X"); const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims(); const auto &x_dims = x->dims();
...@@ -332,6 +331,12 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -332,6 +331,12 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean->mutable_data<T>(ctx.GetPlace()); saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace()); saved_variance->mutable_data<T>(ctx.GetPlace());
// When the input is 2-D and the layout is NCHW, there are no spatial
// dimensions, so the input can be treated as NHWC format.
if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
data_layout = DataLayout::kNHWC;
}
if (!global_stats) { if (!global_stats) {
// saved_xx is used only for this batch of data // saved_xx is used only for this batch of data
EigenVectorArrayMap<T> saved_mean_e( EigenVectorArrayMap<T> saved_mean_e(
...@@ -578,8 +583,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -578,8 +583,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
bool use_global_stats = ctx.Attr<bool>("use_global_stats"); bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout = DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
framework::StringToDataLayout(data_layout_str);
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale")); auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
...@@ -633,6 +637,12 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T> ...@@ -633,6 +637,12 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C; const int sample_size = x->numel() / N / C;
// When the input is 2-D and the layout is NCHW, there are no spatial
// dimensions, so the input can be treated as NHWC format.
if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
data_layout = DataLayout::kNHWC;
}
// init output // init output
if (d_x) { if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace()); d_x->mutable_data<T>(ctx.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册