Commit 7a48b441 authored by mindspore-ci-bot, committed by Gitee

!654 add callback example

Merge pull request !654 from changzherui/mod_callback
@@ -97,41 +97,76 @@ The main attributes of `cb_params` are as follows:
You can inherit the callback base class to customize a callback object.

Here are two examples to further understand the usage of custom Callback.
- Terminate training within the specified time.
```python
import time
from mindspore.train.callback import Callback

class StopAtTime(Callback):
    """Stop training once the elapsed time exceeds run_time minutes."""
    def __init__(self, run_time):
        super(StopAtTime, self).__init__()
        self.run_time = run_time*60          # convert minutes to seconds

    def begin(self, run_context):
        # Record the start time in the cb_params dictionary.
        cb_params = run_context.original_args()
        cb_params.init_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num
        step_num = cb_params.cur_step_num
        loss = cb_params.net_outputs
        cur_time = time.time()
        if (cur_time - cb_params.init_time) > self.run_time:
            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
            run_context.request_stop()       # signal early termination

stop_cb = StopAtTime(run_time=10)
model.train(100, dataset, callbacks=stop_cb)
```
The output is as follows:

```
epoch: 20 step: 32 loss: 2.298344373703003
```

The implementation logic is as follows: the `run_context.original_args` method returns the `cb_params` dictionary, which contains the main attribute information described above.
In addition, you can modify and add values in the dictionary. In the preceding example, an `init_time` object is defined in `begin` and transferred to the `cb_params` dictionary.
A decision is made at each `step_end`. When the training time is greater than the configured time threshold, a training termination signal is sent to the `run_context` to terminate the training in advance, and the current values of epoch, step, and loss are printed.

- Save the checkpoint file with the highest accuracy during training.

```python
from mindspore.train.serialization import _exec_save_checkpoint

class SaveCallback(Callback):
    """Save a checkpoint whenever evaluation accuracy reaches a new high."""
    def __init__(self, model, eval_dataset):
        super(SaveCallback, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.acc = 0.5                       # accuracy threshold for the first save

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset)
        if result['acc'] > self.acc:
            self.acc = result['acc']
            file_name = str(self.acc) + ".ckpt"
            # Manually trigger the checkpoint saving.
            _exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
            print("Saved the checkpoint with the highest accuracy, the accuracy is", self.acc)

network = Lenet()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})
model.train(epoch_size, train_dataset=ds_train, callbacks=SaveCallback(model, ds_eval))
```
The implementation logic is as follows: define a callback object whose initializer receives the model object and `ds_eval` (the validation dataset). Validate the model accuracy in the `step_end` phase. When the accuracy is the highest seen so far, manually trigger the checkpoint saving method to save the current parameters.
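The callback dispatch mechanism the two examples rely on can be illustrated with a minimal pure-Python sketch. The `RunContext`, `Params`, `train`, and `StopAtStep` names below are hypothetical stand-ins, not MindSpore's actual implementation; the sketch only shows how a training loop can consult callbacks at each step and honor `request_stop`.

```python
import time

# Hypothetical stand-in for MindSpore's RunContext (not the real API).
class RunContext:
    def __init__(self, params):
        self._params = params
        self._stop = False

    def original_args(self):
        return self._params

    def request_stop(self):
        self._stop = True

    def get_stop_requested(self):
        return self._stop

# Dictionary with attribute-style access, mimicking cb_params.
class Params(dict):
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__

def train(num_steps, callbacks):
    """Toy training loop: call begin once, step_end after every step."""
    params = Params(cur_step_num=0)
    ctx = RunContext(params)
    for cb in callbacks:
        cb.begin(ctx)
    for step in range(1, num_steps + 1):
        params.cur_step_num = step
        for cb in callbacks:
            cb.step_end(ctx)
        if ctx.get_stop_requested():   # honor the termination signal
            break
    return params.cur_step_num

class StopAtStep:
    """Like StopAtTime above, but keyed on step count for determinism."""
    def __init__(self, stop_step):
        self.stop_step = stop_step

    def begin(self, run_context):
        # Values added to the params dictionary persist across steps.
        run_context.original_args().init_time = time.time()

    def step_end(self, run_context):
        if run_context.original_args().cur_step_num >= self.stop_step:
            run_context.request_stop()

print(train(100, [StopAtStep(5)]))   # prints 5: training stopped early
```

The key design point mirrored here is that callbacks never control the loop directly; they only set a flag on the shared context, and the loop decides when to check it.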
## MindSpore Metrics
@@ -99,41 +99,76 @@ Callback can record important information during training and pass it out through a dictionary
You can inherit the Callback base class to customize a Callback object.

The following two examples further illustrate the usage of custom Callback.
- Terminate training within a specified time.
```python
import time
from mindspore.train.callback import Callback

class StopAtTime(Callback):
    """Stop training once the elapsed time exceeds run_time minutes."""
    def __init__(self, run_time):
        super(StopAtTime, self).__init__()
        self.run_time = run_time*60          # convert minutes to seconds

    def begin(self, run_context):
        # Record the start time in the cb_params dictionary.
        cb_params = run_context.original_args()
        cb_params.init_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        epoch_num = cb_params.cur_epoch_num
        step_num = cb_params.cur_step_num
        loss = cb_params.net_outputs
        cur_time = time.time()
        if (cur_time - cb_params.init_time) > self.run_time:
            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
            run_context.request_stop()       # signal early termination

stop_cb = StopAtTime(run_time=10)
model.train(100, dataset, callbacks=stop_cb)
```
The output is as follows:

```
epoch: 20 step: 32 loss: 2.298344373703003
```

The implementation logic is as follows: the `run_context.original_args` method returns the `cb_params` dictionary, which contains the main attribute information described above.
You can also modify and add values in this dictionary. In the preceding example, an `init_time` object is defined in `begin` and passed into the `cb_params` dictionary.
A check is made at each `step_end`: when the training time exceeds the configured time threshold, a termination signal is sent to `run_context` to stop the training early, and the current `epoch`, `step`, and `loss` values are printed.

- Save the checkpoint file with the highest accuracy during training.

```python
from mindspore.train.serialization import _exec_save_checkpoint

class SaveCallback(Callback):
    """Save a checkpoint whenever evaluation accuracy reaches a new high."""
    def __init__(self, model, eval_dataset):
        super(SaveCallback, self).__init__()
        self.model = model
        self.eval_dataset = eval_dataset
        self.acc = 0.5                       # accuracy threshold for the first save

    def step_end(self, run_context):
        cb_params = run_context.original_args()
        result = self.model.eval(self.eval_dataset)
        if result['acc'] > self.acc:
            self.acc = result['acc']
            file_name = str(self.acc) + ".ckpt"
            # Manually trigger the checkpoint saving.
            _exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
            print("Saved the checkpoint with the highest accuracy, the accuracy is", self.acc)

network = Lenet()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})
model.train(epoch_size, train_dataset=ds_train, callbacks=SaveCallback(model, ds_eval))
```
The implementation logic is as follows: define a callback object whose initializer receives the model object and `ds_eval` (the validation dataset). Validate the model accuracy in the `step_end` phase. When the accuracy is the highest seen so far, manually trigger the checkpoint saving method to save the current parameters.
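The save-best-checkpoint selection logic above can be sketched in isolation, with evaluation and checkpoint writing stubbed out. `BestCheckpointTracker` and its `on_step_end` method are hypothetical names for illustration, not MindSpore APIs; appending to a list stands in for `_exec_save_checkpoint`.

```python
# Minimal sketch of the "save only on a new best accuracy" pattern
# (hypothetical helper, not a MindSpore API).
class BestCheckpointTracker:
    def __init__(self, initial_acc=0.5):
        self.best_acc = initial_acc
        self.saved = []          # file names of checkpoints we "saved"

    def on_step_end(self, acc):
        # Save only when accuracy improves on the best seen so far.
        if acc > self.best_acc:
            self.best_acc = acc
            file_name = str(self.best_acc) + ".ckpt"
            self.saved.append(file_name)   # stand-in for _exec_save_checkpoint

tracker = BestCheckpointTracker()
for acc in [0.4, 0.6, 0.55, 0.7]:
    tracker.on_step_end(acc)
print(tracker.saved)   # prints ['0.6.ckpt', '0.7.ckpt']
```

Note that the 0.55 evaluation is skipped: only strict improvements over the running best trigger a save, which keeps the number of checkpoint files small.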
## MindSpore Metrics