# creatnet.py — model factory (originally committed by hypox64)
from torch import nn
from models import cnn_1d,densenet,dfcnn,lstm,mobilenet,resnet,resnet_1d,squeezenet
from models import multi_scale_resnet,multi_scale_resnet_1d,micro_multi_scale_resnet_1d


def CreatNet(opt):
    """Construct the network named by ``opt.model_name``.

    Args:
        opt: Options object. Reads ``opt.model_name`` (which model to build)
            and ``opt.label`` (passed as the number of output classes). Most
            branches also read ``opt.input_nc`` (number of input channels).

    Returns:
        The constructed network.

    Raises:
        ValueError: If ``opt.model_name`` is not a supported model name.
            Previously this surfaced as an opaque ``UnboundLocalError``.
    """
    name = opt.model_name
    label_num = opt.label

    if name == 'lstm':
        # NOTE(review): 100 and 27 are positional arguments of lstm.lstm whose
        # meaning is defined in models/lstm.py. This branch does not use
        # opt.input_nc — confirm that is intended.
        net = lstm.lstm(100, 27, num_classes=label_num)

    elif name == 'cnn_1d':
        net = cnn_1d.cnn(opt.input_nc, num_classes=label_num)

    elif name in ('resnet18_1d', 'resnet34_1d'):
        net = resnet_1d.resnet18() if name == 'resnet18_1d' else resnet_1d.resnet34()
        # Replace the first convolution and the classifier so the network
        # takes opt.input_nc channels and produces label_num outputs.
        net.conv1 = nn.Conv1d(opt.input_nc, 64, 7, 2, 3, bias=False)
        net.fc = nn.Linear(512, label_num)

    elif name == 'multi_scale_resnet_1d':
        net = multi_scale_resnet_1d.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=label_num)

    elif name == 'micro_multi_scale_resnet_1d':
        net = micro_multi_scale_resnet_1d.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=label_num)

    elif name == 'multi_scale_resnet':
        net = multi_scale_resnet.Multi_Scale_ResNet(inchannel=opt.input_nc, num_classes=label_num)

    elif name == 'dfcnn':
        net = dfcnn.dfcnn(num_classes=label_num)

    elif name in ('resnet101', 'resnet50', 'resnet18'):
        # (constructor, in_features of the replacement fully connected layer)
        builder, fc_in = {
            'resnet101': (resnet.resnet101, 2048),
            'resnet50': (resnet.resnet50, 2048),
            'resnet18': (resnet.resnet18, 512),
        }[name]
        net = builder(pretrained=False)
        # fc is built before conv1. Keep this order so seeded runs draw
        # random initial weights in the same sequence.
        net.fc = nn.Linear(fc_in, label_num)
        net.conv1 = nn.Conv2d(opt.input_nc, 64, 7, 2, 3, bias=False)

    elif name == 'densenet121':
        net = densenet.densenet121(pretrained=False, num_classes=label_num)

    elif name == 'densenet201':
        net = densenet.densenet201(pretrained=False, num_classes=label_num)

    elif name == 'squeezenet':
        # Use the configured channel count like every other model here.
        # It was previously hard-coded to 1, ignoring opt.input_nc.
        net = squeezenet.squeezenet1_1(pretrained=False, num_classes=label_num, inchannel=opt.input_nc)

    else:
        raise ValueError(f"CreatNet: unsupported model_name {name!r}")

    return net