From cc83c95f357dad4d059993b937dcca2237916fe1 Mon Sep 17 00:00:00 2001
From: Leo Guo <58431564+ZibinGuo@users.noreply.github.com>
Date: Thu, 30 Dec 2021 10:15:13 +0800
Subject: [PATCH] Fix the bug of batch_norm and batch_norm_grad op. (#38288)

* Fix the bug of batch_norm and batch_norm_grad op. Add the "roi_align" and
"roi_align_grad" op in xpu2 op list.

* Fix the bug of batch_norm and batch_norm_grad op. Add the "roi_align" and
"roi_align_grad" op in xpu2 op list. test=kunlun

Co-authored-by: Zibin
---
 paddle/fluid/operators/batch_norm_op_xpu.cc   | 279 ++++++++++++++----
 .../fluid/platform/device/xpu/xpu2_op_list.h  |   3 +
 .../unittests/xpu/test_batch_norm_op_xpu.py   |  53 ++++
 3 files changed, 275 insertions(+), 60 deletions(-)

diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc
index 8499d1cdcd6..d232891f3d6 100644
--- a/paddle/fluid/operators/batch_norm_op_xpu.cc
+++ b/paddle/fluid/operators/batch_norm_op_xpu.cc
@@ -15,6 +15,8 @@ limitations under the License. */
 #ifdef PADDLE_WITH_XPU
 
 #include "paddle/fluid/operators/batch_norm_op.h"
+#include <memory>
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -25,23 +27,25 @@ using DDim = framework::DDim;
 template <typename DeviceContext, typename T>
 class BatchNormXPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
+  void Compute(const framework::ExecutionContext &ctx) const override {
     const auto epsilon = ctx.Attr<float>("epsilon");
-    const auto momentum = ctx.Attr<float>("momentum");
+    float momentum = ctx.Attr<float>("momentum");
     const auto is_test = ctx.Attr<bool>("is_test");
     const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
     const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
     bool test_mode = is_test && (!trainable_stats);
+
     bool global_stats = test_mode || use_global_stats;
-    const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
+    const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
     const auto data_layout = framework::StringToDataLayout(data_layout_str);
     PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
                       platform::errors::InvalidArgument(
                           "The 'data_layout' attribute must be NCHW. But "
                           "recevived 'data_layout' is [%s].",
                           data_layout_str));
-    const auto* x = ctx.Input<Tensor>("X");
-    const auto& x_dims = x->dims();
+
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto &x_dims = x->dims();
     PADDLE_ENFORCE_EQ(x_dims.size(), 4,
                       platform::errors::InvalidArgument(
                           "The input tensor X's dimension must equal to 4. But "
But " @@ -51,27 +55,42 @@ class BatchNormXPUKernel : public framework::OpKernel { const int C = x_dims[1]; const int H = x_dims[2]; const int W = x_dims[3]; - const auto* scale = ctx.Input("Scale"); - const auto* bias = ctx.Input("Bias"); - const auto* x_data = x->data(); - const auto* scale_data = scale->data(); - const auto* bias_data = bias->data(); - auto* y = ctx.Output("Y"); - auto* y_data = y->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); + const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + const auto *x_data = x->data(); + const auto *scale_data = scale->data(); + const auto *bias_data = bias->data(); + + auto *y = ctx.Output("Y"); + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_variance = ctx.Output("SavedVariance"); + + // alloc memory + auto *y_data = y->mutable_data(ctx.GetPlace()); + mean_out->mutable_data(ctx.GetPlace()); + variance_out->mutable_data(ctx.GetPlace()); + saved_mean->mutable_data(ctx.GetPlace()); + saved_variance->mutable_data(ctx.GetPlace()); + + auto &dev_ctx = ctx.template device_context(); + if (!global_stats) { - auto* mean_out = ctx.Output("MeanOut"); - auto* variance_out = ctx.Output("VarianceOut"); - auto* saved_mean = ctx.Output("SavedMean"); - auto* saved_variance = ctx.Output("SavedVariance"); - mean_out->mutable_data(ctx.GetPlace()); - variance_out->mutable_data(ctx.GetPlace()); - saved_mean->mutable_data(ctx.GetPlace()); - saved_variance->mutable_data(ctx.GetPlace()); - auto* mean_out_data = mean_out->data(); - auto* variance_out_data = variance_out->data(); - auto* saved_mean_data = saved_mean->data(); - auto* saved_variance_data = saved_variance->data(); + auto *mean_out_data = mean_out->data(); + auto *variance_out_data = variance_out->data(); + auto *saved_mean_data = saved_mean->data(); + auto *saved_variance_data = saved_variance->data(); + + // if MomentumTensor is set, use MomentumTensor value, momentum + // is only used in this training branch + if (ctx.HasInput("MomentumTensor")) { + const auto *mom_tensor = ctx.Input("MomentumTensor"); + Tensor mom_cpu; + TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu); + momentum = mom_tensor->data()[0]; + } + int r = xpu::batch_norm(dev_ctx.x_context(), x_data, y_data, N, C, H, W, epsilon, momentum, scale_data, bias_data, saved_mean_data, saved_variance_data, @@ -81,12 +100,10 @@ class BatchNormXPUKernel : public framework::OpKernel { "The batch_norm XPU API return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); } else { - const auto* mean = ctx.Input("Mean"); - const auto* variance = ctx.Input("Variance"); - const auto* mean_data = mean->data(); - const auto* variance_data = variance->data(); - const auto* x_data = x->data(); - auto* y_data = y->mutable_data(ctx.GetPlace()); + const auto *mean = ctx.Input("Mean"); + const auto *variance = ctx.Input("Variance"); + const auto *mean_data = mean->data(); + const auto *variance_data = variance->data(); int r = xpu::batch_norm_infer(dev_ctx.x_context(), x_data, y_data, N, C, H, W, epsilon, scale_data, bias_data, mean_data, variance_data, true); @@ -99,24 +116,96 @@ class BatchNormXPUKernel : public framework::OpKernel { } }; +template +static int calculate_inv_BN_Y(xpu::Context *ctx, T *x, const T *scale, + const T *bias, const T *mean, const T *variance, + const int N, const int C, const int M, + const T *y) { + PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument( + "X and Y 
+  std::vector<int> tensor_shape_vec({N, C, M});
+  std::vector<int> array_shape_vec({1, C, 1});
+  // y - bias
+  int r1 =
+      xpu::broadcast_sub(ctx, bias, y, x, array_shape_vec, tensor_shape_vec);
+  // (y - bias) / scale
+  int r2 = xpu::broadcast_div(ctx, scale, x, x, array_shape_vec,
+                              tensor_shape_vec);
+  // (y - bias) / scale / variance
+  int r3 = xpu::broadcast_div(ctx, variance, x, x, array_shape_vec,
+                              tensor_shape_vec);
+  // (y - bias) / scale / variance + mean
+  int r4 =
+      xpu::broadcast_add(ctx, mean, x, x, array_shape_vec, tensor_shape_vec);
+
+  return r1 + r2 + r3 + r4;
+}
+
+template <typename T>
+static int calculate_inv_var(xpu::Context *ctx, const T *var, const T epsilon,
+                             const int C, T *epsilon_data, T *inv_var) {
+  int r1 = constant(ctx, epsilon_data, 1, epsilon);
+  std::vector<int> tensor_shape_vec({C});
+  std::vector<int> array_shape_vec({1});
+  int r2 = xpu::broadcast_add(ctx, epsilon_data, var, inv_var,
+                              array_shape_vec, tensor_shape_vec);
+  int r3 = xpu::rsqrt(ctx, inv_var, inv_var, C);
+  return r1 + r2 + r3;
+}
+
 template <typename DeviceContext, typename T>
 class BatchNormGradXPUKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto* x = ctx.Input<Tensor>("X");
-    const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
-    const auto* scale = ctx.Input<Tensor>("Scale");
-    const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
-    // SavedVariance have been reverted in forward operator
-    const auto* saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
-    const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const auto *scale = ctx.Input<Tensor>("Scale");
+    const auto *bias = ctx.Input<Tensor>("Bias");
+
+    const auto &data_layout_str = ctx.Attr<std::string>("data_layout");
+    bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    const bool is_test = ctx.Attr<bool>("is_test");
+    const float epsilon = ctx.Attr<float>("epsilon");
     const auto data_layout = framework::StringToDataLayout(data_layout_str);
+
+    // TODO(guozbin): Transform input tensor from NHWC to NCHW
     PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
                       platform::errors::InvalidArgument(
                           "The 'data_layout' attribute must be NCHW. But "
                           "recevived 'data_layout' is [%s].",
                           data_layout_str));
-    const auto& x_dims = x->dims();
+
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    use_global_stats = is_test || use_global_stats;
+
+    // batch_norm with inplace as false will take X as grad input, which
+    // is same as cuDNN batch_norm backward calculation, batch_norm
+    // with inplace as true only take Y as input and X should be calculate
+    // by inverse operation of batch_norm on Y
+    const Tensor *x;
+    bool is_inplace;
+    if (ctx.HasInput("Y")) {
+      x = ctx.Input<Tensor>("Y");
+      is_inplace = true;
+      // if the input of batch norm is stop_gradient, d_x is null.
+      if (d_x) {
+        PADDLE_ENFORCE_EQ(d_x, d_y,
+                          platform::errors::InvalidArgument(
+                              "X@GRAD and Y@GRAD not inplace in inplace mode"));
+      }
+    } else {
+      x = ctx.Input<Tensor>("X");
+      is_inplace = false;
+      if (d_x) {
+        PADDLE_ENFORCE_NE(
+            d_x, d_y, platform::errors::InvalidArgument(
+                          "X@GRAD and Y@GRAD inplaced in non-inplace mode"));
+      }
+    }
+
+    const auto &x_dims = x->dims();
     PADDLE_ENFORCE_EQ(x_dims.size(), 4,
                       platform::errors::InvalidArgument(
                           "The input tensor X's dimension must equal to 4. But "
But " @@ -126,26 +215,96 @@ class BatchNormGradXPUKernel : public framework::OpKernel { const int C = x_dims[1]; const int H = x_dims[2]; const int W = x_dims[3]; - const auto* x_data = x->data(); - const auto* dy_data = dy->data(); - const auto* scale_data = scale->data(); - const auto* saved_mean_data = saved_mean->data(); - const auto* saved_inv_variance_data = saved_inv_variance->data(); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dscale = ctx.Output(framework::GradVarName("Scale")); - auto* dbias = ctx.Output(framework::GradVarName("Bias")); - auto* dx_data = dx->mutable_data(ctx.GetPlace()); - auto* dscale_data = dscale->mutable_data(ctx.GetPlace()); - auto* dbias_data = dbias->mutable_data(ctx.GetPlace()); - auto& dev_ctx = ctx.template device_context(); - int r = xpu::batch_norm_grad(dev_ctx.x_context(), x_data, dy_data, - dx_data, N, C, H, W, scale_data, - saved_mean_data, saved_inv_variance_data, - dscale_data, dbias_data, true); - PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( - "XPU API(batch_norm_grad) return " - "wrong value[%d %s]", - r, XPUAPIErrorMsg[r])); + + const auto *x_data = x->data(); + const auto *d_y_data = d_y->data(); + const auto *scale_data = scale->data(); + + // init output + T *d_x_data = nullptr; + T *d_bias_data = nullptr; + T *d_scale_data = nullptr; + if (d_x) { + d_x_data = d_x->mutable_data(ctx.GetPlace()); + } + if (d_scale && d_bias) { + d_scale_data = d_scale->mutable_data(ctx.GetPlace()); + d_bias_data = d_bias->mutable_data(ctx.GetPlace()); + } + + PADDLE_ENFORCE_EQ( + scale->dims().size(), 1UL, + platform::errors::InvalidArgument( + "The size of scale's dimensions must equal to 1. But received: " + "the size of scale's dimensions is [%d], the dimensions of scale " + "is [%s].", + scale->dims().size(), scale->dims())); + PADDLE_ENFORCE_EQ( + scale->dims()[0], C, + platform::errors::InvalidArgument( + "The first dimension of scale must equal to Channels[%d]. 
But " + "received: the first dimension of scale is [%d]", + C, scale->dims()[0])); + + auto &dev_ctx = ctx.template device_context(); + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + + const T *mean_data = nullptr; + const T *inv_var_data = nullptr; + + // TODO(guozibin): hadle the situation case of N * H * W = 1 + if (!use_global_stats) { + const auto *saved_mean = ctx.Input("SavedMean"); + // SavedVariance have been reverted in forward operator + const auto *saved_inv_variance = ctx.Input("SavedVariance"); + mean_data = saved_mean->data(); + inv_var_data = saved_inv_variance->data(); + } else { + const auto *running_mean = ctx.Input("Mean"); + const auto *running_variance = ctx.Input("Variance"); + mean_data = running_mean->data(); + inv_var_data = running_variance->data(); + float *running_inv_var_data = + RAII_GUARD.alloc_l3_or_gm(running_variance->numel()); + float *epsilon_data = RAII_GUARD.alloc_l3_or_gm(1); + int r1 = calculate_inv_var(dev_ctx.x_context(), inv_var_data, epsilon, C, + epsilon_data, running_inv_var_data); + PADDLE_ENFORCE_EQ(r1, XPU_SUCCESS, platform::errors::External( + "XPU API(batch_norm_grad " + "calculate_inv_var function) " + "return wrong value[%d %s]", + r1, XPUAPIErrorMsg[r1])); + inv_var_data = running_inv_var_data; + } + if (is_inplace) { + auto px = *x; + int r2 = calculate_inv_BN_Y( + dev_ctx.x_context(), px.mutable_data(ctx.GetPlace()), + scale->data(), bias->data(), mean_data, inv_var_data, N, + C, H * W, x->data()); + PADDLE_ENFORCE_EQ(r2, XPU_SUCCESS, platform::errors::External( + "XPU API(batch_norm_grad " + "calculate_inv_BN_Y function) " + "return wrong value[%d %s]", + r2, XPUAPIErrorMsg[r2])); + } + if (!d_x) { + d_x_data = RAII_GUARD.alloc_l3_or_gm(x->numel()); + } + if (!d_scale) { + d_scale_data = RAII_GUARD.alloc_l3_or_gm(C); + } + if (!d_bias_data) { + d_bias_data = RAII_GUARD.alloc_l3_or_gm(C); + } + + int r3 = xpu::batch_norm_grad( + dev_ctx.x_context(), x_data, d_y_data, d_x_data, N, C, H, W, scale_data, + mean_data, inv_var_data, d_scale_data, d_bias_data, true); + PADDLE_ENFORCE_EQ(r3, XPU_SUCCESS, platform::errors::External( + "XPU API(batch_norm_grad) return " + "wrong value[%d %s]", + r3, XPUAPIErrorMsg[r3])); } }; diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 3d7739f5a06..c5a140a7681 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -262,6 +262,9 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"roi_align", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"roi_align_grad", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), diff --git a/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py index 8132a78f696..9cd34c82650 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py @@ -267,5 +267,58 @@ class TestXPUBatchNormOp(unittest.TestCase): outputs[name], outs[id], atol=1e-4), True) +class TestXPUBatchNormOpUseGlobalStats(unittest.TestCase): + def setUp(self): + self.places = [paddle.XPUPlace(0)] + self.init_test() 
+
+    ### train mode
+    def init_test(self):
+        self.use_global_stats = True
+        self.trainable_statistics = False
+
+    def test_global_stats(self):
+        for p in self.places:
+            with fluid.dygraph.guard(p):
+                x = paddle.randn([2, 6, 6, 4])
+                net1 = paddle.fluid.dygraph.BatchNorm(
+                    6,
+                    param_attr=fluid.ParamAttr(
+                        initializer=fluid.initializer.Constant(1.0)),
+                    use_global_stats=self.use_global_stats,
+                    trainable_statistics=self.trainable_statistics)
+                net2 = paddle.nn.BatchNorm2D(
+                    6, use_global_stats=self.use_global_stats)
+                net2.weight = net1.weight
+                net2.bias = net1.bias
+                if self.trainable_statistics == True:
+                    net1.training = False
+                    net2.training = False
+                y1 = net1(x)
+                y2 = net2(x)
+                self.assertEqual(np.allclose(y1.numpy(), y2.numpy()), True)
+
+
+class TestXPUBatchNormUseGlobalStatsCase1(TestXPUBatchNormOpUseGlobalStats):
+    ### test mode
+    def init_test(self):
+        self.use_global_stats = False
+        self.trainable_statistics = True
+
+
+class TestXPUBatchNormUseGlobalStatsCase2(TestXPUBatchNormOpUseGlobalStats):
+    ### train mode
+    def init_test(self):
+        self.use_global_stats = False
+        self.trainable_statistics = False
+
+
+class TestXPUBatchNormUseGlobalStatsCase3(TestXPUBatchNormOpUseGlobalStats):
+    ### test mode
+    def init_test(self):
+        self.use_global_stats = True
+        self.trainable_statistics = True
+
+
 if __name__ == "__main__":
     unittest.main()
-- 
GitLab
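
For reference, the inverse transform applied by calculate_inv_BN_Y in the in-place gradient path can be checked end to end in plain NumPy: the forward pass computes y = (x - mean) * inv_std * scale + bias, and the backward pass recovers x as ((y - bias) / scale) / inv_std + mean, because an in-place forward has already overwritten X with Y. The sketch below is illustrative only and is not part of the patch; the shapes, random seed, and tolerance are assumptions chosen for the example.

import numpy as np

# Shapes mirror the (N, C, M) tensor / (1, C, 1) array broadcast used by
# calculate_inv_BN_Y, where M = H * W and the statistics are per channel.
N, C, M = 2, 6, 16
eps = 1e-5
rng = np.random.default_rng(0)

x = rng.standard_normal((N, C, M)).astype(np.float32)
scale = rng.uniform(0.5, 1.5, C).astype(np.float32)
bias = rng.uniform(-0.5, 0.5, C).astype(np.float32)

# Forward batch norm in training mode: per-channel mean/var over N and M.
mean = x.mean(axis=(0, 2))
var = x.var(axis=(0, 2))
inv_std = 1.0 / np.sqrt(var + eps)  # what SavedVariance holds after forward
y = (x - mean[None, :, None]) * inv_std[None, :, None] \
    * scale[None, :, None] + bias[None, :, None]

# Inverse reconstruction, step for step as in calculate_inv_BN_Y:
# broadcast_sub, broadcast_div, broadcast_div, broadcast_add.
t = y - bias[None, :, None]        # y - bias
t = t / scale[None, :, None]       # (y - bias) / scale
t = t / inv_std[None, :, None]     # (y - bias) / scale / inv_std
x_rec = t + mean[None, :, None]    # (y - bias) / scale / inv_std + mean

print(np.allclose(x_rec, x, atol=1e-4))  # True: X is recovered from Y

In the use_global_stats branch the same identity is applied to the running statistics, which is why calculate_inv_var first converts the running variance into 1 / sqrt(var + epsilon) before it is handed on as inv_var_data.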