From 21d95be0dbf9d86b6fd5ce7019ae1708c1e7cfb6 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Thu, 2 Apr 2020 00:09:00 +0800 Subject: [PATCH] Add inplace abn op (#22806) * add inplace_abn_op. test=develop --- .../framework/ir/sync_batch_norm_pass.cc | 10 +- paddle/fluid/framework/unused_var_check.cc | 2 + paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/batch_norm_op.cc | 203 +++++-- paddle/fluid/operators/batch_norm_op.cu | 114 +++- paddle/fluid/operators/inplace_abn_op.cc | 208 +++++++ paddle/fluid/operators/inplace_abn_op.cu | 92 +++ paddle/fluid/operators/inplace_abn_op.h | 117 ++++ paddle/fluid/operators/sync_batch_norm_op.cu | 433 +------------- .../fluid/operators/sync_batch_norm_op.cu.h | 530 ++++++++++++++++++ python/paddle/fluid/layers/nn.py | 215 ++++++- python/paddle/fluid/nets.py | 2 +- .../fluid/tests/unittests/CMakeLists.txt | 2 +- .../tests/unittests/test_inplace_abn_op.py | 189 +++++++ .../fluid/tests/unittests/test_layers.py | 22 + 15 files changed, 1654 insertions(+), 487 deletions(-) create mode 100644 paddle/fluid/operators/inplace_abn_op.cc create mode 100644 paddle/fluid/operators/inplace_abn_op.cu create mode 100644 paddle/fluid/operators/inplace_abn_op.h create mode 100644 paddle/fluid/operators/sync_batch_norm_op.cu.h create mode 100644 python/paddle/fluid/tests/unittests/test_inplace_abn_op.py diff --git a/paddle/fluid/framework/ir/sync_batch_norm_pass.cc b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc index 2077304b969..90c1b23fe4f 100644 --- a/paddle/fluid/framework/ir/sync_batch_norm_pass.cc +++ b/paddle/fluid/framework/ir/sync_batch_norm_pass.cc @@ -24,16 +24,24 @@ namespace ir { class SyncBatchNormPass : public Pass { protected: void ApplyImpl(ir::Graph *graph) const override { - VLOG(3) << "Use synchronous batch norm"; + VLOG(3) << "Use synchronize batch norm"; for (const Node *n : graph->Nodes()) { if (n->IsOp() && n->Op()) { auto *op = n->Op(); + // process synchronize in batch_norm if (op->Type() == "batch_norm") { op->SetType("sync_batch_norm"); } if (op->Type() == "batch_norm_grad") { op->SetType("sync_batch_norm_grad"); } + // process synchronize in inplace_abn + if (op->Type() == "inplace_abn") { + op->SetAttr("use_sync_bn", true); + } + if (op->Type() == "inplace_abn_grad") { + op->SetAttr("use_sync_bn", true); + } } } } diff --git a/paddle/fluid/framework/unused_var_check.cc b/paddle/fluid/framework/unused_var_check.cc index a220c79a088..f9eeaae497a 100644 --- a/paddle/fluid/framework/unused_var_check.cc +++ b/paddle/fluid/framework/unused_var_check.cc @@ -41,6 +41,8 @@ const std::unordered_set op_has_unsed_vars_white_list = { "batch_norm_grad", // 0 "sync_batch_norm", // 0 "sync_batch_norm_grad", // 0 + "inplace_abn", // 0 + "inplace_abn_grad", // 0 "dgc_momentum", // 0 "fake_quantize_range_abs_max", // 0 "rmsprop", // 0 diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index b3e8f41f840..ea0c0b82be4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -59,7 +59,7 @@ if(WITH_COVERAGE OR NOT WITH_AVX OR WIN32) endif() register_operators(EXCLUDES py_func_op warpctc_op dgc_op - sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) + sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) if (WITH_GPU) # warpctc_op needs cudnn 7 above diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 4edc3514651..fbb9cfb75ad 100644 --- 
a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -82,16 +82,18 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { PADDLE_ENFORCE_GE( x_dims.size(), 2, - "ShapeError: the dimension of input X must greater than or equal to 2." - "But received: the shape of input X = [%s], the dimension of input X =" - "[%d]", - x_dims, x_dims.size()); + platform::errors::InvalidArgument( + "ShapeError: the dimension of input " + "X must greater than or equal to 2. But received: the shape of input " + "X = [%s], the dimension of input X =[%d]", + x_dims, x_dims.size())); PADDLE_ENFORCE_LE( x_dims.size(), 5, - "ShapeError: the dimension of input X must smaller than or equal to 5." - "But received: the shape of input X = [%s], the dimension of input X =" - "[%d]", - x_dims, x_dims.size()); + platform::errors::InvalidArgument( + "ShapeError: the dimension of input X " + "must smaller than or equal to 5. But received: the shape of input X " + "= [%s], the dimension of input X = [%d]", + x_dims, x_dims.size())); const int64_t C = ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW) @@ -146,14 +148,18 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( if (input_data_type == framework::proto::VarType::FP64) { bn_param_type = framework::proto::VarType::FP64; } - PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Scale")->type(), - "Scale input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Bias")->type(), - "Bias input should be of float type"); - PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Mean")->type(), - "Mean input should be of float type"); + PADDLE_ENFORCE_EQ( + bn_param_type, ctx.Input("Scale")->type(), + platform::errors::InvalidArgument("Scale input should be of float type")); + PADDLE_ENFORCE_EQ( + bn_param_type, ctx.Input("Bias")->type(), + platform::errors::InvalidArgument("Bias input should be of float type")); + PADDLE_ENFORCE_EQ( + bn_param_type, ctx.Input("Mean")->type(), + platform::errors::InvalidArgument("Mean input should be of float type")); PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Variance")->type(), - "Variance input should be of float type"); + platform::errors::InvalidArgument( + "Variance input should be of float type")); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::LibraryType library = framework::LibraryType::kPlain; @@ -204,8 +210,13 @@ void BatchNormOpMaker::Make() { AddAttr("epsilon", "") .SetDefault(1e-5) .AddCustomChecker([](const float &epsilon) { - PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f, - "'epsilon' should be between 0.0 and 0.001."); + PADDLE_ENFORCE_GE( + epsilon, 0.0f, + platform::errors::InvalidArgument( + "'epsilon' should be greater or equal than 0.0.")); + PADDLE_ENFORCE_LE(epsilon, 0.001f, + platform::errors::InvalidArgument( + "'epsilon' should be less or equal than 0.001.")); }); AddAttr("data_layout", "").SetDefault("NCHW"); AddInput("X", "The input tensor"); @@ -259,6 +270,7 @@ void BatchNormOpMaker::Make() { "global mean and variance are also used during train time, " "the BN acts as scaling and shiffting.") .SetDefault(false); + AddComment(R"DOC( Batch Normalization. 
@@ -290,8 +302,12 @@ class BatchNormKernel const auto *x = ctx.Input("X"); const auto &x_dims = x->dims(); - PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, - "The Input dim size should be between 2 and 5"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input X dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), 5, + platform::errors::InvalidArgument( + "The Input X dim size should be less than 6.")); const int N = x_dims[0]; const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] @@ -299,6 +315,7 @@ class BatchNormKernel const int sample_size = x->numel() / N / C; auto *y = ctx.Output("Y"); + auto *mean_out = ctx.Output("MeanOut"); auto *variance_out = ctx.Output("VarianceOut"); auto *saved_mean = ctx.Output("SavedMean"); @@ -432,14 +449,18 @@ class BatchNormKernel void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { // check input - PADDLE_ENFORCE(ctx->HasInput("X")); - PADDLE_ENFORCE(ctx->HasInput("Scale"), "Input(scale) should not be null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), - "Input(Y@GRAD) should not be null."); - PADDLE_ENFORCE(ctx->HasInput("SavedMean"), - "Input(SavedMean) should not be null."); - PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), - "Input(SavedVariance) should not be null"); + PADDLE_ENFORCE_EQ( + ctx->HasInput("Scale"), true, + platform::errors::InvalidArgument("Input(scale) should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasInput(framework::GradVarName("Y")), true, + platform::errors::InvalidArgument("Input(Y@GRAD) should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("SavedMean"), true, + platform::errors::InvalidArgument( + "Input(SavedMean) should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("SavedVariance"), true, + platform::errors::InvalidArgument( + "Input(SavedVariance) should not be null")); // check output PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), ""); @@ -456,25 +477,37 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const { const bool use_global_stats = ctx->Attrs().Get("use_global_stats"); if (use_global_stats) { - PADDLE_ENFORCE(!ctx->Attrs().Get("use_mkldnn"), - "Using global stats during training is not supported " - "in gradient op kernel of batch_norm_mkldnn_op now."); + PADDLE_ENFORCE_EQ( + !ctx->Attrs().Get("use_mkldnn"), true, + platform::errors::InvalidArgument( + "Using global stats during training is not supported " + "in gradient op kernel of batch_norm_mkldnn_op now.")); } - const auto x_dims = ctx->GetInputDim("X"); - const DataLayout data_layout = framework::StringToDataLayout( - ctx->Attrs().Get("data_layout")); - - const int C = - ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW) - ? x_dims[1] - : x_dims[x_dims.size() - 1]); + // batch_norm_grad with inplace takes Y as input, without inplace + // takes X as input. HasInput will throw exception in compile time, + // so only infer shape in run time here. 
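The reconstruction that makes Y a sufficient gradient input here is just the inverse of the batch-norm affine transform described later in this hunk (Y = (X - mean) * inv_std * scale + bias, so X = (Y - bias) / (scale * inv_std) + mean). A minimal NumPy round-trip sketch of that idea, for reference only; it is not part of the patch and all names are illustrative:

    import numpy as np

    def bn_forward(x, scale, bias, mean, inv_std):
        # y = (x - mean) * inv_std * scale + bias, broadcast over the channel axis
        return (x - mean) * inv_std * scale + bias

    def bn_restore_x(y, scale, bias, mean, inv_std):
        # inverse transform used when only Y is kept (inplace mode)
        return (y - bias) / (scale * inv_std) + mean

    # round-trip check on random data, channel-last layout for simplicity
    x = np.random.randn(8, 4)
    scale = np.random.rand(4) + 0.5
    bias = np.random.randn(4)
    mean = x.mean(axis=0)
    inv_std = 1.0 / np.sqrt(x.var(axis=0) + 1e-5)
    y = bn_forward(x, scale, bias, mean, inv_std)
    assert np.allclose(bn_restore_x(y, scale, bias, mean, inv_std), x)
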
+ if (ctx->IsRuntime()) { + PADDLE_ENFORCE_EQ(ctx->HasInput("X") || ctx->HasInput("Y"), true, + platform::errors::InvalidArgument( + "Input(X) and Input(Y) should not be all null.")); + auto input_name = "Y"; + if (ctx->HasInput("X")) input_name = "X"; + const auto x_dims = ctx->GetInputDim(input_name); + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); - ctx->SetOutputDim(framework::GradVarName("X"), x_dims); - // has_scale_grad == has_bias_grad, judge has_scale_grad is enough - if (has_scale_grad) { - ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); - ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); + const int C = + ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW) + ? x_dims[1] + : x_dims[x_dims.size() - 1]); + + ctx->SetOutputDim(framework::GradVarName("X"), x_dims); + // has_scale_grad == has_bias_grad, judge has_scale_grad is enough + if (has_scale_grad) { + ctx->SetOutputDim(framework::GradVarName("Scale"), {C}); + ctx->SetOutputDim(framework::GradVarName("Bias"), {C}); + } } } @@ -482,7 +515,8 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( const framework::ExecutionContext &ctx) const { const auto *var = ctx.InputVar(framework::GradVarName("Y")); if (var == nullptr) { - PADDLE_THROW("can't find Y@GRAD"); + PADDLE_THROW( + platform::errors::InvalidArgument("can't find gradient variable of Y")); } const Tensor *t = nullptr; if (var->IsType()) { @@ -491,7 +525,8 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( t = &var->Get(); } if (t == nullptr) { - PADDLE_THROW("can't find Y@GRAD"); + PADDLE_THROW( + platform::errors::InvalidArgument("gradient variable of Y is empty")); } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready @@ -541,9 +576,9 @@ class BatchNormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - const auto *x = ctx.Input("X"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); const auto *saved_mean = ctx.Input("SavedMean"); // SavedVariance have been reverted in forward operator const auto *saved_inv_variance = ctx.Input("SavedVariance"); @@ -554,6 +589,30 @@ class BatchNormGradKernel const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + // batch_norm with inplace as false will take X as grad input, which + // is same as cuDNN batch_norm backward calculation, batch_norm + // with inplace as true only take Y as input and X should be calculate + // by inverse operation of batch_norm on Y + const Tensor *x; + bool is_inplace; + if (ctx.HasInput("Y")) { + x = ctx.Input("Y"); + is_inplace = true; + PADDLE_ENFORCE_EQ(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD not inplace in inplace mode")); + } else { + x = ctx.Input("X"); + is_inplace = false; + PADDLE_ENFORCE_NE(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD inplaced in non-inplace mode")); + } + PADDLE_ENFORCE_EQ( is_test, false, platform::errors::InvalidArgument( @@ -564,8 +623,12 @@ class BatchNormGradKernel // Get the size for each dimension. 
// NCHW [batch_size, in_channels, in_height, in_width] const auto &x_dims = x->dims(); - PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, - "The Input dim size should be between 2 and 5"); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input X dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), 5, + platform::errors::InvalidArgument( + "The Input X dim size should be less than 6.")); const int N = x_dims[0]; const int C = (data_layout == DataLayout::kNCHW ? x_dims[1] @@ -573,10 +636,6 @@ class BatchNormGradKernel const int sample_size = x->numel() / N / C; // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_scale = ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); - d_x->mutable_data(ctx.GetPlace()); const T *mean_data = saved_mean->data(); @@ -596,6 +655,7 @@ class BatchNormGradKernel } ConstEigenVectorArrayMap scale_arr(scale->data(), C); + ConstEigenVectorArrayMap bias_arr(bias->data(), C); ConstEigenVectorArrayMap mean_arr(mean_data, C); ConstEigenVectorArrayMap inv_var_arr(inv_var_data, C); @@ -643,13 +703,30 @@ class BatchNormGradKernel dy_sum_arr.setZero(); dy_mul_x_sub_mean_mul_invstd_sum_arr.setZero(); + // inplace calculation + // Y: ((x - est_mean) * (inv_var) * scale + bias + // formula transform ====> + // (x * inv_var * scale) + (bias - est_mean * inv_var * scale) + // X: (y - bias) / scale / (inv_var) + est_mean + // formula transform ====> + // (y - bias) / (scale * inv_var) + est_mean switch (data_layout) { case DataLayout::kNCHW: { + if (is_inplace) { + auto px = *x; + EigenArrayMap x_data(px.mutable_data(ctx.GetPlace()), + sample_size, N * C); + ConstEigenArrayMap y_data(x->data(), sample_size, N * C); + for (int nc = 0; nc < N * C; ++nc) { + x_data.col(nc) = (y_data.col(nc) - bias_arr(nc % C)) / + scale_inv_var_nhw(nc % C) / scale_coefff + + mean_arr(nc % C); + } + } ConstEigenArrayMap x_arr(x->data(), sample_size, N * C); ConstEigenArrayMap d_y_arr(d_y->data(), sample_size, N * C); EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), sample_size, N * C); - d_x_arr.setZero(); for (int nc = 0; nc < N * C; ++nc) { int c = nc % C; @@ -667,7 +744,7 @@ class BatchNormGradKernel if (!use_global_stats) { for (int nc = 0; nc < N * C; ++nc) { int c = nc % C; - d_x_arr.col(nc) += + d_x_arr.col(nc) = scale_inv_var_nhw(c) * (d_y_arr.col(nc) * N * sample_size - dy_sum_arr(c) - (x_arr.col(nc) - mean_arr[c]) * @@ -676,17 +753,27 @@ class BatchNormGradKernel } else { for (int nc = 0; nc < N * C; ++nc) { int c = nc % C; - d_x_arr.col(nc) += scale_inv_var_nhw(c) * d_y_arr.col(nc); + d_x_arr.col(nc) = scale_inv_var_nhw(c) * d_y_arr.col(nc); } } break; } case DataLayout::kNHWC: { + if (is_inplace) { + auto px = *x; + EigenArrayMap x_data(px.mutable_data(ctx.GetPlace()), C, + N * sample_size); + ConstEigenArrayMap y_data(x->data(), C, N * sample_size); + for (int nhw = 0; nhw < N * sample_size; nhw++) { + x_data.col(nhw) = (y_data.col(nhw) - bias_arr) / scale_inv_var_nhw / + scale_coefff + + mean_arr; + } + } ConstEigenArrayMap x_arr(x->data(), C, N * sample_size); ConstEigenArrayMap d_y_arr(d_y->data(), C, N * sample_size); EigenArrayMap d_x_arr(d_x->mutable_data(ctx.GetPlace()), C, N * sample_size); - d_x_arr.setZero(); for (int nhw = 0; nhw < N * sample_size; ++nhw) { dy_sum_arr += d_y_arr.col(nhw); @@ -701,7 +788,7 @@ class BatchNormGradKernel if (!use_global_stats) { for (int nhw = 0; nhw < N * sample_size; ++nhw) { - d_x_arr.col(nhw) 
+= + d_x_arr.col(nhw) = scale_inv_var_nhw * (d_y_arr.col(nhw) * N * sample_size - dy_sum_arr - (x_arr.col(nhw) - mean_arr) * @@ -709,7 +796,7 @@ class BatchNormGradKernel } } else { for (int nhw = 0; nhw < N * sample_size; ++nhw) { - d_x_arr.col(nhw) += scale_inv_var_nhw * d_y_arr.col(nhw); + d_x_arr.col(nhw) = scale_inv_var_nhw * d_y_arr.col(nhw); } } break; diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 2534b7cf491..99612e4d433 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -40,8 +40,9 @@ class BatchNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument("It must use CUDAPlace.")); double epsilon = static_cast(ctx.Attr("epsilon")); float momentum = ctx.Attr("momentum"); const bool is_test = ctx.Attr("is_test"); @@ -355,6 +356,41 @@ static __global__ void KeBNBackwardData(const T *dy, } } +template +static __global__ void KeBNRestoreData(const framework::DataLayout layout, T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const BatchNormParamType *mean, + const BatchNormParamType *variance, + double epsilon, int C, int M, + const int num, const T *y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C; + auto y_i = static_cast>(y[i]); + auto x_i = (y_i - bias[c]) / scale[c] / variance[c] + mean[c]; + x[i] = static_cast(x_i); + } +} + +template +class InplaceHelper { + public: + void operator()(const framework::DataLayout layout, T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const BatchNormParamType *mean, + const BatchNormParamType *variance, double epsilon, int C, + int M, const int num, const T *y, int grid2, const int block, + const cudaStream_t &stream) { + PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument( + "X and Y should be inplaced in inplace mode")); + KeBNRestoreData<<>>( + layout, x, scale, bias, mean, variance, epsilon, C, M, num, y); + } +}; + template static __global__ void BNBackwardData(const T *dy, const BatchNormParamType *scale, @@ -417,17 +453,43 @@ class BatchNormGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument("It must use CUDAPlace.")); double epsilon = static_cast(ctx.Attr("epsilon")); const std::string data_layout_str = ctx.Attr("data_layout"); const bool use_global_stats = ctx.Attr("use_global_stats"); const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); - const auto *x = ctx.Input("X"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); const auto *scale = ctx.Input("Scale"); + const auto *bias = ctx.Input("Bias"); + + auto *d_x = ctx.Output(framework::GradVarName("X")); + auto *d_scale = ctx.Output(framework::GradVarName("Scale")); + auto *d_bias = ctx.Output(framework::GradVarName("Bias")); + + // batch_norm with inplace as false will take X as grad input, which + // is same as cuDNN 
batch_norm backward calculation, batch_norm + // with inplace as true only take Y as input and X should be calculate + // by inverse operation of batch_norm on Y + const Tensor *x; + bool is_inplace; + if (ctx.HasInput("Y")) { + x = ctx.Input("Y"); + is_inplace = true; + PADDLE_ENFORCE_EQ(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD not inplace in inplace mode")); + } else { + x = ctx.Input("X"); + is_inplace = false; + PADDLE_ENFORCE_NE(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD inplaced in non-inplace mode")); + } + const bool is_test = ctx.Attr("is_test"); PADDLE_ENFORCE_EQ( is_test, false, @@ -444,11 +506,8 @@ class BatchNormGradKernel ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); // init output - auto *d_x = ctx.Output(framework::GradVarName("X")); - auto *d_scale = ctx.Output(framework::GradVarName("Scale")); - auto *d_bias = ctx.Output(framework::GradVarName("Bias")); - d_x->mutable_data(ctx.GetPlace()); + if (d_scale && d_bias) { d_scale->mutable_data>(ctx.GetPlace()); d_bias->mutable_data>(ctx.GetPlace()); @@ -505,6 +564,8 @@ class BatchNormGradKernel const int max_blocks = std::max(max_threads / block, 1); int grid1 = (num + block - 1) / block; int grid2 = std::min(C, max_blocks); + auto stream = dev_ctx.stream(); + InplaceHelper inplace_functor; if (!use_global_stats) { if ((N * H * W * D) == 1) { @@ -555,6 +616,14 @@ class BatchNormGradKernel const auto *saved_var_data = saved_var->template data>(); + if (is_inplace) { + inplace_functor(compute_format, transformed_x.data(), + scale->template data>(), + bias->template data>(), + saved_mean_data, saved_var_data, epsilon, C, H * W * D, + num, transformed_x.data(), grid2, block, stream); + } + if (d_scale && d_bias) { bool called = false; #if CUDNN_VERSION_MIN(7, 4, 1) @@ -680,30 +749,41 @@ class BatchNormGradKernel const auto *running_var_data = running_var->template data>(); + if (is_inplace) { + auto px = *x; + inplace_functor(data_layout, px.mutable_data(ctx.GetPlace()), + scale->template data>(), + bias->template data>(), + running_mean_data, running_var_data, epsilon, C, + H * W * D, num, x->data(), grid2, block, stream); + } + if (compute_format == DataLayout::kNCHW) { if (d_x) { - KeBNBackwardData<<< - grid1, block, 0, dev_ctx.stream()>>>( + KeBNBackwardData< + T, framework::DataLayout::kNCHW><<>>( d_y->data(), scale->data>(), running_var_data, epsilon, C, H * W, num, d_x->data()); } if (d_scale && d_bias) { - KeBNBackwardScaleBias<<< - grid2, block, 0, dev_ctx.stream()>>>( + KeBNBackwardScaleBias< + T, block, + framework::DataLayout::kNCHW><<>>( d_y->data(), x->data(), running_mean_data, running_var_data, epsilon, N, C, H * W * D, d_scale->data>(), d_bias->data>()); } } else { if (d_x) { - KeBNBackwardData<<< - grid1, block, 0, dev_ctx.stream()>>>( + KeBNBackwardData< + T, framework::DataLayout::kNHWC><<>>( d_y->data(), scale->data>(), running_var_data, epsilon, C, H * W, num, d_x->data()); } if (d_scale && d_bias) { - KeBNBackwardScaleBias<<< - grid2, block, 0, dev_ctx.stream()>>>( + KeBNBackwardScaleBias< + T, block, + framework::DataLayout::kNHWC><<>>( d_y->data(), x->data(), running_mean_data, running_var_data, epsilon, N, C, H * W * D, d_scale->data>(), d_bias->data>()); diff --git a/paddle/fluid/operators/inplace_abn_op.cc b/paddle/fluid/operators/inplace_abn_op.cc new file mode 100644 index 00000000000..0b65699348f --- /dev/null +++ b/paddle/fluid/operators/inplace_abn_op.cc @@ -0,0 +1,208 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/inplace_abn_op.h" +#include +#include +#include +#include "paddle/fluid/framework/framework.pb.h" +#include "paddle/fluid/operators/batch_norm_op.h" + +namespace paddle { +namespace operators { + +class InplaceABNOp : public paddle::operators::BatchNormOp { + public: + using paddle::operators::BatchNormOp::BatchNormOp; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + // By default, the type of the scale, bias, mean, + // and var tensors should both be float. (For float or float16 input tensor) + // or double (For double input tensor). + auto bn_param_type = framework::proto::VarType::FP32; + if (input_data_type == framework::proto::VarType::FP64) { + bn_param_type = framework::proto::VarType::FP64; + } + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Scale")->type(), + platform::errors::InvalidArgument( + "Scale input should be of float type")); + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Bias")->type(), + platform::errors::InvalidArgument( + "Bias input should be of float type")); + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Mean")->type(), + platform::errors::InvalidArgument( + "Mean input should be of float type")); + PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input("Variance")->type(), + platform::errors::InvalidArgument( + "Variance input should be of float type")); + + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library); + } +}; + +class InplaceABNGradOp : public paddle::operators::BatchNormGradOp { + public: + using paddle::operators::BatchNormGradOp::BatchNormGradOp; + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + const auto* var = ctx.InputVar(framework::GradVarName("Y")); + auto input_data_type = ctx.Input("Y")->type(); + if (var == nullptr) { + PADDLE_THROW(platform::errors::InvalidArgument( + "can't find gradient variable of Y")); + } + const Tensor* t = nullptr; + if (var->IsType()) { + t = &var->Get(); + } else if (var->IsType()) { + t = &var->Get(); + } + if (t == nullptr) { + PADDLE_THROW( + platform::errors::InvalidArgument("gradient variable of Y is empty")); + } + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + + return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout, + library); + } +}; + +class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker { + public: + void Make() override { + BatchNormOpMaker::Make(); + AddAttr( + "activation", + "(enum string, default identity, can be identity|elu|leaky-relu) " + "The activation type used for output candidate {h}_t.") + .SetDefault(""); + 
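The activations allowed for this attribute (identity, leaky-relu, elu) are exactly the ones whose input can be recovered from their output, which is what the in-place backward pass in InplaceABNActivation (further down in this patch) relies on; the alpha attribute declared next is the leaky-relu slope / ELU alpha. A small NumPy sketch of that inversion, assuming the standard leaky-relu and ELU definitions; it is illustrative only and not part of the patch:

    import numpy as np

    def leaky_relu(x, alpha):
        return np.where(x >= 0, x, alpha * x)

    def leaky_relu_inverse(y, alpha):
        # valid for alpha > 0; sign of y matches sign of x
        return np.where(y >= 0, y, y / alpha)

    def elu(x, alpha):
        return np.where(x >= 0, x, alpha * (np.exp(x) - 1.0))

    def elu_inverse(y, alpha):
        # valid for alpha > 0 and y > -alpha
        return np.where(y >= 0, y, np.log(y / alpha + 1.0))

    x = np.linspace(-3, 3, 13)
    for act, inv in ((leaky_relu, leaky_relu_inverse), (elu, elu_inverse)):
        y = act(x, 0.1)
        assert np.allclose(inv(y, 0.1), x)
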
AddAttr("alpha", + "(float, default 1.0) Only used in inplace-abn kernel," + "the activation type(identity|elu|leakyrelu) would be fused " + "with batch_norm, " + "this is the alpha value for elu|leakyrelu.") + .SetDefault(0.1f); + AddAttr("use_sync_bn", + "(bool, default false) Whether use synchronize batch " + "normalization.") + .SetDefault(false); + } +}; + +template +class InplaceABNOpGradMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType(this->ForwardOpType() + "_grad"); + op->SetInput("Y", this->Output("Y")); + op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); + + op->SetInput("Scale", this->Input("Scale")); + op->SetInput("Bias", this->Input("Bias")); + op->SetInput("SavedMean", this->Output("SavedMean")); + op->SetInput("SavedVariance", this->Output("SavedVariance")); + + // used when setting use_global_stats True during training + if (boost::get(this->GetAttr("use_global_stats"))) { + op->SetInput("Mean", this->Output("MeanOut")); + op->SetInput("Variance", this->Output("VarianceOut")); + } + + op->SetAttrMap(this->Attrs()); + + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale")); + op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias")); + } +}; + +template +class InplaceABNKernel + : public paddle::operators::BatchNormKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* y = ctx.Output("Y"); + PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument( + "X and Y not inplaced in inplace mode")); + auto activation = + GetInplaceABNActivationType(ctx.Attr("activation")); + auto& place = *ctx.template device_context().eigen_device(); + BatchNormKernel::Compute(ctx); + + auto cur_y = EigenVector::Flatten(*y); + InplaceABNActivation functor; + functor.Compute(ctx, activation, place, cur_y, cur_y); + } +}; + +template +class InplaceABNGradKernel + : public paddle::operators::BatchNormGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* y = ctx.Input("Y"); + auto* d_y = ctx.Input(framework::GradVarName("Y")); + auto* d_x = ctx.Output(framework::GradVarName("X")); + PADDLE_ENFORCE_EQ(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD not inplaced in inplace mode")); + auto& place = *ctx.template device_context().eigen_device(); + auto activation = + GetInplaceABNActivationType(ctx.Attr("activation")); + + auto py = *y; + auto pd_y = *d_y; + auto cur_y = EigenVector::Flatten(py); + auto cur_dy = EigenVector::Flatten(pd_y); + + InplaceABNActivation functor; + functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy); + + BatchNormGradKernel::Compute(ctx); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, ops::InplaceABNOpMaker, + ops::BatchNormOpInferVarType, + ops::InplaceABNOpGradMaker, + ops::InplaceABNOpGradMaker) +REGISTER_OPERATOR(inplace_abn_grad, ops::InplaceABNGradOp) + +REGISTER_OP_CPU_KERNEL( + inplace_abn, + ops::InplaceABNKernel, + ops::InplaceABNKernel); +REGISTER_OP_CPU_KERNEL( + inplace_abn_grad, + ops::InplaceABNGradKernel, + ops::InplaceABNGradKernel); diff --git a/paddle/fluid/operators/inplace_abn_op.cu b/paddle/fluid/operators/inplace_abn_op.cu new file mode 100644 
index 00000000000..9e12a8291c0 --- /dev/null +++ b/paddle/fluid/operators/inplace_abn_op.cu @@ -0,0 +1,92 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/batch_norm_op.h" +#include "paddle/fluid/operators/inplace_abn_op.h" +#include "paddle/fluid/operators/sync_batch_norm_op.cu.h" + +namespace paddle { +namespace operators { + +template +class InplaceABNKernel + : public paddle::operators::SyncBatchNormKernel, + public paddle::operators::BatchNormKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* y = ctx.Output("Y"); + auto* x = ctx.Input("X"); + PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument( + "X and Y not inplaced in inplace mode")); + auto activation = + GetInplaceABNActivationType(ctx.Attr("activation")); + auto& place = *ctx.template device_context().eigen_device(); + + if (ctx.Attr("use_sync_bn")) { + SyncBatchNormKernel::Compute(ctx); + } else { + BatchNormKernel::Compute(ctx); + } + + auto cur_y = EigenVector::Flatten(*y); + InplaceABNActivation functor; + functor.Compute(ctx, activation, place, cur_y, cur_y); + } +}; + +// Deriving the Gradient for the Backward Pass of Batch Normalization +// https://kevinzakka.github.io/2016/09/14/batch_normalization/ +template +class InplaceABNGradKernel + : public paddle::operators::SyncBatchNormGradKernel, + public paddle::operators::BatchNormGradKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* y = ctx.Input("Y"); + auto* d_y = ctx.Input(framework::GradVarName("Y")); + auto* d_x = ctx.Output(framework::GradVarName("X")); + PADDLE_ENFORCE_EQ(d_x, d_y, + platform::errors::InvalidArgument( + "X@GRAD and Y@GRAD not inplaced in inplace mode")); + auto& place = *ctx.template device_context().eigen_device(); + auto activation = + GetInplaceABNActivationType(ctx.Attr("activation")); + + auto py = *y; + auto pd_y = *d_y; + auto cur_y = EigenVector::Flatten(py); + auto cur_dy = EigenVector::Flatten(pd_y); + + InplaceABNActivation functor; + functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy); + + if (ctx.Attr("use_sync_bn")) { + SyncBatchNormGradKernel::Compute(ctx); + } else { + BatchNormGradKernel::Compute(ctx); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_CUDA_KERNEL(inplace_abn, + ops::InplaceABNKernel, + ops::InplaceABNKernel); +REGISTER_OP_CUDA_KERNEL( + inplace_abn_grad, ops::InplaceABNGradKernel, + ops::InplaceABNGradKernel); diff --git a/paddle/fluid/operators/inplace_abn_op.h b/paddle/fluid/operators/inplace_abn_op.h new file mode 100644 index 00000000000..1c90a645bf8 --- /dev/null +++ b/paddle/fluid/operators/inplace_abn_op.h @@ -0,0 +1,117 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; + +enum InplaceABNActivationType { identity = 0, leakyrelu = 1, elu = 2 }; + +inline InplaceABNActivationType GetInplaceABNActivationType( + const std::string& type) { + if (type == "leaky_relu") { + return InplaceABNActivationType::leakyrelu; + } else if (type == "elu") { + return InplaceABNActivationType::elu; + } else if (type == "identity" || type == "") { + return InplaceABNActivationType::identity; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "unsupported activation type %s for Op(inplace_abn)", type)); + } +} + +template +class InplaceABNActivation { + private: + template + void setAttrs(const framework::ExecutionContext& ctx, Functor* functor) { + auto attrs = functor->GetAttrs(); + for (auto& attr : attrs) { + *attr.second = ctx.Attr(attr.first); + } + } + + template + void compute(const framework::ExecutionContext& ctx, Functor* functor, + Args... 
args) { + setAttrs(ctx, functor); + (*functor)(args...); + } + + public: + template + void Compute(const framework::ExecutionContext& ctx, const int act_type, + const Device& d, X x, Y y) { + if (act_type == InplaceABNActivationType::identity) { + y.device(d) = x; + } else if (act_type == InplaceABNActivationType::leakyrelu) { + LeakyReluFunctor functor; + compute(ctx, &functor, d, x, y); + } else if (act_type == InplaceABNActivationType::elu) { + ELUFunctor functor; + compute(ctx, &functor, d, x, y); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("unsupported activation type")); + } + } + + template + void GradCompute(const framework::ExecutionContext& ctx, const int act_type, + const Device& d, X x, Y y, DX dx, DY dy) { + const float alpha = ctx.Attr("alpha"); + + if (act_type == InplaceABNActivationType::identity) { + x.device(d) = y; + dx.device(d) = dy; + } else if (act_type == InplaceABNActivationType::leakyrelu) { + auto temp1 = (y < static_cast(0)).template cast().eval() / + static_cast(alpha); + auto temp2 = (y >= static_cast(0)).template cast().eval(); + x.device(d) = y * (temp1 + temp2).template cast(); + + LeakyReluGradFunctor functor; + compute(ctx, &functor, d, x, y, dy, dx); + } else if (act_type == InplaceABNActivationType::elu) { + auto temp1 = (y >= static_cast(0)).template cast().eval(); + auto temp = (y < static_cast(0)).template cast().eval(); + auto temp2 = (y * temp / static_cast(alpha) + static_cast(1)).log(); + x.device(d) = (y * temp1 + temp2).template cast(); + + ELUGradFunctor functor; + compute(ctx, &functor, d, x, y, dy, dx); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("unsupported activation type")); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu b/paddle/fluid/operators/sync_batch_norm_op.cu index fb4ae48eb07..d79667cfdcb 100644 --- a/paddle/fluid/operators/sync_batch_norm_op.cu +++ b/paddle/fluid/operators/sync_batch_norm_op.cu @@ -12,113 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -// clang-format off -#include -#include -#include -#include -#include -#include "cub/cub.cuh" -#include "paddle/fluid/framework/data_layout.h" -#include "paddle/fluid/memory/malloc.h" -#include "paddle/fluid/operators/batch_norm_op.h" -#include "paddle/fluid/operators/norm_utils.h" -#include "paddle/fluid/platform/cudnn_helper.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/fluid/platform/nccl_helper.h" +#include "paddle/fluid/operators/sync_batch_norm_op.cu.h" namespace paddle { namespace operators { -using Tensor = framework::Tensor; -using DataLayout = framework::DataLayout; template -using CudnnDataType = platform::CudnnDataType; -template -using BatchNormParamType = typename CudnnDataType::BatchNormParamType; - -template -__global__ void KeLocalStats(const T *x, int N, int M, int C, - BatchNormParamType *mean_var) { - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int k = blockIdx.x; k < C; k += gridDim.x) { - BatchNormParamType x_sum = 0.; - BatchNormParamType x2_sum = 0.; - for (int i = threadIdx.x; i < N * M; i += BlockDim) { - int id = layout == framework::DataLayout::kNCHW - ? 
(i / M) * C * M + k * M + i % M - : i * C + k; - auto x_in = static_cast>(x[id]); - x_sum += x_in; - x2_sum += x_in * x_in; - } - __syncthreads(); - auto out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - mean_var[k] = out / (N * M); - } - out = BlockReduce(temp_storage).Reduce(x2_sum, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - mean_var[k + C] = out / (N * M); - } - } - if (blockIdx.x == 0 && threadIdx.x == 0) { - mean_var[2 * C] = static_cast>(1.0); - } -} - -template -__global__ void KeSyncAndMovingStats( - BatchNormParamType *means, BatchNormParamType *variances, - BatchNormParamType *num_dev, const int C, - const BatchNormParamType momentum, const double epsilon, - BatchNormParamType *sv_mean_data, BatchNormParamType *sv_inv_var_data, - BatchNormParamType *moving_means, - BatchNormParamType *moving_variances) { - // sync stats across multi-devices - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (int i = gid; i < C; i += stride) { - auto mean = means[i] / (*num_dev); - auto var = variances[i] / (*num_dev); - var = var - mean * mean; - - // sync stats - sv_mean_data[i] = mean; - sv_inv_var_data[i] = 1.0 / sqrt(var + epsilon); - variances[i] = var; - - // moving stats - moving_means[i] = moving_means[i] * momentum + mean * (1. - momentum); - moving_variances[i] = - moving_variances[i] * momentum + var * (1. - momentum); - } -} - -template -static __global__ void KeNormAffine(const T *x, - const BatchNormParamType *scale, - const BatchNormParamType *bias, - const BatchNormParamType *mean, - const BatchNormParamType *variance, - const double epsilon, const int C, - const int M, const int num, T *y) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - for (int i = gid; i < num; i += stride) { - const int c = layout == framework::DataLayout::kNCHW ? (i / M) % C : i % C; - auto x_i = static_cast>(x[i]); - auto y_i = - (x_i - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]; - y[i] = static_cast(y_i); - } -} - -template -class SyncBatchNormKernel : public framework::OpKernel { +class SyncBatchNormKernel + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { double epsilon = static_cast(ctx.Attr("epsilon")); @@ -127,331 +28,59 @@ class SyncBatchNormKernel : public framework::OpKernel { const std::string layout_str = ctx.Attr("data_layout"); const DataLayout layout = framework::StringToDataLayout(layout_str); const bool use_global_stats = ctx.Attr("use_global_stats"); - PADDLE_ENFORCE( - !use_global_stats, - "sync_batch_norm doesn't support to set use_global_stats True. ", - "Please use batch_norm in this case."); + PADDLE_ENFORCE_EQ(use_global_stats, false, + platform::errors::InvalidArgument( + "sync_batch_norm doesn't support " + "to set use_global_stats True. 
Please use batch_norm " + "in this case.")); const auto *x = ctx.Input("X"); - const auto &x_dims = x->dims(); - PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, - "The Input dim size should be between 2 and 5"); - int N, C, H, W, D; - ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); - int x_numel = x->numel(); - - const T *x_d = x->data(); - const auto *s_d = ctx.Input("Scale")->data>(); - const auto *b_d = ctx.Input("Bias")->data>(); - auto *y = ctx.Output("Y"); - T *y_d = y->mutable_data(ctx.GetPlace()); - - const BatchNormParamType *mean_data = nullptr; - const BatchNormParamType *var_data = nullptr; - - auto &dev_ctx = ctx.cuda_device_context(); - auto stream = dev_ctx.stream(); - auto *comm = dev_ctx.nccl_comm(); - const int block = 512; - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - - paddle::memory::AllocationPtr alloc_ptr{nullptr}; - - if (is_test) { - const auto *est_mean = ctx.Input("Mean"); - const auto *est_var = ctx.Input("Variance"); - mean_data = est_mean->data>(); - var_data = est_var->data>(); - } else { - // x, x^2, 1, here 1 is used to calc device num - // device num also can be got from platform::DeviceContextPool - const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); - alloc_ptr = memory::Alloc(dev_ctx, bytes); - - auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); - const int threads = 256; - int grid = std::min(C, (max_threads + threads - 1) / threads); - if (layout == framework::DataLayout::kNCHW) { - KeLocalStats - <<>>(x_d, N, H * W * D, C, stats); - } else { - KeLocalStats - <<>>(x_d, N, H * W * D, C, stats); - } - - // moving mean/variance - auto *mean_out = ctx.Output("MeanOut"); - auto *variance_out = ctx.Output("VarianceOut"); - auto *est_mean_data = - mean_out->mutable_data>(ctx.GetPlace()); - auto *est_var_data = - variance_out->mutable_data>(ctx.GetPlace()); - - auto *saved_mean = ctx.Output("SavedMean"); - auto *saved_inv_variance = ctx.Output("SavedVariance"); - auto *sv_mean_data = - saved_mean->mutable_data>(ctx.GetPlace()); - auto *sv_inv_var_data = - saved_inv_variance->mutable_data>( - ctx.GetPlace()); - - Tensor c_g_st; - auto *c_g_st_d = c_g_st.mutable_data>( - {2 * C + 1}, platform::CPUPlace()); - auto gplace = boost::get(ctx.GetPlace()); - memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0); - int dtype = platform::ToNCCLDataType(mean_out->type()); - // In-place operation - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( - stats, stats, 2 * C + 1, static_cast(dtype), ncclSum, - comm, stream)); + const auto *est_mean = ctx.Input("Mean"); + const auto *est_var = ctx.Input("Variance"); - // Note, Input('Mean')/Input('Variance') share variable with - // Output('MeanOut')/Output('VarianceOut') - KeSyncAndMovingStats<<<(C + block - 1) / block, block, 0, stream>>>( - stats, stats + C, stats + 2 * C, C, momentum, epsilon, sv_mean_data, - sv_inv_var_data, est_mean_data, est_var_data); + // moving mean/variance + auto *mean_out = ctx.Output("MeanOut"); + auto *variance_out = ctx.Output("VarianceOut"); - mean_data = sv_mean_data; - var_data = stats + C; - } + auto *saved_mean = ctx.Output("SavedMean"); + auto *saved_inv_variance = ctx.Output("SavedVariance"); - int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; - if (layout == framework::DataLayout::kNCHW) { - KeNormAffine - <<>>(x_d, s_d, b_d, mean_data, var_data, - epsilon, C, H * W * D, x_numel, y_d); - } else { - KeNormAffine - <<>>(x_d, s_d, b_d, mean_data, var_data, - epsilon, C, H * W * D, x_numel, y_d); - } + 
SyncBatchNormFunctor( + ctx, layout, x, y, est_mean, est_var, mean_out, variance_out, + saved_mean, saved_inv_variance, epsilon, momentum, is_test, + use_global_stats); } }; -template -__global__ void KeBackwardLocalStats(const T *dy, const T *x, - const BatchNormParamType *means, int N, - int M, int C, - BatchNormParamType *sum_dy_prod) { - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - for (int k = blockIdx.x; k < C; k += gridDim.x) { - BatchNormParamType sum1 = 0.; - BatchNormParamType sum2 = 0.; - auto mean = means[k]; - for (int i = threadIdx.x; i < N * M; i += blockDim.x) { - int id = layout == framework::DataLayout::kNCHW - ? (i / M) * C * M + k * M + i % M - : i * C + k; - auto g = static_cast>(dy[id]); - sum1 += g; - auto x_i = static_cast>(x[id]); - sum2 += g * (x_i - mean); - } - - __syncthreads(); - auto out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - sum_dy_prod[k] = out; - } - out = BlockReduce(temp_storage).Reduce(sum2, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - sum_dy_prod[k + C] = out; - } - } - if (blockIdx.x == 0 && threadIdx.x == 0) { - sum_dy_prod[2 * C] = 1.0; - } -} - -template -static __global__ void KeBNBackwardScaleBias( - const T *dy, const T *x, const BatchNormParamType *mean, - const BatchNormParamType *inv_variance, const double epsilon, - const int N, const int C, const int HxW, BatchNormParamType *dscale, - BatchNormParamType *dbias) { - const int outer_size = C; - const int inner_size = N * HxW; - typedef cub::BlockReduce, BlockDim> BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; - - for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { - BatchNormParamType ds_sum = 0.; - BatchNormParamType db_sum = 0.; - - auto inv_var_i = inv_variance[i]; - auto mean_i = mean[i]; - for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { - const int id = layout == framework::DataLayout::kNCHW - ? ((j / HxW) * C + i) * HxW + (j % HxW) - : j * outer_size + i; - auto x_i = static_cast>(x[id]); - auto dy_i = static_cast>(dy[id]); - ds_sum += dy_i * (x_i - mean_i); - db_sum += dy_i; - } - __syncthreads(); - auto os = BlockReduce(temp_storage).Reduce(ds_sum, cub::Sum()); - __syncthreads(); - auto ob = BlockReduce(temp_storage).Reduce(db_sum, cub::Sum()); - __syncthreads(); - if (threadIdx.x == 0) { - dscale[i] = os * inv_var_i; - dbias[i] = ob; - } - __syncthreads(); - } -} - -template -static __global__ void KeBNBackwardData( - const T *dy, const T *x, const BatchNormParamType *gamma, - const BatchNormParamType *mean, - const BatchNormParamType *inv_variance, - const BatchNormParamType *g_sum_dy, - const BatchNormParamType *g_sum_dy_prod, - const BatchNormParamType *num_dev, const double epsilon, const int C, - const int HxW, const int num, T *dx) { - int gid = blockIdx.x * blockDim.x + threadIdx.x; - int stride = blockDim.x * gridDim.x; - auto scale = static_cast>(C) / num; - auto dev_num = num_dev[0]; - for (int i = gid; i < num; i += stride) { - const int c = layout == framework::DataLayout::kNCHW ? 
i / HxW % C : i % C; - auto inv_var = inv_variance[c]; - auto s_d = gamma[c]; - auto gvar = - -((g_sum_dy_prod[c] / dev_num) * s_d * inv_var * (inv_var * inv_var)); - auto gmean = -((g_sum_dy[c] / dev_num) * s_d * inv_var); - - auto x_i = static_cast>(x[i]); - auto dy_i = static_cast>(dy[i]); - auto dx_i = - dy_i * s_d * inv_var + gmean * scale + gvar * scale * (x_i - mean[c]); - dx[i] = static_cast(dx_i); - } -} - -// Deriving the Gradient for the Backward Pass of Batch Normalization -// https://kevinzakka.github.io/2016/09/14/batch_normalization/ -template -class SyncBatchNormGradKernel : public framework::OpKernel { +template +class SyncBatchNormGradKernel + : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { - PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), - "It must use CUDAPlace."); + PADDLE_ENFORCE_EQ( + platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument("It must use CUDAPlace.")); double epsilon = static_cast(ctx.Attr("epsilon")); const std::string layout_str = ctx.Attr("data_layout"); const DataLayout layout = framework::StringToDataLayout(layout_str); - const auto *x = ctx.Input("X"); const auto *d_y = ctx.Input(framework::GradVarName("Y")); const auto *scale = ctx.Input("Scale"); - - const auto &x_dims = x->dims(); - - PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, - "The Input dim size should be between 2 and 5"); - int N, C, H, W, D; - ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + const auto *bias = ctx.Input("Bias"); // init output auto *d_x = ctx.Output(framework::GradVarName("X")); auto *d_scale = ctx.Output(framework::GradVarName("Scale")); auto *d_bias = ctx.Output(framework::GradVarName("Bias")); - d_x->mutable_data(ctx.GetPlace()); - if (d_scale && d_bias) { - d_scale->mutable_data>(ctx.GetPlace()); - d_bias->mutable_data>(ctx.GetPlace()); - } - PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); - PADDLE_ENFORCE_EQ(scale->dims()[0], C); - - std::vector dims; - std::vector strides; - if (layout == DataLayout::kNCHW) { - dims = {N, C, H, W, D}; - strides = {C * H * W * D, H * W * D, W * D, D, 1}; - } else { - dims = {N, C, H, W, D}; - strides = {H * W * C * D, 1, W * D * C, D * C, C}; - } - - const T *x_d = x->data(); - const T *dy_d = d_y->data(); - - auto &dev_ctx = ctx.cuda_device_context(); - auto stream = dev_ctx.stream(); - auto *comm = dev_ctx.nccl_comm(); - - const auto *saved_mean = - ctx.Input("SavedMean")->data>(); - const auto *saved_inv_var = - ctx.Input("SavedVariance")->data>(); - const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); - auto alloc_ptr = memory::Alloc(dev_ctx, bytes); - auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); - - const int threads = 256; - int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); - int grid = std::min(C, (max_threads + threads - 1) / threads); - int x_numel = x->numel(); - int fsize = H * W * D; - - if (layout == framework::DataLayout::kNCHW) { - KeBackwardLocalStats - <<>>(dy_d, x_d, saved_mean, N, fsize, C, - stats); - } else { - KeBackwardLocalStats - <<>>(dy_d, x_d, saved_mean, N, fsize, C, - stats); - } - int dtype = platform::ToNCCLDataType(scale->type()); - // In-place operation - PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclAllReduce( - stats, stats, 2 * C + 1, static_cast(dtype), ncclSum, - comm, stream)); + const auto *saved_mean = ctx.Input("SavedMean"); + const auto *saved_inv_var = ctx.Input("SavedVariance"); - const int block = 512; - int grid2 = (std::min(x_numel, max_threads) + block 
- 1) / block; - if (layout == framework::DataLayout::kNCHW) { - if (d_scale && d_bias) { - KeBNBackwardScaleBias - <<>>( - dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, - d_scale->data>(), - d_bias->data>()); - } - if (d_x) { - KeBNBackwardData - <<>>( - dy_d, x_d, scale->data>(), saved_mean, - saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, - fsize, x->numel(), d_x->data()); - } - } else { - if (d_scale && d_bias) { - KeBNBackwardScaleBias - <<>>( - dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, - d_scale->data>(), - d_bias->data>()); - } - if (d_x) { - KeBNBackwardData - <<>>( - dy_d, x_d, scale->data>(), saved_mean, - saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, - fsize, x->numel(), d_x->data()); - } - } + SyncBatchNormGradFunctor( + ctx, layout, scale, bias, d_x, d_y, d_scale, d_bias, saved_mean, + saved_inv_var, epsilon); } }; diff --git a/paddle/fluid/operators/sync_batch_norm_op.cu.h b/paddle/fluid/operators/sync_batch_norm_op.cu.h new file mode 100644 index 00000000000..083d22aa2a3 --- /dev/null +++ b/paddle/fluid/operators/sync_batch_norm_op.cu.h @@ -0,0 +1,530 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include "cub/cub.cuh" +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/memory/malloc.h" +#include "paddle/fluid/operators/batch_norm_op.h" +#include "paddle/fluid/operators/norm_utils.h" +#include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/float16.h" +#include "paddle/fluid/platform/nccl_helper.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DataLayout = framework::DataLayout; +template +using CudnnDataType = platform::CudnnDataType; +template +using BatchNormParamType = typename CudnnDataType::BatchNormParamType; + +template +__global__ void KeLocalStats(const T *x, int N, int M, int C, + BatchNormParamType *mean_var) { + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + for (int k = blockIdx.x; k < C; k += gridDim.x) { + BatchNormParamType x_sum = 0.; + BatchNormParamType x2_sum = 0.; + for (int i = threadIdx.x; i < N * M; i += BlockDim) { + int id = layout == framework::DataLayout::kNCHW + ? 
(i / M) * C * M + k * M + i % M + : i * C + k; + auto x_in = static_cast>(x[id]); + x_sum += x_in; + x2_sum += x_in * x_in; + } + __syncthreads(); + auto out = BlockReduce(temp_storage).Reduce(x_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + mean_var[k] = out / (N * M); + } + out = BlockReduce(temp_storage).Reduce(x2_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + mean_var[k + C] = out / (N * M); + } + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + mean_var[2 * C] = static_cast>(1.0); + } +} + +template +__global__ void KeSyncAndMovingStats( + BatchNormParamType *means, BatchNormParamType *variances, + BatchNormParamType *num_dev, const int C, + const BatchNormParamType momentum, const double epsilon, + BatchNormParamType *sv_mean_data, BatchNormParamType *sv_inv_var_data, + BatchNormParamType *moving_means, + BatchNormParamType *moving_variances) { + // sync stats across multi-devices + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < C; i += stride) { + auto mean = means[i] / (*num_dev); + auto var = variances[i] / (*num_dev); + var = var - mean * mean; + + // sync stats + sv_mean_data[i] = mean; + sv_inv_var_data[i] = 1.0 / sqrt(var + epsilon); + variances[i] = var; + + // moving stats + moving_means[i] = moving_means[i] * momentum + mean * (1. - momentum); + moving_variances[i] = + moving_variances[i] * momentum + var * (1. - momentum); + } +} + +template +static __global__ void KeNormAffine(const T *x, + const BatchNormParamType *scale, + const BatchNormParamType *bias, + const BatchNormParamType *mean, + const BatchNormParamType *variance, + const double epsilon, const int C, + const int M, const int num, T *y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? 
(i / M) % C : i % C; + auto x_i = static_cast>(x[i]); + auto y_i = + (x_i - mean[c]) / sqrt(variance[c] + epsilon) * scale[c] + bias[c]; + y[i] = static_cast(y_i); + } +} + +template +void SyncBatchNormFunctor(const framework::ExecutionContext &ctx, + const DataLayout layout, const framework::Tensor *x, + framework::Tensor *y, const framework::Tensor *mean, + const framework::Tensor *variance, + framework::Tensor *mean_out, + framework::Tensor *variance_out, + framework::Tensor *saved_mean, + framework::Tensor *saved_variance, double epsilon, + const float momentum, const bool is_test, + const bool use_global_stats + + ) { + const auto &x_dims = x->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), 5, + platform::errors::InvalidArgument( + "The Input dim size should be less than 6.")); + int N, C, H, W, D; + ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + int x_numel = x->numel(); + + const T *x_d = x->data(); + const auto *s_d = ctx.Input("Scale")->data>(); + const auto *b_d = ctx.Input("Bias")->data>(); + + T *y_d = y->mutable_data(ctx.GetPlace()); + + const BatchNormParamType *mean_data = nullptr; + const BatchNormParamType *var_data = nullptr; + + auto &dev_ctx = ctx.cuda_device_context(); + auto stream = dev_ctx.stream(); + const int block = 512; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + + paddle::memory::AllocationPtr alloc_ptr{nullptr}; + + if (is_test) { + mean_data = mean->data>(); + var_data = variance->data>(); + } else { + // x, x^2, 1, here 1 is used to calc device num + // device num also can be got from platform::DeviceContextPool + const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); + alloc_ptr = memory::Alloc(dev_ctx, bytes); + + auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); + const int threads = 256; + int grid = std::min(C, (max_threads + threads - 1) / threads); + if (layout == framework::DataLayout::kNCHW) { + KeLocalStats<<>>( + x_d, N, H * W * D, C, stats); + } else { + KeLocalStats<<>>( + x_d, N, H * W * D, C, stats); + } + + Tensor c_g_st; + auto *c_g_st_d = c_g_st.mutable_data>( + {2 * C + 1}, platform::CPUPlace()); + auto gplace = boost::get(ctx.GetPlace()); + memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0); + +#ifndef WIN32 + auto *comm = dev_ctx.nccl_comm(); + if (comm) { + int dtype = platform::ToNCCLDataType(mean_out->type()); + // In-place operation + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclAllReduce(stats, stats, 2 * C + 1, + static_cast(dtype), + ncclSum, comm, stream), + platform::errors::InvalidArgument( + "ncclAllReduce in Op(sync_batch_norm) failed")); + } +#endif + + auto *est_mean_data = + mean_out->mutable_data>(ctx.GetPlace()); + auto *est_var_data = + variance_out->mutable_data>(ctx.GetPlace()); + + auto *sv_mean_data = + saved_mean->mutable_data>(ctx.GetPlace()); + auto *sv_inv_var_data = + saved_variance->mutable_data>(ctx.GetPlace()); + + // Note, Input('Mean')/Input('Variance') share variable with + // Output('MeanOut')/Output('VarianceOut') + KeSyncAndMovingStats<<<(C + block - 1) / block, block, 0, stream>>>( + stats, stats + C, stats + 2 * C, C, momentum, epsilon, sv_mean_data, + sv_inv_var_data, est_mean_data, est_var_data); + + mean_data = sv_mean_data; + var_data = stats + C; + } + + int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; + if (layout == framework::DataLayout::kNCHW) { + KeNormAffine<<>>( + x_d, s_d, b_d, mean_data, 
var_data, epsilon, C, H * W * D, x_numel, + y_d); + } else { + KeNormAffine<<>>( + x_d, s_d, b_d, mean_data, var_data, epsilon, C, H * W * D, x_numel, + y_d); + } +} + +template +__global__ void KeBackwardLocalStats(const T *dy, const T *x, + const BatchNormParamType *means, int N, + int M, int C, + BatchNormParamType *sum_dy_prod) { + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + for (int k = blockIdx.x; k < C; k += gridDim.x) { + BatchNormParamType sum1 = 0.; + BatchNormParamType sum2 = 0.; + auto mean = means[k]; + for (int i = threadIdx.x; i < N * M; i += blockDim.x) { + int id = layout == framework::DataLayout::kNCHW + ? (i / M) * C * M + k * M + i % M + : i * C + k; + auto g = static_cast>(dy[id]); + sum1 += g; + auto x_i = static_cast>(x[id]); + sum2 += g * (x_i - mean); + } + + __syncthreads(); + auto out = BlockReduce(temp_storage).Reduce(sum1, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + sum_dy_prod[k] = out; + } + out = BlockReduce(temp_storage).Reduce(sum2, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + sum_dy_prod[k + C] = out; + } + } + if (blockIdx.x == 0 && threadIdx.x == 0) { + sum_dy_prod[2 * C] = 1.0; + } +} + +template +static __global__ void KeBNBackwardScaleBias( + const T *dy, const T *x, const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, const double epsilon, + const int N, const int C, const int HxW, BatchNormParamType *dscale, + BatchNormParamType *dbias) { + const int outer_size = C; + const int inner_size = N * HxW; + typedef cub::BlockReduce, BlockDim> BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + BatchNormParamType ds_sum = 0.; + BatchNormParamType db_sum = 0.; + + auto inv_var_i = inv_variance[i]; + auto mean_i = mean[i]; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int id = layout == framework::DataLayout::kNCHW + ? ((j / HxW) * C + i) * HxW + (j % HxW) + : j * outer_size + i; + auto x_i = static_cast>(x[id]); + auto dy_i = static_cast>(dy[id]); + ds_sum += dy_i * (x_i - mean_i); + db_sum += dy_i; + } + __syncthreads(); + auto os = BlockReduce(temp_storage).Reduce(ds_sum, cub::Sum()); + __syncthreads(); + auto ob = BlockReduce(temp_storage).Reduce(db_sum, cub::Sum()); + __syncthreads(); + if (threadIdx.x == 0) { + dscale[i] = os * inv_var_i; + dbias[i] = ob; + } + __syncthreads(); + } +} + +template +static __global__ void KeBNRestoreData(T *x, const BatchNormParamType *scale, + const BatchNormParamType *bias, + const BatchNormParamType *mean, + const BatchNormParamType *sv_inv, + const double epsilon, int C, int M, + int num, const T *y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? 
(i / M) % C : i % C; + auto y_i = static_cast>(y[i]); + auto x_i = (y_i - bias[c]) / scale[c] / sv_inv[c] + mean[c]; + x[i] = static_cast(x_i); + } +} + +template +static __global__ void KeBNBackwardData( + const T *dy, const T *x, const BatchNormParamType *gamma, + const BatchNormParamType *mean, + const BatchNormParamType *inv_variance, + const BatchNormParamType *g_sum_dy, + const BatchNormParamType *g_sum_dy_prod, + const BatchNormParamType *num_dev, const double epsilon, const int C, + const int HxW, const int num, T *dx) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + auto scale = static_cast>(C) / num; + auto dev_num = num_dev[0]; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; + auto inv_var = inv_variance[c]; + auto s_d = gamma[c]; + auto gvar = + -(g_sum_dy_prod[c] / dev_num) * s_d * inv_var * (inv_var * inv_var); + auto gmean = -(g_sum_dy[c] / dev_num) * s_d * inv_var; + + auto x_i = static_cast>(x[i]); + auto dy_i = static_cast>(dy[i]); + auto dx_i = + dy_i * s_d * inv_var + gmean * scale + gvar * scale * (x_i - mean[c]); + dx[i] = static_cast(dx_i); + } +} + +template +void SyncBatchNormGradFunctor( + const framework::ExecutionContext &ctx, const DataLayout layout, + const framework::Tensor *scale, const framework::Tensor *bias, + framework::Tensor *d_x, const framework::Tensor *d_y, + framework::Tensor *d_scale, framework::Tensor *d_bias, + const framework::Tensor *mean, const framework::Tensor *variance, + const double epsilon) { + // sync_batch_norm with inplace as false will take X as grad input, which + // is same as cuDNN batch_norm backward calculation, batch_norm + // with inplace as true only take Y as input and X should be calculate + // by inverse operation of batch_norm on Y + const Tensor *x; + bool is_inplace; + if (ctx.HasInput("Y")) { + x = ctx.Input("Y"); + is_inplace = true; + } else { + x = ctx.Input("X"); + is_inplace = false; + } + + const auto &x_dims = x->dims(); + + PADDLE_ENFORCE_GE(x_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input X dim size should be larger than 1.")); + PADDLE_ENFORCE_LE(x_dims.size(), 5, + platform::errors::InvalidArgument( + "The Input X dim size should be less than 6.")); + + int N, C, H, W, D; + ExtractNCWHD(x_dims, layout, &N, &C, &H, &W, &D); + PADDLE_ENFORCE_EQ(scale->dims()[0], C, + platform::errors::InvalidArgument( + "Expected first dim for input parameter(scale) of " + "OP(sync_batch_norm) be (%d), but given (%d).", + C, scale->dims()[0])); + + d_x->mutable_data(ctx.GetPlace()); + if (d_scale && d_bias) { + d_scale->mutable_data>(ctx.GetPlace()); + d_bias->mutable_data>(ctx.GetPlace()); + } + PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL, + platform::errors::InvalidArgument( + "Expected rank for input parameter(scale) of " + "OP(sync_batch_norm) be (1), but given (%d).", + scale->dims().size())); + + std::vector dims; + std::vector strides; + if (layout == DataLayout::kNCHW) { + dims = {N, C, H, W, D}; + strides = {C * H * W * D, H * W * D, W * D, D, 1}; + } else { + dims = {N, C, H, W, D}; + strides = {H * W * C * D, 1, W * D * C, D * C, C}; + } + const T *x_d = x->data(); + auto px = *x; + const T *dy_d = d_y->data(); + + auto &dev_ctx = ctx.cuda_device_context(); + auto stream = dev_ctx.stream(); + + const auto *saved_mean = mean->data>(); + const auto *saved_inv_var = variance->data>(); + const int bytes = (C * 2 + 1) * sizeof(BatchNormParamType); + auto alloc_ptr = 
memory::Alloc(dev_ctx, bytes); + auto *stats = reinterpret_cast *>(alloc_ptr->ptr()); + + const int block = 512; + const int threads = 256; + int x_numel = x->numel(); + int fsize = H * W * D; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + int grid = std::min(C, (max_threads + threads - 1) / threads); + int grid2 = (std::min(x_numel, max_threads) + block - 1) / block; + + if (is_inplace) { + if (layout == framework::DataLayout::kNCHW) { + KeBNRestoreData< + T, framework::DataLayout::kNCHW><<>>( + px.mutable_data(ctx.GetPlace()), + scale->data>(), + bias->data>(), saved_mean, saved_inv_var, + epsilon, C, H * W * D, x_numel, x->data()); + } else { + KeBNRestoreData< + T, framework::DataLayout::kNHWC><<>>( + px.mutable_data(ctx.GetPlace()), + scale->data>(), + bias->data>(), saved_mean, saved_inv_var, + epsilon, C, H * W * D, x_numel, x->data()); + } + } + + if (layout == framework::DataLayout::kNCHW) { + KeBackwardLocalStats< + T, threads, framework::DataLayout::kNCHW><<>>( + dy_d, x_d, saved_mean, N, fsize, C, stats); + } else { + KeBackwardLocalStats< + T, threads, framework::DataLayout::kNHWC><<>>( + dy_d, x_d, saved_mean, N, fsize, C, stats); + } + +#ifndef WIN32 + auto *comm = dev_ctx.nccl_comm(); + if (comm) { + int dtype = platform::ToNCCLDataType(scale->type()); + // In-place operation + PADDLE_ENFORCE_CUDA_SUCCESS( + platform::dynload::ncclAllReduce(stats, stats, 2 * C + 1, + static_cast(dtype), + ncclSum, comm, stream), + platform::errors::InvalidArgument( + "ncclAllReduce in Op(sync_batch_norm) failed")); + } +#endif + + if (layout == framework::DataLayout::kNCHW) { + if (d_scale && d_bias) { + KeBNBackwardScaleBias< + T, threads, + framework::DataLayout::kNCHW><<>>( + dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, + d_scale->data>(), + d_bias->data>()); + } + if (d_x) { + KeBNBackwardData< + T, framework::DataLayout::kNCHW><<>>( + dy_d, x_d, scale->data>(), saved_mean, + saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, fsize, + x->numel(), d_x->data()); + } + } else { + if (d_scale && d_bias) { + KeBNBackwardScaleBias< + T, threads, + framework::DataLayout::kNHWC><<>>( + dy_d, x_d, saved_mean, saved_inv_var, epsilon, N, C, fsize, + d_scale->data>(), + d_bias->data>()); + } + if (d_x) { + KeBNBackwardData< + T, framework::DataLayout::kNHWC><<>>( + dy_d, x_d, scale->data>(), saved_mean, + saved_inv_var, stats, stats + C, stats + 2 * C, epsilon, C, fsize, + x->numel(), d_x->data()); + } + } +} + +template +class SyncBatchNormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override; +}; + +// Deriving the Gradient for the Backward Pass of Batch Normalization +// https://kevinzakka.github.io/2016/09/14/batch_normalization/ +template +class SyncBatchNormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index bc00124b354..3f85f89a529 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -50,6 +50,7 @@ __all__ = [ 'adaptive_pool2d', 'adaptive_pool3d', 'batch_norm', + 'inplace_abn', 'instance_norm', 'data_norm', 'conv2d_transpose', @@ -2638,9 +2639,9 @@ def batch_norm(input, If the Initializer of the bias_attr is not set, the bias is initialized zero. Default: None. 
        data_layout (str, optional): Specify the data format of the input, and the data format of the output
-            will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
-            The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
-            `[batch_size, input_channels, input_height, input_width]`.
+            will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
+            The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
+            `[batch_size, input_channels, input_height, input_width]`.
        in_place(bool, Default False): Make the input and output of batch norm reuse memory.
        name(str|None): For detailed information, please refer to :ref:`api_guide_Name`.
            Usually name is no need to set and None by default.
@@ -2657,7 +2658,6 @@ def batch_norm(input,
        or is_test to true, and the behavior is equivalent.
        In train mode, when setting use_global_stats True, the global mean
        and variance are also used during train period.
-
    Returns:
        A Variable holding Tensor which is the result after applying batch normalization on the input,
        has same shape and data type with input.
@@ -2770,8 +2770,8 @@ def batch_norm(input,
        reserve_space = helper.create_variable_for_type_inference(
            dtype=core.VarDesc.VarType.FP16, stop_gradient=True)

-    batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
-        dtype)
+    batch_norm_out = input if in_place else \
+        helper.create_variable_for_type_inference(dtype)

    inputs = {
        "X": input,
@@ -2809,6 +2809,209 @@ def batch_norm(input,
    return helper.append_activation(batch_norm_out)

+def inplace_abn(input,
+                act=None,
+                is_test=False,
+                momentum=0.9,
+                epsilon=1e-05,
+                param_attr=None,
+                bias_attr=None,
+                data_layout='NCHW',
+                name=None,
+                moving_mean_name=None,
+                moving_variance_name=None,
+                do_model_average_for_mean_and_var=True,
+                use_global_stats=False,
+                act_alpha=1.0):
+    """
+    **In-place Activation Batch Normalization Layer**
+
+    This layer calculates batch normalization and activation with in-place memory.
+    For batch normalization calculations, see `fluid.layers.batch_norm`.
+    For in-place activation batch normalization, see `In-Place Activated BatchNorm for
+    Memory-Optimized Training of DNNs <https://arxiv.org/abs/1712.02616>`_
+
+    `inplace_abn` currently only supports the activation types `None`, `identity`,
+    `leaky_relu` and `elu`.
+    `inplace_abn` currently only supports the data types `float32` and `float64`.
+
+    Note:
+        If build_strategy.sync_batch_norm=True, the batch_norm in the network will use
+        sync_batch_norm automatically.
+        `is_test = True` can only be used in test and inference programs; `is_test` CANNOT
+        be set to True in a train program. If you want to use the global statistics of a
+        pre-trained model in a train program, please set `use_global_stats = True` instead.
+
+    Args:
+        input(Variable): The rank of the input variable can be 2, 3, 4 or 5. The data type
+            is float32 or float64.
+        act(string, Default None): Activation type. Only None, 'identity', 'leaky_relu'
+            and 'elu' are supported.
+        is_test (bool, Default False): A flag indicating whether it is in
+            test phase or not.
+        momentum(float|Variable, Default 0.9): The value used for the moving_mean and
+            moving_var computation. This should be a float number or a Variable with
+            shape [1] and data type as float32. The updated formula is:
+            :math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
+            :math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
+            Default is 0.9.
+        epsilon(float, Default 1e-05): A value added to the denominator for
+            numerical stability.
+            Default is 1e-5.
+        param_attr(ParamAttr|None): The parameter attribute for Parameter `scale`
+            of inplace_abn. If it is set to None or one attribute of ParamAttr, inplace_abn
+            will create ParamAttr as param_attr, the name of scale can be set in ParamAttr.
+            If the Initializer of the param_attr is not set, the parameter is initialized
+            with Xavier. Default: None.
+        bias_attr(ParamAttr|None): The parameter attribute for the bias of inplace_abn.
+            If it is set to None or one attribute of ParamAttr, inplace_abn
+            will create ParamAttr as bias_attr, the name of bias can be set in ParamAttr.
+            If the Initializer of the bias_attr is not set, the bias is initialized zero.
+            Default: None.
+        data_layout (str, optional): Specify the data format of the input, and the data format of the output
+            will be consistent with that of the input. An optional string from: `"NCHW"`, `"NHWC"`.
+            The default is `"NCHW"`. When it is `"NCHW"`, the data is stored in the order of:
+            `[batch_size, input_channels, input_height, input_width]`.
+        name(str|None): For detailed information, please refer to :ref:`api_guide_Name`.
+            Usually name is no need to set and None by default.
+        moving_mean_name(str, Default None): The name of the moving_mean which stores the global mean. If it
+            is set to None, inplace_abn will save the global mean with a random name; otherwise, inplace_abn
+            will save the global mean with the given string.
+        moving_variance_name(str, Default None): The name of the moving_variance which stores the global variance.
+            If it is set to None, inplace_abn will save the global variance with a random name; otherwise, inplace_abn
+            will save the global variance with the given string.
+        do_model_average_for_mean_and_var(bool, Default True): Whether parameter mean and variance should do model
+            average when model average is enabled.
+        use_global_stats(bool, Default False): Whether to use global mean and
+            variance. In inference or test mode, set use_global_stats to true
+            or is_test to true, and the behavior is equivalent.
+            In train mode, when setting use_global_stats True, the global mean
+            and variance are also used during train period.
+        act_alpha(float, Default 1.0): When the activation is one of ['elu', 'identity', 'leaky_relu'],
+            in-place activated batch normalization is used, and the alpha parameter of the
+            activation can be given by this parameter.
+    Returns:
+        A Variable holding Tensor which is the result after applying batch normalization and activation
+        on the input; it has the same shape and data type as the input.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+            x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
+            hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
+            hidden2 = fluid.layers.inplace_abn(input=hidden1)
+            hidden3 = fluid.layers.inplace_abn(input=hidden2, act='leaky_relu', act_alpha=0.2)
+
+    """
+    assert act in [None, 'identity', 'leaky_relu', 'elu'], \
+        "inplace_abn only support act as None, 'identity', " \
+        "'leaky_relu', 'elu' currently"
+    assert bias_attr is not False, "bias_attr should not be False in inplace_abn."
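+    # NOTE: the activation whitelist above exists because the in-place op keeps only
+    # Y and reconstructs X from Y during the backward pass, so only activations that
+    # can be inverted (identity, elu and leaky_relu with a suitable alpha) are
+    # allowed; see the is_inplace/KeBNRestoreData path added in
+    # paddle/fluid/operators/sync_batch_norm_op.cu.h.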
+ helper = LayerHelper('inplace_abn', **locals()) + + check_variable_and_dtype(input, 'input', ['float32', 'float64'], + 'inplace_abn') + dtype = helper.input_dtype() + + has_reserve_space = False + if data_layout == 'NHWC': + flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent') + if flag is not None and flag.lower() in ['true', '1']: + has_reserve_space = True + + input_shape = input.shape + if data_layout == 'NCHW': + channel_num = input_shape[1] + else: + if data_layout == 'NHWC': + channel_num = input_shape[-1] + else: + raise ValueError("unsupported data layout:" + data_layout) + + param_shape = [channel_num] + + # create parameter + scale = helper.create_parameter( + attr=helper.param_attr, + shape=param_shape, + dtype=dtype, + default_initializer=Constant(1.0)) + bias = helper.create_parameter( + attr=helper.bias_attr, shape=param_shape, dtype=dtype, is_bias=True) + + mean = helper.create_parameter( + attr=ParamAttr( + name=moving_mean_name, + initializer=Constant(0.0), + trainable=False, + do_model_average=do_model_average_for_mean_and_var), + shape=param_shape, + dtype=dtype) + mean.stop_gradient = True + + variance = helper.create_parameter( + attr=ParamAttr( + name=moving_variance_name, + initializer=Constant(1.0), + trainable=False, + do_model_average=do_model_average_for_mean_and_var), + shape=param_shape, + dtype=dtype) + variance.stop_gradient = True + + # create output + # mean and mean_out share the same memory + mean_out = mean + # variance and variance out share the same memory + variance_out = variance + saved_mean = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + saved_variance = helper.create_variable_for_type_inference( + dtype=dtype, stop_gradient=True) + + reserve_space = None + if has_reserve_space: + reserve_space = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.FP16, stop_gradient=True) + + batch_norm_out = input + + inputs = { + "X": input, + "Scale": scale, + "Bias": bias, + "Mean": mean, + "Variance": variance + } + attrs = { + "epsilon": epsilon, + "is_test": is_test, + "data_layout": data_layout, + "use_mkldnn": False, + "fuse_with_relu": False, + "use_global_stats": use_global_stats, + "activation": act, + "alpha": act_alpha, + } + if isinstance(momentum, Variable): + inputs['MomemtumTensor'] = momentum + else: + attrs['momentum'] = momentum + + outputs = { + "Y": batch_norm_out, + "MeanOut": mean_out, + "VarianceOut": variance_out, + "SavedMean": saved_mean, + "SavedVariance": saved_variance + } + if reserve_space is not None: + outputs["ReserveSpace"] = reserve_space + + helper.append_op( + type="inplace_abn", inputs=inputs, outputs=outputs, attrs=attrs) + + return batch_norm_out + + def instance_norm(input, epsilon=1e-05, param_attr=None, diff --git a/python/paddle/fluid/nets.py b/python/paddle/fluid/nets.py index 118b9d60e3b..21def94ad1e 100644 --- a/python/paddle/fluid/nets.py +++ b/python/paddle/fluid/nets.py @@ -234,7 +234,7 @@ def img_conv_group(input, use_cudnn=use_cudnn) if conv_with_batchnorm[i]: - tmp = layers.batch_norm(input=tmp, act=conv_act, in_place=True) + tmp = layers.batch_norm(input=tmp, act=conv_act) drop_rate = conv_batchnorm_drop_rate[i] if abs(drop_rate) > 1e-5: tmp = layers.dropout(x=tmp, dropout_prob=drop_rate) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 4ac3fff5255..763b04d795e 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ 
b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -360,7 +360,7 @@ set_tests_properties(test_parallel_executor_test_while_train test_parallel_execu test_fetch_unmerged test_buffer_shared_memory_reuse_pass PROPERTIES LABELS "RUN_TYPE=DIST") -set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op +set_tests_properties(test_parallel_executor_crf test_sync_batch_norm_op test_inplace_abn_op test_parallel_executor_seresnext_base_gpu test_parallel_executor_seresnext_with_reduce_gpu test_parallel_executor_seresnext_with_fuse_all_reduce_gpu diff --git a/python/paddle/fluid/tests/unittests/test_inplace_abn_op.py b/python/paddle/fluid/tests/unittests/test_inplace_abn_op.py new file mode 100644 index 00000000000..7b92f6f02c6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_inplace_abn_op.py @@ -0,0 +1,189 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import os +import six +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid import compiler +import paddle.fluid.unique_name as unique_name + + +class TestInplaceANBOpTraining(unittest.TestCase): + def setUp(self): + self.dtype = np.float64 + self.N = 4 + self.C = 5 + self.H = 7 + self.W = 9 + self.dshape = [self.N, self.C, self.H, self.W] + + def build_program(self, + place, + layout, + seed, + only_forward=False, + activation="identity", + alpha=1.0, + use_cuda=False, + inplace=False): + main = fluid.Program() + startup = fluid.Program() + main.random_seed = seed + startup.random_seed = seed + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + data = fluid.layers.data( + name='input', + shape=self.dshape, + dtype=self.dtype, + append_batch_size=False, + stop_gradient=False) + if inplace: + bn = fluid.layers.inplace_abn( + data, + act=activation, + param_attr=fluid.ParamAttr(name='bn_scale'), + bias_attr=fluid.ParamAttr(name='bn_bias'), + moving_mean_name='bn_moving_mean', + moving_variance_name='bn_moving_variance', + data_layout=layout, + is_test=only_forward, + act_alpha=alpha) + else: + bn = fluid.layers.batch_norm( + data, + param_attr=fluid.ParamAttr(name='bn_scale'), + bias_attr=fluid.ParamAttr(name='bn_bias'), + moving_mean_name='bn_moving_mean', + moving_variance_name='bn_moving_variance', + data_layout=layout, + is_test=only_forward, + in_place=inplace) + if activation == 'leaky_relu': + bn = fluid.layers.leaky_relu(bn, alpha) + if activation == 'elu': + bn = fluid.layers.elu(bn, alpha) + + # NOTE: in inplace mode input and output of bn + # may have same name, multiply 1. to generate + # a new Variable for fetch + bn = bn * 1. 
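+                # Both branches compute batch norm followed by the same activation,
+                # so the in-place and reference programs should produce matching
+                # outputs, moving statistics and gradients; compare() below checks
+                # this with np.allclose(atol=1e-2).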
+ + sigmoid = fluid.layers.sigmoid(bn) + out = fluid.layers.reduce_sum(sigmoid) + if not only_forward: + sgd_opt = fluid.optimizer.SGD(learning_rate=0.0) + sgd_opt.backward(out) + return main, startup, [out, bn] + + def compare(self, place, layout, only_forward, activation, alpha, use_cuda): + seed = 10 + os.environ['FLAGS_cudnn_deterministic'] = "1" + data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2 + + fetch_outs = [] + fetch_names = [] + for inplace in [False, True]: + main, startup, outs = self.build_program( + place, + layout, + seed, + only_forward, + activation, + alpha, + inplace=inplace) + exe = fluid.Executor(place) + exe.run(startup) + + fetch_name = [v.name for v in outs] + [ + 'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias' + ] + if not only_forward: + others = [ + 'inplace_abn_0.tmp_0' if inplace else 'batch_norm_0.tmp_0', + 'inplace_abn_0.tmp_1' if inplace else 'batch_norm_0.tmp_1', + 'bn_scale@GRAD', + 'bn_bias@GRAD', + 'input@GRAD', + ] + fetch_name += others + for nm in fetch_name: + fv = fluid.framework._get_var(str(nm), program=main) + fv.persistable = True + + build_strategy = fluid.BuildStrategy() + build_strategy.sync_batch_norm = use_cuda and \ + fluid.core.get_cuda_device_count() > 1 + build_strategy.enable_inplace = inplace + exec_strategy = fluid.ExecutionStrategy() + exec_strategy.num_threads = 1 if os.name == 'nt' else 0 + comp_prog1 = compiler.CompiledProgram(main).with_data_parallel( + outs[0].name if not only_forward else None, + build_strategy=build_strategy, + exec_strategy=exec_strategy) + bn_fetches = exe.run(program=comp_prog1, + feed={'input': data}, + fetch_list=fetch_name) + fetch_outs.append(bn_fetches) + fetch_names.append(fetch_name) + + for bn_val, inplace_abn_val, name1, name2 in zip(*(fetch_outs + + fetch_names)): + self.assertTrue( + np.allclose( + bn_val, inplace_abn_val, atol=1e-2), + "Output (" + name1 + ":" + name2 + + ") has diff on {} with {} layout and {} activation. \n".format( + place, layout, activation) + "\nBN " + str(bn_val) + + "\n" + "Inplace ABN " + str(inplace_abn_val)) + + def test_op(self): + use_cudas = [False, True] if core.is_compiled_with_cuda() else [False] + for use_cuda in use_cudas: + place = core.CUDAPlace(0) if use_cuda else core.CPUPlace() + layouts = ["NCHW", "NHWC"] + for layout in layouts: + for activation, alpha in zip([None, 'elu', 'leaky_relu'], + [0., 1., 0.02]): + for infer_only in [True, False]: + self.compare(place, layout, infer_only, activation, + alpha, use_cuda) + + def test_all_branches(self): + seed = 10 + os.environ['FLAGS_cudnn_deterministic'] = "1" + data = np.random.random(size=self.dshape).astype(self.dtype) * 4. 
- 2 + use_cudas = [False, True] if core.is_compiled_with_cuda() else [False] + alpha = 0.1 + layouts = ["NCHW", "NHWC"] + for use_cuda in use_cudas: + place = core.CUDAPlace(0) if use_cuda else core.CPUPlace() + for layout in layouts: + for activation in ['identity', 'leaky_relu']: + main, startup, outs = self.build_program( + place, layout, seed, False, activation, alpha, use_cuda, + True) + exe = fluid.Executor(place) + exe.run(startup) + exe.run(program=main, feed={'input': data}) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 77b8b90f9ff..a2f8bc56404 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2684,6 +2684,28 @@ class TestBook(LayerTest): out = layers.batch_norm(data, momentum=momentum) return (out) + def make_inplace_abn(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data( + name='data', shape=[32, 128, 128], dtype="float32") + out = layers.inplace_abn(data, act='leaky_relu', act_alpha=0.2) + return (out) + + def make_inplace_abn_momentum_variable(self): + with program_guard(fluid.default_main_program(), + fluid.default_startup_program()): + data = self._get_data( + name='data', shape=[32, 128, 128], dtype="float32") + momentum = self._get_data( + name='momentum', + shape=[1], + dtype='float32', + append_batch_size=False) + out = layers.inplace_abn( + data, momentum=momentum, act='elu', act_alpha=2.0) + return (out) + def make_range(self): with program_guard(fluid.default_main_program(), fluid.default_startup_program()): -- GitLab
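A minimal usage sketch of the `fluid.layers.inplace_abn` API added by this patch (it assumes a build of this branch in which the new op is registered for the chosen place; the toy network, tensor names and shapes are illustrative only):

    import numpy as np
    import paddle.fluid as fluid

    # NCHW feature map; batch norm and leaky_relu are fused by inplace_abn and
    # the normalized output is written in place over the convolution output.
    x = fluid.data(name='x', shape=[-1, 8, 16, 16], dtype='float32')
    conv = fluid.layers.conv2d(x, num_filters=8, filter_size=3, padding=1)
    abn = fluid.layers.inplace_abn(conv, act='leaky_relu', act_alpha=0.01)
    loss = fluid.layers.reduce_mean(abn)
    fluid.optimizer.SGD(learning_rate=1e-3).minimize(loss)

    place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda() \
        else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    feed = {'x': np.random.rand(4, 8, 16, 16).astype('float32')}
    loss_val, = exe.run(feed=feed, fetch_list=[loss.name])

When the same program is compiled with `fluid.compiler.CompiledProgram(...).with_data_parallel(...)` and `build_strategy.sync_batch_norm = True`, the pass updated in sync_batch_norm_pass.cc sets the `use_sync_bn` attribute on the `inplace_abn` op, so the layer runs its synchronized variant across devices.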