diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index 215ae229aff96d76fc948e19bdb42db319af65dc..5d27f5b60c7115a32aeeca5ec2a6654471c310c7 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -80,6 +80,29 @@ class BatchNormOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("SavedVariance", {C});
     ctx->ShareLoD("X", "Y");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("X")->type());
+    // For float or float16 input tensor, the type of the scale, bias, mean,
+    // and var tensors should both be float.
+    auto bn_param_type = framework::proto::VarType::FP32;
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
+                      "Scale input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Bias")->type()),
+                      "Bias input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type,
+                      framework::ToDataType(ctx.Input<Tensor>("Mean")->type()),
+                      "Mean input should be of float type");
+    PADDLE_ENFORCE_EQ(bn_param_type, framework::ToDataType(
+                                         ctx.Input<Tensor>("Variance")->type()),
+                      "Variance input should be of float type");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/batch_norm_op.cu.cc b/paddle/fluid/operators/batch_norm_op.cu.cc
index 2d1556efc66826ea9847de8311ccecdee0ea7871..6ceacc39924a7558e380aaf563aaf234f1bf30a5 100644
--- a/paddle/fluid/operators/batch_norm_op.cu.cc
+++ b/paddle/fluid/operators/batch_norm_op.cu.cc
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <cfloat>
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -26,6 +27,8 @@
 using Tensor = framework::Tensor;
 using DataLayout = framework::DataLayout;
 template <typename T>
 using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using BatchNormParamType = typename CudnnDataType<T>::BatchNormParamType;
 
 void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
                   int *N, int *C, int *H, int *W, int *D) {
@@ -104,8 +107,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
         data_desc_, CudnnDataType<T>::type,
         x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
+    // Note: PERSISTENT not implemented for inference
     CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
-        bn_param_desc_, data_desc_, mode_));
+        bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
 
     const auto *scale = ctx.Input<Tensor>("Scale");
     const auto *bias = ctx.Input<Tensor>("Bias");
@@ -118,15 +122,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     // alloc memory
     y->mutable_data<T>(ctx.GetPlace());
-    mean_out->mutable_data<T>(ctx.GetPlace());
-    variance_out->mutable_data<T>(ctx.GetPlace());
-    saved_mean->mutable_data<T>(ctx.GetPlace());
-    saved_variance->mutable_data<T>(ctx.GetPlace());
+    mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
 
     auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
-    math::SetConstant<platform::CUDADeviceContext, T> functor;
-    functor(dev_ctx, saved_mean, 0);
-    functor(dev_ctx, saved_variance, 0);
+    math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
+        functor;
+    functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
+    functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
 
     auto handle = dev_ctx.cudnn_handle();
@@ -147,8 +152,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
           CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
           data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
-          bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
-          est_mean->template data<T>(), est_var->template data<T>(), epsilon));
+          bn_param_desc_, scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(),
+          est_mean->template data<BatchNormParamType<T>>(),
+          est_var->template data<BatchNormParamType<T>>(), epsilon));
     } else {
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
@@ -159,11 +166,16 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
           data_desc_, x->template data<T>(), data_desc_,
           y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
-          scale->template data<T>(), bias->template data<T>(), this_factor,
-          mean_out->template mutable_data<T>(ctx.GetPlace()),
-          variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
-          saved_mean->template mutable_data<T>(ctx.GetPlace()),
-          saved_variance->template mutable_data<T>(ctx.GetPlace())));
+          scale->template data<BatchNormParamType<T>>(),
+          bias->template data<BatchNormParamType<T>>(), this_factor,
+          mean_out->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          variance_out->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace()),
+          epsilon, saved_mean->template mutable_data<BatchNormParamType<T>>(
+                       ctx.GetPlace()),
+          saved_variance->template mutable_data<BatchNormParamType<T>>(
+              ctx.GetPlace())));
     }
 
     // clean when exit.
@@ -270,9 +282,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
+    batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc
index 17e576a9d5c8f50fbe84b066a93460f03ae6bb08..299a0aed01dfe0448d896738d9fd33319b1b2887 100644
--- a/paddle/fluid/operators/math/math_function.cc
+++ b/paddle/fluid/operators/math/math_function.cc
@@ -278,6 +278,7 @@ void axpy(
   cblas_daxpy(n, alpha, x, 1, y, 1);
 }
 
