提交 e3884091 编写于 作者: W wukesong

update lenet alexnet

上级 474016c6
......@@ -228,8 +228,6 @@ def fc_with_initialize(input_channels, out_channels):
神经网络的各层需要预先在`__init__()`方法中定义,然后通过定义`construct()`方法来完成神经网络的前向构造。按照LeNet的网络结构,定义网络各层如下:
```python
import mindspore.ops.operations as P
class LeNet5(nn.Cell):
"""
Lenet network structure
......@@ -245,7 +243,7 @@ class LeNet5(nn.Cell):
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.flatten = nn.Flatten()
#use the preceding operators to construct networks
def construct(self, x):
......@@ -255,7 +253,7 @@ class LeNet5(nn.Cell):
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
......
......@@ -24,7 +24,6 @@ from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
import mindspore.ops.operations as P
from mindspore.common.initializer import TruncatedNormal
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
......@@ -150,7 +149,7 @@ class LeNet5(nn.Cell):
self.fc3 = fc_with_initialize(84, 10)
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.flatten = nn.Flatten()
# use the preceding operators to construct networks
def construct(self, x):
......@@ -160,7 +159,7 @@ class LeNet5(nn.Cell):
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1))
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册