提交 91d69325 编写于 作者: L liuyang_655

modify save_checkpoint

上级 6737369c
...@@ -165,7 +165,7 @@ The parameter name is model\_parallel\_weight and the dividing strategy is to pe ...@@ -165,7 +165,7 @@ The parameter name is model\_parallel\_weight and the dividing strategy is to pe
2. Call the `save_checkpoint` API to write the parameter data to a file and generate a new checkpoint file. 2. Call the `save_checkpoint` API to write the parameter data to a file and generate a new checkpoint file.
``` ```
save_checkpoint(param_list, “./CKP-Integrated_1-4_32.ckpt”) save_checkpoint(save_obj, “./CKP-Integrated_1-4_32.ckpt”)
``` ```
In the preceding information: In the preceding information:
......
...@@ -138,7 +138,7 @@ Here are two examples to further understand the usage of custom Callback. ...@@ -138,7 +138,7 @@ Here are two examples to further understand the usage of custom Callback.
- Save the checkpoint file with the highest accuracy during training. - Save the checkpoint file with the highest accuracy during training.
```python ```python
from mindspore.train.serialization import _exec_save_checkpoint from mindspore.train.serialization import save_checkpoint
class SaveCallback(Callback): class SaveCallback(Callback):
def __init__(self, model, eval_dataset): def __init__(self, model, eval_dataset):
...@@ -155,7 +155,7 @@ Here are two examples to further understand the usage of custom Callback. ...@@ -155,7 +155,7 @@ Here are two examples to further understand the usage of custom Callback.
if result['acc'] > self.acc: if result['acc'] > self.acc:
self.acc = result['acc'] self.acc = result['acc']
file_name = str(self.acc) + ".ckpt" file_name = str(self.acc) + ".ckpt"
_exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name) save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc) print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
......
...@@ -169,7 +169,7 @@ strategy = build_searched_strategy("./strategy_train.cpkt") ...@@ -169,7 +169,7 @@ strategy = build_searched_strategy("./strategy_train.cpkt")
2. 调用`save_checkpoint`接口,将参数数据写入文件,生成新的CheckPoint文件。 2. 调用`save_checkpoint`接口,将参数数据写入文件,生成新的CheckPoint文件。
``` ```
save_checkpoint(param_list, “./CKP-Integrated_1-4_32.ckpt”) save_checkpoint(save_obj, “./CKP-Integrated_1-4_32.ckpt”)
``` ```
其中, 其中,
- `save_checkpoint`: 通过该接口将网络模型参数信息存入文件。 - `save_checkpoint`: 通过该接口将网络模型参数信息存入文件。
......
...@@ -140,7 +140,7 @@ Callback可以把训练过程中的重要信息记录下来,通过一个字典 ...@@ -140,7 +140,7 @@ Callback可以把训练过程中的重要信息记录下来,通过一个字典
- 保存训练过程中精度最高的checkpoint文件。 - 保存训练过程中精度最高的checkpoint文件。
```python ```python
from mindspore.train.serialization import _exec_save_checkpoint from mindspore.train.serialization import save_checkpoint
class SaveCallback(Callback): class SaveCallback(Callback):
def __init__(self, model, eval_dataset): def __init__(self, model, eval_dataset):
...@@ -157,7 +157,7 @@ Callback可以把训练过程中的重要信息记录下来,通过一个字典 ...@@ -157,7 +157,7 @@ Callback可以把训练过程中的重要信息记录下来,通过一个字典
if result['acc'] > self.acc: if result['acc'] > self.acc:
self.acc = result['acc'] self.acc = result['acc']
file_name = str(self.acc) + ".ckpt" file_name = str(self.acc) + ".ckpt"
_exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name) save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc) print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
......
...@@ -52,7 +52,7 @@ from mindspore.ops import composite as C ...@@ -52,7 +52,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import _exec_save_checkpoint from mindspore.train.serialization import save_checkpoint
from model_zoo.official.cv.lenet.src.dataset import create_dataset from model_zoo.official.cv.lenet.src.dataset import create_dataset
from model_zoo.official.cv.lenet.src.lenet import LeNet5 from model_zoo.official.cv.lenet.src.lenet import LeNet5
``` ```
...@@ -199,7 +199,7 @@ class GradientAccumulation: ...@@ -199,7 +199,7 @@ class GradientAccumulation:
train_dataset.reset() train_dataset.reset()
_exec_save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", ) save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", )
``` ```
### 训练并保存模型 ### 训练并保存模型
......
...@@ -9,7 +9,7 @@ from mindspore.ops import composite as C ...@@ -9,7 +9,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import _exec_save_checkpoint from mindspore.train.serialization import save_checkpoint
from model_zoo.official.cv.lenet.src.dataset import create_dataset from model_zoo.official.cv.lenet.src.dataset import create_dataset
from model_zoo.official.cv.lenet.src.lenet import LeNet5 from model_zoo.official.cv.lenet.src.lenet import LeNet5
...@@ -124,7 +124,7 @@ class GradientAccumulation: ...@@ -124,7 +124,7 @@ class GradientAccumulation:
train_dataset.reset() train_dataset.reset()
_exec_save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", ) save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", )
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册