Commit e830aae3 authored by wukesong

modify quick_start_lenet

Parent fc438758
@@ -190,76 +190,41 @@ The LeNet network is relatively simple. In addition to the input layer, the LeNet
 > For details about the LeNet network, visit <http://yann.lecun.com/exdb/lenet/>.

-You need to initialize the fully connected layers and convolutional layers.
-
-`TruncatedNormal`: a parameter initialization method. MindSpore supports multiple parameter initialization methods, such as `TruncatedNormal`, `Normal`, and `Uniform`. For details, see the description of the `mindspore.common.initializer` module in the MindSpore API.
-
-The following is the sample code for initialization:
-
-```python
-import mindspore.nn as nn
-from mindspore.common.initializer import TruncatedNormal
-
-def weight_variable():
-    """weight initialization"""
-    return TruncatedNormal(0.02)
-
-def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
-    """conv layer weight initialization"""
-    weight = weight_variable()
-    return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=kernel_size, stride=stride, padding=padding,
-                     weight_init=weight, has_bias=False, pad_mode="valid")
-
-def fc_with_initialize(input_channels, out_channels):
-    """fc layer weight initialization"""
-    weight = weight_variable()
-    bias = weight_variable()
-    return nn.Dense(input_channels, out_channels, weight, bias)
-```
+You can initialize the fully connected layers and convolutional layers with `Normal`.
+
+MindSpore supports multiple parameter initialization methods, such as `TruncatedNormal`, `Normal`, and `Uniform`; the default is `Normal`. For details, see the description of the `mindspore.common.initializer` module in the MindSpore API.
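For illustration, the following sketch shows what relying on the default amounts to, assuming `weight_init` accepts an initializer instance as described for `mindspore.common.initializer` (keyword names may differ across MindSpore versions):

```python
import mindspore.nn as nn
from mindspore.common.initializer import Normal

# Passing Normal explicitly mirrors the new default behavior;
# Normal(0.02) sets the standard deviation of the weight distribution to 0.02.
fc = nn.Dense(120, 84, weight_init=Normal(0.02))
conv = nn.Conv2d(1, 6, 5, pad_mode='valid', weight_init=Normal(0.02))
```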
 To define a neural network with MindSpore, inherit `mindspore.nn.cell.Cell`. `Cell` is the base class of all neural networks (such as `Conv2d`).

 Define each layer of the neural network in the `__init__` method in advance, and then define the `construct` method to complete the forward construction of the network. According to the structure of the LeNet network, define the network layers as follows:

 ```python
-import mindspore.ops.operations as P
 import mindspore.nn as nn

 class LeNet5(nn.Cell):
     """Lenet network structure"""
     # define the operators required
-    def __init__(self):
+    def __init__(self, num_class=10, channel=1):
         super(LeNet5, self).__init__()
-        self.conv1 = conv(1, 6, 5)
-        self.conv2 = conv(6, 16, 5)
-        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
-        self.fc2 = fc_with_initialize(120, 84)
-        self.fc3 = fc_with_initialize(84, 10)
+        self.num_class = num_class
+        self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+        self.fc1 = nn.Dense(16 * 5 * 5, 120)
+        self.fc2 = nn.Dense(120, 84)
+        self.fc3 = nn.Dense(84, self.num_class)
         self.relu = nn.ReLU()
         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
-        self.reshape = P.Reshape()
+        self.flatten = nn.Flatten()

     # use the preceding operators to construct the network
     def construct(self, x):
-        x = self.conv1(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.conv2(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.reshape(x, (self.batch_size, -1))
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        x = self.relu(x)
+        x = self.max_pool2d(self.relu(self.conv1(x)))
+        x = self.max_pool2d(self.relu(self.conv2(x)))
+        x = self.flatten(x)
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
         x = self.fc3(x)
         return x
 ```
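To sanity-check the network, a quick smoke test can be run (a sketch, assuming one batch of 32 single-channel 32×32 images, the size this tutorial resizes MNIST to):

```python
import numpy as np
from mindspore import Tensor

net = LeNet5(num_class=10, channel=1)
# One dummy batch: 32 images, 1 channel, 32x32 pixels.
images = Tensor(np.zeros((32, 1, 32, 32), dtype=np.float32))
logits = net(images)
print(logits.shape)  # expected: (32, 10), one score per class for each image
```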
@@ -193,74 +193,41 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
 > A more detailed introduction to the LeNet network is omitted here; to learn more, visit <http://yann.lecun.com/exdb/lenet/>.

-We need to initialize the fully connected layers and convolutional layers.
-
-`TruncatedNormal`: a parameter initialization method. MindSpore supports multiple parameter initialization methods, such as `TruncatedNormal`, `Normal`, and `Uniform`. For details, see the description of the `mindspore.common.initializer` module in the MindSpore API.
-
-Sample code for initialization:
-
-```python
-import mindspore.nn as nn
-from mindspore.common.initializer import TruncatedNormal
-
-def weight_variable():
-    """weight initialization"""
-    return TruncatedNormal(0.02)
-
-def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
-    """conv layer weight initialization"""
-    weight = weight_variable()
-    return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=kernel_size, stride=stride, padding=padding,
-                     weight_init=weight, has_bias=False, pad_mode="valid")
-
-def fc_with_initialize(input_channels, out_channels):
-    """fc layer weight initialization"""
-    weight = weight_variable()
-    bias = weight_variable()
-    return nn.Dense(input_channels, out_channels, weight, bias)
-```
+We use `Normal` to initialize the parameters of the fully connected layers and convolutional layers.
+
+MindSpore supports multiple parameter initialization methods, such as `TruncatedNormal`, `Normal`, and `Uniform`, and uses `Normal` by default. For details, see the description of the `mindspore.common.initializer` module in the MindSpore API.
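For code written against the old helpers, the migration is mechanical; a hedged sketch (assuming `nn.Dense` accepts initializer instances for `weight_init` and `bias_init`, as the removed helpers relied on):

```python
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal

# Old style (removed above): every layer received an explicit TruncatedNormal(0.02).
fc_old = nn.Dense(16 * 5 * 5, 120, weight_init=TruncatedNormal(0.02), bias_init=TruncatedNormal(0.02))
# New style: omit weight_init and rely on the default Normal initialization.
fc_new = nn.Dense(16 * 5 * 5, 120)
```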
 To define a neural network with MindSpore, inherit `mindspore.nn.cell.Cell`. `Cell` is the base class of all neural networks (such as `Conv2d`).

 Each layer of the neural network must be defined in the `__init__` method in advance; the `construct` method then completes the forward construction of the network. Following the LeNet structure, define the layers as follows:

 ```python
 import mindspore.nn as nn

 class LeNet5(nn.Cell):
     """Lenet network structure"""
     # define the operators required
-    def __init__(self):
+    def __init__(self, num_class=10, channel=1):
         super(LeNet5, self).__init__()
-        self.conv1 = conv(1, 6, 5)
-        self.conv2 = conv(6, 16, 5)
-        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
-        self.fc2 = fc_with_initialize(120, 84)
-        self.fc3 = fc_with_initialize(84, 10)
+        self.num_class = num_class
+        self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+        self.fc1 = nn.Dense(16 * 5 * 5, 120)
+        self.fc2 = nn.Dense(120, 84)
+        self.fc3 = nn.Dense(84, self.num_class)
         self.relu = nn.ReLU()
         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
         self.flatten = nn.Flatten()

     # use the preceding operators to construct the network
     def construct(self, x):
-        x = self.conv1(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.conv2(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
+        x = self.max_pool2d(self.relu(self.conv1(x)))
+        x = self.max_pool2d(self.relu(self.conv2(x)))
         x = self.flatten(x)
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        x = self.relu(x)
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
         x = self.fc3(x)
         return x
 ```
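To see what the default initialization created, the learnable parameters can be listed (a sketch; `trainable_params` is a standard `Cell` method):

```python
net = LeNet5()
# Print the name of every parameter produced by the default initializer,
# e.g. conv1.weight, fc1.weight, fc1.bias, ...
for param in net.trainable_params():
    print(param.name)
```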
@@ -26,7 +26,6 @@ from mindspore import context
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train import Model
-from mindspore.common.initializer import TruncatedNormal
 import mindspore.dataset.transforms.vision.c_transforms as CV
 import mindspore.dataset.transforms.c_transforms as C
 from mindspore.dataset.transforms.vision import Inter
@@ -118,53 +117,28 @@ def create_dataset(data_path, batch_size=32, repeat_size=1,
     return mnist_ds

-def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
-    """Conv layer weight initial."""
-    weight = weight_variable()
-    return nn.Conv2d(in_channels, out_channels,
-                     kernel_size=kernel_size, stride=stride, padding=padding,
-                     weight_init=weight, has_bias=False, pad_mode="valid")
-
-def fc_with_initialize(input_channels, out_channels):
-    """Fc layer weight initial."""
-    weight = weight_variable()
-    bias = weight_variable()
-    return nn.Dense(input_channels, out_channels, weight, bias)
-
-def weight_variable():
-    """Weight initial."""
-    return TruncatedNormal(0.02)
-
 class LeNet5(nn.Cell):
     """Lenet network structure."""
     # define the operators required
-    def __init__(self):
+    def __init__(self, num_class=10, channel=1):
         super(LeNet5, self).__init__()
-        self.conv1 = conv(1, 6, 5)
-        self.conv2 = conv(6, 16, 5)
-        self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
-        self.fc2 = fc_with_initialize(120, 84)
-        self.fc3 = fc_with_initialize(84, 10)
+        self.num_class = num_class
+        self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
+        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
+        self.fc1 = nn.Dense(16 * 5 * 5, 120)
+        self.fc2 = nn.Dense(120, 84)
+        self.fc3 = nn.Dense(84, self.num_class)
         self.relu = nn.ReLU()
         self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
         self.flatten = nn.Flatten()

     # use the preceding operators to construct the network
     def construct(self, x):
-        x = self.conv1(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
-        x = self.conv2(x)
-        x = self.relu(x)
-        x = self.max_pool2d(x)
+        x = self.max_pool2d(self.relu(self.conv1(x)))
+        x = self.max_pool2d(self.relu(self.conv2(x)))
         x = self.flatten(x)
-        x = self.fc1(x)
-        x = self.relu(x)
-        x = self.fc2(x)
-        x = self.relu(x)
+        x = self.relu(self.fc1(x))
+        x = self.relu(self.fc2(x))
         x = self.fc3(x)
         return x
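The rest of the script (outside this hunk) wires the network into training. A minimal sketch of that step using the imports above; the loss-constructor arguments are assumptions, since they have changed across MindSpore versions:

```python
net = LeNet5()
# Cross-entropy over the 10 digit classes; argument names vary by version.
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)
model = Model(net, loss, opt, metrics={'accuracy'})
# Train for one epoch on the dataset built by create_dataset above:
# model.train(1, create_dataset('MNIST_Data/train'), callbacks=[LossMonitor()])
```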