From e13f70c786ff993379217695c29bab377c8a5109 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=84=E5=A5=B3=E5=BA=A7=E7=9A=84=E6=9F=9A=E5=AD=90?= <2839719742@qq.com> Date: Thu, 29 Jul 2021 15:52:58 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=B0=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nets/lenet.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 nets/lenet.py diff --git a/nets/lenet.py b/nets/lenet.py new file mode 100644 index 0000000..c9e55c8 --- /dev/null +++ b/nets/lenet.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# @Time : 2021/7/5 16:29 +# @Author : Mat +# @File : lenet.py +# @Software: PyCharm +import math + +import torch.nn as nn +import torch.nn.functional as F + + +# ************************LeNet5********************* +class LeNet(nn.Module): + def __init__(self, num_class=10): + super(LeNet, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, num_class) + + def forward(self, x): + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = x.view(x.size()[0], -1) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.01) + m.bias.data.zero_() + + +if __name__ == '__main__': + net = LeNet() + print(net) -- GitLab