From 6595032498d9e8dd28ebe119b7081ace198e0da7 Mon Sep 17 00:00:00 2001 From: Charles-hit <56987902+Charles-hit@users.noreply.github.com> Date: Thu, 13 Jul 2023 15:51:44 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90AMP=20Prim=20OP=E3=80=91support=20inst?= =?UTF-8?q?ance=5Fnorm=20prim=20ops=20for=20fp16=20and=20bf16=20dtype=20(#?= =?UTF-8?q?55368)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [prim]support fp16 for instance_norm and instance_norm_grad * support fp16 and bfp16 dtype for instance_norm prim rules * fix new ir test --------- Co-authored-by: cxxly --- .../composite_backward_api.h | 68 +++-- paddle/phi/infermeta/ternary.cc | 7 + .../incubate/autograd/composite_rules.py | 14 ++ test/legacy_test/test_instance_norm_op_v2.py | 235 ++++++++++-------- 4 files changed, 205 insertions(+), 119 deletions(-) diff --git a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h index b0090fce1f2..07cd1ada4a9 100644 --- a/paddle/fluid/prim/api/composite_backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -1384,40 +1384,80 @@ void instance_norm_grad(const Tensor& x, const int h = x.dims()[2]; const int w = x.dims()[3]; + auto promoted_y_grad = y_grad; + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + promoted_y_grad = cast(y_grad, phi::DataType::FLOAT32); + } + Tensor x_hat; Tensor std_inv; if (scale_grad || x_grad) { - auto mean = reshape(saved_mean, IntArray({n, c, 1, 1})) + auto promoted_x = x; + auto promoted_saved_mean = saved_mean; + auto promoted_saved_var = saved_variance; + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + promoted_x = cast(x, phi::DataType::FLOAT32); + promoted_saved_mean = cast(saved_mean, phi::DataType::FLOAT32); + promoted_saved_var = cast(saved_variance, phi::DataType::FLOAT32); + } + auto mean = reshape(promoted_saved_mean, IntArray({n, c, 1, 1})) .tile(IntArray({1, 1, h, w})); - std_inv = reshape(saved_variance, IntArray({n, c, 1, 1})) + std_inv = reshape(promoted_saved_var, IntArray({n, c, 1, 1})) .tile(IntArray({1, 1, h, w})); - x_hat = (x - mean) * std_inv; + x_hat = (promoted_x - mean) * std_inv; } // x_grad = scale * inv_var * (y_grad - y_grad.mean(2,3) - x_hat * (y_grad * // x_hat).mean((h,w))) if (x_grad) { - auto scale_t = + auto scale_data = reshape(scale.get_ptr() ? 
scale.get() : full(IntArray({c}), 1., x.dtype()), IntArray({1, c, 1, 1})) .tile(IntArray({n, 1, h, w})); - set_output( - (scale_t * std_inv) * - (y_grad - - y_grad.sum(IntArray({2, 3}), y_grad.dtype(), true) / (h * w) - - (x_hat * - ((y_grad * x_hat).sum(IntArray({2, 3}), y_grad.dtype(), true) / - (h * w)))), - x_grad); + auto promoted_scale = scale_data; + if (scale_data.dtype() == phi::DataType::FLOAT16 || + scale_data.dtype() == phi::DataType::BFLOAT16) { + promoted_scale = cast(scale_data, phi::DataType::FLOAT32); + } + auto result = + (promoted_scale * std_inv) * + (promoted_y_grad - + promoted_y_grad.sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) / + (h * w) - + (x_hat * ((promoted_y_grad * x_hat) + .sum(IntArray({2, 3}), promoted_y_grad.dtype(), true) / + (h * w)))); + if (x.dtype() == phi::DataType::FLOAT16 || + x.dtype() == phi::DataType::BFLOAT16) { + set_output(cast(result, x.dtype()), x_grad); + } else { + set_output(result, x_grad); + } } // scale_grad = x_hat * y_grad.sum(n, h, w) if (scale_grad) { - set_output((y_grad * x_hat).sum(IntArray({0, 2, 3})), scale_grad); + auto result = (promoted_y_grad * x_hat).sum(IntArray({0, 2, 3})); + auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype(); + if (scale_dtype == phi::DataType::FLOAT16 || + scale_dtype == phi::DataType::BFLOAT16) { + set_output(cast(result, scale_dtype), scale_grad); + } else { + set_output(result, scale_grad); + } } // d_bias = y_grad.sum(n, h, w) if (bias_grad) { - set_output(y_grad.sum(IntArray({0, 2, 3})), bias_grad); + auto result = promoted_y_grad.sum(IntArray({0, 2, 3})); + auto scale_dtype = scale.get_ptr() ? scale.get().dtype() : x.dtype(); + if (scale_dtype == phi::DataType::FLOAT16 || + scale_dtype == phi::DataType::BFLOAT16) { + set_output(cast(result, scale_dtype), bias_grad); + } else { + set_output(result, bias_grad); + } } } diff --git a/paddle/phi/infermeta/ternary.cc b/paddle/phi/infermeta/ternary.cc index c4a28877651..b3f17ab91b9 100644 --- a/paddle/phi/infermeta/ternary.cc +++ b/paddle/phi/infermeta/ternary.cc @@ -384,11 +384,18 @@ void InstanceNormInferMeta(const MetaTensor& x, y->share_lod(x); y->set_dtype(x.dtype()); y->set_layout(x.layout()); + phi::DataType x_dtype = x.dtype(); + phi::DataType param_type = + (x_dtype == phi::DataType::BFLOAT16 || x_dtype == phi::DataType::FLOAT16) + ? 
phi::DataType::FLOAT32 + : x_dtype; if (saved_mean) { saved_mean->set_dims({NxC}); + saved_mean->set_dtype(param_type); } if (saved_variance) { saved_variance->set_dims({NxC}); + saved_variance->set_dtype(param_type); } } diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index 7640b9d95a9..caedc31a3c1 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -193,6 +193,16 @@ def instancenorm_composite(x, scale, bias, epsilon): out = (x - mean(x)) / sqrt(var + epsilon)) var = mean((x-mean(x))^2) """ + is_amp = False + from paddle.fluid.data_feeder import convert_dtype + + dtype = convert_dtype(x.dtype) + if dtype in ["float16", "uint16"]: + is_amp = True + x = cast(x, "float32") + scale = cast(scale, "float32") if scale else scale + bias = cast(bias, "float32") if bias else bias + n, c, h, w = x.shape axis = tuple(range(2, len(x.shape))) mean_ = mean(x, axis=axis, keepdim=True) @@ -213,6 +223,10 @@ def instancenorm_composite(x, scale, bias, epsilon): mean_ = reshape(mean_, [-1]) saved_variance = 1 / sqrt_var saved_variance = reshape(saved_variance, [-1]) + + if is_amp: + out = cast(out, dtype) + return out, mean_, saved_variance diff --git a/test/legacy_test/test_instance_norm_op_v2.py b/test/legacy_test/test_instance_norm_op_v2.py index e3b04c0c7fb..8b6745b17bb 100644 --- a/test/legacy_test/test_instance_norm_op_v2.py +++ b/test/legacy_test/test_instance_norm_op_v2.py @@ -12,15 +12,69 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import unittest import numpy as np from eager_op_test import OpTest, convert_float_to_uint16 import paddle -import paddle.nn.functional as F -from paddle import fluid, nn -from paddle.fluid import Program, core, framework, program_guard +from paddle import fluid +from paddle.fluid import Program, core, program_guard + + +def instance_norm_wrapper( + input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW' +): + if data_format == "AnyLayout": + data_format = "NCDHW" + return paddle.nn.functional.instance_norm( + input, None, None, weight, bias, True, momentum, epsilon, data_format + ) + + +def _reference_instance_norm(x, scale, bias, epsilon): + N, C, H, W = x.shape + mean = np.mean(x, axis=(2, 3), keepdims=True) + variance = np.var(x, axis=(2, 3), keepdims=True) + std = np.sqrt(variance) + epsilon + x_norm = (x - mean) / std + scale = scale.reshape([1, C, 1, 1]) + bias = bias.reshape([1, C, 1, 1]) + x_norm = scale * x_norm + bias + return x_norm, mean.reshape(N * C), std.reshape(N * C) + + +def _reference_instance_norm_grad(x, scale, mean, var): + n, c, h, w = x.shape + d_y = np.ones(x.shape) / (np.prod(x.shape)) + d_bias = np.ones((c,)) / c + + mean_tile = np.reshape(mean, (n, c, 1, 1)) + mean_tile = np.tile(mean_tile, (1, 1, h, w)) + var_tile = np.reshape(var, (n, c, 1, 1)) + var_tile = np.tile(var_tile, (1, 1, h, w)) + + d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3)) + var_inv = var_tile + scale_tile = np.reshape(scale, (1, c, 1, 1)) + scale_tile = np.tile(scale_tile, (n, 1, h, w)) + + d_x = ( + scale_tile + * var_inv + * ( + d_y + - np.mean(d_y, axis=(2, 3), keepdims=True) + - (x - mean_tile) + * var_inv + * np.mean( + d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True + ) + ) + ) + + return d_x, d_scale, d_bias class TestInstanceNorm(unittest.TestCase): @@ -86,95 +140,42 @@ class TestInstanceNorm(unittest.TestCase): 
np.testing.assert_allclose(y1, y2, rtol=1e-05) def test_static(self): - places = [fluid.CPUPlace()] - if core.is_compiled_with_cuda() and core.op_support_gpu( - "instance_norm" - ): - places.append(fluid.CUDAPlace(0)) - for p in places: - exe = fluid.Executor(p) - shape = [4, 10, 16, 16] - - def compute_v1(x_np): - with program_guard(Program(), Program()): - ins = paddle.nn.InstanceNorm2D(shape[1]) - x = paddle.static.data( - name='x', shape=x_np.shape, dtype=x_np.dtype - ) - y = ins(x) - exe.run(fluid.default_startup_program()) - r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] - return r - - def compute_v2(x_np): - with program_guard(Program(), Program()): - ins = paddle.nn.InstanceNorm2D(shape[1]) - x = paddle.static.data( - name='x', shape=x_np.shape, dtype=x_np.dtype - ) - y = ins(x) - exe.run(fluid.default_startup_program()) - r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] - return r - - x = np.random.randn(*shape).astype("float32") - y1 = compute_v1(x) - y2 = compute_v2(x) - np.testing.assert_allclose(y1, y2, rtol=1e-05) - - -def instance_norm_warpper( - input, weight, bias, epsilon=1e-5, momentum=0.9, data_format='NCHW' -): - if data_format == "AnyLayout": - data_format = "NCDHW" - return paddle._C_ops.instance_norm( - input, weight, bias, epsilon, momentum, data_format - ) - - -def _reference_instance_norm(x, scale, bias, epsilon): - N, C, H, W = x.shape - mean = np.mean(x, axis=(2, 3), keepdims=True) - variance = np.var(x, axis=(2, 3), keepdims=True) - std = np.sqrt(variance) + epsilon - x_norm = (x - mean) / std - scale = scale.reshape([1, C, 1, 1]) - bias = bias.reshape([1, C, 1, 1]) - x_norm = scale * x_norm + bias - return x_norm, mean.reshape(N * C), std.reshape(N * C) - - -def _reference_instance_norm_grad(x, scale, mean, var): - n, c, h, w = x.shape - d_y = np.ones(x.shape) / (np.prod(x.shape)) - d_bias = np.ones((c,)) / c - - mean_tile = np.reshape(mean, (n, c, 1, 1)) - mean_tile = np.tile(mean_tile, (1, 1, h, w)) - var_tile = np.reshape(var, (n, c, 1, 1)) - var_tile = np.tile(var_tile, (1, 1, h, w)) - - d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3)) - var_inv = var_tile - scale_tile = np.reshape(scale, (1, c, 1, 1)) - scale_tile = np.tile(scale_tile, (n, 1, h, w)) - - d_x = ( - scale_tile - * var_inv - * ( - d_y - - np.mean(d_y, axis=(2, 3), keepdims=True) - - (x - mean_tile) - * var_inv - * np.mean( - d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True - ) - ) - ) - - return d_x, d_scale, d_bias + with paddle.fluid.framework._static_guard(): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda() and core.op_support_gpu( + "instance_norm" + ): + places.append(fluid.CUDAPlace(0)) + for p in places: + exe = fluid.Executor(p) + shape = [4, 10, 16, 16] + + def compute_v1(x_np): + with program_guard(Program(), Program()): + ins = paddle.nn.InstanceNorm2D(shape[1]) + x = paddle.static.data( + name='x', shape=x_np.shape, dtype=x_np.dtype + ) + y = ins(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + def compute_v2(x_np): + with program_guard(Program(), Program()): + ins = paddle.nn.InstanceNorm2D(shape[1]) + x = paddle.static.data( + name='x', shape=x_np.shape, dtype=x_np.dtype + ) + y = ins(x) + exe.run(fluid.default_startup_program()) + r = exe.run(feed={'x': x_np}, fetch_list=[y])[0] + return r + + x = np.random.randn(*shape).astype("float32") + y1 = compute_v1(x) + y2 = compute_v2(x) + np.testing.assert_allclose(y1, y2, rtol=1e-05) class TestInstanceNormFP32OP(OpTest): @@ 
-182,7 +183,6 @@ class TestInstanceNormFP32OP(OpTest): '''Test instance_norm op with default value''' self.op_type = "instance_norm" self.__class__.op_type = self.op_type - self.python_api = instance_norm_warpper self.data_format = "NCHW" self.eps = 1e-5 self.init_dtype() @@ -204,15 +204,18 @@ class TestInstanceNormFP32OP(OpTest): 'SavedMean': mean, 'SavedVariance': 1.0 / variance, } + self.prim_op_type = "comp" + self.python_api = instance_norm_wrapper + self.public_python_api = instance_norm_wrapper + self.check_prim = ( + False if os.getenv("FLAGS_enable_new_ir_in_executor") else True + ) def test_check_output(self): - self.check_output(atol=self.atol) + self.check_output(atol=self.atol, check_prim=self.check_prim) def test_check_grad(self): - self.check_grad( - ['X', 'Scale', 'Bias'], - 'Y', - ) + self.check_grad(['X', 'Scale', 'Bias'], 'Y', check_prim=self.check_prim) def init_dtype(self): self.dtype = np.float32 @@ -228,6 +231,12 @@ class TestInstanceNormFP32OP(OpTest): def set_err_thre(self): self.atol = 1e-3 + self.fw_comp_rtol = 1e-6 + self.fw_comp_atol = 1e-6 + self.rev_comp_rtol = 1e-4 + self.rev_comp_atol = 1e-4 + self.cinn_rtol = 1e-4 + self.cinn_atol = 1e-4 @unittest.skipIf( @@ -236,6 +245,9 @@ class TestInstanceNormFP32OP(OpTest): "core is not compiled with CUDA or not support the float16", ) class TestInstanceNormFP16OP(TestInstanceNormFP32OP): + def setUp(self): + super().setUp() + def init_dtype(self): self.dtype = np.float16 @@ -245,7 +257,9 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, atol=self.atol) + self.check_output_with_place( + place, atol=self.atol, check_prim=self.check_prim + ) def test_check_grad(self): place = core.CUDAPlace(0) @@ -254,6 +268,7 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP): ['X', 'Scale', 'Bias'], 'Y', max_relative_error=self.max_relative_error, + check_prim=self.check_prim, ) @@ -265,8 +280,10 @@ class TestInstanceNormFP16OP(TestInstanceNormFP32OP): class TestInstanceNormBF16OP(OpTest): def setUp(self): self.op_type = "instance_norm" + self.prim_op_type = "comp" self.__class__.op_type = self.op_type - self.python_api = instance_norm_warpper + self.python_api = instance_norm_wrapper + self.public_python_api = instance_norm_wrapper self.eps = 1e-5 self.data_format = "NCHW" self.dtype = np.uint16 @@ -296,6 +313,9 @@ class TestInstanceNormBF16OP(OpTest): 'momentum': 0.9, 'data_format': self.data_format, } + self.check_prim = ( + False if os.getenv("FLAGS_enable_new_ir_in_executor") else True + ) def init_value(self): np.random.seed(0) @@ -308,7 +328,7 @@ class TestInstanceNormBF16OP(OpTest): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place) + self.check_output_with_place(place, check_prim=self.check_prim) def test_check_grad(self): place = core.CUDAPlace(0) @@ -317,19 +337,22 @@ class TestInstanceNormBF16OP(OpTest): ['X', 'Scale', 'Bias'], 'Y', user_defined_grads=self.user_defined_grads, + check_prim=self.check_prim, ) class PrimNet(paddle.nn.Layer): def __init__(self): super().__init__() - self.conv = nn.Conv2D(2, 4, (3, 3), bias_attr=False) - self.instance_norm = nn.InstanceNorm2D(4) + self.conv = paddle.nn.Conv2D(2, 4, (3, 3), bias_attr=False) + self.instance_norm = paddle.nn.InstanceNorm2D(4) def forward(self, x): y = self.conv(x) out = self.instance_norm(y) - res = F.max_pool2d(out, kernel_size=2, stride=2, padding=0) + res = paddle.nn.functional.max_pool2d( + out, kernel_size=2, 
stride=2, padding=0 + ) return res @@ -368,7 +391,9 @@ class TestPrimForwardAndBackward(unittest.TestCase): return loss def test_amp_nchw(self): - if not isinstance(framework._current_expected_place(), core.CPUPlace): + if not isinstance( + paddle.fluid.framework._current_expected_place(), core.CPUPlace + ): expected = self.train(False) actual = self.train(True) np.testing.assert_allclose( -- GitLab
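
All of the hunks above apply the same AMP promotion pattern: when the input dtype is float16 or bfloat16, the composite forward and backward rules cast the tensors to float32, do the normalization arithmetic in float32, and cast only the user-visible results (the normalized output and the x/scale/bias gradients) back to their original low-precision dtypes before set_output, while SavedMean and SavedVariance are declared float32 in InstanceNormInferMeta and stay in float32. The sketch below illustrates that pattern in plain NumPy, following the formula in the composite rule's docstring (out = (x - mean(x)) / sqrt(var + epsilon)). It is an illustration only, not code from the patch: the name instance_norm_ref is made up for this example, and bfloat16 is omitted because NumPy has no native bfloat16 type.

import numpy as np


def instance_norm_ref(x, scale=None, bias=None, epsilon=1e-5):
    # Promote low-precision inputs to float32, mirroring the is_amp branch in
    # instancenorm_composite and the cast(...) calls in instance_norm_grad.
    orig_dtype = x.dtype
    low_precision = orig_dtype == np.float16
    if low_precision:
        x = x.astype(np.float32)
        scale = scale.astype(np.float32) if scale is not None else None
        bias = bias.astype(np.float32) if bias is not None else None

    n, c, h, w = x.shape
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    inv_std = 1.0 / np.sqrt(var + epsilon)
    out = (x - mean) * inv_std
    if scale is not None:
        out = out * scale.reshape(1, c, 1, 1)
    if bias is not None:
        out = out + bias.reshape(1, c, 1, 1)

    # Only the normalized output goes back to the original dtype; the saved
    # statistics keep float32 precision, matching InstanceNormInferMeta.
    if low_precision:
        out = out.astype(orig_dtype)
    return out, mean.reshape(n * c), inv_std.reshape(n * c)


if __name__ == "__main__":
    x = np.random.randn(4, 10, 16, 16).astype(np.float16)
    scale = np.ones(10, dtype=np.float16)
    bias = np.zeros(10, dtype=np.float16)
    y, saved_mean, saved_inv_std = instance_norm_ref(x, scale, bias)
    assert y.dtype == np.float16
    assert saved_mean.dtype == np.float32 and saved_inv_std.dtype == np.float32

Running the sketch on a float16 input shows the dtype contract the patch enforces: the output comes back in the low-precision dtype while the saved statistics remain float32, which is the behavior TestInstanceNormFP16OP and TestInstanceNormBF16OP exercise through the prim (composite) lowering with check_prim enabled.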