提交 4b165ccd 编写于 作者: W Wei Luning

fix doc and eval network build in amp

上级 3cac1bb9
...@@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): ...@@ -133,6 +133,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`. cast_model_type (:class:`mindspore.dtype`): Supports `mstype.float16` or `mstype.float32`.
If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting. If set to `mstype.float16`, use `float16` mode to train. If set, overwrite the level setting.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting.
Only `cast_model_type` is `float16`, `keep_batchnorm_fp32` will take effect.
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If set, overwrite the level setting. scale the loss by LossScaleManager. If set, overwrite the level setting.
""" """
......
...@@ -174,7 +174,7 @@ class Model: ...@@ -174,7 +174,7 @@ class Model:
else: else:
if self._loss_fn is None: if self._loss_fn is None:
raise ValueError("loss_fn can not be None.") raise ValueError("loss_fn can not be None.")
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2") self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level in ["O0", "O3"])
self._eval_indexes = [0, 1, 2] self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册