提交 437aa67c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!952 modify code formats in quantization_aware.md

Merge pull request !952 from lvmingfu/master
...@@ -90,63 +90,36 @@ Define a fusion network and replace the specified operators. ...@@ -90,63 +90,36 @@ Define a fusion network and replace the specified operators.
The definition of the original network model LeNet5 is as follows: The definition of the original network model LeNet5 is as follows:
```python ```python
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
""" """
Lenet network Lenet network
Args: Args:
num_class (int): Num classes. Default: 10. num_class (int): Num classes. Default: 10.
num_channel (int): Num channel. Default: 1.
Returns: Returns:
Tensor, output tensor Tensor, output tensor
Examples: Examples:
>>> LeNet(num_class=10) >>> LeNet(num_class=10, num_channel=1)
""" """
def __init__(self, num_class=10, channel=1): def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.num_class = num_class self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv1 = conv(channel, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.conv2 = conv(6, 16, 5) self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc2 = fc_with_initialize(120, 84) self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
def construct(self, x): def construct(self, x):
x = self.conv1(x) x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.relu(x) x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x) x = self.flatten(x)
x = self.fc1(x) x = self.relu(self.fc1(x))
x = self.relu(x) x = self.relu(self.fc2(x))
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x) x = self.fc3(x)
return x return x
``` ```
...@@ -168,10 +141,8 @@ class LeNet5(nn.Cell): ...@@ -168,10 +141,8 @@ class LeNet5(nn.Cell):
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
def construct(self, x): def construct(self, x):
x = self.conv1(x) x = self.max_pool2d(self.conv1(x))
x = self.max_pool2d(x) x = self.max_pool2d(self.conv2(x))
x = self.conv2(x)
x = self.max_pool2d(x)
x = self.flattern(x) x = self.flattern(x)
x = self.fc1(x) x = self.fc1(x)
x = self.fc2(x) x = self.fc2(x)
......
...@@ -90,63 +90,36 @@ MindSpore的感知量化训练是在训练基础上,使用低精度数据替 ...@@ -90,63 +90,36 @@ MindSpore的感知量化训练是在训练基础上,使用低精度数据替
原网络模型LeNet5的定义如下所示: 原网络模型LeNet5的定义如下所示:
```python ```python
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
"""weight initial for conv layer"""
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode="valid")
def fc_with_initialize(input_channels, out_channels):
"""weight initial for fc layer"""
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
"""weight initial"""
return TruncatedNormal(0.02)
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
""" """
Lenet network Lenet network
Args: Args:
num_class (int): Num classes. Default: 10. num_class (int): Num classes. Default: 10.
num_channel (int): Num channel. Default: 1.
Returns: Returns:
Tensor, output tensor Tensor, output tensor
Examples: Examples:
>>> LeNet(num_class=10) >>> LeNet(num_class=10, num_channel=1)
""" """
def __init__(self, num_class=10, channel=1): def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__() super(LeNet5, self).__init__()
self.num_class = num_class self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv1 = conv(channel, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.conv2 = conv(6, 16, 5) self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc1 = fc_with_initialize(16 * 5 * 5, 120) self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc2 = fc_with_initialize(120, 84) self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() self.flatten = nn.Flatten()
def construct(self, x): def construct(self, x):
x = self.conv1(x) x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.relu(x) x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x) x = self.flatten(x)
x = self.fc1(x) x = self.relu(self.fc1(x))
x = self.relu(x) x = self.relu(self.fc2(x))
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x) x = self.fc3(x)
return x return x
``` ```
...@@ -168,10 +141,8 @@ class LeNet5(nn.Cell): ...@@ -168,10 +141,8 @@ class LeNet5(nn.Cell):
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
def construct(self, x): def construct(self, x):
x = self.conv1(x) x = self.max_pool2d(self.conv1(x))
x = self.max_pool2d(x) x = self.max_pool2d(self.conv2(x))
x = self.conv2(x)
x = self.max_pool2d(x)
x = self.flattern(x) x = self.flattern(x)
x = self.fc1(x) x = self.fc1(x)
x = self.fc2(x) x = self.fc2(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册