+template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
 template struct SetConstant<platform::CPUDeviceContext, float>;
 template struct SetConstant<platform::CPUDeviceContext, double>;
 template struct SetConstant<platform::CPUDeviceContext, int>;
diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu
index c6ca2693a053360ce5dc44765acf1520a11cce2c..1e909db5288afccb9dd0be08a45cf3c27048ae6f 100644
--- a/paddle/fluid/operators/math/math_function.cu
+++ b/paddle/fluid/operators/math/math_function.cu
@@ -348,6 +348,7 @@ void axpy(
                                                 &alpha, x, 1, y, 1));
 }
 
+template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
 template struct SetConstant<platform::CUDADeviceContext, float>;
 template struct SetConstant<platform::CUDADeviceContext, double>;
 template struct SetConstant<platform::CUDADeviceContext, int>;
diff --git a/paddle/fluid/platform/cudnn_helper.h b/paddle/fluid/platform/cudnn_helper.h
index 7e001ecc56173db76e8c576e7efd66f41192f292..7c604e14eb245232ed92f53a00b9bde45c2fbaec 100644
--- a/paddle/fluid/platform/cudnn_helper.h
+++ b/paddle/fluid/platform/cudnn_helper.h
@@ -86,7 +86,8 @@ class CudnnDataType<float16> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_HALF;
   // The scaling param type is float for HALF and FLOAT tensors
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -101,7 +102,8 @@ template <>
 class CudnnDataType<float> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_FLOAT;
