diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index 3912bcd6205fccfa75dedeb7fdc69f389f86df7f..b9e235aed5f3fbcaa3395b756a3c4ee6dcf9cd8a 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -104,10 +104,6 @@ class CheckpointConfig: integrated_save=True, async_save=False): - if not save_checkpoint_steps and not save_checkpoint_seconds and \ - not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: - raise ValueError("The input_param can't be all None or 0") - if save_checkpoint_steps is not None: save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps) if save_checkpoint_seconds is not None: @@ -117,6 +113,10 @@ class CheckpointConfig: if keep_checkpoint_per_n_minutes is not None: keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) + if not save_checkpoint_steps and not save_checkpoint_seconds and \ + not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: + raise ValueError("The input_param can't be all None or 0") + self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_seconds = save_checkpoint_seconds if self._save_checkpoint_steps and self._save_checkpoint_steps > 0: @@ -173,7 +173,6 @@ class CheckpointConfig: return checkpoint_policy - class ModelCheckpoint(Callback): """ The checkpoint callback class. @@ -203,7 +202,7 @@ class ModelCheckpoint(Callback): raise ValueError("Prefix {} for checkpoint file name invalid, " "please check and correct it and then continue.".format(prefix)) - if directory: + if directory is not None: self._directory = _make_directory(directory) else: self._directory = _cur_dir