From 5e813b53c5359bc175b640b6f50b2945b386706a Mon Sep 17 00:00:00 2001 From: Jie Fang Date: Sun, 1 Dec 2019 10:21:42 +0800 Subject: [PATCH] nhwc optimization for batchnorm (#21090) --- paddle/fluid/framework/grad_op_desc_maker.h | 4 + paddle/fluid/imperative/dygraph_grad_maker.h | 6 + paddle/fluid/operators/batch_norm_op.cc | 60 ++-- paddle/fluid/operators/batch_norm_op.cu | 284 +++++++++++++++--- paddle/fluid/operators/batch_norm_op.h | 99 +++++- paddle/fluid/platform/dynload/cudnn.cc | 4 + paddle/fluid/platform/dynload/cudnn.h | 9 + python/paddle/fluid/layers/nn.py | 34 ++- .../tests/unittests/test_batch_norm_op.py | 38 ++- 9 files changed, 461 insertions(+), 77 deletions(-) diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index 5fda027453c..a98e638ce16 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -141,6 +141,10 @@ class GradOpDescMakerBase { return (fwd_op_.Inputs().count(name) > 0); } + bool HasOutput(const std::string& name) const { + return (fwd_op_.Outputs().count(name) > 0); + } + private: const OpDesc& fwd_op_; const std::unordered_set& no_grad_set_; diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index c2107f0a89f..b1f4bb1b6c9 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -107,6 +107,12 @@ class GradOpBaseMakerBase { return it != var_base_map_in_.end(); } + bool HasOutput(const std::string name) const { + auto it = var_base_map_out_.find(name); + + return it != var_base_map_out_.end(); + } + private: std::vector> GetVarBaseList(const std::string& name, bool is_grad, diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 6870e65ef5a..5e2d02332f2 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -25,27 +25,42 @@ namespace paddle { namespace operators { void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Scale"), - "Input(Scale) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Bias"), - "Input(Bias) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Mean"), - "Input(Mean) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput("Variance"), - "Input(Variance) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Y"), - "Output(Y) of ConvOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + platform::errors::InvalidArgument( + "Input(X) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Scale"), true, + platform::errors::InvalidArgument( + "Input(Scale) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Bias"), true, + platform::errors::InvalidArgument( + "Input(Bias) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Mean"), true, + platform::errors::InvalidArgument( + "Input(Mean) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasInput("Variance"), true, + platform::errors::InvalidArgument( + "Input(Variance) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true, + platform::errors::InvalidArgument( + "Output(Y) of BatchNormOp should not be null.")); bool is_test = ctx->Attrs().Get("is_test"); if (!is_test) { - PADDLE_ENFORCE(ctx->HasOutput("MeanOut"), - "Output(MeanOut) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("VarianceOut"), - "Output(VarianceOut) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), - "Output(SavedMean) of ConvOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), - "Output(SavedVariance) of ConvOp should not be null."); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("MeanOut"), true, + platform::errors::InvalidArgument( + "Output(MeanOut) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("VarianceOut"), true, + platform::errors::InvalidArgument( + "Output(VarianceOut) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("SavedMean"), true, + platform::errors::InvalidArgument( + "Output(SavedMean) of BatchNormOp should not be null.")); + PADDLE_ENFORCE_EQ( + ctx->HasOutput("SavedVariance"), true, + platform::errors::InvalidArgument( + "Output(SavedVariance) of BatchNormOp should not be null.")); } // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python @@ -200,6 +215,10 @@ void BatchNormOpMaker::Make() { "Variance of the current mini batch, " "will apply to output when training") .AsIntermediate(); + AddOutput("ReserveSpace", + "Reserve GPU space for triggering the new semi-persistent " + "NHWC kernel") + .AsDispensable(); AddAttr("use_mkldnn", "(bool, default false) Only used in mkldnn kernel") .SetDefault(false); @@ -643,6 +662,9 @@ std::unique_ptr BatchNormGradMaker::Apply() const { op->SetInput("Bias", this->Input("Bias")); op->SetInput("SavedMean", this->Output("SavedMean")); op->SetInput("SavedVariance", this->Output("SavedVariance")); + if (this->HasOutput("ReserveSpace")) { + op->SetInput("ReserveSpace", this->Output("ReserveSpace")); + } // used when setting use_global_stats True during training if (boost::get(this->GetAttr("use_global_stats"))) { diff --git a/paddle/fluid/operators/batch_norm_op.cu b/paddle/fluid/operators/batch_norm_op.cu index 7034fcf0ec7..f8e2d9f393d 100644 --- a/paddle/fluid/operators/batch_norm_op.cu +++ b/paddle/fluid/operators/batch_norm_op.cu @@ -56,12 +56,39 @@ class BatchNormKernel const auto &x_dims = x->dims(); PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, "The Input dim size should be between 2 and 5"); - int N, C, H, W, D; - ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); auto *y = ctx.Output("Y"); y->mutable_data(ctx.GetPlace()); + int N, C, H, W, D; + ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D); + + auto dtype = platform::CudnnDataType::type; + const bool fast_nhwc_batch_norm = + is_test || + (dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent); + + auto compute_format = + fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC + ? DataLayout::kNHWC + : DataLayout::kNCHW; + + Tensor transformed_x(x->type()); + Tensor transformed_y(y->type()); + if (data_layout == DataLayout::kNHWC && + compute_format == DataLayout::kNCHW && x_dims.size() > 2) { + VLOG(3) << "Transform input tensor from NHWC to NCHW."; + ResizeToChannelFirst(ctx, x, + &transformed_x); + TransToChannelFirst(ctx, x, + &transformed_x); + ResizeToChannelFirst(ctx, y, + &transformed_y); + } else { + transformed_x.ShareDataWith(*x); + transformed_y.ShareDataWith(*y); + } + // ------------------- cudnn descriptors --------------------- cudnnTensorDescriptor_t data_desc_; cudnnTensorDescriptor_t bn_param_desc_; @@ -90,7 +117,7 @@ class BatchNormKernel VLOG(3) << "Setting descriptors."; std::vector dims; std::vector strides; - if (data_layout == DataLayout::kNCHW) { + if (compute_format == DataLayout::kNCHW) { dims = {N, C, H, W, D}; strides = {C * H * W * D, H * W * D, W * D, D, 1}; } else { @@ -126,8 +153,9 @@ class BatchNormKernel handle, // Note: PERSISTENT not implemented for inference CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), - CudnnDataType::kZero(), data_desc_, x->template data(), - data_desc_, y->template mutable_data(ctx.GetPlace()), + CudnnDataType::kZero(), data_desc_, + transformed_x.template data(), data_desc_, + transformed_y.template mutable_data(ctx.GetPlace()), bn_param_desc_, scale->template data>(), bias->template data>(), est_mean->template data>(), @@ -167,23 +195,102 @@ class BatchNormKernel } else { double this_factor = 1. - momentum; - CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining( - handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), - data_desc_, x->template data(), data_desc_, - y->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - bias->template data>(), this_factor, - mean_out->template mutable_data>( - ctx.GetPlace()), - variance_out->template mutable_data>( - ctx.GetPlace()), - epsilon, saved_mean->template mutable_data>( - ctx.GetPlace()), - saved_variance->template mutable_data>( - ctx.GetPlace()))); + bool called = false; +#if CUDNN_VERSION_MIN(7, 4, 1) + if (compute_format == DataLayout::kNHWC) { + called = true; + size_t workspace_size = 0; + size_t reserve_space_size = 0; + void *reserve_space_ptr = nullptr; + void *workspace_ptr = nullptr; + Tensor workspace_tensor; + // Create reserve space and workspace for batch norm. + // Create tensor for each batchnorm op, it will be used in the + // backward. Thus this tensor shouldn't be temp. + auto *reserve_space = ctx.Output("ReserveSpace"); + PADDLE_ENFORCE_NOT_NULL( + reserve_space, + platform::errors::NotFound( + "The argument ReserveSpace of batch_norm op is not found.")); + + // --------------- cudnn batchnorm workspace --------------- + CUDNN_ENFORCE( + platform::dynload:: + cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnIps=*/CUDNN_BATCHNORM_OPS_BN, + /*xDesc=*/data_desc_, + /*zDesc=*/nullptr, + /*yDesc=*/data_desc_, + /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, + /*activationDesc=*/nullptr, + /*sizeInBytes=*/&workspace_size)); + + // -------------- cudnn batchnorm reserve space -------------- + CUDNN_ENFORCE( + platform::dynload:: + cudnnGetBatchNormalizationTrainingExReserveSpaceSize( + /*handle=*/handle, + /*mode=*/mode_, + /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, + /*activationDesc=*/nullptr, + /*xDesc=*/data_desc_, + /*sizeInBytes=*/&reserve_space_size)); + + reserve_space_ptr = reserve_space->mutable_data( + ctx.GetPlace(), transformed_x.type(), reserve_space_size); + workspace_ptr = workspace_tensor.mutable_data( + ctx.GetPlace(), transformed_x.type(), workspace_size); + CUDNN_ENFORCE( + platform::dynload::cudnnBatchNormalizationForwardTrainingEx( + handle, mode_, CUDNN_BATCHNORM_OPS_BN, + CudnnDataType::kOne(), CudnnDataType::kZero(), + data_desc_, transformed_x.template data(), nullptr, + nullptr, data_desc_, transformed_y.template data(), + bn_param_desc_, scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>( + ctx.GetPlace()), + variance_out->template mutable_data>( + ctx.GetPlace()), + epsilon, + saved_mean->template mutable_data>( + ctx.GetPlace()), + saved_variance->template mutable_data>( + ctx.GetPlace()), + nullptr, workspace_ptr, workspace_size, reserve_space_ptr, + reserve_space_size)); + } +#endif + if (!called) { + CUDNN_ENFORCE( + platform::dynload::cudnnBatchNormalizationForwardTraining( + handle, mode_, CudnnDataType::kOne(), + CudnnDataType::kZero(), data_desc_, + transformed_x.template data(), data_desc_, + transformed_y.template mutable_data(ctx.GetPlace()), + bn_param_desc_, scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>( + ctx.GetPlace()), + variance_out->template mutable_data>( + ctx.GetPlace()), + epsilon, + saved_mean->template mutable_data>( + ctx.GetPlace()), + saved_variance->template mutable_data>( + ctx.GetPlace()))); + } } } + if (data_layout == DataLayout::kNHWC && + compute_format == DataLayout::kNCHW && x_dims.size() > 2) { + VLOG(3) << "Transform batchnorm output from NCHW to NHWC"; + TransToChannelLast( + ctx, &transformed_y, y); + } // clean when exit. CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_)); CUDNN_ENFORCE( @@ -337,9 +444,41 @@ class BatchNormGradKernel PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL); PADDLE_ENFORCE_EQ(scale->dims()[0], C); + auto dtype = platform::CudnnDataType::type; + const auto *reserve_space = ctx.Input("ReserveSpace"); + const bool fast_nhwc_batch_norm = + dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent && + reserve_space != nullptr; + auto compute_format = + fast_nhwc_batch_norm && data_layout == DataLayout::kNHWC + ? DataLayout::kNHWC + : DataLayout::kNCHW; + + Tensor transformed_x(x->type()); + Tensor transformed_d_y(d_y->type()); + Tensor transformed_d_x(d_x->type()); + if (data_layout == DataLayout::kNHWC && + compute_format == DataLayout::kNCHW) { + VLOG(3) << "Transform input tensor from NHWC to NCHW."; + ResizeToChannelFirst(ctx, x, + &transformed_x); + TransToChannelFirst(ctx, x, + &transformed_x); + ResizeToChannelFirst(ctx, d_y, + &transformed_d_y); + TransToChannelFirst(ctx, d_y, + &transformed_d_y); + ResizeToChannelFirst(ctx, d_x, + &transformed_d_x); + } else { + transformed_x.ShareDataWith(*x); + transformed_d_y.ShareDataWith(*d_y); + transformed_d_x.ShareDataWith(*d_x); + } + std::vector dims; std::vector strides; - if (data_layout == DataLayout::kNCHW) { + if (compute_format == DataLayout::kNCHW) { dims = {N, C, H, W, D}; strides = {C * H * W * D, H * W * D, W * D, D, 1}; } else { @@ -348,7 +487,7 @@ class BatchNormGradKernel } auto &dev_ctx = ctx.template device_context(); - const int num = x->numel(); + const int num = transformed_x.numel(); const int block = 512; int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); const int max_blocks = std::max(max_threads / block, 1); @@ -404,20 +543,95 @@ class BatchNormGradKernel saved_var->template data>(); if (d_scale && d_bias) { - CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( - dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), - CudnnDataType::kZero(), CudnnDataType::kOne(), - CudnnDataType::kZero(), data_desc_, x->template data(), - data_desc_, d_y->template data(), data_desc_, - d_x->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - d_scale->template mutable_data>( - ctx.GetPlace()), - d_bias->template mutable_data>( - ctx.GetPlace()), - epsilon, saved_mean_data, saved_var_data)); + bool called = false; +#if CUDNN_VERSION_MIN(7, 4, 1) + if (compute_format == DataLayout::kNHWC) { + called = true; + size_t workspace_size = 0; + void *workspace_ptr = nullptr; + Tensor workspace_tensor; + auto reserve_space_size = reserve_space->memory_size(); + // --------------- cudnn batchnorm workspace --------------- + CUDNN_ENFORCE(platform::dynload:: + cudnnGetBatchNormalizationBackwardExWorkspaceSize( + /*handle=*/dev_ctx.cudnn_handle(), + /*mode=*/mode_, + /*bnIps=*/CUDNN_BATCHNORM_OPS_BN, + /*xDesc=*/data_desc_, + /*yDesc=*/data_desc_, + /*dyDesc=*/data_desc_, + /*dzDesc=*/nullptr, + /*dxDesc=*/data_desc_, + /*bnScaleBiasMeanVarDesc=*/bn_param_desc_, + /*activationDesc=*/nullptr, + /*sizeInBytes=*/&workspace_size)); + + workspace_ptr = workspace_tensor.mutable_data( + ctx.GetPlace(), transformed_x.type(), workspace_size); + + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackwardEx( + /*handle=*/dev_ctx.cudnn_handle(), + /*mode=*/mode_, + /*bnOps=*/CUDNN_BATCHNORM_OPS_BN, + /*alphaDataDiff=*/CudnnDataType::kOne(), + /*betaDataDiff=*/CudnnDataType::kZero(), + /*alphaParamDiff=*/CudnnDataType::kOne(), + /*betaParamDiff=*/CudnnDataType::kZero(), + /*xDesc=*/data_desc_, + /*xData=*/transformed_x.template data(), + /*yDesc=*/nullptr, + /*yData=*/nullptr, + /*dyDesc=*/data_desc_, + /*dyData=*/transformed_d_y.template data(), + /*dzDesc=*/nullptr, + /*dzData=*/nullptr, + /*dxDesc=*/data_desc_, + /*dxData=*/transformed_d_x.template mutable_data( + ctx.GetPlace()), + /*dBnScaleBiasDesc=*/bn_param_desc_, + /*bnScaleData=*/scale->template data>(), + /*bnBiasData=*/nullptr, + /*dBnScaleData=*/d_scale + ->template mutable_data>( + ctx.GetPlace()), + /*dBnBiasData=*/d_bias + ->template mutable_data>( + ctx.GetPlace()), + /*epsilon=*/epsilon, + /*savedMean=*/saved_mean_data, + /*savedInvVariance=*/saved_var_data, + /*activationDesc=*/nullptr, + /*workspace=*/workspace_ptr, + /*workSpaceSizeInBytes=*/workspace_size, + /*reserveSpace=*/const_cast( + reserve_space->template data()), + /*reserveSpaceSizeInBytes=*/reserve_space_size)); + } +#endif + if (!called) { + CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward( + dev_ctx.cudnn_handle(), mode_, CudnnDataType::kOne(), + CudnnDataType::kZero(), CudnnDataType::kOne(), + CudnnDataType::kZero(), data_desc_, + transformed_x.template data(), data_desc_, + transformed_d_y.template data(), data_desc_, + transformed_d_x.template mutable_data(ctx.GetPlace()), + bn_param_desc_, scale->template data>(), + d_scale->template mutable_data>( + ctx.GetPlace()), + d_bias->template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean_data, saved_var_data)); + } + + if (data_layout == DataLayout::kNHWC && + compute_format == DataLayout::kNCHW) { + VLOG(3) << "Transform batchnorm output from NCHW to NHWC"; + TransToChannelLast( + ctx, &transformed_d_x, d_x); + } } else { - if (data_layout == framework::DataLayout::kNCHW) { + if (compute_format == DataLayout::kNCHW) { if (d_x) { BNBackwardData<<< grid2, block, 0, dev_ctx.stream()>>>( @@ -450,7 +664,7 @@ class BatchNormGradKernel const auto *running_var_data = running_var->template data>(); - if (data_layout == framework::DataLayout::kNCHW) { + if (compute_format == DataLayout::kNCHW) { if (d_x) { KeBNBackwardData<<< grid1, block, 0, dev_ctx.stream()>>>( diff --git a/paddle/fluid/operators/batch_norm_op.h b/paddle/fluid/operators/batch_norm_op.h index 9678e2d3f2b..cb02194b6de 100644 --- a/paddle/fluid/operators/batch_norm_op.h +++ b/paddle/fluid/operators/batch_norm_op.h @@ -16,8 +16,10 @@ limitations under the License. */ #include #include #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/norm_utils.h" namespace paddle { @@ -39,24 +41,109 @@ template using ConstEigenVectorArrayMap = Eigen::Map>; +template +inline void ResizeToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[4]; + in_dims_vec[2] = input->dims()[1]; + in_dims_vec[3] = input->dims()[2]; + in_dims_vec[4] = input->dims()[3]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + + } else if (dim == 2) { + // input + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[3]; + in_dims_vec[2] = input->dims()[1]; + in_dims_vec[3] = input->dims()[2]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } else if (dim == 1) { + transformed_input->Resize(input->dims()); + + auto in_dims_vec = framework::vectorize(input->dims()); + in_dims_vec[1] = input->dims()[2]; + in_dims_vec[2] = input->dims()[1]; + transformed_input->Resize(framework::make_ddim(in_dims_vec)); + transformed_input->mutable_data(context.GetPlace()); + } +} + +template +inline void TransToChannelFirst(const framework::ExecutionContext& context, + const Tensor* input, + Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 4, 1, 2, 3}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 3, 1, 2}; + math::Transpose trans4; + trans4(dev_ctx, *input, transformed_input, axis); + } else if (dim == 1) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 1}; + math::Transpose trans3; + trans3(dev_ctx, *input, transformed_input, axis); + } +} + +template +inline void TransToChannelLast(const framework::ExecutionContext& context, + const Tensor* input, Tensor* transformed_input) { + int dim = input->dims().size() - 2; + if (dim == 3) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 3, 4, 1}; + math::Transpose trans5; + trans5(dev_ctx, *input, transformed_input, axis); + + } else if (dim == 2) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 3, 1}; + math::Transpose trans4; + trans4(dev_ctx, *input, transformed_input, axis); + } else if (dim == 1) { + auto& dev_ctx = context.template device_context(); + std::vector axis{0, 2, 1}; + math::Transpose trans3; + trans3(dev_ctx, *input, transformed_input, axis); + } +} + class BatchNormOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override; + void InferShape(framework::InferShapeContext* ctx) const override; protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override; + const framework::ExecutionContext& ctx) const override; }; class BatchNormGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override; + void InferShape(framework::InferShapeContext* ctx) const override; protected: framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext &ctx) const override; + const framework::ExecutionContext& ctx) const override; }; class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { @@ -85,13 +172,13 @@ class BatchNormOpInferVarType template class BatchNormKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override; + void Compute(const framework::ExecutionContext& ctx) const override; }; template class BatchNormGradKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext &ctx) const override; + void Compute(const framework::ExecutionContext& ctx) const override; }; } // namespace operators diff --git a/paddle/fluid/platform/dynload/cudnn.cc b/paddle/fluid/platform/dynload/cudnn.cc index 91d9a1ef013..edff8761ee1 100644 --- a/paddle/fluid/platform/dynload/cudnn.cc +++ b/paddle/fluid/platform/dynload/cudnn.cc @@ -46,6 +46,10 @@ CUDNN_DNN_ROUTINE_EACH_R6(DEFINE_WRAP); CUDNN_DNN_ROUTINE_EACH_R7(DEFINE_WRAP); #endif +#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R7 +CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP); +#endif + #ifdef PADDLE_USE_DSO bool HasCUDNN() { std::call_once(cudnn_dso_flag, diff --git a/paddle/fluid/platform/dynload/cudnn.h b/paddle/fluid/platform/dynload/cudnn.h index 0f743801582..bec5ceb1f47 100644 --- a/paddle/fluid/platform/dynload/cudnn.h +++ b/paddle/fluid/platform/dynload/cudnn.h @@ -189,6 +189,15 @@ CUDNN_DNN_ROUTINE_EACH_R6(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) #endif +#if CUDNN_VERSION >= 7401 +#define CUDNN_DNN_ROUTINE_EACH_AFTER_R7(__macro) \ + __macro(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize); \ + __macro(cudnnBatchNormalizationForwardTrainingEx); \ + __macro(cudnnGetBatchNormalizationBackwardExWorkspaceSize); \ + __macro(cudnnBatchNormalizationBackwardEx); \ + __macro(cudnnGetBatchNormalizationTrainingExReserveSpaceSize); +CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP) +#endif } // namespace dynload } // namespace platform } // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 784a9d54b98..f7ec2453f2c 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2523,6 +2523,13 @@ def batch_norm(input, check_type_and_dtype(input, 'input', Variable, ['float16', 'float32', 'float64'], 'batch_norm') dtype = helper.input_dtype() + + has_reserve_space = False + if data_layout == 'NHWC': + flag = os.environ.get('FLAGS_cudnn_batchnorm_spatial_persistent') + if flag is not None and flag.lower() in ['true', '1']: + has_reserve_space = True + # use fp32 for bn parameter if dtype == core.VarDesc.VarType.FP16: dtype = core.VarDesc.VarType.FP32 @@ -2577,6 +2584,11 @@ def batch_norm(input, saved_variance = helper.create_variable_for_type_inference( dtype=dtype, stop_gradient=True) + reserve_space = None + if has_reserve_space: + reserve_space = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.FP16, stop_gradient=True) + batch_norm_out = input if in_place else helper.create_variable_for_type_inference( dtype) @@ -2599,17 +2611,19 @@ def batch_norm(input, inputs['MomemtumTensor'] = momentum else: attrs['momentum'] = momentum + + outputs = { + "Y": batch_norm_out, + "MeanOut": mean_out, + "VarianceOut": variance_out, + "SavedMean": saved_mean, + "SavedVariance": saved_variance + } + if reserve_space is not None: + outputs["ReserveSpace"] = reserve_space + helper.append_op( - type="batch_norm", - inputs=inputs, - outputs={ - "Y": batch_norm_out, - "MeanOut": mean_out, - "VarianceOut": variance_out, - "SavedMean": saved_mean, - "SavedVariance": saved_variance - }, - attrs=attrs) + type="batch_norm", inputs=inputs, outputs=outputs, attrs=attrs) return helper.append_activation(batch_norm_out) diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 2d9f38acb89..5e3c41441db 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -14,6 +14,7 @@ from __future__ import print_function +import os import unittest import numpy as np import paddle.fluid.core as core @@ -413,16 +414,28 @@ class TestBatchNormOpTraining(unittest.TestCase): inputs['MomentumTensor'] = block.var('momentum_var') else: attrs['momentum'] = momentum + + outputs = { + "Y": block.var('y'), + "MeanOut": block.var('mean'), # share memory + "VarianceOut": block.var('variance'), # share memory + "SavedMean": block.var('saved_mean'), + "SavedVariance": block.var('saved_variance') + } + has_reserve_space = False + if data_format == 'NHWC': + flag = os.environ.get( + 'FLAGS_cudnn_batchnorm_spatial_persistent') + if flag is not None and flag.lower() in ['true', '1']: + has_reserve_space = True + if has_reserve_space: + block.create_var(name="reserve_space", dtype='float16') + outputs["ReserveSpace"] = block.var('reserve_space') + del os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] bn_op = block.append_op( type="batch_norm", inputs=inputs, - outputs={ - "Y": block.var('y'), - "MeanOut": block.var('mean'), # share memory - "VarianceOut": block.var('variance'), # share memory - "SavedMean": block.var('saved_mean'), - "SavedVariance": block.var('saved_variance') - }, + outputs=outputs, attrs=attrs) block.create_var(name='y@GRAD', dtype='float32', shape=y.shape) @@ -479,6 +492,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining): self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD'] +class TestBatchNormOpTrainingCase2(TestBatchNormOpTraining): + def init_test_case(self): + self.use_global_stats = False + self.no_grad_set = set() + self.fetch_list = [ + 'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD', + 'scale@GRAD', 'bias@GRAD' + ] + os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = "1" + + class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining): def init_test_case(self): self.use_momentum_variable = True -- GitLab