未验证 提交 2ebc8f77 编写于 作者: Z zhangbo9674 提交者: GitHub

[AMP] Add BatchNorm_1D_2D_3D skip for paddle.amp.decorate (#38541)

* add bn_1d_2d_3d for fp16 decorate

* add unittest
上级 e3faf345
...@@ -130,8 +130,10 @@ def pure_fp16_initialize(models): ...@@ -130,8 +130,10 @@ def pure_fp16_initialize(models):
for idx in range(len(models)): for idx in range(len(models)):
for layer in models[idx].sublayers(include_self=True): for layer in models[idx].sublayers(include_self=True):
layer._casted_by_pure_fp16 = True layer._casted_by_pure_fp16 = True
if (layer._dtype is 'float16') or isinstance(layer, ( if (layer._dtype is 'float16') or isinstance(
paddle.nn.BatchNorm, paddle.nn.LayerNorm)): layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D,
paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D,
paddle.nn.LayerNorm)):
continue continue
layer._to_impl(dtype='float16', include_sublayers=False) layer._to_impl(dtype='float16', include_sublayers=False)
return models return models
......
...@@ -598,6 +598,32 @@ class TestAmpDecorator(unittest.TestCase): ...@@ -598,6 +598,32 @@ class TestAmpDecorator(unittest.TestCase):
self.assertEqual(optimizers[0]._multi_precision, False) self.assertEqual(optimizers[0]._multi_precision, False)
self.assertEqual(optimizers[1]._multi_precision, False) self.assertEqual(optimizers[1]._multi_precision, False)
def test_skip_BatchNorm_Layer_norm(self):
model = paddle.nn.LayerNorm(1)
model = paddle.amp.decorate(models=model, level='O2')
for param in model.parameters():
self.assertEqual((param.dtype == paddle.float32), True)
model = paddle.nn.BatchNorm(1)
model = paddle.amp.decorate(models=model, level='O2')
for param in model.parameters():
self.assertEqual((param.dtype == paddle.float32), True)
model = paddle.nn.BatchNorm1D(1)
model = paddle.amp.decorate(models=model, level='O2')
for param in model.parameters():
self.assertEqual((param.dtype == paddle.float32), True)
model = paddle.nn.BatchNorm2D(1)
model = paddle.amp.decorate(models=model, level='O2')
for param in model.parameters():
self.assertEqual((param.dtype == paddle.float32), True)
model = paddle.nn.BatchNorm3D(1)
model = paddle.amp.decorate(models=model, level='O2')
for param in model.parameters():
self.assertEqual((param.dtype == paddle.float32), True)
class TestPureFp16SaveLoad(unittest.TestCase): class TestPureFp16SaveLoad(unittest.TestCase):
def test_save_dtype_exception(self): def test_save_dtype_exception(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册