提交 4683de34 编写于 作者: L liuyang_655

modify save_checkpoint

上级 b346f0b3
......@@ -23,7 +23,7 @@ import mindspore.context as context
from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
from mindspore.train.serialization import save_checkpoint, _save_graph
from ._callback import Callback, set_cur_net
......@@ -306,8 +306,8 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
self._config.async_save)
save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save,
self._config.async_save)
self._latest_ckpt_file_name = cur_file
......
......@@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list):
raise RuntimeError(e.__str__())
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
"""
Saves checkpoint info to a specified file.
Args:
parameter_list (list): Parameters list, each element is a dictionary
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
save_obj (nn.Cell or list): The train network for training or parameters list(each element is a dictionary,
like {"name":xx, "type":xx, "shape":xx, "data":xx}.)
ckpt_file_name (str): Checkpoint file name.
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
Raises:
TypeError: If the parameter save_obj is not nn.Cell or list type.
RuntimeError: Failed to save the Checkpoint file.
"""
if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
logger.info("Execute save checkpoint process.")
if isinstance(save_obj, nn.Cell):
save_obj.init_parameters_data()
param_dict = {}
for _, param in save_obj.parameters_and_names():
param_dict[param.name] = param
param_list = []
for (key, value) in param_dict.items():
each_param = {"name": key}
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
# which should be combined before saving
if integrated_save and key in save_obj.parameter_layout_dict:
param_data = _get_merged_param_data(save_obj, key, param_data)
each_param["data"] = param_data
param_list.append(each_param)
save_obj = param_list
data_list = {}
with _ckpt_mutex:
for param in parameter_list:
for param in save_obj:
key = param["name"]
data_list[key] = []
if isinstance(param["data"], Parameter):
......@@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
thr.start()
else:
_exec_save(ckpt_file_name, data_list)
logger.info("Save checkpoint process finish.")
......@@ -354,39 +383,6 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
"""
Saves checkpoint for 'ms' backend.
Args:
train_network (Network): The train network for training.
ckpt_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
"""
train_network.init_parameters_data()
param_dict = {}
for _, param in train_network.parameters_and_names():
param_dict[param.name] = param
param_list = []
for (key, value) in param_dict.items():
each_param = {"name": key}
if isinstance(value.data, Tensor):
param_data = value.data
else:
param_data = Tensor(value.data)
# in automatic model parallel scenario, some parameters were spliteds to all the devices,
# which should be combined before saving
if integrated_save and key in train_network.parameter_layout_dict:
param_data = _get_merged_param_data(train_network, key, param_data)
each_param["data"] = param_data
param_list.append(each_param)
save_checkpoint(param_list, ckpt_file_name, async_save)
def _get_merged_param_data(net, param_name, param_data):
"""
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
......
......@@ -18,7 +18,7 @@ import os
import numpy as np
import mindspore.context as context
from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint
from mindspore.train.serialization import save_checkpoint, load_checkpoint
from src.config import GatConfig
from src.dataset import load_and_process
......@@ -98,7 +98,7 @@ def train():
val_loss_model = eval_loss
if os.path.exists("ckpts/gat.ckpt"):
os.remove("ckpts/gat.ckpt")
_exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt")
save_checkpoint(train_net.network, "ckpts/gat.ckpt")
val_acc_max = np.max((val_acc_max, eval_acc))
val_loss_min = np.min((val_loss_min, eval_loss))
curr_step = 0
......
......@@ -20,7 +20,7 @@ import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train.callback import Callback
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
from mindspore.ops import operations as P
from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR
from .assessment_method import Accuracy
......@@ -53,9 +53,9 @@ class ModelSaveCkpt(Callback):
self.save_ckpt_step))
if os.path.exists(path):
os.remove(path)
_exec_save_checkpoint(self.network, os.path.join(self.output_dir,
"tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
self.save_ckpt_step)))
save_checkpoint(self.network, os.path.join(self.output_dir,
"tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num),
self.save_ckpt_step)))
class LossCallBack(Callback):
"""
......@@ -113,7 +113,7 @@ class EvalCallBack(Callback):
eval_model_ckpt_file = "eval_model.ckpt"
if os.path.exists(eval_model_ckpt_file):
os.remove(eval_model_ckpt_file)
_exec_save_checkpoint(self.network, eval_model_ckpt_file)
save_checkpoint(self.network, eval_model_ckpt_file)
class BertLearningRate(LearningRateSchedule):
"""
......
......@@ -31,7 +31,7 @@ from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.train.callback import _CheckpointManager
from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
_exec_save_checkpoint, export, _save_graph
export, _save_graph
from ..ut_filter import non_graph_engine
context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb")
......@@ -95,8 +95,8 @@ def test_save_graph():
os.remove(output_file)
def test_save_checkpoint():
""" test_save_checkpoint """
def test_save_checkpoint_for_list():
""" test save_checkpoint for list"""
parameter_list = []
one_param = {}
param1 = {}
......@@ -280,14 +280,15 @@ def test_load_param_into_net():
assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
def test_exec_save_checkpoint():
def test_save_checkpoint_for_network():
""" test save_checkpoint for network"""
net = Net()
loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
loss_net = WithLossCell(net, loss)
train_network = TrainOneStepCell(loss_net, opt)
_exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
load_checkpoint("new_ckpt.ckpt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册