Commit 7a48b441 authored by mindspore-ci-bot, committed by Gitee

!654 add callback example

Merge pull request !654 from changzherui/mod_callback
......@@ -97,41 +97,76 @@ The main attributes of `cb_params` are as follows:
You can inherit the callback base class to customize a callback object.
The following two examples further illustrate the usage of a custom Callback.
- Terminate training within the specified time.
```python
import time

from mindspore.train.callback import Callback

class StopAtTime(Callback):
    """Terminate training once the elapsed time exceeds `run_time` minutes."""
    def __init__(self, run_time):
        super(StopAtTime, self).__init__()
        self.run_time = run_time*60

    def begin(self, run_context):
        # Record the training start time in cb_params so that step_end can read it.
        cb_params = run_context.original_args()
        cb_params.init_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num
        step_num = cb_params.cur_step_num
        loss = cb_params.net_outputs
        cur_time = time.time()
        if (cur_time - cb_params.init_time) > self.run_time:
            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
            run_context.request_stop()

# model and dataset are created earlier in the tutorial.
stop_cb = StopAtTime(run_time=10)
model.train(100, dataset, callbacks=stop_cb)
```
The output is as follows:

```
epoch: 20 step: 32 loss: 2.298344373703003
```
The implementation logic is as follows: the `run_context.original_args` method returns the `cb_params` dictionary, which contains the main attributes described above.
You can also modify and add values in this dictionary; in the preceding example, an `init_time` value is recorded in `begin` and stored in the `cb_params` dictionary.
At each `step_end`, the elapsed time is checked. When it exceeds the configured threshold, a termination signal is sent through `run_context` to stop training early, and the current epoch, step, and loss values are printed.
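The same `run_context` mechanism works from the other callback hooks as well. Below is a minimal sketch, not part of the original tutorial, of a callback that checks the loss at `epoch_end` and stops early once it drops below a threshold; the class name `StopAtLoss` and the threshold value are illustrative, and it assumes `cb_params.net_outputs` holds a scalar loss tensor.

```python
from mindspore.train.callback import Callback

class StopAtLoss(Callback):
    """Illustrative sketch: stop training once the loss is low enough."""
    def __init__(self, loss_threshold=0.05):
        super(StopAtLoss, self).__init__()
        self.loss_threshold = loss_threshold  # hypothetical threshold, tune for your task

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        # Assumes net_outputs is a scalar loss tensor; adapt if your network returns a tuple.
        loss = float(cb_params.net_outputs.asnumpy())
        if loss < self.loss_threshold:
            print("early stop at epoch:", cb_params.cur_epoch_num, "loss:", loss)
            run_context.request_stop()
```

Multiple callbacks can be combined by passing a list, for example `model.train(100, dataset, callbacks=[StopAtTime(run_time=10), StopAtLoss(0.05)])`.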
- Save the checkpoint file with the highest accuracy during training.
```python
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import Callback
from mindspore.train.serialization import _exec_save_checkpoint

class SaveCallback(Callback):
    """Save a checkpoint whenever the evaluation accuracy reaches a new maximum."""
    def __init__(self, model, eval_dataset):
        super(SaveCallback, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.acc = 0.5

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset)
        if result['accuracy'] > self.acc:
            self.acc = result['accuracy']
            file_name = str(self.acc) + ".ckpt"
            _exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
            print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)

# Lenet, epoch_size, ds_train and ds_eval are defined earlier in the tutorial;
# the optimizer hyperparameters below are illustrative values.
network = Lenet()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})
model.train(epoch_size, train_dataset=ds_train, callbacks=SaveCallback(model, ds_eval))
```
The implementation logic is as follows: define a callback object whose initializer receives the model object and `ds_eval` (the validation dataset). In the `step_end` phase, evaluate the accuracy of the model; when the accuracy is the highest seen so far, manually trigger the save-checkpoint method to save the current parameters.
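To use the best checkpoint afterwards, the saved parameters can be loaded back into the network with the public `load_checkpoint` and `load_param_into_net` interfaces. The sketch below is not part of the original example; the file name `0.98.ckpt` is a hypothetical placeholder for whichever accuracy-named file the callback produced.

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# SaveCallback names files str(accuracy) + ".ckpt"; "0.98.ckpt" is only an illustrative example.
param_dict = load_checkpoint("0.98.ckpt")
load_param_into_net(network, param_dict)  # network is the backbone built above
```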
## MindSpore Metrics
......
......@@ -99,41 +99,76 @@ Callback can record important information during the training process and, through a dictionary
You can inherit the Callback base class to customize a Callback object.
The following two examples further illustrate the usage of a custom Callback.
- Terminate training within the specified time.
```python
import time

from mindspore.train.callback import Callback

class StopAtTime(Callback):
    """Terminate training once the elapsed time exceeds `run_time` minutes."""
    def __init__(self, run_time):
        super(StopAtTime, self).__init__()
        self.run_time = run_time*60

    def begin(self, run_context):
        # Record the training start time in cb_params so that step_end can read it.
        cb_params = run_context.original_args()
        cb_params.init_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num
        step_num = cb_params.cur_step_num
        loss = cb_params.net_outputs
        cur_time = time.time()
        if (cur_time - cb_params.init_time) > self.run_time:
            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
            run_context.request_stop()

# model and dataset are created earlier in the tutorial.
stop_cb = StopAtTime(run_time=10)
model.train(100, dataset, callbacks=stop_cb)
```
The output is as follows:

```
epoch: 20 step: 32 loss: 2.298344373703003
```
The implementation logic is as follows: the `run_context.original_args` method returns the `cb_params` dictionary, which contains the main attributes described above.
You can also modify and add values in this dictionary; in the preceding example, an `init_time` value is recorded in `begin` and stored in the `cb_params` dictionary.
At each `step_end`, the elapsed time is checked. When it exceeds the configured threshold, a termination signal is sent through `run_context` to stop training early, and the current `epoch`, `step`, and `loss` values are printed.
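The same `run_context` mechanism works from the other callback hooks as well. Below is a minimal sketch, not part of the original tutorial, of a callback that checks the loss at `epoch_end` and stops early once it drops below a threshold; the class name `StopAtLoss` and the threshold value are illustrative, and it assumes `cb_params.net_outputs` holds a scalar loss tensor.

```python
from mindspore.train.callback import Callback

class StopAtLoss(Callback):
    """Illustrative sketch: stop training once the loss is low enough."""
    def __init__(self, loss_threshold=0.05):
        super(StopAtLoss, self).__init__()
        self.loss_threshold = loss_threshold  # hypothetical threshold, tune for your task

    def epoch_end(self, run_context):
        cb_params = run_context.original_args()
        # Assumes net_outputs is a scalar loss tensor; adapt if your network returns a tuple.
        loss = float(cb_params.net_outputs.asnumpy())
        if loss < self.loss_threshold:
            print("early stop at epoch:", cb_params.cur_epoch_num, "loss:", loss)
            run_context.request_stop()
```

Multiple callbacks can be combined by passing a list, for example `model.train(100, dataset, callbacks=[StopAtTime(run_time=10), StopAtLoss(0.05)])`.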
- Save the checkpoint file with the highest accuracy during training.
```python
import mindspore.nn as nn
from mindspore.train import Model
from mindspore.train.callback import Callback
from mindspore.train.serialization import _exec_save_checkpoint

class SaveCallback(Callback):
    """Save a checkpoint whenever the evaluation accuracy reaches a new maximum."""
    def __init__(self, model, eval_dataset):
        super(SaveCallback, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.acc = 0.5

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset)
        if result['accuracy'] > self.acc:
            self.acc = result['accuracy']
            file_name = str(self.acc) + ".ckpt"
            _exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
            print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)

# Lenet, epoch_size, ds_train and ds_eval are defined earlier in the tutorial;
# the optimizer hyperparameters below are illustrative values.
network = Lenet()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})
model.train(epoch_size, train_dataset=ds_train, callbacks=SaveCallback(model, ds_eval))
```
The implementation logic is as follows: define a callback object whose initializer receives the model object and `ds_eval` (the validation dataset). In the `step_end` phase, evaluate the accuracy of the model; when the accuracy is the highest seen so far, manually trigger the save-checkpoint method to save the current parameters.
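To use the best checkpoint afterwards, the saved parameters can be loaded back into the network with the public `load_checkpoint` and `load_param_into_net` interfaces. The sketch below is not part of the original example; the file name `0.98.ckpt` is a hypothetical placeholder for whichever accuracy-named file the callback produced.

```python
from mindspore.train.serialization import load_checkpoint, load_param_into_net

# SaveCallback names files str(accuracy) + ".ckpt"; "0.98.ckpt" is only an illustrative example.
param_dict = load_checkpoint("0.98.ckpt")
load_param_into_net(network, param_dict)  # network is the backbone built above
```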
## MindSpore Metrics
......