diff --git a/README.md b/README.md
index 33b90894d6b14752761fd96548e8c5ad526e9008..8c29442ecaf5c0f0f2743f47633851a86c8960e2 100644
--- a/README.md
+++ b/README.md
@@ -15,10 +15,10 @@
对于不同的网络结构,对原始eeg信号采取了预处理,使其拥有不同的shape:
LSTM:将30s的eeg信号进行FIR带通滤波,获得θ,σ,α,δ,β波,并将它们进行连接后作为输入数据
resnet_1d:这里使用resnet的一维形式进行实验,(修改nn.Conv2d为nn.Conv1d).
- DFCNN:将30s的eeg信号进行短时傅里叶变换,并生成频谱图作为输入,并使用resnet网络进行分类.
+ DFCNN:将30s的eeg信号进行短时傅里叶变换,并生成频谱图作为输入,并使用图像分类网络进行分类.
* EEG频谱图
- 这里展示5个睡眠阶段对应的频谱图,它们依次是wake, stage 1, stage 2, stage 3, REM
+ 这里展示5个睡眠阶段对应的频谱图,它们依次是wake, stage 1, stage 2, stage 3, REM
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_wake.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N1.png)
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N2.png)
@@ -38,14 +38,14 @@
| resnet18_1d | 0.8434 | 0.9627 | 0.0930 |
| DFCNN+resnet18 | 0.8567 | 0.9663 | 0.0842 |
| DFCNN+resnet50 | 0.7916 | 0.9607 | 0.0983 |
-
+
* sleep-edfx(only sleep time)
| Network | Label average recall | Label average accuracy | error rate |
| :------------- | :------------------- | ---------------------- | ---------- |
| lstm | 0.7864 | 0.9166 | 0.2085 |
| resnet18_1d | xxxxxx | xxxxxx | xxxxxx |
- | DFCNN+resnet18 | xxxxxx | xxxxxx | xxxxxx |
+ | DFCNN+resnet18 | 0.7844 | 0.9124 | 0.219 |
| DFCNN+resnet50 | xxxxxx | xxxxxx | xxxxxx |
-
+
* CinC Challenge 2018
\ No newline at end of file
diff --git a/data.py b/data.py
index a3b94a06cbdbf9cad4974d19702b9ce90b2496b7..5fd6592432d13789f8e34213e28582fd4e591446 100644
--- a/data.py
+++ b/data.py
@@ -96,7 +96,7 @@ def ToInputShape(data,net_name,norm=True,test_flag = False):
result = Normalize(result,maxmin = 1000,avg=0,sigma=1000)
result = result.reshape(batchsize,1,2700)
- elif net_name in ['resnet18','densenet121','densenet201','resnet101','resnet50']:
+ elif net_name in ['dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
result =[]
for i in range(0,batchsize):
spectrum = DSP.signal2spectrum(data[i])
diff --git a/image/DFresnet18_running_recall_sleep-edfx_sleeptime.png b/image/DFresnet18_running_recall_sleep-edfx_sleeptime.png
new file mode 100644
index 0000000000000000000000000000000000000000..060784a326aed345a75fbc1ca3a92e7ffdc388cf
Binary files /dev/null and b/image/DFresnet18_running_recall_sleep-edfx_sleeptime.png differ
diff --git a/image/LSTM_running_recall_sleep-edfx_sleeptime.png b/image/LSTM_running_recall_sleep-edfx_sleeptime.png
new file mode 100644
index 0000000000000000000000000000000000000000..50e1a97a6d0bc0ee587d23b4003093c9b1e1c4fc
Binary files /dev/null and b/image/LSTM_running_recall_sleep-edfx_sleeptime.png differ
diff --git a/image/confusion_mat b/image/confusion_mat
index 78faa3cd96b78f3f055af699b661adce8b553c4a..7c66e830516449ef7db70e6e99fdf99ebe0e8178 100644
--- a/image/confusion_mat
+++ b/image/confusion_mat
@@ -57,3 +57,12 @@ confusion_mat:
[ 26 699 3159 775 384]
[ 5 380 971 5377 111]
[ 13 34 1490 191 12379]]
+
+DFCNN+resnet18
+avg_recall:0.7844 avg_acc:0.9124 error:0.219
+confusion_mat:
+[[ 3623 243 23 2 4]
+ [ 1905 12703 2452 657 94]
+ [ 54 671 3371 399 468]
+ [ 29 379 1454 4759 151]
+ [ 40 50 1272 124 12881]]
diff --git a/models.py b/models.py
index e3dac8a7e158b72ee7edfd2421f92933df14cfdd..7effe1d9fdfb8b89a8e2b3577fc9c64944ab4e97 100644
--- a/models.py
+++ b/models.py
@@ -12,6 +12,8 @@ def CreatNet(name):
return CNN()
elif name == 'resnet18_1d':
return resnet18_1d()
+ elif name == 'dfcnn':
+ return dfcnn()
elif name in ['resnet101','resnet50','resnet18']:
if name =='resnet101':
net = torchvision.models.resnet101(pretrained=False)
@@ -67,6 +69,72 @@ class LSTM(nn.Module):
out = self.out(x)
return out
+class dfcnn(nn.Module):
+ def __init__(self):
+ super(dfcnn, self).__init__()
+ self.layer1 = nn.Sequential(
+ nn.Conv2d(1, 32, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace = True),
+ nn.Conv2d(32, 32, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace = True),
+ nn.MaxPool2d(2),
+ )
+ self.layer2 = nn.Sequential(
+ nn.Conv2d(32, 64, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace = True),
+ nn.Conv2d(64, 64, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace = True),
+ nn.MaxPool2d(2),
+ )
+ self.layer3 = nn.Sequential(
+ nn.Conv2d(64, 128, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(128),
+ nn.ReLU(inplace = True),
+ nn.Conv2d(128, 128, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(128),
+ nn.ReLU(inplace = True),
+ nn.MaxPool2d(2),
+ )
+ self.layer4 = nn.Sequential(
+ nn.Conv2d(128, 256, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace = True),
+ nn.Conv2d(256, 256, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.ReLU(inplace = True),
+ nn.MaxPool2d(2),
+ )
+ self.layer5 = nn.Sequential(
+ nn.Conv2d(256, 512, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(512),
+ nn.ReLU(inplace = True),
+ nn.Conv2d(512, 512, (3,3), 1, 1, bias=False),
+ nn.BatchNorm2d(512),
+ nn.ReLU(inplace = True),
+ nn.MaxPool2d(2),
+ )
+
+ self.out = nn.Linear(512*7*7, 5) # fully connected layer, output 10 classes
+
+ def forward(self, x):
+ x = self.layer1(x)
+
+ x = self.layer2(x)
+
+ x = self.layer3(x)
+
+ x = self.layer4(x)
+
+ x = self.layer5(x)
+ # x = F.avg_pool1d(x, 175)
+ x = x.view(x.size(0), -1)
+ output = self.out(x)
+ return output
+
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
diff --git a/train.py b/train.py
index a06aec935abb7bedfccbc7678683fe81a0b84ec5..fcc2979feac5c64b205d95809729c1819005f8a4 100644
--- a/train.py
+++ b/train.py
@@ -48,11 +48,11 @@ weight = torch.from_numpy(weight).float()
if not opt.no_cuda:
net.cuda()
weight = weight.cuda()
- cudnn.benchmark = True
+ # cudnn.benchmark = True
# print(weight)
# time.sleep(2000)
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
-scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)
+# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95)
criterion = nn.CrossEntropyLoss(weight)
@@ -127,7 +127,7 @@ for epoch in range(opt.epochs):
# torch.cuda.empty_cache()
evalnet(net,signals_eval,stages_eval,epoch+1,plot_result,mode = 'all')
- scheduler.step()
+ # scheduler.step()
if (epoch+1)%opt.network_save_freq == 0:
torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')