diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py
index 8f0e0dff2fcfb3e4431e8e1bd8216ebd6920635e..15adf4cb6faaf61c0601d7a313dd203690289463 100644
--- a/python/paddle/fluid/dygraph/amp/auto_cast.py
+++ b/python/paddle/fluid/dygraph/amp/auto_cast.py
@@ -130,8 +130,10 @@ def pure_fp16_initialize(models):
     for idx in range(len(models)):
         for layer in models[idx].sublayers(include_self=True):
             layer._casted_by_pure_fp16 = True
-            if (layer._dtype is 'float16') or isinstance(layer, (
-                    paddle.nn.BatchNorm, paddle.nn.LayerNorm)):
+            if (layer._dtype is 'float16') or isinstance(
+                    layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D,
+                            paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D,
+                            paddle.nn.LayerNorm)):
                 continue
             layer._to_impl(dtype='float16', include_sublayers=False)
     return models
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
index 2938eabd07b9c840a54cfdeff3feeec6212b931f..a8ed23f5938c069171734b307ac87ea0367bfe62 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_auto_mixed_precision.py
@@ -598,6 +598,32 @@ class TestAmpDecorator(unittest.TestCase):
         self.assertEqual(optimizers[0]._multi_precision, False)
         self.assertEqual(optimizers[1]._multi_precision, False)
 
+    def test_skip_BatchNorm_Layer_norm(self):
+        model = paddle.nn.LayerNorm(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm1D(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm2D(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
+        model = paddle.nn.BatchNorm3D(1)
+        model = paddle.amp.decorate(models=model, level='O2')
+        for param in model.parameters():
+            self.assertEqual((param.dtype == paddle.float32), True)
+
 
 class TestPureFp16SaveLoad(unittest.TestCase):
     def test_save_dtype_exception(self):
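
For context, here is a minimal sketch (not part of the diff) of the behavior the change and its new test cover: after `paddle.amp.decorate(..., level='O2')`, normalization layers are skipped by `pure_fp16_initialize`, so their parameters stay float32 while other layers' parameters are cast to float16. The `Sequential`/`Linear` model below is a hypothetical example, and it assumes a CUDA-enabled Paddle build since pure-fp16 (O2) decoration targets GPU execution.

```python
import paddle

# Hypothetical model mixing a regular layer with a normalization layer.
model = paddle.nn.Sequential(
    paddle.nn.Linear(4, 4),
    paddle.nn.BatchNorm2D(4))

# O2 decoration casts parameters to float16, except for the layers that
# pure_fp16_initialize now skips (BatchNorm, BatchNorm1D/2D/3D, LayerNorm).
model = paddle.amp.decorate(models=model, level='O2')

for name, param in model.named_parameters():
    print(name, param.dtype)
# Expected: Linear weight/bias   -> paddle.float16
#           BatchNorm2D weight/bias -> paddle.float32
```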