diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
index 10bf11e3a775af8d84dbf48e4b1b0409e7777f64..0bffec03728d8285977b86e60ecc2dea602fa91c 100644
--- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
+++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1210,12 +1210,17 @@ void batch_norm_grad(const Tensor& x,
 
   Tensor x_data = x;
   Tensor out_grad_data = out_grad;
-  if (x.dtype() == phi::DataType::FLOAT16) {
+
+  bool need_cast = x.dtype() == phi::DataType::FLOAT16 ||
+                   x.dtype() == phi::DataType::BFLOAT16;
+  if (need_cast) {
     x_data = cast(x, phi::DataType::FLOAT32);
   }
-  if (out_grad.dtype() == phi::DataType::FLOAT16) {
+  if (out_grad.dtype() == phi::DataType::FLOAT16 ||
+      out_grad.dtype() == phi::DataType::BFLOAT16) {
     out_grad_data = cast(out_grad, phi::DataType::FLOAT32);
   }
+
   auto x_dims = x_data.dims();
   const int C = (data_layout_ == DataLayout::kNCHW ? x_dims[1]
                                                    : x_dims[x_dims.size() - 1]);
@@ -1278,7 +1283,7 @@ void batch_norm_grad(const Tensor& x,
       if (use_global_stats) {
         auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
         auto nchw_x_grad = transpose(nhwc_x_grad, nhwc_to_nchw_dim);
-        if (x.dtype() == phi::DataType::FLOAT16) {
+        if (need_cast) {
           nchw_x_grad = cast(nchw_x_grad, x.dtype());
         }
         set_output(nchw_x_grad, x_grad);
@@ -1291,7 +1296,7 @@ void batch_norm_grad(const Tensor& x,
 
         auto x_grad_data = part1 * part2;
         auto nchw_x_grad = transpose(x_grad_data, nhwc_to_nchw_dim);
-        if (x.dtype() == phi::DataType::FLOAT16) {
+        if (need_cast) {
           nchw_x_grad = cast(nchw_x_grad, x.dtype());
         }
         set_output(nchw_x_grad, x_grad);
@@ -1314,7 +1319,7 @@ void batch_norm_grad(const Tensor& x,
           out_grad_data * (x_data - mean_data), reduce_axis, dtype, false);
       if (use_global_stats) {
         auto x_grad_data = scale * rsqrt_var * out_grad_data;
-        if (x.dtype() == phi::DataType::FLOAT16) {
+        if (need_cast) {
           x_grad_data = cast(x_grad_data, x.dtype());
         }
         set_output(x_grad_data, x_grad);
@@ -1328,7 +1333,7 @@ void batch_norm_grad(const Tensor& x,
           out_grad_data - mean_temp1 - (x_data - mean_data) * mean_temp2;
 
       auto x_grad_data = part1 * part2;
-      if (x.dtype() == phi::DataType::FLOAT16) {
+      if (need_cast) {
         x_grad_data = cast(x_grad_data, x.dtype());
       }
       set_output(x_grad_data, x_grad);
diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py
index fa036dc1cc309b25a0a57364ae0152bb51b502a0..231cdf9161a1f8896201acc455b25d76c10e0b31 100644
--- a/python/paddle/incubate/autograd/composite_rules.py
+++ b/python/paddle/incubate/autograd/composite_rules.py
@@ -79,9 +79,12 @@ def composite_batchnorm(
     is_amp = False
     from paddle.fluid.data_feeder import convert_dtype
 
-    if convert_dtype(x.dtype) == "float16":
+    dtype = convert_dtype(x.dtype)
+    if dtype in ["float16", "uint16"]:
         is_amp = True
         x = cast(x, "float32")
+        scale = cast(scale, "float32") if scale else scale
+        bias = cast(bias, "float32") if bias else bias
 
     feature_axis = (
         1 if data_layout in ('NC', 'NCL', 'NCHW', 'NCHWD') else len(x.shape) - 1
@@ -124,7 +127,7 @@
     else:
         y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
     if is_amp:
-        y = cast(y, "float16")
+        y = cast(y, dtype)
 
     # add op assign to detach tensor in void unsafe change outside the rule.
     batch_mean_ = assign(batch_mean)
diff --git a/python/paddle/nn/functional/norm.py b/python/paddle/nn/functional/norm.py
index e1c3d985a68db74de32d4a2a285907f5951d6931..1a556a7c5106e64354a4529aa037f1c3a66d96be 100644
--- a/python/paddle/nn/functional/norm.py
+++ b/python/paddle/nn/functional/norm.py
@@ -240,7 +240,9 @@ def batch_norm(
         from paddle.fluid.data_feeder import convert_dtype
 
         param_dtype = (
-            x.dtype if convert_dtype(x.dtype) != 'float16' else 'float32'
+            x.dtype
+            if convert_dtype(x.dtype) not in ['float16', 'uint16']
+            else 'float32'
         )
         saved_mean = helper.create_variable_for_type_inference(
             dtype=param_dtype, stop_gradient=True
diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt
index 525cac1c5e22c70ef52a90c16c6d81a86fd2f62b..bd8ec526ea9440cd32c32740308e4ea4c303a9c0 100644
--- a/test/legacy_test/CMakeLists.txt
+++ b/test/legacy_test/CMakeLists.txt
@@ -953,6 +953,7 @@ if(WITH_NV_JETSON)
   set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 1200)
   set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 1200)
   set_tests_properties(test_norm_op PROPERTIES TIMEOUT 1200)
+  set_tests_properties(test_batch_norm_op_prim PROPERTIES TIMEOUT 1500)
   set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 1500)
   set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 1500)
 else()
@@ -961,6 +962,7 @@ else()
   set_tests_properties(test_conv3d_transpose_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_conv3d_op PROPERTIES TIMEOUT 120)
   set_tests_properties(test_norm_op PROPERTIES TIMEOUT 120)
+  set_tests_properties(test_batch_norm_op_prim PROPERTIES TIMEOUT 250)
   set_tests_properties(test_layer_norm_op PROPERTIES TIMEOUT 250)
   set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
 endif()
diff --git a/test/legacy_test/eager_op_test.py b/test/legacy_test/eager_op_test.py
index e534879e8053912fd70145e071a354b961fa4eb2..b4ba945b1f5726e649184d6a80a2e766ddf173ae 100644
--- a/test/legacy_test/eager_op_test.py
+++ b/test/legacy_test/eager_op_test.py
@@ -1624,6 +1624,7 @@ class OpTest(unittest.TestCase):
         equal_nan=False,
         check_dygraph=True,
         check_prim=False,
+        only_check_prim=False,
         inplace_atol=None,
         check_cinn=False,
     ):
@@ -2033,6 +2034,8 @@ class OpTest(unittest.TestCase):
             # Support operators which are not in the NO_FP64_CHECK_GRAD_OP_LIST list can be test prim with fp32
             self.__class__.check_prim = True
             self.__class__.op_type = self.op_type
+            if only_check_prim:
+                return
 
         static_checker = StaticChecker(self, self.outputs)
         static_checker.check()
@@ -2150,6 +2153,7 @@ class OpTest(unittest.TestCase):
         check_prim=False,
         inplace_atol=None,
         check_cinn=False,
+        only_check_prim=False,
     ):
         self.__class__.op_type = self.op_type
         if self.is_mkldnn_op():
@@ -2171,9 +2175,12 @@ class OpTest(unittest.TestCase):
                 equal_nan,
                 check_dygraph=check_dygraph,
                 check_prim=check_prim,
+                only_check_prim=only_check_prim,
                 inplace_atol=inplace_atol,
                 check_cinn=check_cinn,
             )
+            if not res and only_check_prim:
+                continue
             if check_dygraph:
                 outs, dygraph_dygraph_outs, fetch_list = res
             else:
diff --git a/test/legacy_test/test_batch_norm_op_prim.py b/test/legacy_test/test_batch_norm_op_prim.py
new file mode 100644
index 0000000000000000000000000000000000000000..17148541bd443485c7f5f5be3bb361eca79cb83d
--- /dev/null
+++ b/test/legacy_test/test_batch_norm_op_prim.py
@@ -0,0 +1,521 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from eager_op_test import (
+    OpTest,
+    _set_use_system_allocator,
+    convert_float_to_uint16,
+)
+
+import paddle
+import paddle.nn.functional as F
+from paddle.fluid import core
+
+paddle.enable_static()
+
+np.random.seed(123)
+paddle.seed(123)
+
+_set_use_system_allocator(True)
+
+
+def batch_norm_wrapper(
+    x,
+    running_mean,
+    running_variance,
+    weight,
+    bias,
+    is_test,
+    momentum,
+    epsilon,
+    data_format,
+    use_global_stats,
+):
+    y = F.batch_norm(
+        x,
+        running_mean,
+        running_variance,
+        weight,
+        bias,
+        training=not is_test,
+        momentum=momentum,
+        epsilon=epsilon,
+        data_format=data_format,
+        use_global_stats=use_global_stats,
+    )
+    z = F.relu(y)
+    return z
+
+
+class TestBatchNormOp(OpTest):
+    def setUp(self):
+        self.python_api = batch_norm_wrapper
+        self.public_python_api = batch_norm_wrapper
+        self.op_type = "batch_norm"
+        self.prim_op_type = "comp"
+        self.python_out_sig = ["Y"]
+        self.initConfig()
+        self.initTestCase()
+
+    def test_check_output(self):
+        if self.dtype not in ("uint16", "float16"):
+            self.check_output_with_place(
+                core.CPUPlace(),
+                no_check_set=None,
+                check_prim=True,
+                only_check_prim=True,
+            )
+        if paddle.is_compiled_with_cuda():
+            self.check_output_with_place(
+                core.CUDAPlace(0),
+                no_check_set=None,
+                check_prim=True,
+                only_check_prim=True,
+            )
+
+    def test_check_grad_x(self):
+        if self.dtype not in ("uint16", "float16"):
+            self.check_grad_with_place(
+                core.CPUPlace(),
+                ["X"],
+                ['Y'],
+                user_defined_grad_outputs=self.out_grad,
+                check_prim=True,
+                only_check_prim=True,
+            )
+        elif self.data_format == "NCHW" and paddle.is_compiled_with_cuda():
+            # the original batch_norm cuda kernel's nhwc x_grad differs depending on whether scale_grad and bias_grad are calculated
+            self.check_grad_with_place(
+                core.CUDAPlace(0),
+                ["X"],
+                ['Y'],
+                user_defined_grad_outputs=self.out_grad,
+                check_prim=True,
+                only_check_prim=True,
+            )
+
+    def test_check_grad_scale_bias(self):
+        self.enable_cinn = False
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        if self.dtype not in ("uint16", "float16"):
+            self.check_grad_with_place(
+                core.CPUPlace(),
+                ["X", "Scale", "Bias"],
+                ['Y'],
+                user_defined_grad_outputs=self.out_grad,
+                check_prim=True,
+                only_check_prim=True,
+            )
+        if paddle.is_compiled_with_cuda():
+            self.check_grad_with_place(
+                core.CUDAPlace(0),
+                ["X", "Scale", "Bias"],
+                ['Y'],
+                user_defined_grad_outputs=self.out_grad,
+                check_prim=True,
+                only_check_prim=True,
+            )
+        # restore init config
+        self.initConfig()
+
+    def initConfig(self):
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+
+        self.cinn_atol = 1e-5
+        self.cinn_rtol = 1e-5
+
+        self.dtype = "float32"
+        self.shape = [16, 24, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+    def initTestCase(self):
+        if (
+            self.dtype in ("uint16", "float16")
+            and not paddle.is_compiled_with_cuda()
+        ):
+            self.__class__.op_type = self.op_type
+            self.__class__.no_need_check_grad = True
+            return
+        np.random.seed(123)
+
+        self.C = self.shape[1] if self.data_format == "NCHW" else self.shape[-1]
+        if self.dtype == "uint16":
+            x = convert_float_to_uint16(
+                np.random.random(self.shape).astype("float32")
+            )
+        else:
+            x = np.random.random(self.shape).astype(self.dtype)
+
+        self.var_dtype = (
+            "float32" if self.dtype in ["float16", "uint16"] else self.dtype
+        )
+        weight = np.random.random(self.C).astype(self.var_dtype)
+        bias = np.random.random(self.C).astype(self.var_dtype)
+        running_mean = np.random.random(self.C).astype(self.var_dtype)
+        running_var = np.random.random(self.C).astype(self.var_dtype)
+        if self.dtype == "uint16":
+            self.out_grad = [
+                convert_float_to_uint16(
+                    np.random.random(self.shape).astype("float32")
+                )
+            ]
+        else:
+            self.out_grad = [np.random.random(self.shape).astype(self.dtype)]
+        self.inputs = {
+            "X": x,
+            "Scale": weight,
+            "Bias": bias,
+            "Mean": running_mean,
+            "Variance": running_var,
+        }
+
+        if self.use_global_stats is None:
+            self.use_global_stats = not self.training
+            trainable_statistics = False
+        else:
+            trainable_statistics = not self.use_global_stats
+
+        self.attrs = {
+            "momentum": self.momentum,
+            "epsilon": self.epsilon,
+            "is_test": not self.training,
+            "data_layout": self.data_format,
+            "use_global_stats": self.use_global_stats,
+            "trainable_statistics": trainable_statistics,
+        }
+
+        paddle.disable_static()
+
+        (
+            y,
+            running_mean,
+            running_var,
+            saved_mean,
+            saved_variance,
+            _,
+        ) = paddle._C_ops.batch_norm(
+            paddle.to_tensor(x),
+            paddle.to_tensor(running_mean),
+            paddle.to_tensor(running_var),
+            paddle.to_tensor(weight),
+            paddle.to_tensor(bias),
+            not self.training,
+            self.momentum,
+            self.epsilon,
+            self.data_format,
+            self.use_global_stats,
+            trainable_statistics,
+        )
+        if self.dtype == "uint16":
+            y = convert_float_to_uint16(y)
+        paddle.enable_static()
+        self.outputs = {
+            "Y": y,
+            "MeanOut": running_mean,
+            "VarianceOut": running_var,
+            "SavedMean": saved_mean,
+            "SavedVariance": saved_variance,
+        }
+
+
+class TestBatchNormOpNCHWTestMode(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = True
+
+
+class TestBatchNormOpNCHWFp64(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-11
+        self.fw_comp_rtol = 1e-11
+        self.rev_comp_atol = 1e-11
+        self.rev_comp_rtol = 1e-11
+        self.dtype = "float64"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNCHWTestModeFp64(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-15
+        self.fw_comp_rtol = 1e-15
+        self.rev_comp_atol = 1e-15
+        self.rev_comp_rtol = 1e-15
+        self.dtype = "float64"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNCHWFp16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.dtype = "float16"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNCHWTestModeFp16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.dtype = "float16"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestBatchNormOpNCHWbf16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.cinn_atol = 1e-3
+        self.cinn_rtol = 1e-3
+        self.dtype = "uint16"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestBatchNormOpNCHWTestModebf16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.cinn_atol = 1e-3
+        self.cinn_rtol = 1e-3
+        self.dtype = "uint16"
+        self.shape = [16, 16, 16, 8]
+        self.training = False
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNHWC(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNHWCFp64(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-11
+        self.fw_comp_rtol = 1e-11
+        self.rev_comp_atol = 1e-11
+        self.rev_comp_rtol = 1e-11
+        self.dtype = "float64"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNHWCFp16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.dtype = "float16"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or not support the bfloat16",
+)
+class TestBatchNormOpNHWCbf16(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-3
+        self.fw_comp_rtol = 1e-3
+        self.rev_comp_atol = 1e-3
+        self.rev_comp_rtol = 1e-3
+        self.cinn_atol = 1e-3
+        self.cinn_rtol = 1e-3
+        self.dtype = "uint16"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNCHWShape2(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [4, 8, 16, 32]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNCHWMomentum2(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.9
+        self.epsilon = 1e-05
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNCHWEps2(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-06
+        self.data_format = "NCHW"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNHWCShape2(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [4, 8, 16, 32]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNHWCMomentum2(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.9
+        self.epsilon = 1e-05
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
+class TestBatchNormOpNHWCEps2(TestBatchNormOp):
+    def initConfig(self):
+        self.fw_comp_atol = 1e-5
+        self.fw_comp_rtol = 1e-5
+        self.rev_comp_atol = 1e-5
+        self.rev_comp_rtol = 1e-5
+        self.dtype = "float32"
+        self.shape = [16, 16, 16, 8]
+        self.training = True
+        self.momentum = 0.1
+        self.epsilon = 1e-06
+        self.data_format = "NHWC"
+        self.use_global_stats = None
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
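
A minimal usage sketch, not part of the patch, of the bfloat16 batch_norm path that the new tests exercise (the composite rule casts half-precision inputs to float32 for the statistics and casts the result back). The shapes and the CUDA/bfloat16 build requirement here are assumptions for illustration:

    import paddle
    import paddle.nn.functional as F

    # Illustrative NCHW input; C = 16 matches the channel count used in the new tests.
    x = paddle.randn([16, 16, 16, 8]).astype('bfloat16')
    running_mean = paddle.zeros([16], dtype='float32')
    running_var = paddle.ones([16], dtype='float32')
    weight = paddle.ones([16], dtype='float32')
    bias = paddle.zeros([16], dtype='float32')

    y = F.batch_norm(
        x,
        running_mean,
        running_var,
        weight,
        bias,
        training=True,
        momentum=0.1,
        epsilon=1e-5,
        data_format='NCHW',
    )
    print(y.dtype)  # expected to remain bfloat16; statistics are computed in float32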