Commit 6ffb04b7 authored by mindspore-ci-bot, committed by Gitee

!605 fix default parameter set in model

Merge pull request !605 from liubuyu/master
@@ -99,12 +99,8 @@ class Model:
         self._loss_scale_manager_set = False
         self._keep_bn_fp32 = True
         self._check_kwargs(kwargs)
-        if 'keep_batchnorm_fp32' in kwargs:
-            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
-        if 'loss_scale_manager' in kwargs:
-            self._loss_scale_manager = kwargs['loss_scale_manager']
-            self._loss_scale_manager_set = True
         self._amp_level = amp_level
+        self._process_amp_args(kwargs)
         self._parallel_mode = _get_parallel_mode()
         self._device_number = _get_device_num()
         self._global_rank = _get_global_rank()
@@ -114,6 +110,15 @@ class Model:
         self._build_eval_network(metrics, eval_network, eval_indexes)
         self._build_predict_network()

+    def _process_amp_args(self, kwargs):
+        if self._amp_level == "O0":
+            self._keep_bn_fp32 = False
+        if 'keep_batchnorm_fp32' in kwargs:
+            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
+        if 'loss_scale_manager' in kwargs:
+            self._loss_scale_manager = kwargs['loss_scale_manager']
+            self._loss_scale_manager_set = True
+
     def _check_kwargs(self, kwargs):
         for arg in kwargs:
             if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
...
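
In effect, the commit moves the amp-related keyword handling out of __init__ into a new _process_amp_args helper and changes one default: when amp_level is "O0" (no mixed precision), keep_batchnorm_fp32 now defaults to False instead of True, while an explicitly passed keep_batchnorm_fp32 or loss_scale_manager keyword still wins. Below is a minimal standalone sketch of that rule for illustration; resolve_amp_args is a hypothetical helper, not part of the MindSpore API.

# A minimal sketch (hypothetical helper, not MindSpore API) of the
# defaulting logic this commit introduces in Model._process_amp_args.
def resolve_amp_args(amp_level, **kwargs):
    """Return (keep_bn_fp32, loss_scale_manager, manager_set)."""
    keep_bn_fp32 = True          # class-wide default, unchanged from before
    loss_scale_manager = None
    manager_set = False
    if amp_level == "O0":
        keep_bn_fp32 = False     # new: pure-FP32 mode no longer forces BN to FP32
    if 'keep_batchnorm_fp32' in kwargs:
        keep_bn_fp32 = kwargs['keep_batchnorm_fp32']   # explicit user value wins
    if 'loss_scale_manager' in kwargs:
        loss_scale_manager = kwargs['loss_scale_manager']
        manager_set = True
    return keep_bn_fp32, loss_scale_manager, manager_set

# With amp_level="O0" and no kwargs, keep_bn_fp32 now defaults to False:
assert resolve_amp_args("O0") == (False, None, False)
# An explicit keep_batchnorm_fp32=True still overrides the O0 default:
assert resolve_amp_args("O0", keep_batchnorm_fp32=True)[0] is True

Note the ordering in the helper mirrors the diff: the "O0" branch runs before the kwargs checks, so a user-supplied keep_batchnorm_fp32 always overrides the level-derived default.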