提交 087779b7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2517 checkpoint add model_type

Merge pull request !2517 from chenzhongming/quant
...@@ -593,6 +593,17 @@ def check_bool(input_param): ...@@ -593,6 +593,17 @@ def check_bool(input_param):
raise TypeError("Input type must be bool!") raise TypeError("Input type must be bool!")
def check_string(input_param, valid_values):
"""String type judgment."""
if isinstance(input_param, str) and input_param in valid_values:
return input_param
if len(valid_values) == 1:
raise ValueError(f'Input should be str and must be {valid_values[0]},'
f' but got {input_param}.')
raise ValueError(f'Input should be str and must be one of {valid_values},'
f' but got {input_param}.')
def check_input_format(input_param): def check_input_format(input_param):
"""Judge input format.""" """Judge input format."""
if input_param == "NCHW": if input_param == "NCHW":
......
...@@ -22,6 +22,7 @@ message Checkpoint { ...@@ -22,6 +22,7 @@ message Checkpoint {
required TensorProto tensor = 2; required TensorProto tensor = 2;
} }
repeated Value value = 1; repeated Value value = 1;
required string model_type = 2;
} }
......
...@@ -21,17 +21,16 @@ import time ...@@ -21,17 +21,16 @@ import time
import mindspore.context as context import mindspore.context as context
from mindspore import log as logger from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative from mindspore._checkparam import check_bool, check_string, check_int_non_negative
from mindspore.train._utils import _make_directory from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
from ._callback import Callback, set_cur_net from ._callback import Callback, set_cur_net
_cur_dir = os.getcwd() _cur_dir = os.getcwd()
_save_dir = _cur_dir _save_dir = _cur_dir
def _check_file_name_prefix(file_name_prefix): def _check_file_name_prefix(file_name_prefix):
""" """
Check file name valid or not. Check file name valid or not.
...@@ -87,6 +86,7 @@ class CheckpointConfig: ...@@ -87,6 +86,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time. Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True. integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel. Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
Raises: Raises:
ValueError: If the input_param is None or 0. ValueError: If the input_param is None or 0.
...@@ -101,7 +101,8 @@ class CheckpointConfig: ...@@ -101,7 +101,8 @@ class CheckpointConfig:
save_checkpoint_seconds=0, save_checkpoint_seconds=0,
keep_checkpoint_max=5, keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0, keep_checkpoint_per_n_minutes=0,
integrated_save=True): integrated_save=True,
model_type="normal"):
if not save_checkpoint_steps and not save_checkpoint_seconds and \ if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
...@@ -115,6 +116,8 @@ class CheckpointConfig: ...@@ -115,6 +116,8 @@ class CheckpointConfig:
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
if keep_checkpoint_per_n_minutes: if keep_checkpoint_per_n_minutes:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
if model_type:
model_type = check_string(model_type, ["normal", "fusion", "quant"])
self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_steps = save_checkpoint_steps
self._save_checkpoint_seconds = save_checkpoint_seconds self._save_checkpoint_seconds = save_checkpoint_seconds
...@@ -129,6 +132,7 @@ class CheckpointConfig: ...@@ -129,6 +132,7 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0: if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1 self._keep_checkpoint_max = 1
self._model_type = model_type
self._integrated_save = check_bool(integrated_save) self._integrated_save = check_bool(integrated_save)
@property @property
...@@ -156,12 +160,18 @@ class CheckpointConfig: ...@@ -156,12 +160,18 @@ class CheckpointConfig:
"""Get the value of _integrated_save.""" """Get the value of _integrated_save."""
return self._integrated_save return self._integrated_save
@property
def model_type(self):
"""Get the value of model_type."""
return self._model_type
def get_checkpoint_policy(self): def get_checkpoint_policy(self):
"""Get the policy of checkpoint.""" """Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps, checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
'save_checkpoint_seconds': self._save_checkpoint_seconds, 'save_checkpoint_seconds': self._save_checkpoint_seconds,
'keep_checkpoint_max': self._keep_checkpoint_max, 'keep_checkpoint_max': self._keep_checkpoint_max,
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes} 'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
'model_type': self._model_type}
return checkpoint_policy return checkpoint_policy
...@@ -226,7 +236,7 @@ class ModelCheckpoint(Callback): ...@@ -226,7 +236,7 @@ class ModelCheckpoint(Callback):
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta') graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
_save_graph(cb_params.train_network, graph_file_name) _save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True self._graph_saved = True
self._save_ckpt(cb_params) self._save_ckpt(cb_params, self._config.model_type)
def end(self, run_context): def end(self, run_context):
""" """
...@@ -237,7 +247,7 @@ class ModelCheckpoint(Callback): ...@@ -237,7 +247,7 @@ class ModelCheckpoint(Callback):
""" """
cb_params = run_context.original_args() cb_params = run_context.original_args()
_to_save_last_ckpt = True _to_save_last_ckpt = True
self._save_ckpt(cb_params, _to_save_last_ckpt) self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
from mindspore.parallel._cell_wrapper import destroy_allgather_cell from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell() destroy_allgather_cell()
...@@ -256,7 +266,7 @@ class ModelCheckpoint(Callback): ...@@ -256,7 +266,7 @@ class ModelCheckpoint(Callback):
return False return False
def _save_ckpt(self, cb_params, force_to_save=False): def _save_ckpt(self, cb_params, model_type, force_to_save=False):
"""Save checkpoint files.""" """Save checkpoint files."""
if cb_params.cur_step_num == self._last_triggered_step: if cb_params.cur_step_num == self._last_triggered_step:
return return
...@@ -292,7 +302,7 @@ class ModelCheckpoint(Callback): ...@@ -292,7 +302,7 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network) set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph() cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save) _exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
if os.path.exists(gen_file): if os.path.exists(gen_file):
shutil.move(gen_file, cur_file) shutil.move(gen_file, cur_file)
......
...@@ -76,7 +76,7 @@ class LossMonitor(Callback): ...@@ -76,7 +76,7 @@ class LossMonitor(Callback):
step_loss = np.mean(step_loss.asnumpy()) step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss) self.losses.append(step_loss)
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)): if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. " raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
...@@ -87,7 +87,7 @@ class LossMonitor(Callback): ...@@ -87,7 +87,7 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0: if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], " print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format( "loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
cb_params.cur_epoch_num - 1, cb_params.epoch_num, cb_params.cur_epoch_num, cb_params.epoch_num,
cur_step_in_epoch, int(cb_params.batch_num), cur_step_in_epoch, int(cb_params.batch_num),
step_loss, np.mean(self.losses), step_loss, np.mean(self.losses),
step_mseconds), flush=True) step_mseconds), flush=True)
...@@ -29,6 +29,7 @@ from mindspore.common.api import _executor ...@@ -29,6 +29,7 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data from mindspore._checkparam import check_input_data
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"] __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
...@@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin ...@@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64, "Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_} "Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
ModelType = ["normal", "fusion", "quant"]
def _special_process_par(par, new_par): def _special_process_par(par, new_par):
""" """
...@@ -101,20 +104,22 @@ def _update_param(param, new_param): ...@@ -101,20 +104,22 @@ def _update_param(param, new_param):
param.set_parameter_data(type(param.data)(new_param.data)) param.set_parameter_data(type(param.data)(new_param.data))
def save_checkpoint(parameter_list, ckpoint_file_name): def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
""" """
Saves checkpoint info to a specified file. Saves checkpoint info to a specified file.
Args: Args:
parameter_list (list): Parameters list, each element is a dict parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}. like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpoint_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type. Default: "normal".
Raises: Raises:
RuntimeError: Failed to save the Checkpoint file. RuntimeError: Failed to save the Checkpoint file.
""" """
logger.info("Execute save checkpoint process.") logger.info("Execute save checkpoint process.")
checkpoint_list = Checkpoint() checkpoint_list = Checkpoint()
checkpoint_list.model_type = model_type
try: try:
for param in parameter_list: for param in parameter_list:
...@@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name): ...@@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
for dim in param['data'].shape: for dim in param['data'].shape:
param_tensor.dims.append(dim) param_tensor.dims.append(dim)
with open(ckpoint_file_name, "wb") as f: with open(ckpt_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString()) f.write(checkpoint_list.SerializeToString())
os.chmod(ckpoint_file_name, stat.S_IRUSR) os.chmod(ckpt_file_name, stat.S_IRUSR)
except BaseException as e: except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name) logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
raise RuntimeError(e.__str__()) raise RuntimeError(e.__str__())
logger.info("Save checkpoint process finish.") logger.info("Save checkpoint process finish.")
def load_checkpoint(ckpoint_file_name, net=None): def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
""" """
Loads checkpoint info from a specified file. Loads checkpoint info from a specified file.
Args: Args:
ckpoint_file_name (str): Checkpoint file name. ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
net (Cell): Cell network. Default: None net (Cell): Cell network. Default: None
Returns: Returns:
...@@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None): ...@@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
Raises: Raises:
ValueError: Checkpoint file is incorrect. ValueError: Checkpoint file is incorrect.
""" """
if not isinstance(ckpoint_file_name, str): if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpoint_file_name must be String.") raise ValueError("The ckpt_file_name must be string.")
if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt": if model_type not in ModelType:
raise ValueError(f"The model_type is not in {ModelType}.")
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.") raise ValueError("Please input the correct checkpoint file name.")
if os.path.getsize(ckpoint_file_name) == 0: if os.path.getsize(ckpt_file_name) == 0:
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.") raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
logger.info("Execute load checkpoint process.") logger.info("Execute load checkpoint process.")
checkpoint_list = Checkpoint() checkpoint_list = Checkpoint()
try: try:
with open(ckpoint_file_name, "rb") as f: with open(ckpt_file_name, "rb") as f:
pb_content = f.read() pb_content = f.read()
checkpoint_list.ParseFromString(pb_content) checkpoint_list.ParseFromString(pb_content)
except BaseException as e: except BaseException as e:
logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name) logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__()) raise ValueError(e.__str__())
parameter_dict = {} parameter_dict = {}
if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
try: try:
for element in checkpoint_list.value: for element in checkpoint_list.value:
data = element.tensor.tensor_content data = element.tensor.tensor_content
...@@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None): ...@@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
logger.info("Load checkpoint process finish.") logger.info("Load checkpoint process finish.")
except BaseException as e: except BaseException as e:
logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name) logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__()) raise RuntimeError(e.__str__())
if net: if net:
...@@ -303,14 +314,15 @@ def _save_graph(network, file_name): ...@@ -303,14 +314,15 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True): def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
""" """
Saves checkpoint for 'ms' backend. Saves checkpoint for 'ms' backend.
Args: Args:
train_network (Network): The train network for training. train_network (Network): The train network for training.
ckpoint_file_name (str): The name of checkpoint file. ckpt_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
""" """
param_dict = {} param_dict = {}
...@@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True ...@@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
each_param["data"] = param_data each_param["data"] = param_data
param_list.append(each_param) param_list.append(each_param)
save_checkpoint(param_list, ckpoint_file_name) save_checkpoint(param_list, ckpt_file_name, model_type)
def _get_merged_param_data(net, param_name, param_data): def _get_merged_param_data(net, param_name, param_data):
......
...@@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt ...@@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
import os import os
import argparse import argparse
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example') parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
...@@ -49,9 +47,6 @@ if __name__ == "__main__": ...@@ -49,9 +47,6 @@ if __name__ == "__main__":
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size repeat_size = cfg.epoch_size
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
......
...@@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following: ...@@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following:
```bash ```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234] >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ... >>> ...
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] >>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] >>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] >>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
``` ```
Also, you can just run this command instead. Also, you can just run this command instead.
...@@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following: ...@@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following:
```bash ```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234] >>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ... >>> ...
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234] >>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234] >>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234] >>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
``` ```
### Evaluate quantization aware model ### Evaluate quantization aware model
...@@ -215,7 +215,7 @@ param_dict = load_checkpoint(args.ckpt_path) ...@@ -215,7 +215,7 @@ param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# convert funsion netwrok to quantization aware network # convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network network = quant.convert_quant_network(network)
``` ```
Also, you can just run this command insread. Also, you can just run this command insread.
......
...@@ -23,7 +23,6 @@ import argparse ...@@ -23,7 +23,6 @@ import argparse
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset from src.dataset import create_dataset
...@@ -47,16 +46,18 @@ if __name__ == "__main__": ...@@ -47,16 +46,18 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
step_size = ds_eval.get_dataset_size() step_size = ds_eval.get_dataset_size()
# define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size # define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max) # call back and monitor
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
param_dict = load_checkpoint(args.ckpt_path) # load check point into network
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
......
...@@ -23,7 +23,6 @@ import argparse ...@@ -23,7 +23,6 @@ import argparse
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant from mindspore.train.quant import quant
...@@ -48,20 +47,21 @@ if __name__ == "__main__": ...@@ -48,20 +47,21 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1) ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
step_size = ds_eval.get_dataset_size() step_size = ds_eval.get_dataset_size()
# define funsion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert funsion netwrok to quantization aware network # convert fusion netwrok to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max) # call back and monitor
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
print("============== Starting Testing ==============") print("============== Starting Testing ==============")
......
...@@ -34,8 +34,8 @@ class LeNet5(nn.Cell): ...@@ -34,8 +34,8 @@ class LeNet5(nn.Cell):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.num_class = num_class self.num_class = num_class
self.conv1 = nn.Conv2d(channel, 6, 5) self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5) self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120) self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84) self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, self.num_class) self.fc3 = nn.Dense(84, self.num_class)
......
...@@ -32,11 +32,12 @@ class LeNet5(nn.Cell): ...@@ -32,11 +32,12 @@ class LeNet5(nn.Cell):
def __init__(self, num_class=10, channel=1): def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.type = "fusion"
self.num_class = num_class self.num_class = num_class
# change `nn.Conv2d` to `nn.Conv2dBnAct` # change `nn.Conv2d` to `nn.Conv2dBnAct`
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu') self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu') self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
# change `nn.Dense` to `nn.DenseBnAct` # change `nn.Dense` to `nn.DenseBnAct`
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
......
...@@ -46,16 +46,24 @@ if __name__ == "__main__": ...@@ -46,16 +46,24 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size() step_size = ds_train.get_dataset_size()
# define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max,
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) model_type=network.type)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============") print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode) dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============") print("============== End Training ==============")
...@@ -48,23 +48,30 @@ if __name__ == "__main__": ...@@ -48,23 +48,30 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size) ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size() step_size = ds_train.get_dataset_size()
# define funsion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)
# convert funsion netwrok to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size, config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max,
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) model_type="quant")
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============") print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()], model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode) dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============") print("============== End Training ==============")
...@@ -85,7 +85,7 @@ if __name__ == '__main__': ...@@ -85,7 +85,7 @@ if __name__ == '__main__':
is_ckpt_exist = os.path.exists(ckpt_file_path) is_ckpt_exist = os.path.exists(ckpt_file_path)
if is_ckpt_exist: if is_ckpt_exist:
param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path) param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path)
load_param_into_net(net, param_dict) load_param_into_net(net, param_dict)
export(net, input_data, file_name=model_path_name, file_format='LITE') export(net, input_data, file_name=model_path_name, file_format='LITE')
print("test lenet predict success.") print("test lenet predict success.")
......
...@@ -111,19 +111,19 @@ def test_save_checkpoint(): ...@@ -111,19 +111,19 @@ def test_save_checkpoint():
os.chmod('./parameters.ckpt', stat.S_IWRITE) os.chmod('./parameters.ckpt', stat.S_IWRITE)
os.remove('./parameters.ckpt') os.remove('./parameters.ckpt')
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt') ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
save_checkpoint(parameter_list, ckpoint_file_name) save_checkpoint(parameter_list, ckpt_file_name)
def test_load_checkpoint_error_filename(): def test_load_checkpoint_error_filename():
ckpoint_file_name = 1 ckpt_file_name = 1
with pytest.raises(ValueError): with pytest.raises(ValueError):
load_checkpoint(ckpoint_file_name) load_checkpoint(ckpt_file_name)
def test_load_checkpoint(): def test_load_checkpoint():
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt') ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
par_dict = load_checkpoint(ckpoint_file_name) par_dict = load_checkpoint(ckpt_file_name)
assert len(par_dict) == 3 assert len(par_dict) == 3
assert par_dict['param_test'].name == 'param_test' assert par_dict['param_test'].name == 'param_test'
...@@ -136,17 +136,17 @@ def test_checkpoint_manager(): ...@@ -136,17 +136,17 @@ def test_checkpoint_manager():
""" test_checkpoint_manager """ """ test_checkpoint_manager """
ckp_mgr = _CheckpointManager() ckp_mgr = _CheckpointManager()
ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt') ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
with open(ckpoint_file_name, 'w'): with open(ckpt_file_name, 'w'):
os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR) os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test") ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
assert ckp_mgr.ckpoint_num == 1 assert ckp_mgr.ckpoint_num == 1
ckp_mgr.remove_ckpoint_file(ckpoint_file_name) ckp_mgr.remove_ckpoint_file(ckpt_file_name)
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test") ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
assert ckp_mgr.ckpoint_num == 0 assert ckp_mgr.ckpoint_num == 0
assert not os.path.exists(ckpoint_file_name) assert not os.path.exists(ckpt_file_name)
another_file_name = os.path.join(_cur_dir, './test2.ckpt') another_file_name = os.path.join(_cur_dir, './test2.ckpt')
another_file_name = os.path.realpath(another_file_name) another_file_name = os.path.realpath(another_file_name)
...@@ -283,7 +283,7 @@ def test_exec_save_checkpoint(): ...@@ -283,7 +283,7 @@ def test_exec_save_checkpoint():
loss_net = WithLossCell(net, loss) loss_net = WithLossCell(net, loss)
train_network = TrainOneStepCell(loss_net, opt) train_network = TrainOneStepCell(loss_net, opt)
_exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt") _exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
load_checkpoint("new_ckpt.ckpt") load_checkpoint("new_ckpt.ckpt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册