From 5ad8dc3a6aa76846ced02bbed1a7827c33ba6430 Mon Sep 17 00:00:00 2001 From: wukesong Date: Thu, 3 Sep 2020 14:28:20 +0800 Subject: [PATCH] modify lenet --- chapter03/lenet/lenet.py | 48 ++++++++++------------------------------ 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/chapter03/lenet/lenet.py b/chapter03/lenet/lenet.py index 5c08ed1..068719d 100644 --- a/chapter03/lenet/lenet.py +++ b/chapter03/lenet/lenet.py @@ -14,63 +14,39 @@ # ============================================================================ """LeNet.""" import mindspore.nn as nn -from mindspore.common.initializer import TruncatedNormal -def conv(in_channels, out_channels, kernel_size, stride=1, padding=0): - """weight initial for conv layer""" - weight = weight_variable() - return nn.Conv2d(in_channels, out_channels, - kernel_size=kernel_size, stride=stride, padding=padding, - weight_init=weight, has_bias=False, pad_mode="valid") - -def fc_with_initialize(input_channels, out_channels): - """weight initial for fc layer""" - weight = weight_variable() - bias = weight_variable() - return nn.Dense(input_channels, out_channels, weight, bias) - -def weight_variable(): - """weight initial""" - return TruncatedNormal(0.02) - class LeNet5(nn.Cell): """ Lenet network Args: num_class (int): Num classes. Default: 10. + channel (int): Num classes. Default: 1. Returns: Tensor, output tensor Examples: - >>> LeNet(num_class=10) + >>> LeNet(num_class=10, channel=1) """ - def __init__(self, num_class=10): + def __init__(self, num_class=10, channel=1): super(LeNet5, self).__init__() self.num_class = num_class - self.batch_size = 32 - self.conv1 = conv(1, 6, 5) - self.conv2 = conv(6, 16, 5) - self.fc1 = fc_with_initialize(16 * 5 * 5, 120) - self.fc2 = fc_with_initialize(120, 84) - self.fc3 = fc_with_initialize(84, self.num_class) + self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') + self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.fc1 = nn.Dense(16 * 5 * 5, 120) + self.fc2 = nn.Dense(120, 84) + self.fc3 = nn.Dense(84, self.num_class) self.relu = nn.ReLU() self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) self.flatten = nn.Flatten() def construct(self, x): - x = self.conv1(x) - x = self.relu(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.relu(x) - x = self.max_pool2d(x) + x = self.max_pool2d(self.relu(self.conv1(x))) + x = self.max_pool2d(self.relu(self.conv2(x))) x = self.flatten(x) - x = self.fc1(x) - x = self.relu(x) - x = self.fc2(x) - x = self.relu(x) + x = self.relu(self.fc1(x)) + x = self.relu(self.fc2(x)) x = self.fc3(x) return x -- GitLab