From 3b632eac465f9a4ea0aa789fa1b634e3f87852ce Mon Sep 17 00:00:00 2001 From: chenzomi Date: Tue, 23 Jun 2020 21:32:35 +0800 Subject: [PATCH] checkpoint add model_type --- mindspore/_checkparam.py | 11 ++++ mindspore/ccsrc/utils/checkpoint.proto | 1 + mindspore/train/callback/_checkpoint.py | 28 +++++++---- mindspore/train/callback/_loss_monitor.py | 6 +-- mindspore/train/serialization.py | 50 ++++++++++++------- model_zoo/lenet/eval.py | 11 ++-- model_zoo/lenet_quant/README.md | 16 +++--- model_zoo/lenet_quant/eval.py | 13 ++--- model_zoo/lenet_quant/eval_quant.py | 16 +++--- model_zoo/lenet_quant/src/lenet.py | 4 +- model_zoo/lenet_quant/src/lenet_fusion.py | 5 +- model_zoo/lenet_quant/train.py | 16 ++++-- model_zoo/lenet_quant/train_quant.py | 23 ++++++--- .../python/predict/test_predict_save_model.py | 2 +- tests/ut/python/utils/test_serialize.py | 24 ++++----- 15 files changed, 136 insertions(+), 90 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 880d26bfa..d5ac7c3e3 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -593,6 +593,17 @@ def check_bool(input_param): raise TypeError("Input type must be bool!") +def check_string(input_param, valid_values): + """String type judgment.""" + if isinstance(input_param, str) and input_param in valid_values: + return input_param + if len(valid_values) == 1: + raise ValueError(f'Input should be str and must be {valid_values[0]},' + f' but got {input_param}.') + raise ValueError(f'Input should be str and must be one of {valid_values},' + f' but got {input_param}.') + + def check_input_format(input_param): """Judge input format.""" if input_param == "NCHW": diff --git a/mindspore/ccsrc/utils/checkpoint.proto b/mindspore/ccsrc/utils/checkpoint.proto index 31c7cd800..7fca399e2 100644 --- a/mindspore/ccsrc/utils/checkpoint.proto +++ b/mindspore/ccsrc/utils/checkpoint.proto @@ -22,6 +22,7 @@ message Checkpoint { required TensorProto tensor = 2; } repeated Value value = 1; + 
required string model_type = 2; } diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index d185377c8..4e686c414 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -21,17 +21,16 @@ import time import mindspore.context as context from mindspore import log as logger -from mindspore._checkparam import check_bool, check_int_non_negative +from mindspore._checkparam import check_bool, check_string, check_int_non_negative from mindspore.train._utils import _make_directory from mindspore.train.serialization import _exec_save_checkpoint, _save_graph - from ._callback import Callback, set_cur_net + _cur_dir = os.getcwd() _save_dir = _cur_dir - def _check_file_name_prefix(file_name_prefix): """ Check file name valid or not. @@ -87,6 +86,7 @@ class CheckpointConfig: Can't be used with keep_checkpoint_max at the same time. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. + model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal". Raises: ValueError: If the input_param is None or 0. 
@@ -101,7 +101,8 @@ class CheckpointConfig: save_checkpoint_seconds=0, keep_checkpoint_max=5, keep_checkpoint_per_n_minutes=0, - integrated_save=True): + integrated_save=True, + model_type="normal"): if not save_checkpoint_steps and not save_checkpoint_seconds and \ not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: @@ -115,6 +116,8 @@ class CheckpointConfig: keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) if keep_checkpoint_per_n_minutes: keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) + if model_type: + model_type = check_string(model_type, ["normal", "fusion", "quant"]) self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_seconds = save_checkpoint_seconds @@ -129,6 +132,7 @@ class CheckpointConfig: if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: self._keep_checkpoint_max = 1 + self._model_type = model_type self._integrated_save = check_bool(integrated_save) @property @@ -156,12 +160,18 @@ class CheckpointConfig: """Get the value of _integrated_save.""" return self._integrated_save + @property + def model_type(self): + """Get the value of model_type.""" + return self._model_type + def get_checkpoint_policy(self): """Get the policy of checkpoint.""" checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, 'save_checkpoint_seconds': self._save_checkpoint_seconds, 'keep_checkpoint_max': self._keep_checkpoint_max, - 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes} + 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes, + 'model_type': self._model_type} return checkpoint_policy @@ -226,7 +236,7 @@ class ModelCheckpoint(Callback): graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') _save_graph(cb_params.train_network, graph_file_name) self._graph_saved = True - self._save_ckpt(cb_params) + self._save_ckpt(cb_params, self._config.model_type) def end(self, 
run_context): """ @@ -237,7 +247,7 @@ class ModelCheckpoint(Callback): """ cb_params = run_context.original_args() _to_save_last_ckpt = True - self._save_ckpt(cb_params, _to_save_last_ckpt) + self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt) from mindspore.parallel._cell_wrapper import destroy_allgather_cell destroy_allgather_cell() @@ -256,7 +266,7 @@ class ModelCheckpoint(Callback): return False - def _save_ckpt(self, cb_params, force_to_save=False): + def _save_ckpt(self, cb_params, model_type, force_to_save=False): """Save checkpoint files.""" if cb_params.cur_step_num == self._last_triggered_step: return @@ -292,7 +302,7 @@ class ModelCheckpoint(Callback): set_cur_net(cb_params.train_network) cb_params.train_network.exec_checkpoint_graph() - _exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) + _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save) if os.path.exists(gen_file): shutil.move(gen_file, cur_file) diff --git a/mindspore/train/callback/_loss_monitor.py b/mindspore/train/callback/_loss_monitor.py index 22b134287..3f93c6314 100644 --- a/mindspore/train/callback/_loss_monitor.py +++ b/mindspore/train/callback/_loss_monitor.py @@ -76,7 +76,7 @@ class LossMonitor(Callback): step_loss = np.mean(step_loss.asnumpy()) self.losses.append(step_loss) - cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1 if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. 
" @@ -87,7 +87,7 @@ class LossMonitor(Callback): if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( - cb_params.cur_epoch_num - 1, cb_params.epoch_num, - cur_step_in_epoch, cb_params.batch_num, + cb_params.cur_epoch_num, cb_params.epoch_num, + cur_step_in_epoch, int(cb_params.batch_num), step_loss, np.mean(self.losses), step_mseconds), flush=True) diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index c39104c6f..ce776d682 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -29,6 +29,7 @@ from mindspore.common.api import _executor from mindspore.common import dtype as mstype from mindspore._checkparam import check_input_data + __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, @@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} +ModelType = ["normal", "fusion", "quant"] + def _special_process_par(par, new_par): """ @@ -101,20 +104,22 @@ def _update_param(param, new_param): param.set_parameter_data(type(param.data)(new_param.data)) -def save_checkpoint(parameter_list, ckpoint_file_name): +def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"): """ Saves checkpoint info to a specified file. Args: parameter_list (list): Parameters list, each element is a dict like {"name":xx, "type":xx, "shape":xx, "data":xx}. - ckpoint_file_name (str): Checkpoint file name. + ckpt_file_name (str): Checkpoint file name. + model_type (str): The name of model type. Default: "normal". 
Raises: RuntimeError: Failed to save the Checkpoint file. """ logger.info("Execute save checkpoint process.") checkpoint_list = Checkpoint() + checkpoint_list.model_type = model_type try: for param in parameter_list: @@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name): for dim in param['data'].shape: param_tensor.dims.append(dim) - with open(ckpoint_file_name, "wb") as f: + with open(ckpt_file_name, "wb") as f: f.write(checkpoint_list.SerializeToString()) - os.chmod(ckpoint_file_name, stat.S_IRUSR) + os.chmod(ckpt_file_name, stat.S_IRUSR) except BaseException as e: - logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name) + logger.error("Failed to save the checkpoint file %s.", ckpt_file_name) raise RuntimeError(e.__str__()) logger.info("Save checkpoint process finish.") -def load_checkpoint(ckpoint_file_name, net=None): +def load_checkpoint(ckpt_file_name, model_type="normal", net=None): """ Loads checkpoint info from a specified file. Args: - ckpoint_file_name (str): Checkpoint file name. + ckpt_file_name (str): Checkpoint file name. + model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal". net (Cell): Cell network. Default: None Returns: @@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None): Raises: ValueError: Checkpoint file is incorrect. 
""" - if not isinstance(ckpoint_file_name, str): - raise ValueError("The ckpoint_file_name must be String.") + if not isinstance(ckpt_file_name, str): + raise ValueError("The ckpt_file_name must be string.") - if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt": + if model_type not in ModelType: + raise ValueError(f"The model_type is not in {ModelType}.") + + if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt": raise ValueError("Please input the correct checkpoint file name.") - if os.path.getsize(ckpoint_file_name) == 0: + if os.path.getsize(ckpt_file_name) == 0: raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.") logger.info("Execute load checkpoint process.") checkpoint_list = Checkpoint() try: - with open(ckpoint_file_name, "rb") as f: + with open(ckpt_file_name, "rb") as f: pb_content = f.read() checkpoint_list.ParseFromString(pb_content) except BaseException as e: - logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name) + logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name) raise ValueError(e.__str__()) parameter_dict = {} - + if model_type != checkpoint_list.model_type: + raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format( + checkpoint_list.model_type, model_type)) try: for element in checkpoint_list.value: data = element.tensor.tensor_content @@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None): logger.info("Load checkpoint process finish.") except BaseException as e: - logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name) + logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) raise RuntimeError(e.__str__()) if net: @@ -303,14 +314,15 @@ def _save_graph(network, file_name): os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) -def 
_exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True): +def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True): """ Saves checkpoint for 'ms' backend. Args: train_network (Network): The train network for training. - ckpoint_file_name (str): The name of checkpoint file. - integrated_save (bool): Whether to intergrated save in automatic model parallel scene. + ckpt_file_name (str): The name of checkpoint file. + model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal". + integrated_save (bool): Whether to integrated save in automatic model parallel scene. """ param_dict = {} @@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True each_param["data"] = param_data param_list.append(each_param) - save_checkpoint(param_list, ckpoint_file_name) + save_checkpoint(param_list, ckpt_file_name, model_type) def _get_merged_param_data(net, param_name, param_data): diff --git a/model_zoo/lenet/eval.py b/model_zoo/lenet/eval.py index a9842f442..bcd5503c3 100644 --- a/model_zoo/lenet/eval.py +++ b/model_zoo/lenet/eval.py @@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt import os import argparse -from src.dataset import create_dataset -from src.config import mnist_cfg as cfg -from src.lenet import LeNet5 import mindspore.nn as nn from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train import Model from mindspore.nn.metrics import Accuracy - +from src.dataset import create_dataset +from src.config import mnist_cfg as cfg +from src.lenet import LeNet5 if __name__ == "__main__": parser = argparse.ArgumentParser(description='MindSpore Lenet Example') @@ -49,9 +47,6 @@ if __name__ == "__main__": net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, 
sparse=True, reduction="mean") repeat_size = cfg.epoch_size net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) print("============== Starting Testing ==============") diff --git a/model_zoo/lenet_quant/README.md b/model_zoo/lenet_quant/README.md index b3bac22c0..c895f68be 100644 --- a/model_zoo/lenet_quant/README.md +++ b/model_zoo/lenet_quant/README.md @@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following: ```bash >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234] >>> ... ->>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] ->>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] ->>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] +>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] +>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] +>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] ``` To save your time, just run this command. @@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following: ```bash >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234] >>> ... 
->>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] ->>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] ->>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] +>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] +>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] +>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] ``` ### Evaluate quantization aware model @@ -214,8 +214,8 @@ network = LeNet5Fusion(cfg.num_classes) param_dict = load_checkpoint(args.ckpt_path) load_param_into_net(network, param_dict) -# convert funsion netwrok to aware quantizaiton network -network = quant.convert_quant_network(network +# convert fusion network to quantization aware network +network = quant.convert_quant_network(network) ``` To save your time, just run this command. diff --git a/model_zoo/lenet_quant/eval.py b/model_zoo/lenet_quant/eval.py index c1e3a5fd8..d94e77279 100644 --- a/model_zoo/lenet_quant/eval.py +++ b/model_zoo/lenet_quant/eval.py @@ -23,7 +23,6 @@ import argparse import mindspore.nn as nn from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train import Model from mindspore.nn.metrics import Accuracy from src.dataset import create_dataset @@ -47,16 +46,18 @@ if __name__ == "__main__": ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) step_size = ds_eval.get_dataset_size() + # define fusion network network = LeNet5Fusion(cfg.num_classes) + # define loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") - repeat_size = cfg.epoch_size + # define network optimization net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size 
* step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) + + # call back and monitor model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) - param_dict = load_checkpoint(args.ckpt_path) + # load check point into network + param_dict = load_checkpoint(args.ckpt_path, network.type) load_param_into_net(network, param_dict) print("============== Starting Testing ==============") diff --git a/model_zoo/lenet_quant/eval_quant.py b/model_zoo/lenet_quant/eval_quant.py index 0ff943f8c..2c2477123 100644 --- a/model_zoo/lenet_quant/eval_quant.py +++ b/model_zoo/lenet_quant/eval_quant.py @@ -23,7 +23,6 @@ import argparse import mindspore.nn as nn from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net -from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train import Model from mindspore.nn.metrics import Accuracy from mindspore.train.quant import quant @@ -48,20 +47,21 @@ if __name__ == "__main__": ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) step_size = ds_eval.get_dataset_size() - # define funsion network + # define fusion network network = LeNet5Fusion(cfg.num_classes) - # convert funsion netwrok to aware quantizaiton network + # convert fusion network to quantization aware network network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + # define loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + # define network optimization net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) + + # call back and monitor model = Model(network, net_loss, net_opt, 
metrics={"Accuracy": Accuracy()}) - # load aware quantizaiton network checkpoint - param_dict = load_checkpoint(args.ckpt_path) + # load quantization aware network checkpoint + param_dict = load_checkpoint(args.ckpt_path, model_type="quant") load_param_into_net(network, param_dict) print("============== Starting Testing ==============") diff --git a/model_zoo/lenet_quant/src/lenet.py b/model_zoo/lenet_quant/src/lenet.py index 026f1e8df..1efcf9e7d 100644 --- a/model_zoo/lenet_quant/src/lenet.py +++ b/model_zoo/lenet_quant/src/lenet.py @@ -34,8 +34,8 @@ class LeNet5(nn.Cell): super(LeNet5, self).__init__() self.num_class = num_class - self.conv1 = nn.Conv2d(channel, 6, 5) - self.conv2 = nn.Conv2d(6, 16, 5) + self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') self.fc1 = nn.Dense(16 * 5 * 5, 120) self.fc2 = nn.Dense(120, 84) self.fc3 = nn.Dense(84, self.num_class) diff --git a/model_zoo/lenet_quant/src/lenet_fusion.py b/model_zoo/lenet_quant/src/lenet_fusion.py index 809276a48..88b359350 100644 --- a/model_zoo/lenet_quant/src/lenet_fusion.py +++ b/model_zoo/lenet_quant/src/lenet_fusion.py @@ -32,11 +32,12 @@ class LeNet5(nn.Cell): def __init__(self, num_class=10, channel=1): super(LeNet5, self).__init__() + self.type = "fusion" self.num_class = num_class # change `nn.Conv2d` to `nn.Conv2dBnAct` - self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu') - self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu') + self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') + self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') # change `nn.Dense` to `nn.DenseBnAct` self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu') diff --git a/model_zoo/lenet_quant/train.py b/model_zoo/lenet_quant/train.py index 6e7a46fb3..b6040776e 100644 --- a/model_zoo/lenet_quant/train.py +++ b/model_zoo/lenet_quant/train.py @@ -46,16 
+46,24 @@ if __name__ == "__main__": ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) step_size = ds_train.get_dataset_size() + # define fusion network network = LeNet5Fusion(cfg.num_classes) + # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + # define network optimization net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + + # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) + config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, + keep_checkpoint_max=cfg.keep_checkpoint_max, + model_type=network.type) + ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) + + # define model model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) print("============== Starting Training ==============") - model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], + model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], dataset_sink_mode=args.dataset_sink_mode) print("============== End Training ==============") diff --git a/model_zoo/lenet_quant/train_quant.py b/model_zoo/lenet_quant/train_quant.py index 3de700af7..eb1f783a7 100644 --- a/model_zoo/lenet_quant/train_quant.py +++ b/model_zoo/lenet_quant/train_quant.py @@ -48,23 +48,30 @@ if __name__ == "__main__": ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) step_size = ds_train.get_dataset_size() - # define funsion network + # define fusion network network = LeNet5Fusion(cfg.num_classes) - # load aware quantizaiton network checkpoint - param_dict = load_checkpoint(args.ckpt_path) 
+ # load quantization aware network checkpoint + param_dict = load_checkpoint(args.ckpt_path, network.type) load_param_into_net(network, param_dict) - # convert funsion netwrok to aware quantizaiton network + # convert fusion network to quantization aware network network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") + # define network optimization net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) + + # call back and monitor time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) - config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, - keep_checkpoint_max=cfg.keep_checkpoint_max) - ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) + config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, + keep_checkpoint_max=cfg.keep_checkpoint_max, + model_type="quant") + ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt) + + # define model model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) print("============== Starting Training ==============") - model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], + model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()], dataset_sink_mode=args.dataset_sink_mode) print("============== End Training ==============") diff --git a/tests/ut/python/predict/test_predict_save_model.py b/tests/ut/python/predict/test_predict_save_model.py index 4f5fe16ad..f57875d07 100644 --- a/tests/ut/python/predict/test_predict_save_model.py +++ b/tests/ut/python/predict/test_predict_save_model.py @@ -85,7 +85,7 @@ if __name__ == '__main__': is_ckpt_exist = os.path.exists(ckpt_file_path) if is_ckpt_exist: - param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path) + param_dict = 
load_checkpoint(ckpt_file_name=ckpt_file_path) load_param_into_net(net, param_dict) export(net, input_data, file_name=model_path_name, file_format='LITE') print("test lenet predict success.") diff --git a/tests/ut/python/utils/test_serialize.py b/tests/ut/python/utils/test_serialize.py index 19e9bd72e..c5b458656 100644 --- a/tests/ut/python/utils/test_serialize.py +++ b/tests/ut/python/utils/test_serialize.py @@ -111,19 +111,19 @@ def test_save_checkpoint(): os.chmod('./parameters.ckpt', stat.S_IWRITE) os.remove('./parameters.ckpt') - ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt') - save_checkpoint(parameter_list, ckpoint_file_name) + ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') + save_checkpoint(parameter_list, ckpt_file_name) def test_load_checkpoint_error_filename(): - ckpoint_file_name = 1 + ckpt_file_name = 1 with pytest.raises(ValueError): - load_checkpoint(ckpoint_file_name) + load_checkpoint(ckpt_file_name) def test_load_checkpoint(): - ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt') - par_dict = load_checkpoint(ckpoint_file_name) + ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt') + par_dict = load_checkpoint(ckpt_file_name) assert len(par_dict) == 3 assert par_dict['param_test'].name == 'param_test' @@ -136,17 +136,17 @@ def test_checkpoint_manager(): """ test_checkpoint_manager """ ckp_mgr = _CheckpointManager() - ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt') - with open(ckpoint_file_name, 'w'): - os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR) + ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt') + with open(ckpt_file_name, 'w'): + os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR) ckp_mgr.update_ckpoint_filelist(_cur_dir, "test") assert ckp_mgr.ckpoint_num == 1 - ckp_mgr.remove_ckpoint_file(ckpoint_file_name) + ckp_mgr.remove_ckpoint_file(ckpt_file_name) ckp_mgr.update_ckpoint_filelist(_cur_dir, "test") assert ckp_mgr.ckpoint_num == 0 - assert not 
os.path.exists(ckpoint_file_name) + assert not os.path.exists(ckpt_file_name) another_file_name = os.path.join(_cur_dir, './test2.ckpt') another_file_name = os.path.realpath(another_file_name) @@ -283,7 +283,7 @@ def test_exec_save_checkpoint(): loss_net = WithLossCell(net, loss) train_network = TrainOneStepCell(loss_net, opt) - _exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt") + _exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt") load_checkpoint("new_ckpt.ckpt") -- GitLab