Commit e4298d7d authored by mindspore-ci-bot, committed by Gitee

!2326 fix the problem of BatchNorm config failure at Amp O3 level and some unexpected indent

Merge pull request !2326 from liangzelang/master
@@ -127,7 +127,8 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
- O2: Cast network to float16, keep batchnorm and `loss_fn` (if set) run in float32,
using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
-    O2 is recommended on GPU, O3 is recommemded on Ascend.
+    O2 is recommended on GPU, O3 is recommended on Ascend.
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
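For context, here is a minimal usage sketch of the `build_train_network` interface documented in the hunk above. The network, loss, and optimizer are placeholders, and the import path assumes the `mindspore.train.amp` module touched by this commit:

import mindspore.nn as nn
from mindspore.train.amp import build_train_network

net = nn.Dense(16, 10)                                   # placeholder network
loss = nn.SoftmaxCrossEntropyWithLogits()                # placeholder loss_fn
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# level="O3" casts the whole network, BatchNorm included, to float16,
# i.e. keep_batchnorm_fp32 defaults to False at this level.
train_net = build_train_network(net, opt, loss, level="O3")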
@@ -61,6 +61,7 @@ class Model:
- O0: Do not change.
- O2: Cast network to float16, keep batchnorm run in float32, using dynamic loss scale.
- O3: Cast network to float16, with additional property 'keep_batchnorm_fp32=False'.
+    O2 is recommended on GPU, O3 is recommended on Ascend.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
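The same levels apply through the `Model` wrapper documented above. A minimal sketch, again with placeholder network, loss, and optimizer:

import mindspore.nn as nn
from mindspore.train.model import Model

net = nn.Dense(16, 10)                                   # placeholder network
loss = nn.SoftmaxCrossEntropyWithLogits()                # placeholder loss_fn
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# "O2" is recommended on GPU (BatchNorm stays in float32); "O3" on Ascend.
model = Model(net, loss_fn=loss, optimizer=opt, metrics={"acc"}, amp_level="O3")

# After this fix, an explicit keyword still overrides the O3 default of
# keep_batchnorm_fp32=False:
# model = Model(net, loss_fn=loss, optimizer=opt, amp_level="O3",
#               keep_batchnorm_fp32=True)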
@@ -115,7 +116,7 @@ class Model:
        self._build_predict_network()

    def _process_amp_args(self, kwargs):
-        if self._amp_level == "O0":
+        if self._amp_level in ["O0", "O3"]:
self._keep_bn_fp32 = False
if 'keep_batchnorm_fp32' in kwargs:
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
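The behavioural effect of the `_process_amp_args` change can be summarised in a standalone sketch (not the actual method; the initial default of True is an assumption based on code outside this hunk):

def resolve_keep_bn_fp32(amp_level, **kwargs):
    # Assumed default; the real attribute is set elsewhere in Model.__init__.
    keep_bn_fp32 = True
    if amp_level in ["O0", "O3"]:
        # O0 does no casting and O3 casts BatchNorm to float16 as well,
        # so both levels now default to keep_batchnorm_fp32=False.
        keep_bn_fp32 = False
    if 'keep_batchnorm_fp32' in kwargs:
        # An explicit keyword argument always overrides the level default.
        keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
    return keep_bn_fp32

assert resolve_keep_bn_fp32("O3") is False                           # fixed by this commit
assert resolve_keep_bn_fp32("O3", keep_batchnorm_fp32=True) is True  # user override still wins
assert resolve_keep_bn_fp32("O2") is True                            # O2 keeps BatchNorm in float32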