提交 8dcaf8b1 编写于 作者: W wukesong

1.modify network; 2. remove is_grad

上级 d043b195
......@@ -14,6 +14,7 @@
# ============================================================================
"""LeNet."""
import mindspore.nn as nn
from mindspore.common.initializer import Normal
class LeNet5(nn.Cell):
......@@ -22,22 +23,21 @@ class LeNet5(nn.Cell):
Args:
num_class (int): Num classes. Default: 10.
channel (int): Num classes. Default: 1.
num_channel (int): Num channels. Default: 1.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10, channel=1)
>>> LeNet(num_class=10, num_channel=1)
"""
def __init__(self, num_class=10, channel=1):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid')
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120)
self.fc2 = nn.Dense(120, 84)
self.fc3 = nn.Dense(84, self.num_class)
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
......
......@@ -91,7 +91,7 @@ if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
network = LeNet5(cfg.num_classes)
net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
repeat_size = 1
net_opt = nn.Momentum(network.trainable_params(), cfg.lr, cfg.momentum)
model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册