提交 3b632eac 编写于 作者: C chenzomi

checkpoint add model_type

上级 166d8865
......@@ -593,6 +593,17 @@ def check_bool(input_param):
raise TypeError("Input type must be bool!")
def check_string(input_param, valid_values):
"""String type judgment."""
if isinstance(input_param, str) and input_param in valid_values:
return input_param
if len(valid_values) == 1:
raise ValueError(f'Input should be str and must be {valid_values[0]},'
f' but got {input_param}.')
raise ValueError(f'Input should be str and must be one of {valid_values},'
f' but got {input_param}.')
def check_input_format(input_param):
"""Judge input format."""
if input_param == "NCHW":
......
......@@ -22,6 +22,7 @@ message Checkpoint {
required TensorProto tensor = 2;
}
repeated Value value = 1;
required string model_type = 2;
}
......
......@@ -21,17 +21,16 @@ import time
import mindspore.context as context
from mindspore import log as logger
from mindspore._checkparam import check_bool, check_int_non_negative
from mindspore._checkparam import check_bool, check_string, check_int_non_negative
from mindspore.train._utils import _make_directory
from mindspore.train.serialization import _exec_save_checkpoint, _save_graph
from ._callback import Callback, set_cur_net
_cur_dir = os.getcwd()
_save_dir = _cur_dir
def _check_file_name_prefix(file_name_prefix):
"""
Check file name valid or not.
......@@ -87,6 +86,7 @@ class CheckpointConfig:
Can't be used with keep_checkpoint_max at the same time.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene. Default: True.
Integrated save function is only supported in automatic parallel scene, not supported in manual parallel.
model_type (str): Model type in `normal`, `fusion` or `quant`. Default: "normal".
Raises:
ValueError: If the input_param is None or 0.
......@@ -101,7 +101,8 @@ class CheckpointConfig:
save_checkpoint_seconds=0,
keep_checkpoint_max=5,
keep_checkpoint_per_n_minutes=0,
integrated_save=True):
integrated_save=True,
model_type="normal"):
if not save_checkpoint_steps and not save_checkpoint_seconds and \
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
......@@ -115,6 +116,8 @@ class CheckpointConfig:
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
if keep_checkpoint_per_n_minutes:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
if model_type:
model_type = check_string(model_type, ["normal", "fusion", "quant"])
self._save_checkpoint_steps = save_checkpoint_steps
self._save_checkpoint_seconds = save_checkpoint_seconds
......@@ -129,6 +132,7 @@ class CheckpointConfig:
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
self._keep_checkpoint_max = 1
self._model_type = model_type
self._integrated_save = check_bool(integrated_save)
@property
......@@ -156,12 +160,18 @@ class CheckpointConfig:
"""Get the value of _integrated_save."""
return self._integrated_save
@property
def model_type(self):
"""Get the value of model_type."""
return self._model_type
def get_checkpoint_policy(self):
"""Get the policy of checkpoint."""
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
'save_checkpoint_seconds': self._save_checkpoint_seconds,
'keep_checkpoint_max': self._keep_checkpoint_max,
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes}
'keep_checkpoint_per_n_minutes': self._keep_checkpoint_per_n_minutes,
'model_type': self._model_type}
return checkpoint_policy
......@@ -226,7 +236,7 @@ class ModelCheckpoint(Callback):
graph_file_name = os.path.join(self._directory, self._prefix + '-graph.meta')
_save_graph(cb_params.train_network, graph_file_name)
self._graph_saved = True
self._save_ckpt(cb_params)
self._save_ckpt(cb_params, self._config.model_type)
def end(self, run_context):
"""
......@@ -237,7 +247,7 @@ class ModelCheckpoint(Callback):
"""
cb_params = run_context.original_args()
_to_save_last_ckpt = True
self._save_ckpt(cb_params, _to_save_last_ckpt)
self._save_ckpt(cb_params, self._config.model_type, _to_save_last_ckpt)
from mindspore.parallel._cell_wrapper import destroy_allgather_cell
destroy_allgather_cell()
......@@ -256,7 +266,7 @@ class ModelCheckpoint(Callback):
return False
def _save_ckpt(self, cb_params, force_to_save=False):
def _save_ckpt(self, cb_params, model_type, force_to_save=False):
"""Save checkpoint files."""
if cb_params.cur_step_num == self._last_triggered_step:
return
......@@ -292,7 +302,7 @@ class ModelCheckpoint(Callback):
set_cur_net(cb_params.train_network)
cb_params.train_network.exec_checkpoint_graph()
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
_exec_save_checkpoint(cb_params.train_network, gen_file, model_type, self._config.integrated_save)
if os.path.exists(gen_file):
shutil.move(gen_file, cur_file)
......
......@@ -76,7 +76,7 @@ class LossMonitor(Callback):
step_loss = np.mean(step_loss.asnumpy())
self.losses.append(step_loss)
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num
cur_step_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num) + 1
if isinstance(step_loss, float) and (np.isnan(step_loss) or np.isinf(step_loss)):
raise ValueError("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}]. "
......@@ -87,7 +87,7 @@ class LossMonitor(Callback):
if self._per_print_times != 0 and cb_params.cur_step_num % self._per_print_times == 0:
print("Epoch: [{:3d}/{:3d}], step: [{:5d}/{:5d}], "
"loss: [{:5.4f}/{:5.4f}], time: [{:5.4f}]".format(
cb_params.cur_epoch_num - 1, cb_params.epoch_num,
cur_step_in_epoch, cb_params.batch_num,
cb_params.cur_epoch_num, cb_params.epoch_num,
cur_step_in_epoch, int(cb_params.batch_num),
step_loss, np.mean(self.losses),
step_mseconds), flush=True)
......@@ -29,6 +29,7 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
......@@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
ModelType = ["normal", "fusion", "quant"]
def _special_process_par(par, new_par):
"""
......@@ -101,20 +104,22 @@ def _update_param(param, new_param):
param.set_parameter_data(type(param.data)(new_param.data))
def save_checkpoint(parameter_list, ckpoint_file_name):
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
"""
Saves checkpoint info to a specified file.
Args:
parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type. Default: "normal".
Raises:
RuntimeError: Failed to save the Checkpoint file.
"""
logger.info("Execute save checkpoint process.")
checkpoint_list = Checkpoint()
checkpoint_list.model_type = model_type
try:
for param in parameter_list:
......@@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
for dim in param['data'].shape:
param_tensor.dims.append(dim)
with open(ckpoint_file_name, "wb") as f:
with open(ckpt_file_name, "wb") as f:
f.write(checkpoint_list.SerializeToString())
os.chmod(ckpoint_file_name, stat.S_IRUSR)
os.chmod(ckpt_file_name, stat.S_IRUSR)
except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
raise RuntimeError(e.__str__())
logger.info("Save checkpoint process finish.")
def load_checkpoint(ckpoint_file_name, net=None):
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
"""
Loads checkpoint info from a specified file.
Args:
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
net (Cell): Cell network. Default: None
Returns:
......@@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
Raises:
ValueError: Checkpoint file is incorrect.
"""
if not isinstance(ckpoint_file_name, str):
raise ValueError("The ckpoint_file_name must be String.")
if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.")
if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
if model_type not in ModelType:
raise ValueError(f"The model_type is not in {ModelType}.")
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.")
if os.path.getsize(ckpoint_file_name) == 0:
if os.path.getsize(ckpt_file_name) == 0:
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
logger.info("Execute load checkpoint process.")
checkpoint_list = Checkpoint()
try:
with open(ckpoint_file_name, "rb") as f:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__())
parameter_dict = {}
if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
try:
for element in checkpoint_list.value:
data = element.tensor.tensor_content
......@@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
logger.info("Load checkpoint process finish.")
except BaseException as e:
logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__())
if net:
......@@ -303,14 +314,15 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
"""
Saves checkpoint for 'ms' backend.
Args:
train_network (Network): The train network for training.
ckpoint_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
ckpt_file_name (str): The name of checkpoint file.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
"""
param_dict = {}
......@@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
each_param["data"] = param_data
param_list.append(each_param)
save_checkpoint(param_list, ckpoint_file_name)
save_checkpoint(param_list, ckpt_file_name, model_type)
def _get_merged_param_data(net, param_name, param_data):
......
......@@ -20,16 +20,14 @@ python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
import os
import argparse
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
from src.config import mnist_cfg as cfg
from src.lenet import LeNet5
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore Lenet Example')
......@@ -49,9 +47,6 @@ if __name__ == "__main__":
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Testing ==============")
......
......@@ -128,9 +128,9 @@ After all the following we will get the loss value of each step as following:
```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```
To save your time, just run this command.
......@@ -197,9 +197,9 @@ After all the following we will get the loss value of each step as following:
```bash
>>> Epoch: [ 1/ 10] step: [ 1/ 900], loss: [2.3040/2.5234], time: [1.300234]
>>> ...
>>> Epoch: [ 10/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 10/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [887/ 900], loss: [0.0113/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [888/ 900], loss: [0.0334/0.0223], time: [1.300234]
>>> Epoch: [ 9/ 10] step: [889/ 900], loss: [0.0233/0.0223], time: [1.300234]
```
### Evaluate quantization aware model
......@@ -214,8 +214,8 @@ network = LeNet5Fusion(cfg.num_classes)
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network
network = quant.convert_quant_network(network
# convert funsion netwrok to quantization aware network
network = quant.convert_quant_network(network)
```
To save your time, just run this command.
......
......@@ -23,7 +23,6 @@ import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from src.dataset import create_dataset
......@@ -47,16 +46,18 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
step_size = ds_eval.get_dataset_size()
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
# call back and monitor
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
param_dict = load_checkpoint(args.ckpt_path)
# load check point into network
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict)
print("============== Starting Testing ==============")
......
......@@ -23,7 +23,6 @@ import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant
......@@ -48,20 +47,21 @@ if __name__ == "__main__":
ds_eval = create_dataset(os.path.join(args.data_path, "test"), cfg.batch_size, 1)
step_size = ds_eval.get_dataset_size()
# define funsion network
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# convert funsion netwrok to aware quantizaiton network
# convert fusion netwrok to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
# call back and monitor
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
# load aware quantizaiton network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, model_type="quant")
load_param_into_net(network, param_dict)
print("============== Starting Testing ==============")
......
......@@ -34,8 +34,8 @@ class LeNet5(nn.Cell):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = nn.Conv2d(channel, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, self.num_class)
......
......@@ -32,11 +32,12 @@ class LeNet5(nn.Cell):
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.type = "fusion"
self.num_class = num_class
# change `nn.Conv2d` to `nn.Conv2dBnAct`
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, 5, activation='relu')
self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu')
# change `nn.Dense` to `nn.DenseBnAct`
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
......
......@@ -46,16 +46,24 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size()
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type=network.type)
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============")
......@@ -48,23 +48,30 @@ if __name__ == "__main__":
ds_train = create_dataset(os.path.join(args.data_path, "train"), cfg.batch_size, cfg.epoch_size)
step_size = ds_train.get_dataset_size()
# define funsion network
# define fusion network
network = LeNet5Fusion(cfg.num_classes)
# load aware quantizaiton network checkpoint
param_dict = load_checkpoint(args.ckpt_path)
# load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path, network.type)
load_param_into_net(network, param_dict)
# convert funsion netwrok to aware quantizaiton network
# convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000)
# define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# define network optimization
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
# call back and monitor
time_cb = TimeMonitor(data_size=ds_train.get_dataset_size())
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
config_ckpt = CheckpointConfig(save_checkpoint_steps=cfg.epoch_size * step_size,
keep_checkpoint_max=cfg.keep_checkpoint_max,
model_type="quant")
ckpt_callback = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ckpt)
# define model
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpoint_cb, LossMonitor()],
model.train(cfg['epoch_size'], ds_train, callbacks=[time_cb, ckpt_callback, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
print("============== End Training ==============")
......@@ -85,7 +85,7 @@ if __name__ == '__main__':
is_ckpt_exist = os.path.exists(ckpt_file_path)
if is_ckpt_exist:
param_dict = load_checkpoint(ckpoint_file_name=ckpt_file_path)
param_dict = load_checkpoint(ckpt_file_name=ckpt_file_path)
load_param_into_net(net, param_dict)
export(net, input_data, file_name=model_path_name, file_format='LITE')
print("test lenet predict success.")
......
......@@ -111,19 +111,19 @@ def test_save_checkpoint():
os.chmod('./parameters.ckpt', stat.S_IWRITE)
os.remove('./parameters.ckpt')
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
save_checkpoint(parameter_list, ckpoint_file_name)
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
save_checkpoint(parameter_list, ckpt_file_name)
def test_load_checkpoint_error_filename():
ckpoint_file_name = 1
ckpt_file_name = 1
with pytest.raises(ValueError):
load_checkpoint(ckpoint_file_name)
load_checkpoint(ckpt_file_name)
def test_load_checkpoint():
ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
par_dict = load_checkpoint(ckpoint_file_name)
ckpt_file_name = os.path.join(_cur_dir, './parameters.ckpt')
par_dict = load_checkpoint(ckpt_file_name)
assert len(par_dict) == 3
assert par_dict['param_test'].name == 'param_test'
......@@ -136,17 +136,17 @@ def test_checkpoint_manager():
""" test_checkpoint_manager """
ckp_mgr = _CheckpointManager()
ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
with open(ckpoint_file_name, 'w'):
os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
ckpt_file_name = os.path.join(_cur_dir, './test1.ckpt')
with open(ckpt_file_name, 'w'):
os.chmod(ckpt_file_name, stat.S_IWUSR | stat.S_IRUSR)
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
assert ckp_mgr.ckpoint_num == 1
ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
ckp_mgr.remove_ckpoint_file(ckpt_file_name)
ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
assert ckp_mgr.ckpoint_num == 0
assert not os.path.exists(ckpoint_file_name)
assert not os.path.exists(ckpt_file_name)
another_file_name = os.path.join(_cur_dir, './test2.ckpt')
another_file_name = os.path.realpath(another_file_name)
......@@ -283,7 +283,7 @@ def test_exec_save_checkpoint():
loss_net = WithLossCell(net, loss)
train_network = TrainOneStepCell(loss_net, opt)
_exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")
_exec_save_checkpoint(train_network, ckpt_file_name="./new_ckpt.ckpt")
load_checkpoint("new_ckpt.ckpt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册