diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 4277797731d0c9eb0af3fae67ef3f609e146a268..fdb6fecb77d6762f3b021324c7c30708a2f206e9 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -15,6 +15,7 @@ """Model and parameters serialization.""" import os import stat +import math from threading import Thread, Lock import numpy as np @@ -42,6 +43,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} _ckpt_mutex = Lock() +SLICE_SIZE = 512 * 1024 * 1024 + def _special_process_par(par, new_par): """ @@ -105,26 +108,38 @@ def _update_param(param, new_param): def _exec_save(ckpt_file_name, data_list): """Execute save checkpoint into file process.""" - checkpoint_list = Checkpoint() try: with _ckpt_mutex: - for name, value in data_list.items(): - param_value = checkpoint_list.value.add() - param_value.tag = name - param_tensor = param_value.tensor - param_tensor.dims.extend(value[0]) - param_tensor.tensor_type = value[1] - param_tensor.tensor_content = value[2].tostring() - - with open(ckpt_file_name, "wb") as f: - f.write(checkpoint_list.SerializeToString()) - os.chmod(ckpt_file_name, stat.S_IRUSR) + if os.path.exists(ckpt_file_name): + os.remove(ckpt_file_name) + with open(ckpt_file_name, "ab") as f: + for name, value in data_list.items(): + data_size = value[2].nbytes + if data_size > SLICE_SIZE: + slice_count = math.ceil(data_size / SLICE_SIZE) + param_slice_list = np.array_split(value[2], slice_count) + else: + param_slice_list = [value[2]] + + for param_slice in param_slice_list: + checkpoint_list = Checkpoint() + param_value = checkpoint_list.value.add() + param_value.tag = name + param_tensor = param_value.tensor + param_tensor.dims.extend(value[0]) + param_tensor.tensor_type = value[1] + param_tensor.tensor_content = param_slice.tostring() + + f.write(checkpoint_list.SerializeToString()) + + os.chmod(ckpt_file_name, stat.S_IRUSR) except BaseException as e: logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) raise RuntimeError(e.__str__()) + def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): """ Saves checkpoint info to a specified file. @@ -206,28 +221,37 @@ def load_checkpoint(ckpt_file_name, net=None): parameter_dict = {} try: + element_id = 0 + param_data_list = [] for element in checkpoint_list.value: data = element.tensor.tensor_content data_type = element.tensor.tensor_type np_type = tensor_to_np_type[data_type] ms_type = tensor_to_ms_type[data_type] - param_data = np.fromstring(data, np_type) - dims = element.tensor.dims - - if dims == [0]: - if 'Float' in data_type: - param_data = float(param_data[0]) - elif 'Int' in data_type: - param_data = int(param_data[0]) - parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) - elif dims == [1]: - parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) - else: - param_dim = [] - for dim in dims: - param_dim.append(dim) - param_value = param_data.reshape(param_dim) - parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) + element_data = np.frombuffer(data, np_type) + param_data_list.append(element_data) + if (element_id == len(checkpoint_list.value) - 1) or \ + (element.tag != checkpoint_list.value[element_id + 1].tag): + param_data = np.concatenate((param_data_list), axis=0) + param_data_list.clear() + dims = element.tensor.dims + + if dims == [0]: + if 'Float' in data_type: + param_data = float(param_data[0]) + elif 'Int' in data_type: + param_data = int(param_data[0]) + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) + elif dims == [1]: + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) + else: + param_dim = [] + for dim in dims: + param_dim.append(dim) + param_value = param_data.reshape(param_dim) + parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) + + element_id += 1 logger.info("Load checkpoint process finish.")