提交 966f0523 编写于 作者: C changzherui

asyn save checkpoint to file merge to r0.3

上级 e5193176
......@@ -16,7 +16,6 @@
import os
import stat
import shutil
import time
import numpy as np
......@@ -625,8 +624,6 @@ class ModelCheckpoint(Callback):
global _save_dir
_save_dir = self._directory
cur_file = os.path.join(self._directory, cur_ckpoint_file)
tmp_ckpt_file_name_for_cur_process = str(os.getpid()) + "-" + 'parameters.ckpt'
gen_file = os.path.join(_save_dir, tmp_ckpt_file_name_for_cur_process)
self._last_time_for_keep = time.time()
self._last_triggered_step = cb_params.cur_step_num
......@@ -634,10 +631,8 @@ class ModelCheckpoint(Callback):
_set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
_exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save)
if os.path.exists(gen_file):
shutil.move(gen_file, cur_file)
self._latest_ckpt_file_name = cur_file
@property
......
......@@ -84,12 +84,12 @@ class DatasetHelper:
class _DatasetIter:
"""Base iter for dataset help"""
def __init__(self, dataset):
self.loop_size = 1
if not hasattr(dataset, '__loop_size__'):
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
if not hasattr(dataset, '__ME_INITED__'):
if not hasattr(dataset, '__loop_size__'):
self.loop_size = dataset.get_dataset_size()
else:
self.loop_size = dataset.__loop_size__
dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.loop_size)
dataset.__ME_INITED__ = dataset.__TRANSFER_DATASET__.queue_name
......
......@@ -15,6 +15,7 @@
"""Model and parameters serialization."""
import os
import stat
from threading import Thread
import numpy as np
import mindspore.nn as nn
......@@ -96,7 +97,23 @@ def _update_param(param, new_param):
param.set_parameter_data(type(param.data)(new_param.data))
def save_checkpoint(parameter_list, ckpoint_file_name):
def asyn_thread(fun):
def wrapper(*args, **kwargs):
thr = Thread(target=fun, args=args, kwargs=kwargs)
thr.start()
return wrapper
@asyn_thread
def asyn_save_fun(ckpoint_file_name, checkpoint_list):
logger.info("Asynchronous execute save checkpoint into file.")
with open(ckpoint_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpoint_file_name, stat.S_IRUSR)
logger.info("Asynchronous save checkpoint into file process finish.")
def save_checkpoint(parameter_list, ckpoint_file_name, asyn_exec=False):
"""
Saves checkpoint info to a specified file.
......@@ -104,6 +121,7 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpoint_file_name (str): Checkpoint file name.
asyn_exec (bool): Whether asynchronous execute save checkpoint into file.
Raises:
RuntimeError: Failed to save the Checkpoint file.
......@@ -127,10 +145,12 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
else:
for dim in param['data'].shape():
param_tensor.dims.append(dim)
with open(ckpoint_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpoint_file_name, stat.S_IRUSR)
if asyn_exec:
asyn_save_fun(ckpoint_file_name, checkpoint_list)
else:
with open(ckpoint_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpoint_file_name, stat.S_IRUSR)
except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
......@@ -298,7 +318,7 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True, asyn_save=False):
"""
Saves checkpoint for 'ms' backend.
......@@ -329,7 +349,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
each_param["data"] = param_data
param_list.append(each_param)
save_checkpoint(param_list, ckpoint_file_name)
save_checkpoint(param_list, ckpoint_file_name, asyn_save)
def _get_merged_param_data(net, param_name, param_data):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册