-  typedef const float ScalingParamType;
+  using ScalingParamType = const float;
+  using BatchNormParamType = float;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
@@ -116,7 +118,8 @@ template <>
 class CudnnDataType<double> {
  public:
   static const cudnnDataType_t type = CUDNN_DATA_DOUBLE;
-  typedef const double ScalingParamType;
+  using ScalingParamType = const double;
+  using BatchNormParamType = double;
   static ScalingParamType* kOne() {
     static ScalingParamType v = 1.0;
     return &v;
diff --git a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
index 80e6fa6df3c21aa19feb571916f11c41ccd6bb10..10aa63e18a6eeaa44e5b12f7532998dca2bc5e9f 100644
--- a/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
+++ b/python/paddle/fluid/tests/unittests/test_batch_norm_op.py
@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
     return backward_op
 
 
+def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
+    x_shape = x.shape
+    if len(x_shape) == 2:
+        if data_format == "NCHW":
+            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
+        else:
+            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+
+    if data_format == "NCHW":
+        n, c, h, w = x.shape
+        mean_tile = np.reshape(mean, (1, c, 1, 1))
+        mean_tile = np.tile(mean_tile, (n, 1, h, w))
+        var_tile = np.reshape(var, (1, c, 1, 1))
+        var_tile = np.tile(var_tile, (n, 1, h, w))
+        normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
+        scale_tile = np.reshape(scale, (1, c, 1, 1))
+        scale_tile = np.tile(scale_tile, (n, 1, h, w))
+        offset_tile = np.reshape(offset, (1, c, 1, 1))
+        offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
+        y = normalized * scale_tile + offset_tile
+    elif data_format == "NHWC":
+        normalized = (x - mean) / np.sqrt(var + epsilon)
+        y = normalized * scale + offset
+    else:
+        raise ValueError("Unknown data order.")
+
+    if len(x_shape) == 2:
+        y = np.reshape(y, x_shape)
+    return y
+
+
 def _reference_training(x, scale, offset, epsilon, data_format):
     x_shape = x.shape
     if len(x_shape) == 2:
@@ -155,11 +186,159 @@ def set_output_grad(scope, outputs, place, feed_dict=None):
             __set_tensor__(output, data)
 
 
-class TestBatchNormOp(OpTest):
+class TestBatchNormOpInference(OpTest):
+    def setUp(self):
+        self.dtype = np.float32
+
     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
         self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
 
-    def test_python(self):
+    def check_with_place(self, place, data_layout, dtype, shape):
+        epsilon = 0.00001
+        if len(shape) == 2:
+            x_shape = shape
+            c = x_shape[1]
+        else:
+            n, h, w, c = shape[0], shape[1], shape[2], shape[3]
+            if data_layout == "NHWC":
+                x_shape = [n, h, w, c]
+            elif data_layout == "NCHW":
+                x_shape = [n, c, h, w]
+            else:
+                raise ValueError("Unknown data layout.")
+        scale_shape = [c]
+
+        x_val = np.random.random_sample(x_shape).astype(dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+        bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+
+        mean = np.zeros(scale_shape).astype(np.float32)
+        variance = np.ones(scale_shape).astype(np.float32)
+
+        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
+                                   epsilon, data_layout).astype(dtype)
+
+        scope = core.Scope()
+
+        # create input
+        x_tensor = create_or_get_tensor(scope, "x_val",
+                                        OpTest.np_dtype_to_fluid_dtype(x_val),
+                                        place)
+        scale_tensor = create_or_get_tensor(
+            scope, "scale_val",
+            OpTest.np_dtype_to_fluid_dtype(scale_val), place)
+        bias_tensor = create_or_get_tensor(
+            scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
+        mean_tensor = create_or_get_tensor(scope, "mean",
+                                           OpTest.np_dtype_to_fluid_dtype(mean),
+                                           place)
+        variance_tensor = create_or_get_tensor(
+            scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)
+
+        # create output
+        y_tensor = create_or_get_tensor(scope, "y_out", None, place)
+        saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
+                                                 place)
+        saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
+                                                     None, place)
+        mean_out_tensor = mean_tensor
+        variance_out_tensor = variance_tensor
+
+        batch_norm_op = Operator(
+            "batch_norm",
+            # inputs
+            X="x_val",
+            Scale="scale_val",
+            Bias="bias_val",
+            Mean="mean",
+            Variance="variance",
+            # outputs
+            Y="y_out",
+            MeanOut="mean",
+            VarianceOut="variance",
+            SavedMean="saved_mean",
+            SavedVariance="saved_variance",
+            # attrs
+            is_test=True,
+            data_layout=data_layout,
+            epsilon=epsilon)
+
+        batch_norm_op.run(scope, place)
+
+        # check inference result
+        self.__assert_close(
+            y_tensor,
+            y_out,
+            "inference output are different at " + str(place) + ", " +
+            data_layout + ", " + str(np.dtype(dtype)) +
+            str(np.array(y_tensor)) + str(y_out),
+            atol=1e-3)
+
+    def test_check_output(self):
+        places = [core.CPUPlace()]
+        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
+            places.append(core.CUDAPlace(0))
+
+        for place in places:
+            for data_format in ["NCHW", "NHWC"]:
"NHWC"]: + self.check_with_place(place, data_format, self.dtype, + [2, 3, 4, 5]) + self.check_with_place(place, data_format, self.dtype, [2, 3]) + + +class TestFP16BatchNormOpInference(TestBatchNormOpInference): + def setUp(self): + self.dtype = np.float16 + + def test_check_output(self): + places = [] + if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"): + place = core.CUDAPlace(0) + if core.is_float16_supported(place): + places.append(place) + + for place in places: + for data_format in ["NCHW", "NHWC"]: + self.check_with_place(place, data_format, self.dtype, + [2, 3, 4, 5]) + self.check_with_place(place, data_format, self.dtype, [2, 3]) + + +class TestBatchNormOpTraining(OpTest): + def __assert_close(self, tensor, np_array, msg, atol=1e-4): + self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg) + + def test_python_testing(self): + data_format = "NHWC" + epsilon = 0.00001 + + n, h, w, c = 2, 3, 4, 5 + x_shape = [n, h, w, c] + scale_shape = [c] + + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + bias_val = np.random.random_sample(scale_shape).astype(np.float32) + + mean = np.zeros(scale_shape).astype(np.float32) + variance = np.ones(scale_shape).astype(np.float32) + + y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance, + epsilon, "NHWC") + + # running N, C, H, W case + # should produce the same results + x_shape2 = [n, c, h, w] + x_val2 = np.transpose(x_val, (0, 3, 1, 2)) + y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance, + epsilon, "NCHW") + + # transfer (N, C, H, W) back to (N, H, W, C) + y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) + self.__assert_close(y_out, y_out2_trans, "inference output") + print 'python: NHWC, NCHW, inference checking passed' + + def test_python_training(self): data_format = "NHWC" epsilon = 0.00001 momentum = 0.9 @@ -197,7 +376,7 @@ class TestBatchNormOp(OpTest): # transfer (N, C, H, W) back to (N, H, W, C) y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1)) - self.__assert_close(y_out, y_out2_trans, "batch variance") + self.__assert_close(y_out, y_out2_trans, "batch output") print 'python: NHWC, NCHW, forward checking passed' # test backward now