From 9774f9650e6f2f39c0972d968c945c365e75c769 Mon Sep 17 00:00:00 2001
From: Zhangjingyu06 <92561254+Zhangjingyu06@users.noreply.github.com>
Date: Thu, 21 Apr 2022 14:16:27 +0800
Subject: [PATCH] modify batch_norm and batch_norm_grad. *test=kunlun (#41976)

---
 paddle/fluid/operators/batch_norm_op_xpu.cc | 176 +++++++++++---------
 1 file changed, 101 insertions(+), 75 deletions(-)

diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc
index d6826e8710e..da138fb482e 100644
--- a/paddle/fluid/operators/batch_norm_op_xpu.cc
+++ b/paddle/fluid/operators/batch_norm_op_xpu.cc
@@ -1,5 +1,4 @@
 /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
@@ -38,15 +37,25 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
     bool global_stats = test_mode || use_global_stats;
     const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
     const auto data_layout = framework::StringToDataLayout(data_layout_str);
+    PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The 'data_layout' attribute must be NCHW or NHWC. "
+                          "But received 'data_layout' is [%s].",
+                          data_layout_str));
+
     const auto *x = ctx.Input<Tensor>("X");
     const auto &x_dims = x->dims();
-    int temp = x_dims[3];
-    temp = (x_dims.size() != 4) ? 1 : temp;
-    bool is_nchw = (data_layout == DataLayout::kNCHW);
-    const int N = x_dims[0];
-    const int C = is_nchw ? x_dims[1] : temp;
-    const int H = is_nchw ? x_dims[2] : x_dims[1];
-    const int W = is_nchw ? temp : x_dims[2];
+    PADDLE_ENFORCE_EQ(
+        x_dims.size() >= 2 && x_dims.size() <= 5, true,
+        platform::errors::InvalidArgument(
+            "The size of input's dimensions should be between 2 and 5. "
+            "But received: the size of input's dimensions is [%d].",
+            x_dims.size()));
+
+    int N, C, H, W, D;
+    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
+
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
     const auto *x_data = x->data<T>();
@@ -67,6 +76,7 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
     saved_variance->mutable_data<float>(ctx.GetPlace());
 
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    bool is_nchw = data_layout_str == "NCHW";
 
     if (!global_stats) {
       auto *mean_out_data = mean_out->data<float>();
@@ -83,35 +93,29 @@ class BatchNormXPUKernel : public framework::OpKernel<T> {
                             &mom_cpu);
         momentum = mom_tensor->data<float>()[0];
       }
-      if (C == 1) {
-        int r = xpu::batch_norm(dev_ctx.x_context(), x_data, y_data, N, 1, H,
-                                W, epsilon, momentum, scale_data, bias_data,
-                                saved_mean_data, saved_variance_data,
-                                mean_out_data, variance_out_data, true);
-        PADDLE_ENFORCE_EQ(
-            r, xpu::Error_t::SUCCESS,
-            platform::errors::External(
-                "The batch_norm XPU API return wrong value[%d %s]", r,
-                XPUAPIErrorMsg[r]));
-      } else {
-        int r = xpu::batch_norm(dev_ctx.x_context(), x_data, y_data, N, C, H,
-                                W, epsilon, momentum, scale_data, bias_data,
-                                saved_mean_data, saved_variance_data,
-                                mean_out_data, variance_out_data, is_nchw);
-        PADDLE_ENFORCE_EQ(
-            r, xpu::Error_t::SUCCESS,
-            platform::errors::External(
-                "The batch_norm XPU API return wrong value[%d %s]", r,
-                XPUAPIErrorMsg[r]));
-      }
+
+      int r = xpu::batch_norm(dev_ctx.x_context(), x_data, y_data, N, C, H,
+                              W, epsilon, momentum, scale_data, bias_data,
+                              saved_mean_data, saved_variance_data,
+                              mean_out_data, variance_out_data, is_nchw);
+      PADDLE_ENFORCE_EQ(r,
+                        xpu::Error_t::SUCCESS,
+                        platform::errors::External(
+                            "The batch_norm XPU API returns wrong value[%d %s]",
+                            r, XPUAPIErrorMsg[r]));
     } else {
+      PADDLE_ENFORCE_EQ(
+          data_layout_str == "NCHW", true,
+          platform::errors::InvalidArgument(
+              "The batch_norm_infer 'data_layout' attribute must be NCHW. "
+              "But received 'data_layout' is [%s].",
+              data_layout_str));
       const auto *mean = ctx.Input<Tensor>("Mean");
       const auto *variance = ctx.Input<Tensor>("Variance");
       const auto *mean_data = mean->data<float>();
       const auto *variance_data = variance->data<float>();
       int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C,
                                     H, W, epsilon, scale_data, bias_data,
-                                    mean_data, variance_data, true);
+                                    mean_data, variance_data, is_nchw);
       PADDLE_ENFORCE_EQ(
           r, xpu::Error_t::SUCCESS,
           platform::errors::External(
@@ -172,6 +176,13 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
     const float epsilon = ctx.Attr<float>("epsilon");
     const auto data_layout = framework::StringToDataLayout(data_layout_str);
 
+    PADDLE_ENFORCE_EQ(data_layout_str == "NCHW" || data_layout_str == "NHWC",
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The 'data_layout' attribute must be NCHW or NHWC. "
+                          "But received 'data_layout' is [%s].",
+                          data_layout_str));
+
     auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
     auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
@@ -204,13 +215,15 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
     }
 
     const auto &x_dims = x->dims();
-    int temp = x_dims[3];
-    temp = (x_dims.size() != 4) ? 1 : temp;
-    bool is_nchw = (data_layout == DataLayout::kNCHW);
-    const int N = x_dims[0];
-    const int C = is_nchw ? x_dims[1] : temp;
-    const int H = is_nchw ? x_dims[2] : x_dims[1];
-    const int W = is_nchw ? temp : x_dims[2];
+    PADDLE_ENFORCE_EQ(
+        x_dims.size() >= 2 && x_dims.size() <= 5, true,
+        platform::errors::InvalidArgument(
+            "The size of input's dimensions should be between 2 and 5. "
+            "But received: the size of input's dimensions is [%d].",
+            x_dims.size()));
+
+    int N, C, H, W, D;
+    ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
 
     const auto *x_data = x->data<T>();
     const auto *d_y_data = d_y->data<T>();
@@ -235,42 +248,45 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
                           "the size of scale's dimensions is [%d], the dimensions of scale "
                           "is [%s].",
                           scale->dims().size(), scale->dims()));
+    PADDLE_ENFORCE_EQ(
+        scale->dims()[0], C,
+        platform::errors::InvalidArgument(
+            "The first dimension of scale must be equal to Channels[%d]. But "
+            "received: the first dimension of scale is [%d].",
+            C, scale->dims()[0]));
 
     auto &dev_ctx = ctx.template device_context<DeviceContext>();
     xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
 
-    const T *mean_data = nullptr;
-    const T *inv_var_data = nullptr;
+    const auto *batch_mean = ctx.Input<Tensor>("SavedMean");
+    const auto *batch_inv_std = ctx.Input<Tensor>("SavedVariance");
+    const auto *global_mean = ctx.Input<Tensor>("Mean");
+    const auto *global_var = ctx.Input<Tensor>("Variance");
 
     // TODO(guozibin): handle the case of N * H * W = 1
-    if (!use_global_stats) {
-      const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
-      // SavedVariance have been reverted in forward operator
-      const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
-      mean_data = saved_mean->data<T>();
-      inv_var_data = saved_inv_variance->data<T>();
-    } else {
-      const auto *running_mean = ctx.Input<Tensor>("Mean");
-      const auto *running_variance = ctx.Input<Tensor>("Variance");
-      mean_data = running_mean->data<T>();
-      inv_var_data = running_variance->data<T>();
-      float *running_inv_var_data =
-          RAII_GUARD.alloc_l3_or_gm<float>(running_variance->numel());
-      float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
-      int r1 = calculate_inv_var(dev_ctx.x_context(), inv_var_data, epsilon, C,
-                                 epsilon_data, running_inv_var_data);
-      PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External(
-                                             "XPU API(batch_norm_grad "
-                                             "calculate_inv_var function) "
-                                             "return wrong value[%d %s]",
-                                             r1, XPUAPIErrorMsg[r1]));
-      inv_var_data = running_inv_var_data;
-    }
     if (is_inplace) {
+      float *global_inv_std_data = nullptr;
+      if (use_global_stats) {
+        global_inv_std_data =
+            RAII_GUARD.alloc_l3_or_gm<float>(global_var->numel());
+        float *epsilon_data = RAII_GUARD.alloc_l3_or_gm<float>(1);
+        int r1 =
+            calculate_inv_var(dev_ctx.x_context(), global_var->data<float>(),
+                              epsilon, C, epsilon_data, global_inv_std_data);
+        PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External(
+                                               "XPU API(batch_norm_grad "
+                                               "calculate_inv_var function) "
+                                               "returns wrong value[%d %s]",
+                                               r1, XPUAPIErrorMsg[r1]));
+      }
       auto px = *x;
+      auto *inv_std_data =
+          use_global_stats ? global_inv_std_data : batch_inv_std->data<float>();
+      auto mean_data = use_global_stats ? global_mean->data<float>()
+                                        : batch_mean->data<float>();
       int r2 = calculate_inv_BN_Y(
           dev_ctx.x_context(), px.mutable_data<T>(ctx.GetPlace()),
-          scale->data<float>(), bias->data<float>(), mean_data, inv_var_data, N,
+          scale->data<float>(), bias->data<float>(), mean_data, inv_std_data, N,
           C, H * W, x->data<T>());
       PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External(
                                              "XPU API(batch_norm_grad "
@@ -278,19 +294,29 @@ class BatchNormGradXPUKernel : public framework::OpKernel<T> {
                                              "calculate_inv_BN_Y function) "
                                              "return wrong value[%d %s]",
                                              r2, XPUAPIErrorMsg[r2]));
     }
-    if (!d_x) {
-      d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
-    }
-    if (!d_scale) {
-      d_scale_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
-    }
-    if (!d_bias_data) {
-      d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
-    }
-    int r3 = xpu::batch_norm_grad(
-        dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W, scale_data,
-        mean_data, inv_var_data, d_scale_data, d_bias_data, is_nchw);
+    int r3;
+    bool is_nchw = data_layout_str == "NCHW";
+    if (use_global_stats) {
+      r3 = xpu::batch_norm_grad(
+          dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W,
+          scale_data, nullptr, nullptr, d_scale_data, d_bias_data, is_nchw,
+          global_mean->data<float>(), global_var->data<float>(), epsilon);
+    } else {
+      if (!d_x) {
+        d_x_data = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
+      }
+      if (!d_scale) {
+        d_scale_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
+      }
+      if (!d_bias_data) {
+        d_bias_data = RAII_GUARD.alloc_l3_or_gm<float>(C);
+      }
+      r3 = xpu::batch_norm_grad(
+          dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W,
+          scale_data, batch_mean->data<float>(), batch_inv_std->data<float>(),
+          d_scale_data, d_bias_data, is_nchw);
+    }
     PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External(
                                            "XPU API(batch_norm_grad) return "
                                            "wrong value[%d %s]",
--
GitLab
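
A note on the two helper kernels the gradient path relies on. `calculate_inv_var` is assumed to turn the running variance into a per-channel inverse standard deviation, inv_std = 1 / sqrt(var + epsilon), and `calculate_inv_BN_Y` is assumed to invert y = scale * (x - mean) * inv_std + bias so the forward input can be recovered when the op ran in place (X and Y share one buffer). A minimal host-side C++ sketch of that algebra follows; the names `InvStd` and `InverseBNY` are hypothetical stand-ins, not the XPU device API, and NCHW indexing is assumed:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Per-channel inverse standard deviation: inv_std[c] = 1 / sqrt(var[c] + eps).
    // Assumed semantics of calculate_inv_var on the use_global_stats path.
    std::vector<float> InvStd(const std::vector<float> &var, float eps) {
      std::vector<float> inv_std(var.size());
      for (std::size_t c = 0; c < var.size(); ++c) {
        inv_std[c] = 1.0f / std::sqrt(var[c] + eps);
      }
      return inv_std;
    }

    // Invert y = scale * (x - mean) * inv_std + bias in place, recovering x.
    // Assumed semantics of calculate_inv_BN_Y; layout is NCHW with hw = H * W.
    void InverseBNY(float *y, const float *scale, const float *bias,
                    const float *mean, const float *inv_std, int n, int c,
                    int hw) {
      for (int i = 0; i < n; ++i) {
        for (int j = 0; j < c; ++j) {
          for (int k = 0; k < hw; ++k) {
            float &v = y[(i * c + j) * hw + k];
            v = (v - bias[j]) / (scale[j] * inv_std[j]) + mean[j];
          }
        }
      }
    }

The in-place inversion is needed because batch_norm_grad consumes the forward input x, which no longer exists as a separate buffer when is_inplace is true; dividing by scale * inv_std assumes scale has no zero entries.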