From ffa22a5f905daa7b7a4a9e4e6a1e3f17b9fea073 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 20:01:27 -0700 Subject: [PATCH] fix scaling param type --- paddle/fluid/operators/batch_norm_op.cu.cc | 40 ++++++++++++---------- paddle/fluid/platform/cudnn_helper.h | 5 --- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index 5e97678862..2de935d087 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { @@ -27,7 +28,7 @@ using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; template -using bn_param_type = CudnnDataType::bn_param_type; +using ScalingParamType = typename CudnnDataType::ScalingParamType; void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, int *N, int *C, int *H, int *W, int *D) { @@ -121,15 +122,15 @@ class BatchNormKernel // alloc memory y->mutable_data(ctx.GetPlace()); - mean_out->mutable_data>(ctx.GetPlace()); - variance_out->mutable_data>(ctx.GetPlace()); - saved_mean->mutable_data>(ctx.GetPlace()); - saved_variance->mutable_data>(ctx.GetPlace()); + mean_out->mutable_data>(ctx.GetPlace()); + variance_out->mutable_data>(ctx.GetPlace()); + saved_mean->mutable_data>(ctx.GetPlace()); + saved_variance->mutable_data>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); - math::SetConstant> functor; - functor(dev_ctx, saved_mean, static_cast>(0)); - functor(dev_ctx, saved_variance, static_cast>(0)); + math::SetConstant> functor; + functor(dev_ctx, saved_mean, static_cast>(0)); + functor(dev_ctx, saved_variance, static_cast>(0)); auto handle = dev_ctx.cudnn_handle(); @@ -150,10 +151,10 @@ class BatchNormKernel CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), - bn_param_desc_, scale->template data>(), - bias->template data>(), - est_mean->template data>(), - est_var->template data>(), epsilon)); + bn_param_desc_, scale->template data>(), + bias->template data>(), + est_mean->template data>(), + est_var->template data>(), epsilon)); } else { // Run training mode. // obtain running mean and running inv var, and see if we need to @@ -164,13 +165,14 @@ class BatchNormKernel handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data>(), - bias->template data>(), this_factor, - mean_out->template mutable_data>(ctx.GetPlace()), - variance_out->template mutable_data>(ctx.GetPlace()), - epsilon, - saved_mean->template mutable_data>(ctx.GetPlace()), - saved_variance->template mutable_data>( + scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>(ctx.GetPlace()), + variance_out->template mutable_data>( + ctx.GetPlace()), + epsilon, saved_mean->template mutable_data>( + ctx.GetPlace()), + saved_variance->template mutable_data>( ctx.GetPlace()))); } diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index a40c366241..7e001ecc56 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -85,9 +85,6 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_HALF; - // cudnn batch norm requires that Scale, Bias, Mean, and Variance - // to be FLOAT tensors when the input x is HALF tensor - static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; // The scaling param type is float for HALF and FLOAT tensors typedef const float ScalingParamType; static ScalingParamType* kOne() { @@ -104,7 +101,6 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; - static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; typedef const float ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; @@ -120,7 +116,6 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; - static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE; typedef const double ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; -- GitLab