提交 c46e267c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!21 update lenet , add alexnet in example

Merge pull request !21 from wukesong/add_lenet_alexnet
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in train.py
"""
from easydict import EasyDict as edict
alexnet_cfg = edict({
'num_classes': 10,
'learning_rate': 0.002,
'momentum': 0.9,
'epoch_size': 1,
'batch_size': 32,
'buffer_size': 1000,
'image_height': 227,
'image_width': 227,
'save_checkpoint_steps': 1562,
'keep_checkpoint_max': 10,
})
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""
from config import alexnet_cfg as cfg
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as CV
from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1, status="train"):
"""
create dataset for train or test
"""
cifar_ds = ds.Cifar10Dataset(data_path)
rescale = 1.0 / 255.0
shift = 0.0
resize_op = CV.Resize((cfg.image_height, cfg.image_width))
rescale_op = CV.Rescale(rescale, shift)
normalize_op = CV.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
if status == "train":
random_crop_op = CV.RandomCrop([32, 32], [4, 4, 4, 4])
random_horizontal_op = CV.RandomHorizontalFlip()
channel_swap_op = CV.HWC2CHW()
typecast_op = C.TypeCast(mstype.int32)
cifar_ds = cifar_ds.map(input_columns="label", operations=typecast_op)
if status == "train":
cifar_ds = cifar_ds.map(input_columns="image", operations=random_crop_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=random_horizontal_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=resize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=rescale_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=normalize_op)
cifar_ds = cifar_ds.map(input_columns="image", operations=channel_swap_op)
cifar_ds = cifar_ds.shuffle(buffer_size=cfg.buffer_size)
cifar_ds = cifar_ds.batch(batch_size, drop_remainder=True)
cifar_ds = cifar_ds.repeat(repeat_size)
return cifar_ds
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## eval alexnet example ########################
eval alexnet according to model file:
python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
"""
import argparse
from config import alexnet_cfg as cfg
from dataset import create_dataset
import mindspore.nn as nn
from mindspore import context
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
repeat_size = cfg.epoch_size
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) # test
print("============== Starting Testing ==============")
param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict)
ds_eval = create_dataset(args.data_path,
cfg.batch_size,
1,
"test")
acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
print("============== Accuracy:{} ==============".format(acc))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train alexnet example ########################
train alexnet and get network model files(.ckpt) :
python train.py --data_path /YourDataPath
"""
import argparse
from config import alexnet_cfg as cfg
from dataset import create_dataset
import mindspore.nn as nn
from mindspore import context
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
from mindspore.model_zoo.alexnet import AlexNet
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore AlexNet Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./", help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="./ckpt", help='if is test, must provide\
path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = AlexNet(cfg.num_classes)
loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
opt = nn.Momentum(network.trainable_params(), cfg.learning_rate, cfg.momentum)
model = Model(network, loss, opt, metrics={"Accuracy": Accuracy()}) # test
print("============== Starting Training ==============")
ds_train = create_dataset(args.data_path,
cfg.batch_size,
cfg.epoch_size,
"train")
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_alexnet", directory=args.ckpt_path, config=config_ck)
model.train(cfg.epoch_size, ds_train, callbacks=[ckpoint_cb, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
...@@ -13,8 +13,9 @@ ...@@ -13,8 +13,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
network config setting, will be used in main.py network config setting, will be used in train.py
""" """
from easydict import EasyDict as edict from easydict import EasyDict as edict
mnist_cfg = edict({ mnist_cfg = edict({
...@@ -23,7 +24,6 @@ mnist_cfg = edict({ ...@@ -23,7 +24,6 @@ mnist_cfg = edict({
'momentum': 0.9, 'momentum': 0.9,
'epoch_size': 1, 'epoch_size': 1,
'batch_size': 32, 'batch_size': 32,
'repeat_size': 1,
'buffer_size': 1000, 'buffer_size': 1000,
'image_height': 32, 'image_height': 32,
'image_width': 32, 'image_width': 32,
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Produce the dataset
"""
import mindspore.dataset as ds
import mindspore.dataset.transforms.vision.c_transforms as CV
import mindspore.dataset.transforms.c_transforms as C
from mindspore.dataset.transforms.vision import Inter
from mindspore.common import dtype as mstype
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
mnist_ds = ds.MnistDataset(data_path)
resize_height, resize_width = 32, 32
rescale = 1.0 / 255.0
shift = 0.0
rescale_nml = 1 / 0.3081
shift_nml = -1 * 0.1307 / 0.3081
# define map operations
resize_op = CV.Resize((resize_height, resize_width), interpolation=Inter.LINEAR) # Bilinear mode
rescale_nml_op = CV.Rescale(rescale_nml, shift_nml)
rescale_op = CV.Rescale(rescale, shift)
hwc2chw_op = CV.HWC2CHW()
type_cast_op = C.TypeCast(mstype.int32)
# apply map operations on images
mnist_ds = mnist_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=resize_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=rescale_nml_op, num_parallel_workers=num_parallel_workers)
mnist_ds = mnist_ds.map(input_columns="image", operations=hwc2chw_op, num_parallel_workers=num_parallel_workers)
# apply DatasetOps
buffer_size = 10000
mnist_ds = mnist_ds.shuffle(buffer_size=buffer_size) # 10000 as in LeNet train script
mnist_ds = mnist_ds.batch(batch_size, drop_remainder=True)
mnist_ds = mnist_ds.repeat(repeat_size)
return mnist_ds
...@@ -13,113 +13,52 @@ ...@@ -13,113 +13,52 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
######################## train and test lenet example ######################## ######################## eval lenet example ########################
1. train lenet and get network model files(.ckpt) : eval lenet according to model file:
python main.py --data_path /home/workspace/mindspore_dataset/Tutorial_Network/Lenet/MNIST_Data python eval.py --data_path /YourDataPath --ckpt_path Your.ckpt
2. test lenet according to model file:
python main.py --data_path /home/workspace/mindspore_dataset/Tutorial_Network/Lenet/MNIST_Data
--mode test --ckpt_path checkpoint_lenet_1-1_1875.ckpt
""" """
import os import os
import argparse import argparse
from dataset import create_dataset
from config import mnist_cfg as cfg from config import mnist_cfg as cfg
import mindspore.dataengine as de
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.model_zoo.lenet import LeNet5 from mindspore.model_zoo.lenet import LeNet5
from mindspore import context, Tensor from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model from mindspore.train import Model
import mindspore.ops.operations as P
import mindspore.transforms.c_transforms as C
from mindspore.transforms import Inter
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
class CrossEntropyLoss(nn.Cell):
"""
Define loss for network
"""
def __init__(self):
super(CrossEntropyLoss, self).__init__()
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
def construct(self, logits, label):
label = self.one_hot(label, F.shape(logits)[1], self.on_value, self.off_value)
loss = self.cross_entropy(logits, label)[0]
loss = self.mean(loss, (-1,))
return loss
def create_dataset(data_path, batch_size=32, repeat_size=1,
num_parallel_workers=1):
"""
create dataset for train or test
"""
# define dataset
ds1 = de.MnistDataset(data_path)
# apply map operations on images
ds1 = ds1.map(input_columns="label", operations=C.TypeCast(mstype.int32))
ds1 = ds1.map(input_columns="image", operations=C.Resize((cfg.image_height, cfg.image_width),
interpolation=Inter.LINEAR),
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=C.Rescale(1 / 0.3081, -1 * 0.1307 / 0.3081),
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=C.Rescale(1.0 / 255.0, 0.0),
num_parallel_workers=num_parallel_workers)
ds1 = ds1.map(input_columns="image", operations=C.HWC2CHW(), num_parallel_workers=num_parallel_workers)
# apply DatasetOps
ds1 = ds1.shuffle(buffer_size=cfg.buffer_size) # 10000 as in LeNet train script
ds1 = ds1.batch(batch_size, drop_remainder=True)
ds1 = ds1.repeat(repeat_size)
return ds1
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore MNIST Example') parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)') help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--mode', type=str, default="train", choices=['train', 'test'],
help='implement phase, set to train or test')
parser.add_argument('--data_path', type=str, default="./MNIST_Data", parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved') help='path where the dataset is saved')
parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\ parser.add_argument('--ckpt_path', type=str, default="", help='if mode is test, must provide\
path where the trained ckpt file') path where the trained ckpt file')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args() args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = LeNet5(cfg.num_classes) network = LeNet5(cfg.num_classes)
network.set_train() net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
# net_loss = nn.SoftmaxCrossEntropyWithLogits() # support this loss soon repeat_size = cfg.epoch_size
net_loss = CrossEntropyLoss()
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum) net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps, config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max) keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
if args.mode == 'train': # train print("============== Starting Testing ==============")
ds = create_dataset(os.path.join(args.data_path, args.mode), batch_size=cfg.batch_size, param_dict = load_checkpoint(args.ckpt_path)
repeat_size=cfg.epoch_size) load_param_into_net(network, param_dict)
print("============== Starting Training ==============") ds_eval = create_dataset(os.path.join(args.data_path, "test"),
model.train(cfg['epoch_size'], ds, callbacks=[ckpoint_cb, LossMonitor()], dataset_sink_mode=False) cfg.batch_size,
elif args.mode == 'test': # test 1)
print("============== Starting Testing ==============") acc = model.eval(ds_eval, dataset_sink_mode=args.dataset_sink_mode)
param_dict = load_checkpoint(args.ckpt_path) print("============== Accuracy:{} ==============".format(acc))
load_param_into_net(network, param_dict)
ds_eval = create_dataset(os.path.join(args.data_path, "test"), 32, 1)
acc = model.eval(ds_eval, dataset_sink_mode=False)
print("============== Accuracy:{} ==============".format(acc))
else:
raise RuntimeError('mode should be train or test, rather than {}'.format(args.mode))
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
######################## train lenet example ########################
train lenet and get network model files(.ckpt) :
python train.py --data_path /YourDataPath
"""
import os
import argparse
from config import mnist_cfg as cfg
from dataset import create_dataset
import mindspore.nn as nn
from mindspore.model_zoo.lenet import LeNet5
from mindspore import context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
from mindspore.train import Model
from mindspore.nn.metrics import Accuracy
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='MindSpore MNIST Example')
parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'],
help='device where the code will be implemented (default: Ascend)')
parser.add_argument('--data_path', type=str, default="./MNIST_Data",
help='path where the dataset is saved')
parser.add_argument('--dataset_sink_mode', type=bool, default=False, help='dataset_sink_mode is False or True')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, enable_mem_reuse=False)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
config_ck = CheckpointConfig(save_checkpoint_steps=cfg.save_checkpoint_steps,
keep_checkpoint_max=cfg.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
ds_train = create_dataset(os.path.join(args.data_path, "train"),
cfg.batch_size,
cfg.epoch_size)
print("============== Starting Training ==============")
model.train(cfg['epoch_size'], ds_train, callbacks=[ckpoint_cb, LossMonitor()],
dataset_sink_mode=args.dataset_sink_mode)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Alexnet."""
import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid"):
weight = weight_variable()
return nn.Conv2d(in_channels, out_channels,
kernel_size=kernel_size, stride=stride, padding=padding,
weight_init=weight, has_bias=False, pad_mode=pad_mode)
def fc_with_initialize(input_channels, out_channels):
weight = weight_variable()
bias = weight_variable()
return nn.Dense(input_channels, out_channels, weight, bias)
def weight_variable():
return TruncatedNormal(0.02) # 0.02
class AlexNet(nn.Cell):
"""
Alexnet
"""
def __init__(self, num_classes=10):
super(AlexNet, self).__init__()
self.batch_size = 32
self.conv1 = conv(3, 96, 11, stride=4)
self.conv2 = conv(96, 256, 5, pad_mode="same")
self.conv3 = conv(256, 384, 3, pad_mode="same")
self.conv4 = conv(384, 384, 3, pad_mode="same")
self.conv5 = conv(384, 256, 3, pad_mode="same")
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2)
self.flatten = nn.Flatten()
self.fc1 = fc_with_initialize(6*6*256, 4096)
self.fc2 = fc_with_initialize(4096, 4096)
self.fc3 = fc_with_initialize(4096, num_classes)
def construct(self, x):
x = self.conv1(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.conv3(x)
x = self.relu(x)
x = self.conv4(x)
x = self.relu(x)
x = self.conv5(x)
x = self.relu(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""LeNet.""" """LeNet."""
import mindspore.ops.operations as P
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.common.initializer import TruncatedNormal from mindspore.common.initializer import TruncatedNormal
...@@ -62,7 +61,7 @@ class LeNet5(nn.Cell): ...@@ -62,7 +61,7 @@ class LeNet5(nn.Cell):
self.fc3 = fc_with_initialize(84, self.num_class) self.fc3 = fc_with_initialize(84, self.num_class)
self.relu = nn.ReLU() self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape() self.flatten = nn.Flatten()
def construct(self, x): def construct(self, x):
x = self.conv1(x) x = self.conv1(x)
...@@ -71,7 +70,7 @@ class LeNet5(nn.Cell): ...@@ -71,7 +70,7 @@ class LeNet5(nn.Cell):
x = self.conv2(x) x = self.conv2(x)
x = self.relu(x) x = self.relu(x)
x = self.max_pool2d(x) x = self.max_pool2d(x)
x = self.reshape(x, (self.batch_size, -1)) x = self.flatten(x)
x = self.fc1(x) x = self.fc1(x)
x = self.relu(x) x = self.relu(x)
x = self.fc2(x) x = self.fc2(x)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册