Commit c81567fb authored by mindspore-ci-bot, committed by Gitee

!255 Add SummaryCollector callback tutorial to the train visualization tutorials

Merge pull request !255 from ougongchang/feature_collector
......@@ -38,12 +38,49 @@ Scalars, images, computational graphs, and model hyperparameters during training
### Collect Summary Data
Currently, MindSpore supports saving scalars, images, computational graphs, and model hyperparameters to summary log files and displaying them on the web page.

MindSpore currently supports three ways to record data into the summary log file.
**Method one: Automatic collection through `SummaryCollector`**

The `Callback` mechanism in MindSpore provides a quick and easy way to collect common information, including the computational graph, loss value, learning rate, parameter weights, etc. This callback is named `SummaryCollector`.

When you write a training script, you just instantiate the `SummaryCollector` and apply it to either `model.train` or `model.eval` to automatically collect some common summary data. For detailed usage of `SummaryCollector`, refer to the API documentation for `mindspore.train.callback.SummaryCollector`.
The sample code is as follows:
```python
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.callback import SummaryCollector
context.set_context(mode=context.GRAPH_MODE)
network = AlexNet(num_classes=10)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
lr = Tensor(0.1)
opt = nn.Momentum(network.trainable_params(), lr, momentum=0.9)
model = Model(network, loss, opt)
# create_dataset is a user-defined function that prepares the dataset
ds_train = create_dataset('./dataset_path')

# Init a SummaryCollector callback instance, and use it in model.train or model.eval
summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_freq=1)

# Note: dataset_sink_mode should be set to False, else you should modify collect_freq in SummaryCollector
model.train(epoch=1, train_dataset=ds_train, callbacks=[summary_collector], dataset_sink_mode=False)

ds_eval = create_dataset('./dataset_path')
model.eval(ds_eval, callbacks=[summary_collector])
```
**Method two: Custom collection of network data with summary operators and `SummaryCollector`**

In addition to the `SummaryCollector` that automatically collects some summary data, MindSpore provides summary operators that enable custom collection of other data in the network, such as the input of each convolutional layer or the loss value in the loss function. The recording method is shown in the following steps.
Step 1: Call the summary operator in the `construct` function of the derived class that inherits `nn.Cell` to collect image or scalar data.
For example, when a network is defined, image data is recorded in `construct` of the network. When the loss function is defined, the loss value is recorded in `construct` of the loss function.
......@@ -120,62 +157,77 @@ class Net(nn.Cell):
return out
```
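The definitions above are collapsed in this diff. As a minimal sketch of the pattern Step 1 describes (the class shape and wiring below are assumptions for illustration, not the collapsed code itself), a loss function can record its loss value with the `ScalarSummary` operator in its `construct`:

```python
import mindspore.nn as nn
import mindspore.ops.operations as P

class CrossEntropyLoss(nn.Cell):
    """Loss function that records its loss value with a summary operator."""
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        # Summary operators (ScalarSummary, ImageSummary, TensorSummary,
        # HistogramSummary) are declared in __init__ and called in construct
        self.scalar_summary = P.ScalarSummary()

    def construct(self, logits, label):
        loss = self.cross_entropy(logits, label)
        # Record the loss under the tag 'loss' into the summary log file
        self.scalar_summary("loss", loss)
        return loss
```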
Step 2: In the training script, instantiate the `SummaryCollector` and apply it to `model.train`.

The sample code is as follows:
```python
from mindspore import Model, nn, context
from mindspore.train.callback import SummaryCollector

context.set_context(mode=context.GRAPH_MODE)

net = Net()
loss_fn = CrossEntropyLoss()
optim = MyOptimizer(learning_rate=0.01, params=net.trainable_params())
model = Model(net, loss_fn=loss_fn, optimizer=optim, metrics=None)

train_ds = create_mindrecord_dataset_for_training()

summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_freq=1)
model.train(epoch=2, train_dataset=train_ds, callbacks=[summary_collector])
```
**Method three: Custom callback recording data**

MindSpore supports custom callbacks, in which you can record data into the summary log file and display it on the web page.
The following pseudocode shows how, in a CNN network, developers can use the network output together with the original labels and the predicted labels to generate a confusion matrix image, and then record it into the summary log file through the `SummaryRecord` module. For detailed usage of `SummaryRecord`, refer to the API documentation for `mindspore.train.summary.SummaryRecord`.
The sample code is as follows:
```python
from mindspore.train.callback import Callback
from mindspore.train.summary import SummaryRecord


class ConfusionMatrixCallback(Callback):
    def __init__(self, summary_dir):
        self._summary_dir = summary_dir

    def __enter__(self):
        # Init your summary record here; when the training script runs,
        # it will be inited before training
        self.summary_record = SummaryRecord(self._summary_dir)
        return self

    def __exit__(self, *exc_args):
        # Note: you must close the summary record, which releases the process
        # pool resource, else your training script will not exit from training.
        self.summary_record.close()

    def step_end(self, run_context):
        cb_params = run_context.original_args()

        # Create a confusion matrix image, and record it to the summary file
        confusion_matrix = create_confusion_matrix(cb_params)
        self.summary_record.add_value('image', 'confusion_matrix', confusion_matrix)
        self.summary_record.record(cb_params.cur_step_num)

# Init your training script
...

confusion_matrix_cb = ConfusionMatrixCallback(summary_dir='./summary_dir')
model.train(epoch, train_ds, callbacks=[confusion_matrix_cb])
```
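`create_confusion_matrix` is left undefined in the pseudocode above. Below is a minimal sketch of one possible implementation; note that the `predictions` and `labels` fields on `cb_params` are assumptions made for illustration (they are not standard MindSpore callback parameters and would have to be attached by your own training loop):

```python
import numpy as np
from mindspore import Tensor

def create_confusion_matrix(cb_params):
    # Assumption: the training loop has stashed the latest batch's predicted
    # and true class ids on cb_params; these are NOT standard fields.
    preds = np.asarray(cb_params.predictions)
    labels = np.asarray(cb_params.labels)
    num_classes = int(max(preds.max(), labels.max())) + 1

    # Count (label, prediction) pairs into a num_classes x num_classes grid
    matrix = np.zeros((num_classes, num_classes), dtype=np.float32)
    for truth, pred in zip(labels, preds):
        matrix[truth, pred] += 1

    # Normalize each row, then upscale to a visible grayscale image in the
    # NCHW layout expected by the 'image' plugin of SummaryRecord.add_value
    matrix /= matrix.sum(axis=1, keepdims=True).clip(min=1)
    image = np.kron(matrix, np.ones((16, 16), dtype=np.float32))
    return Tensor(image[np.newaxis, np.newaxis, :, :])
```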
All three methods above support recording the computational graph, loss value, and other data. In addition, MindSpore can save the computational graphs of other phases of training: setting the `save_graphs` option of `context.set_context` to `True` in the training script records these graphs, including the computational graph after operator fusion. Among the saved files, `ms_output_after_hwopt.pb` is the computational graph after operator fusion, which can be viewed on the web page.
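For example, a minimal sketch of enabling this option in the training script:

```python
from mindspore import context

# save_graphs=True makes MindSpore dump the computational graphs of other
# training phases, including ms_output_after_hwopt.pb (the graph after
# operator fusion), so they can be viewed on the web page
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
```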
> - Currently MindSpore supports recording the computational graph after operator fusion for the Ascend 910 AI processor only.
> - It's recommended that you keep calls to `HistogramSummary` under 10 per batch; the more calls, the greater the performance overhead.
> - Please use the *with statement* to ensure that `SummaryRecord` is properly closed at the end, otherwise the process may fail to exit.
### Collect Performance Profile Data
......
......@@ -43,16 +43,54 @@
### Collect Summary Data

Currently, MindSpore supports saving scalars, images, computational graphs, and model hyperparameters to summary log files and displaying them on the web page.

MindSpore currently supports three ways to record data into the summary log file.
**Method one: Automatic collection through `SummaryCollector`**

The `Callback` mechanism in MindSpore provides a quick and easy way to collect common information, including the computational graph, loss value, learning rate, parameter weights, etc. This callback is named `SummaryCollector`.

When you write a training script, you just instantiate the `SummaryCollector` and apply it to either `model.train` or `model.eval` to automatically collect some common information. For detailed usage of `SummaryCollector`, refer to the API documentation for `mindspore.train.callback.SummaryCollector`.

The sample code is as follows:
```python
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.train import Model
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.callback import SummaryCollector
context.set_context(mode=context.GRAPH_MODE)
network = AlexNet(num_classes=10)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
lr = Tensor(0.1)
opt = nn.Momentum(network.trainable_params(), lr, momentum=0.9)
model = Model(network, loss, opt)
# create_dataset is a user-defined function that prepares the dataset
ds_train = create_dataset('./dataset_path')

# Init a SummaryCollector callback instance, and use it in model.train or model.eval
summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_freq=1)

# Note: dataset_sink_mode should be set to False, else you should modify collect_freq in SummaryCollector
model.train(epoch=1, train_dataset=ds_train, callbacks=[summary_collector], dataset_sink_mode=False)

ds_eval = create_dataset('./dataset_path')
model.eval(ds_eval, callbacks=[summary_collector])
```
**Method two: Custom collection of network data with summary operators and `SummaryCollector`**

In addition to the `SummaryCollector` that automatically collects some common data, MindSpore provides summary operators that enable custom collection of other data in the network, such as the input of each convolutional layer or the loss value in the loss function. The recording method is shown in the following steps.

Step 1: Call the summary operators in the `construct` function of the derived class that inherits `nn.Cell` to collect image, scalar, or other data.

For example, when defining a network, record image data in the network's `construct`; when defining a loss function, record the loss value in the loss function's `construct`.

To record a dynamic learning rate, record the learning rate in the optimizer's `construct` when defining the optimizer.

The sample code is as follows:
......@@ -108,7 +146,6 @@ class MyOptimizer(Optimizer):
......
class Net(nn.Cell):
"""Net definition."""
def __init__(self):
......@@ -126,62 +163,78 @@ class Net(nn.Cell):
```
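The definitions above are collapsed in this diff. As a minimal sketch of the pattern Step 1 describes (the class shape and wiring below are assumptions for illustration, not the collapsed code itself), a loss function can record its loss value with the `ScalarSummary` operator in its `construct`:

```python
import mindspore.nn as nn
import mindspore.ops.operations as P

class CrossEntropyLoss(nn.Cell):
    """Loss function that records its loss value with a summary operator."""
    def __init__(self):
        super(CrossEntropyLoss, self).__init__()
        self.cross_entropy = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
        # Summary operators (ScalarSummary, ImageSummary, TensorSummary,
        # HistogramSummary) are declared in __init__ and called in construct
        self.scalar_summary = P.ScalarSummary()

    def construct(self, logits, label):
        loss = self.cross_entropy(logits, label)
        # Record the loss under the tag 'loss' into the summary log file
        self.scalar_summary("loss", loss)
        return loss
```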
Step 2: In the training script, instantiate the `SummaryCollector` and apply it to `model.train`.

The sample code is as follows:

```python
from mindspore import Model, nn, context
from mindspore.train.callback import SummaryCollector

context.set_context(mode=context.GRAPH_MODE)

net = Net()
loss_fn = CrossEntropyLoss()
optim = MyOptimizer(learning_rate=0.01, params=net.trainable_params())
model = Model(net, loss_fn=loss_fn, optimizer=optim, metrics=None)

train_ds = create_mindrecord_dataset_for_training()

summary_collector = SummaryCollector(summary_dir='./summary_dir', collect_freq=1)
model.train(epoch=2, train_dataset=train_ds, callbacks=[summary_collector])
```
**Method three: Custom callback recording data**

MindSpore supports custom callbacks, in which you can record data into the summary log file and view it on the visualization page.
The following pseudocode shows how, in a CNN network, developers can use the network output together with the original labels and the predicted labels to generate a confusion matrix image, and then record it into the summary log file through the `SummaryRecord` module. For detailed usage of `SummaryRecord`, refer to the API documentation for `mindspore.train.summary.SummaryRecord`.
The sample code is as follows:
```python
from mindspore.train.callback import Callback
from mindspore.train.summary import SummaryRecord


class ConfusionMatrixCallback(Callback):
    def __init__(self, summary_dir):
        self._summary_dir = summary_dir

    def __enter__(self):
        # Init your summary record here; when the training script runs,
        # it will be inited before training
        self.summary_record = SummaryRecord(self._summary_dir)
        return self

    def __exit__(self, *exc_args):
        # Note: you must close the summary record, which releases the process
        # pool resource, else your training script will not exit from training.
        self.summary_record.close()

    def step_end(self, run_context):
        cb_params = run_context.original_args()

        # Create a confusion matrix image, and record it to the summary file
        confusion_matrix = create_confusion_matrix(cb_params)
        self.summary_record.add_value('image', 'confusion_matrix', confusion_matrix)
        self.summary_record.record(cb_params.cur_step_num)

# Init your training script
...

confusion_matrix_cb = ConfusionMatrixCallback(summary_dir='./summary_dir')
model.train(epoch, train_ds, callbacks=[confusion_matrix_cb])
```
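`create_confusion_matrix` is left undefined in the pseudocode above. Below is a minimal sketch of one possible implementation; note that the `predictions` and `labels` fields on `cb_params` are assumptions made for illustration (they are not standard MindSpore callback parameters and would have to be attached by your own training loop):

```python
import numpy as np
from mindspore import Tensor

def create_confusion_matrix(cb_params):
    # Assumption: the training loop has stashed the latest batch's predicted
    # and true class ids on cb_params; these are NOT standard fields.
    preds = np.asarray(cb_params.predictions)
    labels = np.asarray(cb_params.labels)
    num_classes = int(max(preds.max(), labels.max())) + 1

    # Count (label, prediction) pairs into a num_classes x num_classes grid
    matrix = np.zeros((num_classes, num_classes), dtype=np.float32)
    for truth, pred in zip(labels, preds):
        matrix[truth, pred] += 1

    # Normalize each row, then upscale to a visible grayscale image in the
    # NCHW layout expected by the 'image' plugin of SummaryRecord.add_value
    matrix /= matrix.sum(axis=1, keepdims=True).clip(min=1)
    image = np.kron(matrix, np.ones((16, 16), dtype=np.float32))
    return Tensor(image[np.newaxis, np.newaxis, :, :])
```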
All three methods above support recording the computational graph, loss value, and other data. In addition, MindSpore can save the computational graphs of other phases of training: setting the `save_graphs` option of `context.set_context` to `True` in the training script records these graphs, including the computational graph after operator fusion. Among the saved files, `ms_output_after_hwopt.pb` is the computational graph after operator fusion, which can be viewed on the visualization page.
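For example, a minimal sketch of enabling this option in the training script:

```python
from mindspore import context

# save_graphs=True makes MindSpore dump the computational graphs of other
# training phases, including ms_output_after_hwopt.pb (the graph after
# operator fusion), so they can be viewed on the visualization page
context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
```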
> - Currently MindSpore supports exporting the computational graph after operator fusion on the Ascend 910 AI processor only.
> - Please keep calls to the `HistogramSummary` operator under 10 per batch; the more often it is called, the greater the performance overhead.
> - Please use the *with statement* to ensure that `SummaryRecord` is properly closed at the end, otherwise the process may fail to exit.
### Collect Performance Profile Data
......