提交 e130247c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!687 delete lenet and create_dataset api description

Merge pull request !687 from jinyaohui/master
......@@ -28,7 +28,7 @@
最终目的是为了达到跟直接用N*Mini-batch数据训练几乎同样的效果。
> 本教程用于GPU、Ascend 910 AI处理器, 你可以在这里下载完整的样例代码:<https://gitee.com/mindspore/docs/tree/master/tutorials/tutorial_code/gradient_accumulation>
> 本教程用于GPU、Ascend 910 AI处理器, 你可以在这里下载主要的训练样例代码:<https://gitee.com/mindspore/docs/tree/master/tutorials/tutorial_code/gradient_accumulation>
## 创建梯度累积模型
......@@ -57,111 +57,11 @@ from model_zoo.official.cv.lenet.src.lenet import LeNet5
### 加载数据集
利用MindSpore的dataset提供的`MnistDataset`接口加载MNIST数据集。
```python
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
```
利用MindSpore的dataset提供的`MnistDataset`接口加载MNIST数据集,此部分代码由model_zoo中lenet目录下的[dataset.py](<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet/src/dataset.py>)导入。
### 定义网络
这里以LeNet网络为例进行介绍,当然也可以使用其它的网络,如ResNet-50、BERT等。
```python
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Number classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10, channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = conv(channel, 6, 5)
self.conv2 = conv(6, 16, 5)
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
self.fc2 = fc_with_initialize(120, 84)
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
```
这里以LeNet网络为例进行介绍,当然也可以使用其它的网络,如ResNet-50、BERT等, 此部分代码由model_zoo中lenet目录下的[lenet.py](<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet/src/lenet.py>)导入。
### 定义训练模型
将训练流程拆分为正向反向训练、参数更新和累积梯度清理三个部分:
......@@ -350,7 +250,7 @@ if __name__ == "__main__":
**验证模型**
通过model_zoo下lenet网络[eval.py](<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet/train.py>),使用保存的CheckPoint文件,加载验证数据集,进行验证。
通过model_zoo中lenet目录下[eval.py](<https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/lenet/train.py>),使用保存的CheckPoint文件,加载验证数据集,进行验证。
```shell
$ python eval.py --data_path=./MNIST_Data --ckpt_path=./gradient_accumulation.ckpt
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册