From 446d54f5c32d8cf15ad83ba71783f92b19621931 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 20:24:11 -0700 Subject: [PATCH] update --- paddle/fluid/operators/batch_norm_op.cc | 2 +- paddle/fluid/operators/batch_norm_op.cu.cc | 38 ++++++++++++---------- paddle/fluid/platform/cudnn_helper.h | 9 +++-- 3 files changed, 27 insertions(+), 22 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index ae970acc27..5d27f5b60c 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -83,7 +83,7 @@ class BatchNormOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( - const ExecutionContext &ctx) const override { + const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::ToDataType(ctx.Input("X")->type()); // For float or float16 input tensor, the type of the scale, bias, mean, diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index 2de935d087..6ceacc3992 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -28,7 +28,7 @@ using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; template -using ScalingParamType = typename CudnnDataType::ScalingParamType; +using BatchNormParamType = typename CudnnDataType::BatchNormParamType; void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, int *N, int *C, int *H, int *W, int *D) { @@ -122,15 +122,16 @@ class BatchNormKernel // alloc memory y->mutable_data(ctx.GetPlace()); - mean_out->mutable_data>(ctx.GetPlace()); - variance_out->mutable_data>(ctx.GetPlace()); - saved_mean->mutable_data>(ctx.GetPlace()); - saved_variance->mutable_data>(ctx.GetPlace()); + mean_out->mutable_data>(ctx.GetPlace()); + variance_out->mutable_data>(ctx.GetPlace()); + saved_mean->mutable_data>(ctx.GetPlace()); + saved_variance->mutable_data>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); - math::SetConstant> functor; - functor(dev_ctx, saved_mean, static_cast>(0)); - functor(dev_ctx, saved_variance, static_cast>(0)); + math::SetConstant> + functor; + functor(dev_ctx, saved_mean, static_cast>(0)); + functor(dev_ctx, saved_variance, static_cast>(0)); auto handle = dev_ctx.cudnn_handle(); @@ -151,10 +152,10 @@ class BatchNormKernel CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), - bn_param_desc_, scale->template data>(), - bias->template data>(), - est_mean->template data>(), - est_var->template data>(), epsilon)); + bn_param_desc_, scale->template data>(), + bias->template data>(), + est_mean->template data>(), + est_var->template data>(), epsilon)); } else { // Run training mode. // obtain running mean and running inv var, and see if we need to @@ -165,14 +166,15 @@ class BatchNormKernel handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - bias->template data>(), this_factor, - mean_out->template mutable_data>(ctx.GetPlace()), - variance_out->template mutable_data>( + scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>( ctx.GetPlace()), - epsilon, saved_mean->template mutable_data>( + variance_out->template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean->template mutable_data>( ctx.GetPlace()), - saved_variance->template mutable_data>( + saved_variance->template mutable_data>( ctx.GetPlace()))); } diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 7e001ecc56..7c604e14eb 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -86,7 +86,8 @@ class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_HALF; // The scaling param type is float for HALF and FLOAT tensors - typedef const float ScalingParamType; + using ScalingParamType = const float; + using BatchNormParamType = float; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; return &v; @@ -101,7 +102,8 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; - typedef const float ScalingParamType; + using ScalingParamType = const float; + using BatchNormParamType = float; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; return &v; @@ -116,7 +118,8 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; - typedef const double ScalingParamType; + using ScalingParamType = const double; + using BatchNormParamType = double; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; return &v; -- GitLab