Commit d45abc5f authored by D d00455729

Asynchronous checkpoint saving

Parent c99cc0df
......@@ -15,7 +15,6 @@
"""Checkpoint related classes and functions."""
import os
import shutil
import stat
import time
......@@ -86,6 +85,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene. Default: True.
The integrated save function is only supported in the automatic parallel scene, not in manual parallel.
async_save (bool): Whether to execute the checkpoint saving asynchronously. Default: False.
Raises:
ValueError: If the input_param is None or 0.
......@@ -100,7 +100,8 @@ class CheckpointConfig:
save_checkpoint_seconds=0,
keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0,
integrated_save=True):
integrated_save=True,
async_save=False):
if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
......@@ -129,6 +130,7 @@ class CheckpointConfig:
self._keep_checkpoint_max = 1
self._integrated_save = check_bool(integrated_save)
self._async_save = check_bool(async_save)
@property
def save_checkpoint_steps(self):
......@@ -155,6 +157,11 @@ class CheckpointConfig:
"""Get the value of _integrated_save."""
return self._integrated_save
@property
def async_save(self):
"""Get the value of _async_save."""
return self._async_save
def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
......@@ -282,8 +289,6 @@ class ModelCheckpoint(Callback):
global _save_dir
_save_dir = self._directory
cur_file = os.path.join(self._directory, cur_ckpoint_file)
tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
self._last_time_for_keep = time.time()
self._last_triggered_step = cb_params.cur_step_num
......@@ -291,10 +296,9 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
_exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
self._config.async_save)
if os.path.exists(gen_file):
shutil.move(gen_file, cur_file)
self._latest_ckpt_file_name = cur_file
@property
......
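For context, a minimal usage sketch of the new option follows. It is not part of this change; the network, optimizer, and dataset are assumed to be defined elsewhere, and only async_save=True comes from this commit.

# Example (illustrative only): enable asynchronous checkpoint saving through the
# new CheckpointConfig option; ModelCheckpoint then writes each checkpoint in a
# background thread instead of blocking the training step.
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

config_ck = CheckpointConfig(save_checkpoint_steps=100,
                             keep_checkpoint_max=5,
                             async_save=True)
ckpt_cb = ModelCheckpoint(prefix="resnet", directory="./ckpt", config=config_ck)
# model.train(epoch=10, train_dataset=ds_train, callbacks=[ckpt_cb])  # model and ds_train assumed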
......@@ -15,6 +15,7 @@
"""Model and parameters serialization."""
import os
import stat
from threading import Thread, Lock
import numpy as np
import mindspore.nn as nn
......@@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
_ckpt_mutex = Lock()
def _special_process_par(par, new_par):
"""
......@@ -101,7 +103,29 @@ def _update_param(param, new_param):
param.set_parameter_data(type(param.data)(new_param.data))
def save_checkpoint(parameter_list, ckpt_file_name):
def _exec_save(ckpt_file_name, data_list):
"""Execute save checkpoint into file process."""
checkpoint_list = Checkpoint()
try:
with _ckpt_mutex:
for name, value in data_list.items():
param_value = checkpoint_list.value.add()
param_value.tag = name
param_tensor = param_value.tensor
param_tensor.dims.extend(value[0])
param_tensor.tensor_type = value[1]
param_tensor.tensor_content = value[2].tostring()
with open(ckpt_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpt_file_name, stat.S_IRUSR)
except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
raise RuntimeError(e.__str__())
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
"""
Saves checkpoint info to a specified file.
......@@ -109,37 +133,37 @@ def save_checkpoint(parameter_list, ckpt_file_name):
parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpt_file_name (str): Checkpoint file name.
async_save (bool): Whether to execute the checkpoint saving asynchronously. Default: False.
Raises:
RuntimeError: Failed to save the Checkpoint file.
"""
logger.info("Execute save checkpoint process.")
checkpoint_list = Checkpoint()
try:
data_list = {}
with _ckpt_mutex:
for param in parameter_list:
param_value = checkpoint_list.value.add()
param_value.tag = param["name"]
param_tensor = param_value.tensor
key = param["name"]
data_list[key] = []
if isinstance(param["data"], Parameter):
param["data"].init_data()
param_data = param["data"].asnumpy().reshape(-1)
param_tensor.tensor_content = param_data.tostring()
param_tensor.tensor_type = str(param["data"].dtype)
dims = []
if param['data'].shape == ():
param_tensor.dims.append(0)
dims.append(0)
else:
for dim in param['data'].shape:
param_tensor.dims.append(dim)
with open(ckpt_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpt_file_name, stat.S_IRUSR)
except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
raise RuntimeError(e.__str__())
dims.append(dim)
data_list[key].append(dims)
tensor_type = str(param["data"].dtype)
data_list[key].append(tensor_type)
data = param["data"].asnumpy().reshape(-1)
data_list[key].append(data)
if async_save:
thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
thr.start()
else:
_exec_save(ckpt_file_name, data_list)
logger.info("Save checkpoint process finish.")
......@@ -305,7 +329,7 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
"""
Saves checkpoint for 'ms' backend.
......@@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
train_network (Network): The train network for training.
ckpt_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to perform integrated save in the automatic model parallel scene.
async_save (bool): Whether to execute the checkpoint saving asynchronously. Default: False.
"""
param_dict = {}
......@@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
each_param["data"] = param_data
param_list.append(each_param)
save_checkpoint(param_list, ckpt_file_name)
save_checkpoint(param_list, ckpt_file_name, async_save)
def _get_merged_param_data(net, param_name, param_data):
......
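The concurrency pattern introduced here is small enough to show in isolation: a module-level Lock serializes writers so overlapping checkpoint requests cannot interleave their file writes, while an optional worker thread keeps the caller from blocking on I/O. The sketch below uses hypothetical names (_write_mutex, _write_file, save_async) purely to illustrate the pattern:

from threading import Thread, Lock

_write_mutex = Lock()               # plays the role of _ckpt_mutex in the diff

def _write_file(file_name, payload):
    with _write_mutex:              # only one writer at a time
        with open(file_name, "wb") as f:
            f.write(payload)

def save_async(file_name, payload):
    # hand the write off to a background thread, mirroring async_save=True
    Thread(target=_write_file, args=(file_name, payload)).start()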