提交 097b77c3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3273 Optimized checkpoint save slice tensor

Merge pull request !3273 from changzherui/save_slice_tensor
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Model and parameters serialization.""" """Model and parameters serialization."""
import os import os
import stat import stat
import math
from threading import Thread, Lock from threading import Thread, Lock
import numpy as np import numpy as np
...@@ -42,6 +43,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin ...@@ -42,6 +43,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
_ckpt_mutex = Lock() _ckpt_mutex = Lock()
SLICE_SIZE = 512 * 1024 * 1024
def _special_process_par(par, new_par): def _special_process_par(par, new_par):
""" """
...@@ -105,26 +108,38 @@ def _update_param(param, new_param): ...@@ -105,26 +108,38 @@ def _update_param(param, new_param):
def _exec_save(ckpt_file_name, data_list):
    """
    Execute save checkpoint into file process.

    Every parameter is serialized as one or more single-entry ``Checkpoint``
    messages that are appended to the file one after another. A parameter
    whose data is larger than ``SLICE_SIZE`` bytes is written as several
    consecutive messages sharing the same tag, dims and tensor type, each
    holding one contiguous piece of the flattened data.

    Args:
        ckpt_file_name (str): Path of the checkpoint file. An existing file at
            this path is removed before writing.
        data_list (dict): Maps a parameter name to an indexable value laid out
            as ``(dims, tensor_type, data)``. ``data`` is expected to be a
            numpy array (it is queried for ``nbytes`` and converted to bytes).

    Raises:
        RuntimeError: If any step of the save fails. The original exception is
            attached as the cause.
    """
    try:
        with _ckpt_mutex:
            if os.path.exists(ckpt_file_name):
                # A previous save leaves the file read-only (see the chmod
                # below); restore write permission first so the removal also
                # works on platforms that refuse to delete read-only files.
                os.chmod(ckpt_file_name, stat.S_IWUSR)
                os.remove(ckpt_file_name)

            with open(ckpt_file_name, "ab") as f:
                for name, value in data_list.items():
                    dims, tensor_type, data = value[0], value[1], value[2]

                    if data.nbytes > SLICE_SIZE:
                        slice_count = math.ceil(data.nbytes / SLICE_SIZE)
                        # Split the flattened array rather than axis 0: a
                        # short leading dimension would otherwise produce
                        # empty slices and pieces still above SLICE_SIZE.
                        # The concatenated bytes are the same C-order stream.
                        param_slice_list = np.array_split(data.reshape(-1), slice_count)
                    else:
                        param_slice_list = [data]

                    for param_slice in param_slice_list:
                        # One message per slice keeps each serialized chunk
                        # bounded by the slice size.
                        checkpoint_list = Checkpoint()
                        param_value = checkpoint_list.value.add()
                        param_value.tag = name
                        param_tensor = param_value.tensor
                        param_tensor.dims.extend(dims)
                        param_tensor.tensor_type = tensor_type
                        param_tensor.tensor_content = param_slice.tobytes()
                        f.write(checkpoint_list.SerializeToString())

            # Mark the finished checkpoint read-only for the owner.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(str(e)) from e
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
""" """
Saves checkpoint info to a specified file. Saves checkpoint info to a specified file.
...@@ -206,28 +221,37 @@ def load_checkpoint(ckpt_file_name, net=None): ...@@ -206,28 +221,37 @@ def load_checkpoint(ckpt_file_name, net=None):
parameter_dict = {} parameter_dict = {}
try: try:
element_id = 0
param_data_list = []
for element in checkpoint_list.value: for element in checkpoint_list.value:
data = element.tensor.tensor_content data = element.tensor.tensor_content
data_type = element.tensor.tensor_type data_type = element.tensor.tensor_type
np_type = tensor_to_np_type[data_type] np_type = tensor_to_np_type[data_type]
ms_type = tensor_to_ms_type[data_type] ms_type = tensor_to_ms_type[data_type]
param_data = np.fromstring(data, np_type) element_data = np.frombuffer(data, np_type)
dims = element.tensor.dims param_data_list.append(element_data)
if (element_id == len(checkpoint_list.value) - 1) or \
if dims == [0]: (element.tag != checkpoint_list.value[element_id + 1].tag):
if 'Float' in data_type: param_data = np.concatenate((param_data_list), axis=0)
param_data = float(param_data[0]) param_data_list.clear()
elif 'Int' in data_type: dims = element.tensor.dims
param_data = int(param_data[0])
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) if dims == [0]:
elif dims == [1]: if 'Float' in data_type:
parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) param_data = float(param_data[0])
else: elif 'Int' in data_type:
param_dim = [] param_data = int(param_data[0])
for dim in dims: parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
param_dim.append(dim) elif dims == [1]:
param_value = param_data.reshape(param_dim) parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag) else:
param_dim = []
for dim in dims:
param_dim.append(dim)
param_value = param_data.reshape(param_dim)
parameter_dict[element.tag] = Parameter(Tensor(param_value, ms_type), name=element.tag)
element_id += 1
logger.info("Load checkpoint process finish.") logger.info("Load checkpoint process finish.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册