Unverified · Commit 9ab3c76b · Authored by zhangbo9674, committed by GitHub

fix sync_bn error in fp16 amp-o2 (#40943)

Parent 9261dff4
... @@ -171,7 +171,7 @@ def pure_fp16_initialize(models):
         if (layer._dtype == 'float16') or isinstance(
                 layer, (paddle.nn.BatchNorm, paddle.nn.BatchNorm1D,
                         paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D,
-                        paddle.nn.LayerNorm)):
+                        paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm)):
             continue
         layer._to_impl(dtype='float16', include_sublayers=False)
     return models
...
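For context, a minimal sketch of the code path this change affects: under AMP O2 ("pure fp16"), `pure_fp16_initialize` casts layer parameters to float16 but skips normalization layers so they keep float32 parameters for numerical stability; before this fix, `paddle.nn.SyncBatchNorm` was missing from the skip list, which caused the sync_bn error in fp16. The snippet below is illustrative rather than taken from the PR; the model and optimizer names are hypothetical, and it assumes a Paddle 2.x dynamic-graph environment.

import paddle
import paddle.nn as nn

# Hypothetical model containing a SyncBatchNorm layer.
model = nn.Sequential(
    nn.Conv2D(3, 16, kernel_size=3, padding=1),
    nn.SyncBatchNorm(16),
    nn.ReLU(),
)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.01, parameters=model.parameters())

# AMP O2 decoration invokes pure_fp16_initialize, which casts layer
# parameters to float16. With this fix, SyncBatchNorm is skipped just
# like the other BatchNorm/LayerNorm variants and keeps float32 weights.
model, optimizer = paddle.amp.decorate(
    models=model, optimizers=optimizer, level='O2')

# Expected after the fix: the SyncBatchNorm weight stays float32.
print(model[1].weight.dtype)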