diff --git a/mindspore/train/callback.py b/mindspore/train/callback.py index e691cfd83738d0c2db5dfd2ac1cfdb97ad45ac9d..572cb746c44c813a4f297ab3d975f9fab6e332b6 100644 --- a/mindspore/train/callback.py +++ b/mindspore/train/callback.py @@ -16,7 +16,6 @@ import os import stat -import shutil import time import numpy as np @@ -625,8 +624,6 @@ class ModelCheckpoint(Callback): global _save_dir _save_dir = self._directory cur_file = os.path.join(self._directory, cur_ckpoint_file) - tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt' - gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process) self._last_time_for_keep = time.time() self._last_triggered_step = cb_params.cur_step_num @@ -634,10 +631,8 @@ class ModelCheckpoint(Callback): _set_cur_net(cb_params.train_network) cb_params.train_network.exec_checkpoint_graph() - _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) + _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save) - if os.path.exists(gen_file): - shutil.move(gen_file, cur_file) self._latest_ckpt_file_name = cur_file @property diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py index 28d65349ee78c5b7db5d0e64928dc507907fae01..eb119a2907378eb9e9dcc3a448050da7d5288dd8 100644 --- a/mindspore/train/dataset_helper.py +++ b/mindspore/train/dataset_helper.py @@ -84,12 +84,12 @@ class DatasetHelper: class _DatasetIter: """Base iter for dataset help""" def __init__(self, dataset): - self.loop_size = 1 + if not hasattr(dataset, '__loop_size__'): + self.loop_size = dataset.get_dataset_size() + else: + self.loop_size = dataset.__loop_size__ + if not hasattr(dataset, '__ME_INITED__'): - if not hasattr(dataset, '__loop_size__'): - self.loop_size = dataset.get_dataset_size() - else: - self.loop_size = dataset.__loop_size__ dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size) dataset.__ME_INITED__ = 
dataset.__TRANSFER_DATASET__.queue_name diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 4e6e67e32bc2bdcf5e955c94161d4e166c06c8aa..30473ba523ed5ea3409be2143f80fc06ede0cc2a 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -15,6 +15,7 @@ """Model and parameters serialization.""" import os import stat +from threading import Thread import numpy as np import mindspore.nn as nn @@ -96,7 +97,23 @@ def _update_param(param, new_param): param.set_parameter_data(type(param.data)(new_param.data)) -def save_checkpoint(parameter_list, ckpoint_file_name): +def asyn_thread(fun): + def wrapper(*args, **kwargs): + thr = Thread(target=fun, args=args, kwargs=kwargs) + thr.start() + return wrapper + + +@asyn_thread +def asyn_save_fun(ckpoint_file_name, checkpoint_list): + logger.info("Asynchronous execute save checkpoint into file.") + with open(ckpoint_file_name, "wb") as f: + f.write(checkpoint_list.SerializeToString()) + os.chmod(ckpoint_file_name, stat.S_IRUSR) + logger.info("Asynchronous save checkpoint into file process finish.") + + +def save_checkpoint(parameter_list, ckpoint_file_name, asyn_exec=False): """ Saves checkpoint info to a specified file. @@ -104,6 +121,7 @@ parameter_list (list): Parameters list, each element is a dict like {"name":xx, "type":xx, "shape":xx, "data":xx}. ckpoint_file_name (str): Checkpoint file name. + asyn_exec (bool): Whether to execute the checkpoint saving asynchronously. Raises: RuntimeError: Failed to save the Checkpoint file. 
@@ -127,10 +145,12 @@ def save_checkpoint(parameter_list, ckpoint_file_name): else: for dim in param['data'].shape(): param_tensor.dims.append(dim) - - with open(ckpoint_file_name, "wb") as f: - f.write(checkpoint_list.SerializeToString()) - os.chmod(ckpoint_file_name, stat.S_IRUSR) + if asyn_exec: + asyn_save_fun(ckpoint_file_name, checkpoint_list) + else: + with open(ckpoint_file_name, "wb") as f: + f.write(checkpoint_list.SerializeToString()) + os.chmod(ckpoint_file_name, stat.S_IRUSR) except BaseException as e: logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name) @@ -298,7 +318,7 @@ def _save_graph(network, file_name): os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) -def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True): +def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True, asyn_save=False): """ Saves checkpoint for 'ms' backend. @@ -329,7 +349,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True each_param["data"] = param_data param_list.append(each_param) - save_checkpoint(param_list, ckpoint_file_name) + save_checkpoint(param_list, ckpoint_file_name, asyn_save) def _get_merged_param_data(net, param_name, param_data):