提交 91d69325 编写于 作者: L liuyang_655

modify save_checkpoint

上级 6737369c
......@@ -165,7 +165,7 @@ The parameter name is model\_parallel\_weight and the dividing strategy is to pe
2. Call the `save_checkpoint` API to write the parameter data to a file and generate a new checkpoint file.
```
save_checkpoint(param_list, “./CKP-Integrated_1-4_32.ckpt”)
save_checkpoint(save_obj, “./CKP-Integrated_1-4_32.ckpt”)
```
In the preceding information:
......
......@@ -138,7 +138,7 @@ Here are two examples to further understand the usage of custom Callback.
- Save the checkpoint file with the highest accuracy during training.
```python
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
class SaveCallback(Callback):
def __init__(self, model, eval_dataset):
......@@ -155,7 +155,7 @@ Here are two examples to further understand the usage of custom Callback.
if result['acc'] > self.acc:
self.acc = result['acc']
file_name = str(self.acc) + ".ckpt"
_exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
......
......@@ -169,7 +169,7 @@ strategy = build_searched_strategy("./strategy_train.cpkt")
2. 调用`save_checkpoint`接口,将参数数据写入文件,生成新的CheckPoint文件。
```
save_checkpoint(param_list, “./CKP-Integrated_1-4_32.ckpt”)
save_checkpoint(save_obj, “./CKP-Integrated_1-4_32.ckpt”)
```
其中,
- `save_checkpoint`: 通过该接口将网络模型参数信息存入文件。
......
......@@ -140,7 +140,7 @@ Callback可以把训练过程中的重要信息记录下来,通过一个字典
- 保存训练过程中精度最高的checkpoint文件。
```python
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
class SaveCallback(Callback):
def __init__(self, model, eval_dataset):
......@@ -157,7 +157,7 @@ Callback可以把训练过程中的重要信息记录下来,通过一个字典
if result['acc'] > self.acc:
self.acc = result['acc']
file_name = str(self.acc) + ".ckpt"
_exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
save_checkpoint(save_obj=cb_params.train_network, ckpt_file_name=file_name)
print("Save the maximum accuracy checkpoint,the accuracy is", self.acc)
......
......@@ -52,7 +52,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
from model_zoo.official.cv.lenet.src.dataset import create_dataset
from model_zoo.official.cv.lenet.src.lenet import LeNet5
```
......@@ -199,7 +199,7 @@ class GradientAccumulation:
train_dataset.reset()
_exec_save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", )
save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", )
```
### 训练并保存模型
......
......@@ -9,7 +9,7 @@ from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.train.dataset_helper import DatasetHelper
from mindspore.train.serialization import _exec_save_checkpoint
from mindspore.train.serialization import save_checkpoint
from model_zoo.official.cv.lenet.src.dataset import create_dataset
from model_zoo.official.cv.lenet.src.lenet import LeNet5
......@@ -124,7 +124,7 @@ class GradientAccumulation:
train_dataset.reset()
_exec_save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", )
save_checkpoint(self._train_forward_backward, "gradient_accumulation.ckpt", )
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册