creatnet.py 2.9 KB
Newer Older
H
hypox64 已提交
1
from torch import nn
H
Add mlp  
HypoX64 已提交
2
from .net_1d import cnn_1d,lstm,resnet_1d,multi_scale_resnet_1d,micro_multi_scale_resnet_1d,autoencoder,mlp
H
hypox64 已提交
3 4
from .net_2d import densenet,dfcnn,mobilenet,resnet,squeezenet,multi_scale_resnet

H
hypox64 已提交
5

H
hypox64 已提交
6
def creatnet(opt):
    """Build the network selected by ``opt.model_name``.

    Args:
        opt: options object. ``opt.model_name`` selects the architecture,
            ``opt.input_nc`` is passed as the number of input channels and
            ``opt.label`` as the number of output classes. Some models read
            extra fields: ``opt.feature`` (autoencoder), ``opt.finesize``
            (autoencoder, mlp), ``opt.lstm_inputsize`` and
            ``opt.lstm_timestep`` (lstm).

    Returns:
        The constructed network.

    Raises:
        ValueError: if ``opt.model_name`` is not one of the supported names.
    """
    name = opt.model_name

    #---------------------------------1d---------------------------------
    #encoder
    if name == 'autoencoder':
        net = autoencoder.Autoencoder(opt.input_nc, opt.feature, opt.label, opt.finesize)
    #mlp
    elif name == 'mlp':
        net = mlp.mlp(opt.input_nc, opt.label, opt.finesize)
    #lstm
    elif name == 'lstm':
        net = lstm.lstm(opt.lstm_inputsize, opt.lstm_timestep, input_nc=opt.input_nc, num_classes=opt.label)
    #cnn
    elif name == 'cnn_1d':
        net = cnn_1d.cnn(opt.input_nc, num_classes=opt.label)
    elif name in ('resnet18_1d', 'resnet34_1d'):
        net = resnet_1d.resnet18() if name == 'resnet18_1d' else resnet_1d.resnet34()
        # Swap the stem and the classifier head so the network matches the
        # configured number of input channels and output classes.
        net.conv1 = nn.Conv1d(opt.input_nc, 64, 7, 2, 3, bias=False)
        net.fc = nn.Linear(512, opt.label)
    elif name == 'multi_scale_resnet_1d':
        net = multi_scale_resnet_1d.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=opt.label)
    elif name == 'micro_multi_scale_resnet_1d':
        net = micro_multi_scale_resnet_1d.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=opt.label)

    #---------------------------------2d---------------------------------
    elif name == 'dfcnn':
        net = dfcnn.dfcnn(num_classes=opt.label, input_nc=opt.input_nc)
    elif name == 'multi_scale_resnet':
        net = multi_scale_resnet.Multi_Scale_ResNet(input_nc=opt.input_nc, num_classes=opt.label)
    elif name in ('resnet101', 'resnet50', 'resnet18'):
        if name == 'resnet101':
            net = resnet.resnet101(pretrained=True)
            net.fc = nn.Linear(2048, opt.label)
        elif name == 'resnet50':
            net = resnet.resnet50(pretrained=True)
            net.fc = nn.Linear(2048, opt.label)
        else:
            net = resnet.resnet18(pretrained=True)
            net.fc = nn.Linear(512, opt.label)
        # NOTE: conv1 is replaced after the pretrained weights are loaded, so
        # the first conv layer starts from fresh (non-pretrained) parameters.
        net.conv1 = nn.Conv2d(opt.input_nc, 64, 7, 2, 3, bias=False)
    elif name in ('densenet121', 'densenet201'):
        # Only these two variants are built; any other "densenet*" name used
        # to fall through here and crash with UnboundLocalError.
        if name == 'densenet121':
            net = densenet.densenet121(pretrained=False, num_classes=opt.label)
        else:
            net = densenet.densenet201(pretrained=False, num_classes=opt.label)
        net.features.conv0 = nn.Conv2d(opt.input_nc, 64, kernel_size=7, stride=2, padding=3, bias=False)
    elif name == 'squeezenet':
        net = squeezenet.squeezenet1_1(pretrained=False, num_classes=opt.label, inchannel=opt.input_nc)
    elif name == 'mobilenet':
        net = mobilenet.mobilenet_v2(pretrained=False, num_classes=opt.label, input_nc=opt.input_nc)
    else:
        # Previously an unknown name left `net` unbound and failed at return.
        raise ValueError('Unsupported model_name: {!r}'.format(name))
    return net