提交 3c0ab755 编写于 作者: HypoX64's avatar HypoX64

regular commit

上级 ae3aa71a
......@@ -15,10 +15,10 @@
对于不同的网络结构,对原始eeg信号采取了预处理,使其拥有不同的shape:<br>
LSTM:将30s的eeg信号进行FIR带通滤波,获得θ,σ,α,δ,β波,并将它们进行连接后作为输入数据<br>
resnet_1d:这里使用resnet的一维形式进行实验,(修改nn.Conv2d为nn.Conv1d).<br>
DFCNN:将30s的eeg信号进行短时傅里叶变换,并生成频谱图作为输入,并使用resnet网络进行分类.<br>
DFCNN:将30s的eeg信号进行短时傅里叶变换,并生成频谱图作为输入,并使用图像分类网络进行分类.<br>
* EEG频谱图<br>
这里展示5个睡眠阶段对应的频谱图,它们依次是wake, stage 1, stage 2, stage 3, REM
这里展示5个睡眠阶段对应的频谱图,它们依次是wake, stage 1, stage 2, stage 3, REM<br>
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_wake.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N1.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N2.png)
......@@ -38,14 +38,14 @@
| resnet18_1d | 0.8434 | 0.9627 | 0.0930 |
| DFCNN+resnet18 | 0.8567 | 0.9663 | 0.0842 |
| DFCNN+resnet50 | 0.7916 | 0.9607 | 0.0983 |
<br>
* sleep-edfx(only sleep time)<br>
| Network | Label average recall | Label average accuracy | error rate |
| :------------- | :------------------- | ---------------------- | ---------- |
| lstm | 0.7864 | 0.9166 | 0.2085 |
| resnet18_1d | xxxxxx | xxxxxx | xxxxxx |
| DFCNN+resnet18 | xxxxxx | xxxxxx | xxxxxx |
| DFCNN+resnet18 | 0.7844 | 0.9124 | 0.219 |
| DFCNN+resnet50 | xxxxxx | xxxxxx | xxxxxx |
<br>
* CinC Challenge 2018<br>
\ No newline at end of file
......@@ -96,7 +96,7 @@ def ToInputShape(data,net_name,norm=True,test_flag = False):
result = Normalize(result,maxmin = 1000,avg=0,sigma=1000)
result = result.reshape(batchsize,1,2700)
elif net_name in ['resnet18','densenet121','densenet201','resnet101','resnet50']:
elif net_name in ['dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
result =[]
for i in range(0,batchsize):
spectrum = DSP.signal2spectrum(data[i])
......
......@@ -57,3 +57,12 @@ confusion_mat:
[ 26 699 3159 775 384]
[ 5 380 971 5377 111]
[ 13 34 1490 191 12379]]
DFCNN+resnet18
avg_recall:0.7844 avg_acc:0.9124 error:0.219
confusion_mat:
[[ 3623 243 23 2 4]
[ 1905 12703 2452 657 94]
[ 54 671 3371 399 468]
[ 29 379 1454 4759 151]
[ 40 50 1272 124 12881]]
......@@ -12,6 +12,8 @@ def CreatNet(name):
return CNN()
elif name == 'resnet18_1d':
return resnet18_1d()
elif name == 'dfcnn':
return dfcnn()
elif name in ['resnet101','resnet50','resnet18']:
if name =='resnet101':
net = torchvision.models.resnet101(pretrained=False)
......@@ -67,6 +69,72 @@ class LSTM(nn.Module):
out = self.out(x)
return out
class dfcnn(nn.Module):
def __init__(self):
super(dfcnn, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 32, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace = True),
nn.Conv2d(32, 32, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(inplace = True),
nn.MaxPool2d(2),
)
self.layer2 = nn.Sequential(
nn.Conv2d(32, 64, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace = True),
nn.Conv2d(64, 64, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(64),
nn.ReLU(inplace = True),
nn.MaxPool2d(2),
)
self.layer3 = nn.Sequential(
nn.Conv2d(64, 128, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace = True),
nn.Conv2d(128, 128, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace = True),
nn.MaxPool2d(2),
)
self.layer4 = nn.Sequential(
nn.Conv2d(128, 256, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace = True),
nn.Conv2d(256, 256, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace = True),
nn.MaxPool2d(2),
)
self.layer5 = nn.Sequential(
nn.Conv2d(256, 512, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace = True),
nn.Conv2d(512, 512, (3,3), 1, 1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace = True),
nn.MaxPool2d(2),
)
self.out = nn.Linear(512*7*7, 5) # fully connected layer, output 10 classes
def forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.layer5(x)
# x = F.avg_pool1d(x, 175)
x = x.view(x.size(0), -1)
output = self.out(x)
return output
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
......
......@@ -48,11 +48,11 @@ weight = torch.from_numpy(weight).float()
if not opt.no_cuda:
net.cuda()
weight = weight.cuda()
cudnn.benchmark = True
# cudnn.benchmark = True
# print(weight)
# time.sleep(2000)
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)
criterion = nn.CrossEntropyLoss(weight)
......@@ -127,7 +127,7 @@ for epoch in range(opt.epochs):
# torch.cuda.empty_cache()
evalnet(net,signals_eval,stages_eval,epoch+1,plot_result,mode = 'all')
scheduler.step()
# scheduler.step()
if (epoch+1)%opt.network_save_freq == 0:
torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册