From 4683de3443eed992108220e18378f8199080ccba Mon Sep 17 00:00:00 2001 From: liuyang_655 Date: Sat, 29 Aug 2020 13:08:28 +0800 Subject: [PATCH] modify save_checkpoint --- mindspore/train/callback/_checkpoint.py | 6 +- mindspore/train/serialization.py | 70 +++++++++----------- model_zoo/official/gnn/gat/train.py | 4 +- model_zoo/official/nlp/tinybert/src/utils.py | 10 +-- tests/ut/python/utils/test_serialize.py | 11 +-- 5 files changed, 49 insertions(+), 52 deletions(-) diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index c9926f7b1..a76c73e11 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -23,7 +23,7 @@ import mindspore.context as context from mindspore import log as logger from mindspore._checkparam import check_bool, check_int_non_negative from mindspore.train._utils import _make_directory -from mindspore.train.serialization import _exec_save_checkpoint, _save_graph +from mindspore.train.serialization import save_checkpoint, _save_graph from ._callback import Callback, set_cur_net @@ -306,8 +306,8 @@ class ModelCheckpoint(Callback): set_cur_net(cb_params.train_network) cb_params.train_network.exec_checkpoint_graph() - _exec_save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, - self._config.async_save) + save_checkpoint(cb_params.train_network, cur_file, self._config.integrated_save, + self._config.async_save) self._latest_ckpt_file_name = cur_file diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index c12e5615d..0b53ad126 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list): raise RuntimeError(e.__str__()) -def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): +def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False): """ Saves checkpoint info to a specified file. 
Args: - parameter_list (list): Parameters list, each element is a dictionary - like {"name":xx, "type":xx, "shape":xx, "data":xx}. + save_obj (nn.Cell or list): The train network for training or parameters list (each element is a dictionary, + like {"name":xx, "type":xx, "shape":xx, "data":xx}.) ckpt_file_name (str): Checkpoint file name. + integrated_save (bool): Whether to perform integrated save in automatic model parallel scene. Default: True. async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False Raises: + TypeError: If the parameter save_obj is not nn.Cell or list type. RuntimeError: Failed to save the Checkpoint file. """ + + if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list): + raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj))) + logger.info("Execute save checkpoint process.") + if isinstance(save_obj, nn.Cell): + save_obj.init_parameters_data() + param_dict = {} + for _, param in save_obj.parameters_and_names(): + param_dict[param.name] = param + param_list = [] + for (key, value) in param_dict.items(): + each_param = {"name": key} + if isinstance(value.data, Tensor): + param_data = value.data + else: + param_data = Tensor(value.data) + + # in automatic model parallel scenario, some parameters were split across all the devices, + # which should be combined before saving + if integrated_save and key in save_obj.parameter_layout_dict: + param_data = _get_merged_param_data(save_obj, key, param_data) + + each_param["data"] = param_data + param_list.append(each_param) + save_obj = param_list + data_list = {} with _ckpt_mutex: - for param in parameter_list: + for param in save_obj: key = param["name"] data_list[key] = [] if isinstance(param["data"], Parameter): @@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False): thr.start() else: _exec_save(ckpt_file_name, data_list) + logger.info("Save checkpoint process finish.") @@ -354,39 +383,6 @@ 
def _save_graph(network, file_name): os.chmod(file_name, stat.S_IRUSR) -def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False): - """ - Saves checkpoint for 'ms' backend. - - Args: - train_network (Network): The train network for training. - ckpt_file_name (str): The name of checkpoint file. - integrated_save (bool): Whether to integrated save in automatic model parallel scene. - async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False. - """ - train_network.init_parameters_data() - param_dict = {} - for _, param in train_network.parameters_and_names(): - param_dict[param.name] = param - param_list = [] - for (key, value) in param_dict.items(): - each_param = {"name": key} - if isinstance(value.data, Tensor): - param_data = value.data - else: - param_data = Tensor(value.data) - - # in automatic model parallel scenario, some parameters were spliteds to all the devices, - # which should be combined before saving - if integrated_save and key in train_network.parameter_layout_dict: - param_data = _get_merged_param_data(train_network, key, param_data) - - each_param["data"] = param_data - param_list.append(each_param) - - save_checkpoint(param_list, ckpt_file_name, async_save) - - def _get_merged_param_data(net, param_name, param_data): """ Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map. 
diff --git a/model_zoo/official/gnn/gat/train.py b/model_zoo/official/gnn/gat/train.py index acfbb05b7..94ac6f069 100644 --- a/model_zoo/official/gnn/gat/train.py +++ b/model_zoo/official/gnn/gat/train.py @@ -18,7 +18,7 @@ import os import numpy as np import mindspore.context as context -from mindspore.train.serialization import _exec_save_checkpoint, load_checkpoint +from mindspore.train.serialization import save_checkpoint, load_checkpoint from src.config import GatConfig from src.dataset import load_and_process @@ -98,7 +98,7 @@ def train(): val_loss_model = eval_loss if os.path.exists("ckpts/gat.ckpt"): os.remove("ckpts/gat.ckpt") - _exec_save_checkpoint(train_net.network, "ckpts/gat.ckpt") + save_checkpoint(train_net.network, "ckpts/gat.ckpt") val_acc_max = np.max((val_acc_max, eval_acc)) val_loss_min = np.min((val_loss_min, eval_loss)) curr_step = 0 diff --git a/model_zoo/official/nlp/tinybert/src/utils.py b/model_zoo/official/nlp/tinybert/src/utils.py index 5e1e77570..84746ae51 100644 --- a/model_zoo/official/nlp/tinybert/src/utils.py +++ b/model_zoo/official/nlp/tinybert/src/utils.py @@ -20,7 +20,7 @@ import numpy as np from mindspore import Tensor from mindspore.common import dtype as mstype from mindspore.train.callback import Callback -from mindspore.train.serialization import _exec_save_checkpoint +from mindspore.train.serialization import save_checkpoint from mindspore.ops import operations as P from mindspore.nn.learning_rate_schedule import LearningRateSchedule, PolynomialDecayLR, WarmUpLR from .assessment_method import Accuracy @@ -53,9 +53,9 @@ class ModelSaveCkpt(Callback): self.save_ckpt_step)) if os.path.exists(path): os.remove(path) - _exec_save_checkpoint(self.network, os.path.join(self.output_dir, - "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num), - self.save_ckpt_step))) + save_checkpoint(self.network, os.path.join(self.output_dir, + "tiny_bert_{}_{}.ckpt".format(int(saved_ckpt_num), + self.save_ckpt_step))) class LossCallBack(Callback): 
""" @@ -113,7 +113,7 @@ class EvalCallBack(Callback): eval_model_ckpt_file = "eval_model.ckpt" if os.path.exists(eval_model_ckpt_file): os.remove(eval_model_ckpt_file) - _exec_save_checkpoint(self.network, eval_model_ckpt_file) + save_checkpoint(self.network, eval_model_ckpt_file) class BertLearningRate(LearningRateSchedule): """ diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index dae05e983..5aea787c1 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -31,7 +31,7 @@ from mindspore.nn.optim.momentum import Momentum from mindspore.ops import operations as P from mindspore.train.callback import _CheckpointManager from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \ - _exec_save_checkpoint, export, _save_graph + export, _save_graph from ..ut_filter import non_graph_engine context.set_context(mode=context.GRAPH_MODE, print_file_path="print/print.pb") @@ -95,8 +95,8 @@ def test_save_graph(): os.remove(output_file) -def test_save_checkpoint(): - """ test_save_checkpoint """ +def test_save_checkpoint_for_list(): + """ test save_checkpoint for list""" parameter_list = [] one_param = {} param1 = {} @@ -280,14 +280,15 @@ def test_load_param_into_net(): assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1 -def test_exec_save_checkpoint(): +def test_save_checkpoint_for_network(): + """ test save_checkpoint for network""" net = Net() loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024) loss_net = WithLossCell(net, loss) train_network = TrainOneStepCell(loss_net, opt) - _exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt") + save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt") load_checkpoint("new_ckpt.ckpt") -- GitLab