diff --git a/tutorials/source_zh_cn/quick_start/quick_start.md b/tutorials/source_zh_cn/quick_start/quick_start.md index c00b41ef18d8c89e7c8ec5f09c74ea552b6e22a9..96bbb5b39697220d2dfec5b0623f859e4f043b50 100644 --- a/tutorials/source_zh_cn/quick_start/quick_start.md +++ b/tutorials/source_zh_cn/quick_start/quick_start.md @@ -228,8 +228,6 @@ def fc_with_initialize(input_channels, out_channels): 神经网络的各层需要预先在`__init__()`方法中定义,然后通过定义`construct()`方法来完成神经网络的前向构造。按照LeNet的网络结构,定义网络各层如下: ```python -import mindspore.ops.operations as P - class LeNet5(nn.Cell): """ Lenet network structure @@ -245,7 +243,7 @@ class LeNet5(nn.Cell): self.fc3 = fc_with_initialize(84, 10) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() + self.flatten = nn.Flatten() #use the preceding operators to construct networks def construct(self, x): @@ -255,7 +253,7 @@ class LeNet5(nn.Cell): x = self.conv2(x) x = self.relu(x) x = self.max_pool2d(x) - x = self.reshape(x, (self.batch_size, -1)) + x = self.flatten(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) diff --git a/tutorials/tutorial_code/lenet.py b/tutorials/tutorial_code/lenet.py index f3899de1464efb4ce16feea3957dcd7b2bb1b5b0..a9b4571ffb5f0035dc2f0215853ddbc0107813a6 100644 --- a/tutorials/tutorial_code/lenet.py +++ b/tutorials/tutorial_code/lenet.py @@ -24,7 +24,6 @@ from mindspore import context from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train import Model -import mindspore.ops.operations as P from mindspore.common.initializer import TruncatedNormal import mindspore.dataset.transforms.vision.c_transforms as CV import mindspore.dataset.transforms.c_transforms as C @@ -150,7 +149,7 @@ class LeNet5(nn.Cell): self.fc3 = fc_with_initialize(84, 10) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) - self.reshape = P.Reshape() + self.flatten = nn.Flatten() # use the preceding operators to construct networks def construct(self, x): @@ -160,7 +159,7 @@ class LeNet5(nn.Cell): x = self.conv2(x) x = self.relu(x) x = self.max_pool2d(x) - x = self.reshape(x, (self.batch_size, -1)) + x = self.flatten(x) x = self.fc1(x) x = self.relu(x) x = self.fc2(x)