From 1d3b27cae8a7d88db80358a2810279874835fc68 Mon Sep 17 00:00:00 2001 From: ceci3 Date: Mon, 21 Sep 2020 20:22:41 +0800 Subject: [PATCH] add double grad compute for batch norm (#27296) * add double grad compute for batch norm,test=develop * fix unittest, test=develop * remove unuse tensor,test=develop * add format,test=develop * update, test=develop --- paddle/fluid/operators/batch_norm_op.cc | 405 ++++++++++++++- paddle/fluid/operators/batch_norm_op.cu | 44 ++ paddle/fluid/operators/batch_norm_op.h | 61 +++ paddle/fluid/operators/instance_norm_op.cc | 8 +- paddle/fluid/operators/norm_utils.cu.h | 486 ++++++++++++++++++ python/paddle/fluid/layers/nn.py | 4 +- .../unittests/test_imperative_double_grad.py | 2 +- .../tests/unittests/test_norm_nn_grad.py | 62 +++ 8 files changed, 1066 insertions(+), 6 deletions(-) create mode 100644 paddle/fluid/operators/norm_utils.cu.h diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index c92f72e653d..dcfe8bb1bb4 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -831,6 +831,401 @@ void BatchNormGradMaker::Apply(GradOpPtr op) const { op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); } +template +void BatchNormDoubleGradMaker::Apply(GradOpPtr op) const { + op->SetType("batch_norm_grad_grad"); + op->SetInput("X", this->Input("X")); + op->SetInput("Scale", this->Input("Scale")); + op->SetInput("SavedMean", this->Input("SavedMean")); + op->SetInput("SavedVariance", this->Input("SavedVariance")); + if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) { + op->SetInput("Variance", this->Input("Variance")); + } + op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); + op->SetInput("DDScale", this->OutputGrad(framework::GradVarName("Scale"))); + op->SetInput("DDBias", this->OutputGrad(framework::GradVarName("Bias"))); + op->SetInput("DY", this->Input(framework::GradVarName("Y"))); + + op->SetAttrMap(this->Attrs()); + op->SetOutput("DX", this->InputGrad("X")); + op->SetOutput("DScale", this->InputGrad("Scale")); + op->SetOutput("DDY", this->InputGrad(framework::GradVarName("Y"))); +} + +void BatchNormDoubleGradOp::InferShape( + framework::InferShapeContext *ctx) const { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "BatchNormDoubleGrad"); + OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", + "BatchNormDoubleGrad"); + OP_INOUT_CHECK(ctx->HasInput("SavedMean"), "Input", "SavedMean", + "BatchNormDoubleGrad"); + OP_INOUT_CHECK(ctx->HasInput("SavedVariance"), "Input", "SavedVariance", + "BatchNormDoubleGrad"); + + const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); + if (use_global_stats) { + OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "VarianceOut", + "BatchNormDoubleGrad"); + } + + OP_INOUT_CHECK(ctx->HasInput("DDX"), "Input", "DDX", "BatchNormDoubleGrad"); + OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "BatchNormDoubleGrad"); + + // check output + OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX", "BatchNormDoubleGrad"); + + const auto x_dims = ctx->GetInputDim("X"); + const int C = x_dims[1]; + if (ctx->HasOutput("DX")) { + ctx->SetOutputDim("DX", x_dims); + } + if (ctx->HasOutput("DScale")) { + ctx->SetOutputDim("DScale", {C}); + } + if (ctx->HasOutput("DDY")) { + ctx->ShareDim("X", "DDY"); + } +} + +framework::OpKernelType BatchNormDoubleGradOp::GetExpectedKernelType( + const framework::ExecutionContext &ctx) const { + const auto *var = ctx.InputVar("DY"); + if (var == nullptr) { + PADDLE_THROW( + platform::errors::NotFound("cannot find gradient variable of Y")); + } + const Tensor *t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } + if (t == nullptr) { + PADDLE_THROW( + platform::errors::InvalidArgument("gradient variable of Y is empty")); + } + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); +} + +template +class BatchNormDoubleGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *X = ctx.Input("X"); + const auto *Scale = ctx.Input("Scale"); + const auto *dY = ctx.Input("DY"); + const auto *Saved_mean = ctx.Input("SavedMean"); + const auto *Saved_variance = ctx.Input("SavedVariance"); + const float epsilon = ctx.Attr("epsilon"); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool is_test = ctx.Attr("is_test"); + + PADDLE_ENFORCE_EQ( + is_test, false, + platform::errors::InvalidArgument( + "`is_test = True` CANNOT be used in train program. If " + "you want to use global status in pre_train model, " + "please set `use_global_stats = True`")); + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + const auto *ddX = ctx.Input("DDX"); + const auto *ddScale = ctx.Input("DDScale"); + const auto *ddBias = ctx.Input("DDBias"); + + auto *dX = ctx.Output("DX"); + auto *dScale = ctx.Output("DScale"); + auto *ddY = ctx.Output("DDY"); + dX->mutable_data(ctx.GetPlace()); + ddY->mutable_data(ctx.GetPlace()); + + auto &dev_ctx = ctx.template device_context(); + + const auto &x_dims = X->dims(); + const int C = + (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int sample_size = X->numel() / C; + math::SetConstant set_constant; + + const T *mean_data = Saved_mean->data(); + const T *inv_var_data = Saved_variance->data(); + + Tensor inv_var_tensor; + if (use_global_stats) { + const auto *running_variance = ctx.Input("Variance"); + inv_var_tensor.Resize({C}); + + T *running_inv_var_data = inv_var_tensor.mutable_data(ctx.GetPlace()); + EigenVectorArrayMap inv_var_tmp(running_inv_var_data, C); + ConstEigenVectorArrayMap var_arr(running_variance->data(), C); + + inv_var_tmp = (var_arr + epsilon).sqrt().inverse(); + inv_var_data = running_inv_var_data; + } + + // transpose NCHW -> NHWC for easy calculate + Tensor transformed_x(X->type()); + Tensor transformed_dy(dY->type()); + Tensor transformed_ddx(ddX->type()); + + Tensor transformed_dx(dX->type()); + Tensor transformed_ddy(ddY->type()); + if (data_layout == DataLayout::kNCHW && x_dims.size() > 2) { + VLOG(3) << "Transform batchnorm output from NCHW to NHWC"; + // Input Tensor + ResizeToChannelLast(ctx, X, + &transformed_x); + TransToChannelLast(ctx, X, &transformed_x); + ResizeToChannelLast(ctx, dY, + &transformed_dy); + TransToChannelLast(ctx, dY, + &transformed_dy); + ResizeToChannelLast(ctx, ddX, + &transformed_ddx); + TransToChannelLast(ctx, ddX, + &transformed_ddx); + // Output Tensor + ResizeToChannelLast(ctx, dX, + &transformed_dx); + ResizeToChannelLast(ctx, ddY, + &transformed_ddy); + } else { + transformed_x.ShareDataWith(*X); + transformed_dy.ShareDataWith(*dY); + transformed_ddx.ShareDataWith(*ddX); + + transformed_dx.ShareDataWith(*dX); + transformed_ddy.ShareDataWith(*ddY); + } + + ConstEigenArrayMap x_arr(transformed_x.data(), C, sample_size); + ConstEigenVectorArrayMap mean_arr(mean_data, C); + ConstEigenVectorArrayMap inv_var_arr(inv_var_data, C); + + Tensor mean_tile; + mean_tile.Resize({C, sample_size}); + mean_tile.mutable_data(ctx.GetPlace()); + EigenArrayMap mean_tile_data(mean_tile.mutable_data(ctx.GetPlace()), + C, sample_size); + + Tensor inv_var_tile; + inv_var_tile.Resize({C, sample_size}); + inv_var_tile.mutable_data(ctx.GetPlace()); + EigenArrayMap inv_var_tile_data( + inv_var_tile.mutable_data(ctx.GetPlace()), C, sample_size); + + mean_tile_data = mean_arr.replicate(1, sample_size); + inv_var_tile_data = inv_var_arr.replicate(1, sample_size); + + Tensor Scale_data; + if (!Scale) { + Scale_data.mutable_data({C}, ctx.GetPlace()); + set_constant(dev_ctx, &Scale_data, static_cast(1)); + } + ConstEigenVectorArrayMap scale_arr( + Scale ? Scale->data() : Scale_data.data(), C); + + Tensor scale_tile; + scale_tile.Resize({C, sample_size}); + scale_tile.mutable_data(ctx.GetPlace()); + EigenArrayMap scale_tile_data(scale_tile.mutable_data(ctx.GetPlace()), + C, sample_size); + scale_tile_data = scale_arr.replicate(1, sample_size); + + ConstEigenArrayMap dy_arr(transformed_dy.data(), C, sample_size); + ConstEigenArrayMap ddx_arr(transformed_ddx.data(), C, sample_size); + + Tensor x_sub_mean_mul_invstd; + x_sub_mean_mul_invstd.Resize({C, sample_size}); + x_sub_mean_mul_invstd.mutable_data(ctx.GetPlace()); + EigenArrayMap x_sub_mean_mul_invstd_arr( + x_sub_mean_mul_invstd.mutable_data(ctx.GetPlace()), C, sample_size); + x_sub_mean_mul_invstd_arr = (x_arr - mean_tile_data) * inv_var_tile_data; + + if (dX) { + dX->mutable_data(ctx.GetPlace()); + EigenArrayMap dx_arr(transformed_dx.mutable_data(ctx.GetPlace()), C, + sample_size); + dx_arr.setZero(); + if (use_global_stats) { + // math: dx = (ddscale * dy) * inv_var + if (ddScale) { + ConstEigenVectorArrayMap ddscale_arr(ddScale->data(), C); + Tensor ddscale_tile; + ddscale_tile.Resize({C, sample_size}); + EigenArrayMap ddscale_tile_data( + ddscale_tile.mutable_data(ctx.GetPlace()), C, sample_size); + ddscale_tile_data = ddscale_arr.replicate(1, sample_size); + + dx_arr = dy_arr * ddscale_tile_data * inv_var_tile_data; + } + } else { + // math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx, + // axis=(n,h,w)) * + // np.sum(dy, axis=(n,h,w)) - + // np.sum(dy * ddx, axis=(n,h,w)) + 3 * np.mean(dy * (x - + // mean), + // axis=(n,h,w)) * inv_var.pow(2) * + // np.sum(ddx * (x - mean), axis=(n,h,w))) + inv_var.pow(3) / + // NxHxW * + // np.sum(ddx * (x - mean)) * + // (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW * + // np.sum(dy, + // axis=(n,h,w)) * (x - mean) * + // (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var - + // inv_var + // * + // np.mean(dy, axis=(n,h,w)) - + // inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean), + // axis=(n,h,w)))) + + if (ddX) { + dx_arr += + (x_sub_mean_mul_invstd_arr * inv_var_tile_data * + inv_var_tile_data / sample_size) + .colwise() * + (ddx_arr.rowwise().sum() * dy_arr.rowwise().sum() / sample_size - + (dy_arr * ddx_arr).rowwise().sum() + + 3. * (dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() * + (ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() / + sample_size); + + dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() * + (ddx_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() / + sample_size * + (dy_arr.rowwise().sum() / sample_size - dy_arr); + + dx_arr += (inv_var_tile_data * inv_var_tile_data).colwise() * + (dy_arr * x_sub_mean_mul_invstd_arr).rowwise().sum() / + sample_size * + (ddx_arr.rowwise().sum() / sample_size - ddx_arr); + + dx_arr = scale_tile_data * dx_arr; + } + if (ddScale) { + ConstEigenVectorArrayMap ddscale_arr(ddScale->data(), C); + Tensor ddscale_tile; + ddscale_tile.Resize({C, sample_size}); + EigenArrayMap ddscale_tile_data( + ddscale_tile.mutable_data(ctx.GetPlace()), C, sample_size); + ddscale_tile_data = ddscale_arr.replicate(1, sample_size); + + dx_arr += (dy_arr * inv_var_tile_data - + (dy_arr.rowwise().sum().replicate(1, sample_size) / + sample_size) * + inv_var_tile_data - + x_sub_mean_mul_invstd_arr * inv_var_tile_data * + (dy_arr * x_sub_mean_mul_invstd_arr) + .rowwise() + .sum() + .replicate(1, sample_size) / + sample_size) * + ddscale_tile_data; + } + } + if (data_layout == DataLayout::kNCHW) { + VLOG(3) << "Transform batchnorm output from NHWC to NCHW"; + TransToChannelFirst( + ctx, &transformed_dx, dX); + } + } + if (dScale) { + dScale->mutable_data(ctx.GetPlace()); + EigenVectorArrayMap dscale_arr(dScale->mutable_data(ctx.GetPlace()), + C); + dscale_arr.setZero(); + if (use_global_stats) { + // math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var + if (ddX) { + dscale_arr = (ddx_arr * dy_arr * inv_var_tile_data).rowwise().sum(); + } + } else { + // math: dscale = inv_var * (dy - np.mean(dy, axis=(n,h,w) - (x-mean) * + // inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) * + // ddx + if (ddX) { + Tensor first_grad; + first_grad.Resize({C, sample_size}); + EigenArrayMap first_grad_arr( + first_grad.mutable_data(ctx.GetPlace()), C, sample_size); + first_grad_arr.setZero(); + + first_grad_arr += + inv_var_tile_data * + (dy_arr - + dy_arr.rowwise().sum().replicate(1, sample_size) / sample_size - + x_sub_mean_mul_invstd_arr * + (dy_arr * x_sub_mean_mul_invstd_arr) + .rowwise() + .sum() + .replicate(1, sample_size) / + sample_size); + dscale_arr = (first_grad_arr * ddx_arr).rowwise().sum(); + } + } + } + + if (ddY) { + ddY->mutable_data(ctx.GetPlace()); + EigenArrayMap ddy_arr(transformed_ddy.mutable_data(ctx.GetPlace()), + C, sample_size); + ddy_arr.setZero(); + if (use_global_stats) { + // math: ddy = r * ddx * inv_var + if (ddX) { + ddy_arr = scale_tile_data * ddx_arr * inv_var_tile_data; + } + } else { + // math: ddy = (x - mean) * inv_var * ddscale + ddbias + + // scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) * + // np.mean(ddx * (x - mean), axis=(n,h,w))) + if (ddX) { + ddy_arr += + scale_tile_data * inv_var_tile_data * + (ddx_arr - + ddx_arr.rowwise().sum().replicate(1, sample_size) / sample_size - + x_sub_mean_mul_invstd_arr * + (ddx_arr * x_sub_mean_mul_invstd_arr) + .rowwise() + .sum() + .replicate(1, sample_size) / + sample_size); + } + if (ddScale && ddBias) { + ConstEigenVectorArrayMap ddscale_arr(ddScale->data(), C); + Tensor ddscale_tile; + ddscale_tile.Resize({C, sample_size}); + EigenArrayMap ddscale_tile_data( + ddscale_tile.mutable_data(ctx.GetPlace()), C, sample_size); + ddscale_tile_data = ddscale_arr.replicate(1, sample_size); + + ConstEigenVectorArrayMap ddbias_arr(ddBias->data(), C); + Tensor ddbias_tile; + ddbias_tile.Resize({C, sample_size}); + EigenArrayMap ddbias_tile_data( + ddbias_tile.mutable_data(ctx.GetPlace()), C, sample_size); + ddbias_tile_data = ddbias_arr.replicate(1, sample_size); + + ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data; + ddy_arr += ddbias_tile_data; + } + } + if (data_layout == DataLayout::kNCHW) { + VLOG(3) << "Transform batchnorm output from NHWC to NCHW"; + TransToChannelFirst( + ctx, &transformed_ddy, ddY); + } + } + } +}; + +DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"}); + } // namespace operators } // namespace paddle @@ -839,7 +1234,11 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker, ops::BatchNormOpInferVarType, ops::BatchNormGradMaker, ops::BatchNormGradMaker); -REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp); +REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp, + ops::BatchNormDoubleGradMaker, + ops::BatchNormDoubleGradMaker); +REGISTER_OPERATOR(batch_norm_grad_grad, ops::BatchNormDoubleGradOp, + ops::BatchNormDoubleGradOpInplaceInferer); REGISTER_OP_CPU_KERNEL( batch_norm, ops::BatchNormKernel, @@ -848,3 +1247,7 @@ REGISTER_OP_CPU_KERNEL( batch_norm_grad, ops::BatchNormGradKernel, ops::BatchNormGradKernel); +REGISTER_OP_CPU_KERNEL( + batch_norm_grad_grad, + ops::BatchNormDoubleGradKernel, + ops::BatchNormDoubleGradKernel); diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index be834772679..2d5b395ac68 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/data_layout.h" #include "paddle/fluid/operators/batch_norm_op.h" #include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/operators/norm_utils.cu.h" #include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/float16.h" @@ -840,6 +841,45 @@ class BatchNormGradKernel } }; +template +class BatchNormDoubleGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *X = ctx.Input("X"); + const auto *Scale = ctx.Input("Scale"); + const auto *dY = ctx.Input("DY"); + const auto *Saved_mean = ctx.Input("SavedMean"); + const auto *Saved_variance = ctx.Input("SavedVariance"); + const double epsilon = static_cast(ctx.Attr("epsilon")); + const bool use_global_stats = ctx.Attr("use_global_stats"); + const bool is_test = ctx.Attr("is_test"); + + PADDLE_ENFORCE_EQ( + is_test, false, + platform::errors::InvalidArgument( + "`is_test = True` CANNOT be used in train program. If " + "you want to use global status in pre_train model, " + "please set `use_global_stats = True`")); + + const std::string data_layout_str = ctx.Attr("data_layout"); + const DataLayout data_layout = + framework::StringToDataLayout(data_layout_str); + + const auto *ddX = ctx.Input("DDX"); + const auto *ddScale = ctx.Input("DDScale"); + const auto *ddBias = ctx.Input("DDBias"); + + auto *dX = ctx.Output("DX"); + auto *dScale = ctx.Output("DScale"); + auto *ddY = ctx.Output("DDY"); + + NormDoubleGradFunctor( + ctx, data_layout, X, Scale, dY, Saved_mean, Saved_variance, epsilon, + use_global_stats, ddX, ddScale, ddBias, dX, dScale, ddY); + } +}; + } // namespace operators } // namespace paddle @@ -853,3 +893,7 @@ REGISTER_OP_CUDA_KERNEL( batch_norm_grad, ops::BatchNormGradKernel, ops::BatchNormGradKernel, ops::BatchNormGradKernel); +REGISTER_OP_CUDA_KERNEL( + batch_norm_grad_grad, + ops::BatchNormDoubleGradKernel, + ops::BatchNormDoubleGradKernel); diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h index 9f844b7c078..1440b74290c 100644 --- a/paddle/fluid/operators/batch_norm_op.h +++ b/paddle/fluid/operators/batch_norm_op.h @@ -103,6 +103,42 @@ inline void TransToChannelFirst(const framework::ExecutionContext& context, } } +template +inline void ResizeToChannelLast(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[3]; + in_dims_vec[3] = input->dims()[4]; + in_dims_vec[4] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + + } else if (dim == 2) { + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[3]; + in_dims_vec[3] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } else if (dim == 1) { + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } +} + template inline void TransToChannelLast(const framework::ExecutionContext& context, const Tensor* input, Tensor* transformed_input) { @@ -154,6 +190,16 @@ class BatchNormGradOp : public framework::OperatorWithKernel { const framework::OpKernelType& expected_kernel_type) const override; }; +class BatchNormDoubleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override; +}; + class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; @@ -168,6 +214,15 @@ class BatchNormGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override; }; +template +class BatchNormDoubleGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override; +}; + class BatchNormOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { protected: @@ -190,5 +245,11 @@ class BatchNormGradKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override; }; +template +class BatchNormDoubleGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override; +}; + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index f72f7e8b85b..a5b270c1dfe 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -595,9 +595,13 @@ class InstanceNormDoubleGradKernel first_grad_arr += inv_var_tile_data * - (dy_arr - dy_arr.colwise().sum() / sample_size - + (dy_arr - + dy_arr.colwise().sum().replicate(sample_size, 1) / sample_size - x_sub_mean_mul_invstd_arr * - (dy_arr * x_sub_mean_mul_invstd_arr).colwise().sum() / + (dy_arr * x_sub_mean_mul_invstd_arr) + .colwise() + .sum() + .replicate(sample_size, 1) / sample_size); first_grad_arr = first_grad_arr * ddx_arr; for (int nc = 0; nc < NxC; ++nc) { diff --git a/paddle/fluid/operators/norm_utils.cu.h b/paddle/fluid/operators/norm_utils.cu.h new file mode 100644 index 00000000000..07333f1ae11 --- /dev/null +++ b/paddle/fluid/operators/norm_utils.cu.h @@ -0,0 +1,486 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include +#include +#include +#include "cub/cub.cuh" +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DataLayout = framework::DataLayout; + +// math: dx = scale * ((x - mean) * inv_var / NxHxW * (np.mean(ddx, +// axis=(n,h,w)) * +// np.sum(dy, axis=(n,h,w)) - +// np.sum(dy * ddx, axis=(n,h,w)) + 3 * np.mean(dy * (x - +// mean), +// axis=(n,h,w)) * inv_var.pow(2) * +// np.sum(ddx * (x - mean), axis=(n,h,w))) + inv_var.pow(3) / +// NxHxW * +// np.sum(ddx * (x - mean)) * +// (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW * +// np.sum(dy, +// axis=(n,h,w)) * (x - mean) * +// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var - +// inv_var +// * +// np.mean(dy, axis=(n,h,w)) - +// inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean), +// axis=(n,h,w)))) + +template +__global__ void DoubleGradComputeDX(const T *x, const T *mean, + const T *variance, const T *ddx, + const T *dy, const T *scale, + const T *ddscale, const int N, const int C, + const int sample_size, const double epsilon, + T *dx) { + const int outer_size = C; + const int inner_size = N * sample_size; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage dy_storage; + __shared__ typename BlockReduce::TempStorage ddx_storage; + __shared__ typename BlockReduce::TempStorage dy_mul_ddx_storage; + __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage; + __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage; + __shared__ T dy_sum_val; + __shared__ T ddx_sum_val; + __shared__ T dy_mul_ddx_sum_val; + __shared__ T dy_mul_x_sub_mean_sum_val; + __shared__ T ddx_mul_x_sub_mean_sum_val; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + T mean_val = mean[i]; + T var_val = variance[i]; + T dy_sum = 0; + T ddx_sum = 0; + T dy_mul_ddx_sum = 0; + T dy_mul_x_sub_mean_sum = 0; + T ddx_mul_x_sub_mean_sum = 0; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + T ddx_i = ddx[index]; + T dy_i = dy[index]; + T tmp = x[index] - mean_val; + + dy_sum += dy_i; + ddx_sum += ddx_i; + dy_mul_ddx_sum += (ddx_i * dy_i); + + dy_mul_x_sub_mean_sum += (dy_i * tmp); + ddx_mul_x_sub_mean_sum += (ddx_i * tmp); + } + + dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum()); + ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum()); + dy_mul_ddx_sum = + BlockReduce(dy_mul_ddx_storage).Reduce(dy_mul_ddx_sum, cub::Sum()); + dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage) + .Reduce(dy_mul_x_sub_mean_sum, cub::Sum()); + ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage) + .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum()); + + if (threadIdx.x == 0) { + dy_sum_val = dy_sum; + ddx_sum_val = ddx_sum; + dy_mul_ddx_sum_val = dy_mul_ddx_sum; + dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum; + ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum; + } + __syncthreads(); + + if (ddx != nullptr) { + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + dx[index] += + ((x[index] - mean_val) * var_val * var_val * var_val / inner_size * + (ddx_sum_val * dy_sum_val / inner_size - dy_mul_ddx_sum_val + + 3. * dy_mul_x_sub_mean_sum_val * var_val * + ddx_mul_x_sub_mean_sum_val * var_val / inner_size) + + ddx_mul_x_sub_mean_sum_val * var_val / inner_size * var_val * + var_val * (dy_sum_val / inner_size - dy[index]) + + dy_mul_x_sub_mean_sum_val * var_val / inner_size * var_val * + var_val * (ddx_sum_val / inner_size - ddx[index])) * + scale[i]; + } + } + __syncthreads(); + if (ddscale != nullptr) { + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val - + (x[index] - mean_val) * var_val * + dy_mul_x_sub_mean_sum_val * var_val / inner_size) * + ddscale[i]; + } + } + } +} + +// math: ddy = (x - mean) * inv_var * ddscale + ddbias + +// scale * inv_var * (ddx - (x - mean) * inv_var.pow(2) * +// np.mean(ddx * (x - mean), axis=(n,h,w))) +template +__global__ void DoubleGradComputeDDY(const T *x, const T *mean, + const T *variance, const T *ddscale, + const T *ddbias, const T *ddx, + const T *scale, const int N, const int C, + const int sample_size, + const double epsilon, T *ddy) { + const int outer_size = C; + const int inner_size = N * sample_size; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage ddx_storage; + __shared__ typename BlockReduce::TempStorage ddx_mul_x_sub_mean_storage; + __shared__ T ddx_sum_val; + __shared__ T ddx_mul_x_sub_mean_sum_val; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + T mean_val = mean[i]; + T var_val = variance[i]; + T ddx_sum = 0; + T ddx_mul_x_sub_mean_sum = 0; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + T ddx_i = ddx[index]; + ddx_sum += ddx_i; + ddx_mul_x_sub_mean_sum += (ddx_i * (x[index] - mean_val)); + } + ddx_sum = BlockReduce(ddx_storage).Reduce(ddx_sum, cub::Sum()); + ddx_mul_x_sub_mean_sum = BlockReduce(ddx_mul_x_sub_mean_storage) + .Reduce(ddx_mul_x_sub_mean_sum, cub::Sum()); + + if (threadIdx.x == 0) { + ddx_sum_val = ddx_sum; + ddx_mul_x_sub_mean_sum_val = ddx_mul_x_sub_mean_sum; + } + __syncthreads(); + + if (ddx != nullptr) { + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + ddy[index] += scale[i] * var_val * + (ddx[index] - ddx_sum_val / inner_size - + (x[index] - mean_val) * var_val * + ddx_mul_x_sub_mean_sum_val * var_val / inner_size); + } + } + __syncthreads(); + if (ddscale != nullptr) { + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + ddy[index] += (x[index] - mean_val) * var_val * ddscale[i]; + } + } + __syncthreads(); + if (ddbias != nullptr) { + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + ddy[index] += ddbias[i]; + } + } + } +} + +// math: dscale = inv_var * (dy - np.mean(dy, axis=(n,h,w) - (x-mean) * +// inv_var.pow(2) * np.mean(dy * (x-mean), axis=(n,h,w)))) * +// ddx +template +__global__ void DoubleGradComputeDScale(const T *x, const T *mean, + const T *variance, const T *ddx, + const T *dy, const int N, const int C, + const int sample_size, + const double epsilon, T *dscale) { + const int outer_size = C; + const int inner_size = N * sample_size; + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage dy_storage; + __shared__ typename BlockReduce::TempStorage dy_mul_x_sub_mean_storage; + __shared__ typename BlockReduce::TempStorage dscale_tmp_storage; + __shared__ T dy_sum_val; + __shared__ T dy_mul_x_sub_mean_sum_val; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + T dy_sum = 0; + T dy_mul_x_sub_mean_sum = 0; + T mean_val = mean[i]; + T var_val = variance[i]; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + T dy_i = dy[index]; + dy_sum += dy_i; + dy_mul_x_sub_mean_sum += (dy_i * (x[index] - mean_val)); + } + dy_sum = BlockReduce(dy_storage).Reduce(dy_sum, cub::Sum()); + dy_mul_x_sub_mean_sum = BlockReduce(dy_mul_x_sub_mean_storage) + .Reduce(dy_mul_x_sub_mean_sum, cub::Sum()); + + if (threadIdx.x == 0) { + dy_sum_val = dy_sum; + dy_mul_x_sub_mean_sum_val = dy_mul_x_sub_mean_sum; + } + __syncthreads(); + + if (ddx != nullptr) { + T dscale_tmp = 0; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + dscale_tmp += ddx[index] * var_val * + (dy[index] - dy_sum_val / inner_size - + dy_mul_x_sub_mean_sum_val * (x[index] - mean_val) * + var_val * var_val / inner_size); + } + dscale_tmp = + BlockReduce(dscale_tmp_storage).Reduce(dscale_tmp, cub::Sum()); + + if (threadIdx.x == 0) { + dscale[i] += dscale_tmp; + } + __syncthreads(); + } + } +} + +// math: dscale = np.sum(ddx * dy, axis=(n,h,w)) * inv_var +template +__global__ void DoubleGradComputeDScaleWithGlobal( + const T *ddx, const T *variance, const T *dy, const double epsilon, + const int N, const int C, const int sample_size, T *dscale) { + int outer_size = C; + int inner_size = N * sample_size; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage ddx_mul_dy_storage; + __shared__ T ddx_mul_dy_sum_val; + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + T inv_var_i = 1.0 / sqrt(variance[i] + epsilon); + T ddx_mul_dy_sum = 0; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = + layout == framework::DataLayout::kNCHW + ? (j / sample_size * C + i) * sample_size + j % sample_size + : j * outer_size + i; + T ddx_i = ddx[index]; + T dy_i = dy[index]; + ddx_mul_dy_sum += (ddx_i * dy_i); + } + ddx_mul_dy_sum = + BlockReduce(ddx_mul_dy_storage).Reduce(ddx_mul_dy_sum, cub::Sum()); + if (threadIdx.x == 0) { + ddx_mul_dy_sum_val = ddx_mul_dy_sum; + } + __syncthreads(); + + if (ddx != nullptr) { + dscale[i] = inv_var_i * ddx_mul_dy_sum_val; + } + } +} + +// math: dx = ddscale * dy * inv_var +// math: ddy = scale * ddx * inv_var +template +__global__ void DoubleGradComputeDataWithGlobal( + const T *dy, const T *scale, const T *variance, const double epsilon, + const int C, const int sample_size, const int num, T *dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + if (scale != nullptr) { + for (int i = gid; i < num; i += stride) { + const int c = + layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C; + T inv_var = 1.0 / sqrt(variance[c] + epsilon); + dx[i] = dy[i] * scale[c] * inv_var; + } + } +} + +template +void NormDoubleGradFunctor(const framework::ExecutionContext &ctx, + const DataLayout data_layout, const Tensor *X, + const Tensor *Scale, const Tensor *dY, + const Tensor *Saved_mean, + const Tensor *Saved_variance, const double epsilon, + const bool use_global_stats, const Tensor *ddX, + const Tensor *ddScale, const Tensor *ddBias, + Tensor *dX, Tensor *dScale, Tensor *ddY) { + const T *x_data = X->data(); + const T *dy_data = dY->data(); + const T *ddx_data = (ddX == nullptr ? nullptr : ddX->data()); + + const T *ddscale_data = (ddScale == nullptr ? nullptr : ddScale->data()); + const T *ddbias_data = (ddBias == nullptr ? nullptr : ddBias->data()); + + auto &dev_ctx = ctx.template device_context(); + math::SetConstant set_constant; + + auto &x_dims = X->dims(); + const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] + : x_dims[x_dims.size() - 1]); + const int N = x_dims[0]; + const int num = X->numel(); + const int sample_size = num / N / C; + Tensor scale_tmp; + if (!Scale) { + scale_tmp.mutable_data({C}, ctx.GetPlace()); + set_constant(dev_ctx, &scale_tmp, static_cast(1)); + } + const T *scale_data = Scale ? Scale->data() : scale_tmp.data(); + + const int block = 512; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + int grid = std::min(C, max_blocks); + int grid1 = (num + block - 1) / block; + + const T *mean_data, *variance_data; + if (use_global_stats) { + const auto *running_var = ctx.Input("Variance"); + const auto *running_var_data = running_var->template data(); + variance_data = running_var_data; + } else { + const T *smean_data = Saved_mean->data(); + const T *svariance_data = Saved_variance->data(); + mean_data = smean_data; + variance_data = svariance_data; + } + + if (dX) { + T *dx_data = dX->mutable_data(ctx.GetPlace()); + set_constant(dev_ctx, dX, static_cast(0)); + if (use_global_stats) { + if (data_layout == DataLayout::kNHWC) { + DoubleGradComputeDataWithGlobal< + T, DataLayout::kNHWC><<>>( + dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num, + dx_data); + } else { + DoubleGradComputeDataWithGlobal< + T, DataLayout::kNCHW><<>>( + dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num, + dx_data); + } + } else { + if (data_layout == DataLayout::kNHWC) { + DoubleGradComputeDX< + T, block, DataLayout::kNHWC><<>>( + x_data, mean_data, variance_data, ddx_data, dy_data, scale_data, + ddscale_data, N, C, sample_size, epsilon, dx_data); + } else { + DoubleGradComputeDX< + T, block, DataLayout::kNCHW><<>>( + x_data, mean_data, variance_data, ddx_data, dy_data, scale_data, + ddscale_data, N, C, sample_size, epsilon, dx_data); + } + } + } + if (dScale) { + T *dscale_data = dScale->mutable_data(ctx.GetPlace()); + set_constant(dev_ctx, dScale, static_cast(0)); + if (use_global_stats) { + if (data_layout == DataLayout::kNHWC) { + DoubleGradComputeDScaleWithGlobal< + T, block, DataLayout::kNHWC><<>>( + ddx_data, variance_data, dy_data, epsilon, N, C, sample_size, + dscale_data); + } else { + DoubleGradComputeDScaleWithGlobal< + T, block, DataLayout::kNCHW><<>>( + ddx_data, variance_data, dy_data, epsilon, N, C, sample_size, + dscale_data); + } + } else { + if (data_layout == DataLayout::kNHWC) { + DoubleGradComputeDScale< + T, block, DataLayout::kNHWC><<>>( + x_data, mean_data, variance_data, ddx_data, dy_data, N, C, + sample_size, epsilon, dscale_data); + } else { + DoubleGradComputeDScale< + T, block, DataLayout::kNCHW><<>>( + x_data, mean_data, variance_data, ddx_data, dy_data, N, C, + sample_size, epsilon, dscale_data); + } + } + } + if (ddY) { + T *ddy_data = ddY->mutable_data(ctx.GetPlace()); + set_constant(dev_ctx, ddY, static_cast(0)); + if (use_global_stats) { + if (data_layout == DataLayout::kNHWC) { + DoubleGradComputeDataWithGlobal< + T, DataLayout::kNHWC><<>>( + ddx_data, scale_data, variance_data, epsilon, C, sample_size, num, + ddy_data); + } else { + DoubleGradComputeDataWithGlobal< + T, DataLayout::kNCHW><<>>( + ddx_data, scale_data, variance_data, epsilon, C, sample_size, num, + ddy_data); + } + } else { + if (data_layout == DataLayout::kNHWC) { + DoubleGradComputeDDY< + T, block, DataLayout::kNHWC><<>>( + x_data, mean_data, variance_data, ddscale_data, ddbias_data, + ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data); + } else { + DoubleGradComputeDDY< + T, block, DataLayout::kNCHW><<>>( + x_data, mean_data, variance_data, ddscale_data, ddbias_data, + ddx_data, scale_data, N, C, sample_size, epsilon, ddy_data); + } + } + } +} + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4a750f301a0..3e7d10f8d1a 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -3167,7 +3167,7 @@ def instance_norm(input, param_shape = [channel_num] - if param_attr and bias_attr: + if param_attr != False and bias_attr != False: # create parameter scale = helper.create_parameter( attr=helper.param_attr, @@ -3190,7 +3190,7 @@ def instance_norm(input, instance_norm_out = helper.create_variable_for_type_inference(dtype) inputs = {"X": input} - if param_attr and bias_attr: + if param_attr != False and bias_attr != False: inputs["Scale"] = scale inputs["Bias"] = bias diff --git a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py index 720c9f95c25..39c6fca89cc 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_double_grad.py @@ -346,7 +346,7 @@ class TestRaiseNoDoubleGradOp(TestCase): with fluid.dygraph.guard(): x = fluid.layers.ones(shape=[2, 3, 2, 2], dtype='float32') x.stop_gradient = False - y = paddle.fluid.layers.batch_norm(x) + y = paddle.fluid.layers.group_norm(x, groups=1) dx = fluid.dygraph.grad( outputs=[y], inputs=[x], create_graph=True, diff --git a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py index c44ea454271..a89b9fde7f9 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py @@ -68,5 +68,67 @@ class TestInstanceNormDoubleGradCheckWithoutParamBias( [x], z, x_init=x_arr, atol=atol, place=place, eps=eps) +class TestBatchNormDoubleGradCheck(unittest.TestCase): + def setUp(self): + self.init_test() + + def init_test(self): + self.data_layout = 'NCHW' + self.use_global_stats = False + self.shape = [2, 3, 4, 5] + + @prog_scope() + def func(self, place): + prog = fluid.Program() + with fluid.program_guard(prog): + np.random.seed() + dtype = "float32" + eps = 0.005 + atol = 1e-4 + x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x') + z = fluid.layers.batch_norm( + input=x, + data_layout=self.data_layout, + use_global_stats=self.use_global_stats) + x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype) + gradient_checker.double_grad_check( + [x], z, x_init=x_arr, atol=atol, place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + +class TestBatchNormDoubleGradCheckCase1(TestBatchNormDoubleGradCheck): + def init_test(self): + self.data_layout = 'NHWC' + self.use_global_stats = False + self.shape = [2, 3, 4, 5] + + +class TestBatchNormDoubleGradCheckCase2(TestBatchNormDoubleGradCheck): + def init_test(self): + self.data_layout = 'NCHW' + self.use_global_stats = True + self.shape = [2, 3, 4, 5] + + +class TestBatchNormDoubleGradCheckCase3(TestBatchNormDoubleGradCheck): + def init_test(self): + self.data_layout = 'NHWC' + self.use_global_stats = True + self.shape = [2, 3, 4, 5] + + +class TestBatchNormDoubleGradCheckCase4(TestBatchNormDoubleGradCheck): + def init_test(self): + self.data_layout = 'NCHW' + self.use_global_stats = False + self.shape = [2, 2, 3, 4, 5] + + if __name__ == "__main__": unittest.main() -- GitLab