diff --git a/model_zoo/official/cv/lenet/src/lenet.py b/model_zoo/official/cv/lenet/src/lenet.py index 3864315dba35f1e4fd013317b71572ad6db82e95..003c3e0b850c8940f6d09510679113cdd3e2e7b1 100644 --- a/model_zoo/official/cv/lenet/src/lenet.py +++ b/model_zoo/official/cv/lenet/src/lenet.py @@ -14,27 +14,6 @@ # ============================================================================ """LeNet.""" import mindspore.nn as nn -from mindspore.common.initializer import TruncatedNormal - - -def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): - """weight initial for conv layer""" - weight = weight_variable() - return nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, - weight_init=weight, has_bias=False, pad_mode="valid") - - -def fc_with_initialize(input_channels, out_channels): - """weight initial for fc layer""" - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - - -def weight_variable(): - """weight initial""" - return TruncatedNormal(0.02) class LeNet5(nn.Cell): @@ -43,6 +22,7 @@ class LeNet5(nn.Cell): Args: num_class (int): Num classes. Default: 10. + channel (int): Num classes. Default: 1. Returns: Tensor, output tensor @@ -53,26 +33,20 @@ class LeNet5(nn.Cell): def __init__(self, num_class=10, channel=1): super(LeNet5, self).__init__() self.num_class = num_class - self.conv1 = conv(channel, 6, 5) - self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16 * 5 * 5, 120) - self.fc2 = fc_with_initialize(120, 84) - self.fc3 = fc_with_initialize(84, self.num_class) + self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.fc1 = nn.Dense(16 * 5 * 5, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, self.num_class) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.flatten = nn.Flatten() def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) + x = self.max_pool2d(self.relu(self.conv1(x))) + x = self.max_pool2d(self.relu(self.conv2(x))) x = self.flatten(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) x = self.fc3(x) return x