提交 768ff072 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4010 modify checkpoint config param check

Merge pull request !4010 from changzherui/mod_ckpt_param
...@@ -104,10 +104,6 @@ class CheckpointConfig: ...@@ -104,10 +104,6 @@ class CheckpointConfig:
integrated_save=True, integrated_save=True,
async_save=False): async_save=False):
if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input_param can't be all None or 0")
if save_checkpoint_steps is not None: if save_checkpoint_steps is not None:
save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps) save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
if save_checkpoint_seconds is not None: if save_checkpoint_seconds is not None:
...@@ -117,6 +113,10 @@ class CheckpointConfig: ...@@ -117,6 +113,10 @@ class CheckpointConfig:
if keep_checkpoint_per_n_minutes is not None: if keep_checkpoint_per_n_minutes is not None:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input_param can't be all None or 0")
self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_steps = save_checkpoint_steps
self._save_checkpoint_seconds = save_checkpoint_seconds self._save_checkpoint_seconds = save_checkpoint_seconds
if self._save_checkpoint_steps and self._save_checkpoint_steps > 0: if self._save_checkpoint_steps and self._save_checkpoint_steps > 0:
...@@ -173,7 +173,6 @@ class CheckpointConfig: ...@@ -173,7 +173,6 @@ class CheckpointConfig:
return checkpoint_policy return checkpoint_policy
class ModelCheckpoint(Callback): class ModelCheckpoint(Callback):
""" """
The checkpoint callback class. The checkpoint callback class.
...@@ -203,7 +202,7 @@ class ModelCheckpoint(Callback): ...@@ -203,7 +202,7 @@ class ModelCheckpoint(Callback):
raise ValueError("Prefix {} for checkpoint file name invalid, " raise ValueError("Prefix {} for checkpoint file name invalid, "
"please check and correct it and then continue.".format(prefix)) "please check and correct it and then continue.".format(prefix))
if directory: if directory is not None:
self._directory = _make_directory(directory) self._directory = _make_directory(directory)
else: else:
self._directory = _cur_dir self._directory = _cur_dir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册