提交 6ffb04b7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!605 fix default parameter set in model

Merge pull request !605 from liubuyu/master
......@@ -99,12 +99,8 @@ class Model:
self._loss_scale_manager_set = False
self._keep_bn_fp32 = True
self._check_kwargs(kwargs)
if 'keep_batchnorm_fp32' in kwargs:
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
if 'loss_scale_manager' in kwargs:
self._loss_scale_manager = kwargs['loss_scale_manager']
self._loss_scale_manager_set = True
self._amp_level = amp_level
self._process_amp_args(kwargs)
self._parallel_mode = _get_parallel_mode()
self._device_number = _get_device_num()
self._global_rank = _get_global_rank()
......@@ -114,6 +110,15 @@ class Model:
self._build_eval_network(metrics, eval_network, eval_indexes)
self._build_predict_network()
def _process_amp_args(self, kwargs):
if self._amp_level == "O0":
self._keep_bn_fp32 = False
if 'keep_batchnorm_fp32' in kwargs:
self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
if 'loss_scale_manager' in kwargs:
self._loss_scale_manager = kwargs['loss_scale_manager']
self._loss_scale_manager_set = True
def _check_kwargs(self, kwargs):
for arg in kwargs:
if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册