From d45abc5f54737b9294a771dc10d4152676c512aa Mon Sep 17 00:00:00 2001 From: d00455729 Date: Tue, 14 Jul 2020 20:47:38 +0800 Subject: [PATCH] Asynchronous save checkpoint --- mindspore/train/callback/_checkpoint.py | 18 ++++--- mindspore/train/serialization.py | 69 +++++++++++++++++-------- 2 files changed, 58 insertions(+), 29 deletions(-) diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index e0048ad71..a9389fd39 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -15,7 +15,6 @@ """Checkpoint related classes and functions.""" import os -import shutil import stat import time @@ -86,6 +85,7 @@ class CheckpointConfig: Can't be used with keep_checkpoint_max at the same time. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. + async_save (bool): Whether to asynchronously save the checkpoint file. Default: False. Raises: ValueError: If the input_param is None or 0. 
@@ -100,7 +100,8 @@ class CheckpointConfig: save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, - integrated_save=True): + integrated_save=True, + async_save=False): if not save_checkpoint_steps and not save_checkpoint_seconds and \ not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: @@ -129,6 +130,7 @@ class CheckpointConfig: self._keep_checkpoint_max = 1 self._integrated_save = check_bool(integrated_save) + self._async_save = check_bool(async_save) @property def save_checkpoint_steps(self): @@ -155,6 +157,11 @@ class CheckpointConfig: """Get the value of _integrated_save.""" return self._integrated_save + @property + def async_save(self): + """Get the value of _async_save.""" + return self._async_save + def get_checkpoint_policy(self): """Get the policy of checkpoint.""" checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, @@ -282,8 +289,6 @@ class ModelCheckpoint(Callback): global _save_dir _save_dir = self._directory cur_file = os.path.join(self._directory, cur_ckpoint_file) - tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt' - gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process) self._last_time_for_keep = time.time() self._last_triggered_step = cb_params.cur_step_num @@ -291,10 +296,9 @@ class ModelCheckpoint(Callback): set_cur_net(cb_params.train_network) cb_params.train_network.exec_checkpoint_graph() - _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) + _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, + self._config.async_save) - if os.path.exists(gen_file): - shutil.move(gen_file, cur_file) self._latest_ckpt_file_name = cur_file @property diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 381269841..6a4fc36db 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -15,6 +15,7 @@ """Model and parameters 
serialization.""" import os import stat +from threading import Thread, Lock import numpy as np import mindspore.nn as nn @@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} +_ckpt_mutex = Lock() def _special_process_par(par, new_par): """ @@ -101,7 +103,29 @@ def _update_param(param, new_param): param.set_parameter_data(type(param.data)(new_param.data)) -def save_checkpoint(parameter_list, ckpt_file_name): +def _exec_save(ckpt_file_name, data_list): + """Execute the process of saving the checkpoint to a file.""" + checkpoint_list = Checkpoint() + + try: + with _ckpt_mutex: + for name, value in data_list.items(): + param_value = checkpoint_list.value.add() + param_value.tag = name + param_tensor = param_value.tensor + param_tensor.dims.extend(value[0]) + param_tensor.tensor_type = value[1] + param_tensor.tensor_content = value[2].tostring() + + with open(ckpt_file_name, "wb") as f: + f.write(checkpoint_list.SerializeToString()) + os.chmod(ckpt_file_name, stat.S_IRUSR) + + except BaseException as e: + logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) + raise RuntimeError(e.__str__()) + +def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): """ Saves checkpoint info to a specified file. @@ -109,37 +133,37 @@ def save_checkpoint(parameter_list, ckpt_file_name): parameter_list (list): Parameters list, each element is a dict like {"name":xx, "type":xx, "shape":xx, "data":xx}. ckpt_file_name (str): Checkpoint file name. + async_save (bool): Whether to asynchronously save the checkpoint file. Default: False. Raises: RuntimeError: Failed to save the Checkpoint file. 
""" logger.info("Execute save checkpoint process.") - checkpoint_list = Checkpoint() - try: + data_list = {} + with _ckpt_mutex: for param in parameter_list: - param_value = checkpoint_list.value.add() - param_value.tag = param["name"] - param_tensor = param_value.tensor + key = param["name"] + data_list[key] = [] if isinstance(param["data"], Parameter): param["data"].init_data() - param_data = param["data"].asnumpy().reshape(-1) - param_tensor.tensor_content = param_data.tostring() - param_tensor.tensor_type = str(param["data"].dtype) - + dims = [] if param['data'].shape == (): - param_tensor.dims.append(0) + dims.append(0) else: for dim in param['data'].shape: - param_tensor.dims.append(dim) - - with open(ckpt_file_name, "wb") as f: - f.write(checkpoint_list.SerializeToString()) - os.chmod(ckpt_file_name, stat.S_IRUSR) - - except BaseException as e: - logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) - raise RuntimeError(e.__str__()) + dims.append(dim) + data_list[key].append(dims) + tensor_type = str(param["data"].dtype) + data_list[key].append(tensor_type) + data = param["data"].asnumpy().reshape(-1) + data_list[key].append(data) + + if async_save: + thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list)) + thr.start() + else: + _exec_save(ckpt_file_name, data_list) logger.info("Save checkpoint process finish.") @@ -305,7 +329,7 @@ def _save_graph(network, file_name): os.chmod(file_name, stat.S_IRUSR) -def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): +def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False): """ Saves checkpoint for 'ms' backend. @@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): train_network (Network): The train network for training. ckpt_file_name (str): The name of checkpoint file. integrated_save (bool): Whether to integrated save in automatic model parallel scene. 
+ async_save (bool): Whether to asynchronously save the checkpoint file. Default: False. """ param_dict = {} @@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): each_param["data"] = param_data param_list.append(each_param) - save_checkpoint(param_list, ckpt_file_name) + save_checkpoint(param_list, ckpt_file_name, async_save) def _get_merged_param_data(net, param_name, param_data): -- GitLab