From 2ebc8f776a44d6d8aec053c4e90f9b20d28b38f5 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 29 Dec 2021 16:03:20 +0800 Subject: [PATCH] [AMP] Add BatchNorm_1D_2D_3D skip for paddle.amp.decorate (#38541) * add bn_1d_2d_3d for fp16 decorate * add unittest --- python/paddle/fluid/dygraph/amp/auto_cast.py | 6 +++-- .../test_imperative_auto_mixed_precision.py | 26 +++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 8f0e0dff2f..15adf4cb6f 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -130,8 +130,10 @@ def pure_fp16_initialize(models): for idx in range(len(models)): for layer in models[idx].sublayers(include_self=True): layer._casted_by_pure_fp16 = True - if (layer._dtype is 'float16') or isinstance(layer, ( - paddle.nn.BatchNorm, paddle.nn.LayerNorm)): + if (layer._dtype is 'float16') or isinstance( + layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D, + paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D, + paddle.nn.LayerNorm)): continue layer._to_impl(dtype='float16', include_sublayers=False) return models diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py index 2938eabd07..a8ed23f593 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py @@ -598,6 +598,32 @@ class TestAmpDecorator(unittest.TestCase): self.assertEqual(optimizers[0]._multi_precision, False) self.assertEqual(optimizers[1]._multi_precision, False) + def test_skip_BatchNorm_Layer_norm(self): + model = paddle.nn.LayerNorm(1) + model = paddle.amp.decorate(models=model, level='O2') + for param in model.parameters(): + self.assertEqual((param.dtype == paddle.float32), True) + + model = paddle.nn.BatchNorm(1) + model = paddle.amp.decorate(models=model, level='O2') + for param in model.parameters(): + self.assertEqual((param.dtype == paddle.float32), True) + + model = paddle.nn.BatchNorm1D(1) + model = paddle.amp.decorate(models=model, level='O2') + for param in model.parameters(): + self.assertEqual((param.dtype == paddle.float32), True) + + model = paddle.nn.BatchNorm2D(1) + model = paddle.amp.decorate(models=model, level='O2') + for param in model.parameters(): + self.assertEqual((param.dtype == paddle.float32), True) + + model = paddle.nn.BatchNorm3D(1) + model = paddle.amp.decorate(models=model, level='O2') + for param in model.parameters(): + self.assertEqual((param.dtype == paddle.float32), True) + class TestPureFp16SaveLoad(unittest.TestCase): def test_save_dtype_exception(self): -- GitLab