diff --git a/tutorials/source_en/advanced_use/customized_debugging_information.md b/tutorials/source_en/advanced_use/customized_debugging_information.md
index 7f735f6dfdfe3bd00bccb7350e12a7be40902e69..d9132bd1bcceac5b27e4a0d4bbf895051e69cb7f
--- a/tutorials/source_en/advanced_use/customized_debugging_information.md
+++ b/tutorials/source_en/advanced_use/customized_debugging_information.md
@@ -97,41 +97,76 @@ The main attributes of `cb_params` are as follows:
 
 You can inherit the callback base class to customize a callback object.
 
-The following example describes how to use a custom callback function.
+The following two examples further illustrate how to use a custom callback.
+
+- Terminate training within a specified time.
+
+    ```python
+    class StopAtTime(Callback):
+        def __init__(self, run_time):
+            super(StopAtTime, self).__init__()
+            self.run_time = run_time*60
+
+        def begin(self, run_context):
+            cb_params = run_context.original_args()
+            cb_params.init_time = time.time()
+
+        def step_end(self, run_context):
+            cb_params = run_context.original_args()
+            epoch_num = cb_params.cur_epoch_num
+            step_num = cb_params.cur_step_num
+            loss = cb_params.net_outputs
+            cur_time = time.time()
+            if (cur_time - cb_params.init_time) > self.run_time:
+                print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
+                run_context.request_stop()
+
+    stop_cb = StopAtTime(run_time=10)
+    model.train(100, dataset, callbacks=stop_cb)
+    ```
 
-```python
-class StopAtTime(Callback):
-    def __init__(self, run_time):
-        super(StopAtTime, self).__init__()
-        self.run_time = run_time*60
+
+    The output is as follows:
 
-    def begin(self, run_context):
-        cb_params = run_context.original_args()
-        cb_params.init_time = time.time()
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        epoch_num = cb_params.cur_epoch_num
-        step_num = cb_params.cur_step_num
-        loss = cb_params.net_outputs
-        cur_time = time.time()
-        if (cur_time - cb_params.init_time) > self.run_time:
-            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
-            run_context.request_stop()
-
-stop_cb = StopAtTime(run_time=10)
-model.train(100, dataset, callbacks=stop_cb)
-```
+    ```
+    epoch: 20 step: 32 loss: 2.298344373703003
+    ```
 
-The output is as follows:
+    The implementation logic is as follows: the `run_context.original_args` method obtains the `cb_params` dictionary, which contains the main attribute information described above.
+    You can also modify and add values in this dictionary. In the preceding example, an `init_time` value is defined in `begin` and passed on through the `cb_params` dictionary.
+    A check is made at every `step_end`. When the training time exceeds the configured threshold, a termination signal is sent to `run_context` to stop the training early, and the current values of epoch, step, and loss are printed.
 
-```
-epoch: 20 step: 32 loss: 2.298344373703003
-```
+- Save the checkpoint file with the highest accuracy during training.
+
+    ```python
+    from mindspore.train.serialization import _exec_save_checkpoint
+
+    class SaveCallback(Callback):
+        def __init__(self, model, eval_dataset):
+            super(SaveCallback, self).__init__()
+            self.model = model
+            self.eval_dataset = eval_dataset
+            self.acc = 0.5
+
+        def step_end(self, run_context):
+            cb_params = run_context.original_args()
+
+            result = self.model.eval(self.eval_dataset)
+            if result['accuracy'] > self.acc:
+                self.acc = result['accuracy']
+                file_name = str(self.acc) + ".ckpt"
+                _exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
+                print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)
+
+
+    network = Lenet()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
+    model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})
+    model.train(epoch_size, train_dataset=ds_train, callbacks=SaveCallback(model, ds_eval))
+    ```
 
-This callback function is used to terminate the training within a specified period. You can use the `run_context.original_args` method to obtain the `cb_params` dictionary, which contains the main attribute information described above.
-In addition, you can modify and add values in the dictionary. In the preceding example, an `init_time` object is defined in `begin` and transferred to the `cb_params` dictionary.
-A decision is made at each `step_end`. When the training time is greater than the configured time threshold, a training termination signal will be sent to the `run_context` to terminate the training in advance and the current values of epoch, step, and loss will be printed.
+    The implementation logic is as follows: define a callback object whose initializer receives the model object and `ds_eval` (the validation dataset). The model accuracy is verified in the `step_end` phase; when the accuracy is the highest seen so far, the checkpoint-saving method is triggered manually to save the current parameters.
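+
+    Note that `_exec_save_checkpoint` is an internal interface. The snippet below is a minimal sketch of the same idea written against the public `save_checkpoint` function, with the evaluation moved to `epoch_end` to keep its overhead manageable. The class name `SaveBestCallback` is illustrative, and whether `save_checkpoint` accepts the network object directly depends on the MindSpore version, so verify both points against your installed release.
+
+    ```python
+    from mindspore.train.callback import Callback
+    from mindspore.train.serialization import save_checkpoint
+
+    class SaveBestCallback(Callback):
+        """Sketch: save a checkpoint whenever the validation accuracy improves."""
+        def __init__(self, model, eval_dataset):
+            super(SaveBestCallback, self).__init__()
+            self.model = model
+            self.eval_dataset = eval_dataset
+            self.best_acc = 0.5
+
+        def epoch_end(self, run_context):
+            cb_params = run_context.original_args()
+            # The result key matches the metric name passed to Model(..., metrics={"accuracy"}).
+            result = self.model.eval(self.eval_dataset)
+            if result['accuracy'] > self.best_acc:
+                self.best_acc = result['accuracy']
+                # Assumption: this MindSpore version lets the public save_checkpoint
+                # interface take the network object as its first argument.
+                save_checkpoint(cb_params.train_network, str(self.best_acc) + ".ckpt")
+                print("Save the maximum accuracy checkpoint, the accuracy is", self.best_acc)
+    ```
+
+    An instance such as `SaveBestCallback(model, ds_eval)` is then passed to `model.train` through the `callbacks` argument, exactly like `SaveCallback` above.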
 
 ## MindSpore Metrics
 
diff --git a/tutorials/source_zh_cn/advanced_use/customized_debugging_information.md b/tutorials/source_zh_cn/advanced_use/customized_debugging_information.md
index adfd17b572dd4e91e2a490d6cacbe79278e320ab..9e66951fbd0f6147f8c5496ccf62d1f8ac8a9876 100644
--- a/tutorials/source_zh_cn/advanced_use/customized_debugging_information.md
+++ b/tutorials/source_zh_cn/advanced_use/customized_debugging_information.md
@@ -99,41 +99,76 @@ Callback can record important information from the training process through a dictionary
 
 Users can inherit the Callback base class to customize a Callback object.
 
-The following example provides a deeper look at how to use a custom callback.
+The following two examples further illustrate how to use a custom callback.
+
+- Terminate training within a specified time.
+
+    ```python
+    class StopAtTime(Callback):
+        def __init__(self, run_time):
+            super(StopAtTime, self).__init__()
+            self.run_time = run_time*60
+
+        def begin(self, run_context):
+            cb_params = run_context.original_args()
+            cb_params.init_time = time.time()
+
+        def step_end(self, run_context):
+            cb_params = run_context.original_args()
+            epoch_num = cb_params.cur_epoch_num
+            step_num = cb_params.cur_step_num
+            loss = cb_params.net_outputs
+            cur_time = time.time()
+            if (cur_time - cb_params.init_time) > self.run_time:
+                print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
+                run_context.request_stop()
+
+    stop_cb = StopAtTime(run_time=10)
+    model.train(100, dataset, callbacks=stop_cb)
+    ```
 
-```python
-class StopAtTime(Callback):
-    def __init__(self, run_time):
-        super(StopAtTime, self).__init__()
-        self.run_time = run_time*60
+
+    Output:
 
-    def begin(self, run_context):
-        cb_params = run_context.original_args()
-        cb_params.init_time = time.time()
-
-    def step_end(self, run_context):
-        cb_params = run_context.original_args()
-        epoch_num = cb_params.cur_epoch_num
-        step_num = cb_params.cur_step_num
-        loss = cb_params.net_outputs
-        cur_time = time.time()
-        if (cur_time - cb_params.init_time) > self.run_time:
-            print("epoch: ", epoch_num, " step: ", step_num, " loss: ", loss)
-            run_context.request_stop()
-
-stop_cb = StopAtTime(run_time=10)
-model.train(100, dataset, callbacks=stop_cb)
-```
+    ```
+    epoch: 20 step: 32 loss: 2.298344373703003
+    ```
 
-Output:
+    The implementation logic is as follows: the `run_context.original_args` method obtains the `cb_params` dictionary, which contains the main attribute information described above.
+    You can also modify and add values in this dictionary. In the example above, an `init_time` value is defined in `begin` and passed on through the `cb_params` dictionary.
+    A check is made at every `step_end`. When the training time exceeds the configured threshold, a termination signal is sent to `run_context` to stop the training early, and the current `epoch`, `step`, and `loss` values are printed.
 
-```
-epoch: 20 step: 32 loss: 2.298344373703003
-```
+- Save the checkpoint file with the highest accuracy during training.
+
+    ```python
+    from mindspore.train.serialization import _exec_save_checkpoint
+
+    class SaveCallback(Callback):
+        def __init__(self, model, eval_dataset):
+            super(SaveCallback, self).__init__()
+            self.model = model
+            self.eval_dataset = eval_dataset
+            self.acc = 0.5
+
+        def step_end(self, run_context):
+            cb_params = run_context.original_args()
+
+            result = self.model.eval(self.eval_dataset)
+            if result['accuracy'] > self.acc:
+                self.acc = result['accuracy']
+                file_name = str(self.acc) + ".ckpt"
+                _exec_save_checkpoint(train_network=cb_params.train_network, ckpt_file_name=file_name)
+                print("Save the maximum accuracy checkpoint, the accuracy is", self.acc)
+
+
+    network = Lenet()
+    loss = nn.SoftmaxCrossEntropyWithLogits()
+    optimizer = nn.Momentum(network.trainable_params(), learning_rate=0.01, momentum=0.9)
+    model = Model(network, loss_fn=loss, optimizer=optimizer, metrics={"accuracy"})
+    model.train(epoch_size, train_dataset=ds_train, callbacks=SaveCallback(model, ds_eval))
+    ```
 
-The function of this callback is to terminate training within a specified time. The `run_context.original_args` method obtains the `cb_params` dictionary, which contains the main attribute information described above.
-You can also modify and add values in the dictionary. In the example above, an `init_time` object defined in `begin` is passed on through the `cb_params` dictionary.
-At every `step_end`, a check is made. When the training time exceeds the configured threshold, a termination signal is sent to `run_context` to stop the training early, and the current `epoch`, `step`, and `loss` values are printed.
+    The implementation logic is as follows: define a callback object whose initializer receives the model object and `ds_eval` (the validation dataset). The model accuracy is verified in the `step_end` phase; when the accuracy is the highest seen so far, the checkpoint-saving method is triggered manually to save the current parameters.
 
 ## MindSpore Metrics