diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index ae970acc272b077692ce258f1111ed5b8cf6f1cc..5d27f5b60c7115a32aeeca5ec2a6654471c310c7 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -83,7 +83,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const ExecutionContext &ctx) const override {
+      const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
     // For float or float16 input tensor, the type of the scale, bias, mean,
diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc
index 2de935d087a4bb902a3e553a580421e3a0665e8f..6ceacc39924a7558e380aaf563aaf234f1bf30a5 100644
--- a/paddle/fluid/operators/batch_norm_op.cu.cc
+++ b/paddle/fluid/operators/batch_norm_op.cu.cc
@@ -28,7 +28,7 @@
 using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
 template <typename T>
-using ScalingParamType = typename CudnnDataType<T>::ScalingParamType;
+using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 
 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
@@ -122,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     // alloc memory
     y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
-    variance_out->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
-    saved_mean->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
-    saved_variance->mutable_data<ScalingParamType<T>>(ctx.GetPlace());
+    mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, ScalingParamType<T>> functor;
-    functor(dev_ctx, saved_mean, static_cast<ScalingParamType<T>>(0));
-    functor(dev_ctx, saved_variance, static_cast<ScalingParamType<T>>(0));
+    math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
+        functor;
+    functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
 
     auto handle = dev_ctx.cudnn_handle();
@@ -151,10 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
           CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
           data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<ScalingParamType<T>>(),
-          bias->template data<ScalingParamType<T>>(),
-          est_mean->template data<ScalingParamType<T>>(),
-          est_var->template data<ScalingParamType<T>>(), epsilon));
+          bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(),
+          est_mean->template data<BatchNormParamType<T>>(),
+          est_var->template data<BatchNormParamType<T>>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -165,14 +166,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
           data_desc_, x->template data<T>(), data_desc_,
           y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<ScalingParamType<T>>(),
-          bias->template data<ScalingParamType<T>>(), this_factor,
-          mean_out->template mutable_data<ScalingParamType<T>>(ctx.GetPlace()),
-          variance_out->template mutable_data<ScalingParamType<T>>(
+          scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(), this_factor,
+          mean_out->template mutable_data<BatchNormParamType<T>>(
               ctx.GetPlace()),
-          epsilon, saved_mean->template mutable_data<ScalingParamType<T>>(
+          variance_out->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
               ctx.GetPlace()),
-          saved_variance->template mutable_data<ScalingParamType<T>>(
+          saved_variance->template mutable_data<BatchNormParamType<T>>(
               ctx.GetPlace())));
     }
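Reviewer's note, not part of the patch: ScalingParamType is const-qualified because it is only ever handed to cuDNN as the read-only alpha/beta scaling pointers (kOne()/kZero()). The batch-norm kernel, however, must also allocate writable tensors of the same underlying type for the running and saved statistics, and a const-qualified element type is unusable for that. The new non-const BatchNormParamType alias carries the same underlying type for allocation. A minimal standalone sketch of the trait pattern follows; float16 and CudnnDataTypeSketch are illustrative stand-ins, not Paddle code:

    #include <type_traits>

    struct float16 { unsigned short x; };  // stand-in for platform::float16

    template <typename T>
    struct CudnnDataTypeSketch;

    // fp16 tensors keep fp32 batch-norm parameters, mirroring cuDNN's rule
    // that scale/bias/mean/variance are float when the data tensor is half.
    template <>
    struct CudnnDataTypeSketch<float16> {
      using ScalingParamType = const float;  // read-only alpha/beta
      using BatchNormParamType = float;      // writable statistics
    };

    template <typename T>
    using BatchNormParamType =
        typename CudnnDataTypeSketch<T>::BatchNormParamType;

    static_assert(std::is_same<BatchNormParamType<float16>, float>::value,
                  "fp16 batch norm allocates writable fp32 parameter buffers");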
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 7e001ecc56173db76e8c576e7efd66f41192f292..7c604e14eb245232ed92f53a00b9bde45c2fbaec 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_HALF;
   // The scaling param type is float for HALF and FLOAT tensors
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -101,7 +102,8 @@ template <>
 class CudnnDataType<float> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -116,7 +118,8 @@ template <>
 class CudnnDataType<double> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
-  typedef const double ScalingParamType;
+  using ScalingParamType = const double;
+  using BatchNormParamType = double;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
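Closing sketch (assuming only that Tensor::mutable_data<U>() returns U*, as its uses above suggest): the const qualifier is exactly what makes ScalingParamType unsuitable for allocating the mean/variance outputs, which is why each specialization now exposes both aliases. The mutable_data_sketch helper below is hypothetical, not Paddle's API:

    // A mutable_data-like analogue returning U*: with U = const float the
    // caller gets back a pointer it cannot write through.
    template <typename U>
    U* mutable_data_sketch(void* raw) {
      return static_cast<U*>(raw);
    }

    void demo(void* buf) {
      float* out = mutable_data_sketch<float>(buf);  // BatchNormParamType
      *out = 0.0f;                                   // writable: OK
      const float* ro = mutable_data_sketch<const float>(buf);
      // *ro = 0.0f;  // would not compile: ScalingParamType is const float
      (void)ro;
    }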