提交 9f79db30 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5 update lenet and alexnet

Merge pull request !5 from wukesong/update-lenet-alexnet
...@@ -228,8 +228,6 @@ def fc_with_initialize(input_channels, out_channels): ...@@ -228,8 +228,6 @@ def fc_with_initialize(input_channels, out_channels):
神经网络的各层需要预先在`__init__()`方法中定义,然后通过定义`construct()`方法来完成神经网络的前向构造。按照LeNet的网络结构,定义网络各层如下: 神经网络的各层需要预先在`__init__()`方法中定义,然后通过定义`construct()`方法来完成神经网络的前向构造。按照LeNet的网络结构,定义网络各层如下:
```python ```python
import mindspore.ops.operations as P
class LeNet5(nn.Cell): class LeNet5(nn.Cell):
""" """
Lenet network structure Lenet network structure
...@@ -245,7 +243,7 @@ class LeNet5(nn.Cell): ...@@ -245,7 +243,7 @@ class LeNet5(nn.Cell):
self.fc3 = fc_with_initialize(84, 10) self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape() self.flatten = nn.Flatten()
#use the preceding operators to construct networks #use the preceding operators to construct networks
def construct(self, x): def construct(self, x):
...@@ -255,7 +253,7 @@ class LeNet5(nn.Cell): ...@@ -255,7 +253,7 @@ class LeNet5(nn.Cell):
x = self.conv2(x) x = self.conv2(x)
x = self.relu(x) x = self.relu(x)
x = self.max_pool2d(x) x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1)) x = self.flatten(x)
x = self.fc1(x) x = self.fc1(x)
x = self.relu(x) x = self.relu(x)
x = self.fc2(x) x = self.fc2(x)
......
...@@ -24,7 +24,6 @@ from mindspore import context ...@@ -24,7 +24,6 @@ from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model from mindspore.train import Model
import mindspore.ops.operations as P
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
import mindspore.dataset.transforms.vision.c_transforms as CV import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.c_transforms as C
...@@ -150,7 +149,7 @@ class LeNet5(nn.Cell): ...@@ -150,7 +149,7 @@ class LeNet5(nn.Cell):
self.fc3 = fc_with_initialize(84, 10) self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape() self.flatten = nn.Flatten()
# use the preceding operators to construct networks # use the preceding operators to construct networks
def construct(self, x): def construct(self, x):
...@@ -160,7 +159,7 @@ class LeNet5(nn.Cell): ...@@ -160,7 +159,7 @@ class LeNet5(nn.Cell):
x = self.conv2(x) x = self.conv2(x)
x = self.relu(x) x = self.relu(x)
x = self.max_pool2d(x) x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1)) x = self.flatten(x)
x = self.fc1(x) x = self.fc1(x)
x = self.relu(x) x = self.relu(x)
x = self.fc2(x) x = self.fc2(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册