提交 d45abc5f 编写于 作者: D d00455729

Asynchronous save checkpoint

上级 c99cc0df
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
"""Checkpoint related classes and functions.""" """Checkpoint related classes and functions."""
import os import os
import shutil
import stat import stat
import time import time
...@@ -86,6 +85,7 @@ class CheckpointConfig: ...@@ -86,6 +85,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time. Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
Raises: Raises:
ValueError: If the input_param is None or 0. ValueError: If the input_param is None or 0.
...@@ -100,7 +100,8 @@ class CheckpointConfig: ...@@ -100,7 +100,8 @@ class CheckpointConfig:
save_checkpoint_seconds=0, save_checkpoint_seconds=0,
keep_checkpoint_max=5, keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0, keep_checkpoint_per_n_minutes=0,
integrated_save=True): integrated_save=True,
async_save=False):
if not save_checkpoint_steps and not save_checkpoint_seconds and \ if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
...@@ -129,6 +130,7 @@ class CheckpointConfig: ...@@ -129,6 +130,7 @@ class CheckpointConfig:
self._keep_checkpoint_max = 1 self._keep_checkpoint_max = 1
self._integrated_save = check_bool(integrated_save) self._integrated_save = check_bool(integrated_save)
self._async_save = check_bool(async_save)
@property @property
def save_checkpoint_steps(self): def save_checkpoint_steps(self):
...@@ -155,6 +157,11 @@ class CheckpointConfig: ...@@ -155,6 +157,11 @@ class CheckpointConfig:
"""Get the value of _integrated_save.""" """Get the value of _integrated_save."""
return self._integrated_save return self._integrated_save
@property
def async_save(self):
"""Get the value of _async_save."""
return self._async_save
def get_checkpoint_policy(self): def get_checkpoint_policy(self):
"""Get the policy of checkpoint.""" """Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
...@@ -282,8 +289,6 @@ class ModelCheckpoint(Callback): ...@@ -282,8 +289,6 @@ class ModelCheckpoint(Callback):
global _save_dir global _save_dir
_save_dir = self._directory _save_dir = self._directory
cur_file = os.path.join(self._directory, cur_ckpoint_file) cur_file = os.path.join(self._directory, cur_ckpoint_file)
tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
self._last_time_for_keep = time.time() self._last_time_for_keep = time.time()
self._last_triggered_step = cb_params.cur_step_num self._last_triggered_step = cb_params.cur_step_num
...@@ -291,10 +296,9 @@ class ModelCheckpoint(Callback): ...@@ -291,10 +296,9 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
self._config.async_save)
if os.path.exists(gen_file):
shutil.move(gen_file, cur_file)
self._latest_ckpt_file_name = cur_file self._latest_ckpt_file_name = cur_file
@property @property
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
"""Model and parameters serialization.""" """Model and parameters serialization."""
import os import os
import stat import stat
from threading import Thread, Lock
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
...@@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin ...@@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
_ckpt_mutex = Lock()
def _special_process_par(par, new_par): def _special_process_par(par, new_par):
""" """
...@@ -101,7 +103,29 @@ def _update_param(param, new_param): ...@@ -101,7 +103,29 @@ def _update_param(param, new_param):
param.set_parameter_data(type(param.data)(new_param.data)) param.set_parameter_data(type(param.data)(new_param.data))
def save_checkpoint(parameter_list, ckpt_file_name): def _exec_save(ckpt_file_name, data_list):
"""Execute save checkpoint into file process."""
checkpoint_list = Checkpoint()
try:
with _ckpt_mutex:
for name, value in data_list.items():
param_value = checkpoint_list.value.add()
param_value.tag = name
param_tensor = param_value.tensor
param_tensor.dims.extend(value[0])
param_tensor.tensor_type = value[1]
param_tensor.tensor_content = value[2].tostring()
with open(ckpt_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpt_file_name, stat.S_IRUSR)
except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
raise RuntimeError(e.__str__())
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
""" """
Saves checkpoint info to a specified file. Saves checkpoint info to a specified file.
...@@ -109,37 +133,37 @@ def save_checkpoint(parameter_list, ckpt_file_name): ...@@ -109,37 +133,37 @@ def save_checkpoint(parameter_list, ckpt_file_name):
parameter_list (list): Parameters list, each element is a dict parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}. like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpt_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False
Raises: Raises:
RuntimeError: Failed to save the Checkpoint file. RuntimeError: Failed to save the Checkpoint file.
""" """
logger.info("Execute save checkpoint process.") logger.info("Execute save checkpoint process.")
checkpoint_list = Checkpoint()
try: data_list = {}
with _ckpt_mutex:
for param in parameter_list: for param in parameter_list:
param_value = checkpoint_list.value.add() key = param["name"]
param_value.tag = param["name"] data_list[key] = []
param_tensor = param_value.tensor
if isinstance(param["data"], Parameter): if isinstance(param["data"], Parameter):
param["data"].init_data() param["data"].init_data()
param_data = param["data"].asnumpy().reshape(-1) dims = []
param_tensor.tensor_content = param_data.tostring()
param_tensor.tensor_type = str(param["data"].dtype)
if param['data'].shape == (): if param['data'].shape == ():
param_tensor.dims.append(0) dims.append(0)
else: else:
for dim in param['data'].shape: for dim in param['data'].shape:
param_tensor.dims.append(dim) dims.append(dim)
data_list[key].append(dims)
with open(ckpt_file_name, "wb") as f: tensor_type = str(param["data"].dtype)
f.write(checkpoint_list.SerializeToString()) data_list[key].append(tensor_type)
os.chmod(ckpt_file_name, stat.S_IRUSR) data = param["data"].asnumpy().reshape(-1)
data_list[key].append(data)
except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) if async_save:
raise RuntimeError(e.__str__()) thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
thr.start()
else:
_exec_save(ckpt_file_name, data_list)
logger.info("Save checkpoint process finish.") logger.info("Save checkpoint process finish.")
...@@ -305,7 +329,7 @@ def _save_graph(network, file_name): ...@@ -305,7 +329,7 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IRUSR) os.chmod(file_name, stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
""" """
Saves checkpoint for 'ms' backend. Saves checkpoint for 'ms' backend.
...@@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): ...@@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
train_network (Network): The train network for training. train_network (Network): The train network for training.
ckpt_file_name (str): The name of checkpoint file. ckpt_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to integrated save in automatic model parallel scene. integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
""" """
param_dict = {} param_dict = {}
...@@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True): ...@@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
each_param["data"] = param_data each_param["data"] = param_data
param_list.append(each_param) param_list.append(each_param)
save_checkpoint(param_list, ckpt_file_name) save_checkpoint(param_list, ckpt_file_name, async_save)
def _get_merged_param_data(net, param_name, param_data): def _get_merged_param_data(net, param_name, param_data):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册