未验证 提交 56759ff4 编写于 作者: C crystal 提交者: GitHub

optimization batch_norm 2D and NCHW format on CPU (#34585)

上级 a3cc2d0b
......@@ -295,8 +295,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
bool global_stats = test_mode || use_global_stats;
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
......@@ -332,6 +331,12 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
// input dimension is 2 and the format is NCHW. The input can be regarded
// as NHWC format
if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
data_layout = DataLayout::kNHWC;
}
if (!global_stats) {
// saved_xx is use just in this batch of data
EigenVectorArrayMap<T> saved_mean_e(
......@@ -578,8 +583,7 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
DataLayout data_layout = framework::StringToDataLayout(data_layout_str);
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
......@@ -633,6 +637,12 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
: x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C;
// input dimension is 2 and the format is NCHW. The input can be regarded as
// NHWC format
if (x_dims.size() == 2 && data_layout == DataLayout::kNCHW) {
data_layout = DataLayout::kNHWC;
}
// init output
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册