From e870947cfd1f0a2d86d5d422d445c41a99913090 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Sun, 18 Mar 2018 17:58:42 -0700 Subject: [PATCH] fix batch norm fp16 param type --- paddle/fluid/operators/batch_norm_op.cc | 23 +++++++++++ paddle/fluid/operators/batch_norm_op.cu.cc | 38 +++++++++++-------- paddle/fluid/platform/cudnn_helper.h | 5 +++ .../tests/unittests/test_batch_norm_op.py | 31 ++++++++------- 4 files changed, 69 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 215ae229aff..ae970acc272 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel { ctx->SetOutputDim("SavedVariance", {C}); ctx->ShareLoD("X", "Y"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + // For float or float16 input tensor, the type of the scale, bias, mean, + // and var tensors should both be float. + auto bn_param_type = framework::proto::VarType::FP32; + PADDLE_ENFORCE_EQ(bn_param_type, + framework::ToDataType(ctx.Input("Scale")->type()), + "Scale input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, + framework::ToDataType(ctx.Input("Bias")->type()), + "Bias input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, + framework::ToDataType(ctx.Input("Mean")->type()), + "Mean input should be of float type"); + PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType( + ctx.Input("Variance")->type()), + "Variance input should be of float type"); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc index f4919398eb9..5e976788624 100644 --- a/paddle/fluid/operators/batch_norm_op.cu.cc +++ b/paddle/fluid/operators/batch_norm_op.cu.cc @@ -26,6 +26,8 @@ using Tensor = framework::Tensor; using DataLayout = framework::DataLayout; template using CudnnDataType = platform::CudnnDataType; +template +using bn_param_type = CudnnDataType::bn_param_type; void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout, int *N, int *C, int *H, int *W, int *D) { @@ -104,8 +106,9 @@ class BatchNormKernel CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor( data_desc_, CudnnDataType::type, x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data())); + // Note: PERSISTENT not implemented for inference CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor( - bn_param_desc_, data_desc_, mode_)); + bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_)); const auto *scale = ctx.Input("Scale"); const auto *bias = ctx.Input("Bias"); @@ -118,15 +121,15 @@ class BatchNormKernel // alloc memory y->mutable_data(ctx.GetPlace()); - mean_out->mutable_data(ctx.GetPlace()); - variance_out->mutable_data(ctx.GetPlace()); - saved_mean->mutable_data(ctx.GetPlace()); - saved_variance->mutable_data(ctx.GetPlace()); + mean_out->mutable_data>(ctx.GetPlace()); + variance_out->mutable_data>(ctx.GetPlace()); + saved_mean->mutable_data>(ctx.GetPlace()); + saved_variance->mutable_data>(ctx.GetPlace()); auto &dev_ctx = ctx.template device_context(); - math::SetConstant functor; - functor(dev_ctx, saved_mean, static_cast(0)); - functor(dev_ctx, saved_variance, static_cast(0)); + math::SetConstant> functor; + functor(dev_ctx, saved_mean, static_cast>(0)); + functor(dev_ctx, saved_variance, static_cast>(0)); auto handle = dev_ctx.cudnn_handle(); @@ -147,8 +150,10 @@ class BatchNormKernel CUDNN_BATCHNORM_SPATIAL, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), - bn_param_desc_, scale->template data(), bias->template data(), - est_mean->template data(), est_var->template data(), epsilon)); + bn_param_desc_, scale->template data>(), + bias->template data>(), + est_mean->template data>(), + est_var->template data>(), epsilon)); } else { // Run training mode. // obtain running mean and running inv var, and see if we need to @@ -159,11 +164,14 @@ class BatchNormKernel handle, mode_, CudnnDataType::kOne(), CudnnDataType::kZero(), data_desc_, x->template data(), data_desc_, y->template mutable_data(ctx.GetPlace()), bn_param_desc_, - scale->template data(), bias->template data(), this_factor, - mean_out->template mutable_data(ctx.GetPlace()), - variance_out->template mutable_data(ctx.GetPlace()), epsilon, - saved_mean->template mutable_data(ctx.GetPlace()), - saved_variance->template mutable_data(ctx.GetPlace()))); + scale->template data>(), + bias->template data>(), this_factor, + mean_out->template mutable_data>(ctx.GetPlace()), + variance_out->template mutable_data>(ctx.GetPlace()), + epsilon, + saved_mean->template mutable_data>(ctx.GetPlace()), + saved_variance->template mutable_data>( + ctx.GetPlace()))); } // clean when exit. diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h index 7e001ecc561..a40c3662419 100644 --- a/paddle/fluid/platform/cudnn_helper.h +++ b/paddle/fluid/platform/cudnn_helper.h @@ -85,6 +85,9 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_HALF; + // cudnn batch norm requires that Scale, Bias, Mean, and Variance + // to be FLOAT tensors when the input x is HALF tensor + static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; // The scaling param type is float for HALF and FLOAT tensors typedef const float ScalingParamType; static ScalingParamType* kOne() { @@ -101,6 +104,7 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_FLOAT; + static const cudnnDataType_t bn_param_type = CUDNN_DATA_FLOAT; typedef const float ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; @@ -116,6 +120,7 @@ template <> class CudnnDataType { public: static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; + static const cudnnDataType_t bn_param_type = CUDNN_DATA_DOUBLE; typedef const double ScalingParamType; static ScalingParamType* kOne() { static ScalingParamType v = 1.0; diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py index 91a9d826a0c..261c457708a 100644 --- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py @@ -193,7 +193,7 @@ class TestBatchNormOpInference(OpTest): def __assert_close(self, tensor, np_array, msg, atol=1e-4): self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) - def check_with_place(place, data_layout, dtype, shape): + def check_with_place(self, place, data_layout, dtype, shape): epsilon = 0.00001 if len(shape) == 2: x_shape = shape @@ -209,11 +209,11 @@ class TestBatchNormOpInference(OpTest): scale_shape = [c] x_val = np.random.random_sample(x_shape).astype(dtype) - scale_val = np.random.random_sample(scale_shape).astype(dtype) - bias_val = np.random.random_sample(scale_shape).astype(dtype) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + bias_val = np.random.random_sample(scale_shape).astype(np.float32) - mean = np.zeros(scale_shape).astype(dtype) - variance = np.ones(scale_shape).astype(dtype) + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, epsilon, data_layout).astype(dtype) @@ -266,9 +266,13 @@ class TestBatchNormOpInference(OpTest): batch_norm_op.run(scope, place) # check inference result - self.__assert_close(y_tensor, y_out, - "inference output are different at " + str(place) + - ", " + data_layout + ", " + str(np.dtype(dtype))) + self.__assert_close( + y_tensor, + y_out, + "inference output are different at " + str(place) + ", " + + data_layout + ", " + str(np.dtype(dtype)) + + str(np.array(y_tensor)) + str(y_out), + atol=2e-2) def test_check_output(self): places = [core.CPUPlace()] @@ -277,8 +281,9 @@ class TestBatchNormOpInference(OpTest): for place in places: for data_format in ["NCHW", "NHWC"]: - check_with_place(place, data_format, self.dtype, [2, 3, 4, 5]) - check_with_place(place, data_format, self.dtype, [2, 3]) + self.check_with_place(place, data_format, self.dtype, + [2, 3, 4, 5]) + self.check_with_place(place, data_format, self.dtype, [2, 3]) class TestFP16BatchNormOpInference(TestBatchNormOpInference): @@ -294,9 +299,9 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference): for place in places: for data_format in ["NCHW", "NHWC"]: - check_output_with_place(place, data_format, self.dtype, - [2, 3, 4, 5]) - check_output_with_place(place, data_format, self.dtype, [2, 3]) + self.check_with_place(place, data_format, self.dtype, + [2, 3, 4, 5]) + self.check_with_place(place, data_format, self.dtype, [2, 3]) class TestBatchNormOpTraining(OpTest): -- GitLab