diff --git a/README.md b/README.md index 33b90894d6b14752761fd96548e8c5ad526e9008..8c29442ecaf5c0f0f2743f47633851a86c8960e2 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,10 @@ 对于不同的网络结构,对原始eeg信号采取了预处理,使其拥有不同的shape:
LSTM:将30s的eeg信号进行FIR带通滤波,获得θ,σ,α,δ,β波,并将它们进行连接后作为输入数据
resnet_1d:这里使用resnet的一维形式进行实验,(修改nn.Conv2d为nn.Conv1d).
- DFCNN:将30s的eeg信号进行短时傅里叶变换,并生成频谱图作为输入,并使用resnet网络进行分类.
+ DFCNN:将30s的eeg信号进行短时傅里叶变换,并生成频谱图作为输入,并使用图像分类网络进行分类.
* EEG频谱图
- 这里展示5个睡眠阶段对应的频谱图,它们依次是wake, stage 1, stage 2, stage 3, REM + 这里展示5个睡眠阶段对应的频谱图,它们依次是wake, stage 1, stage 2, stage 3, REM
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_wake.png) ![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N1.png) ![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_N2.png) @@ -38,14 +38,14 @@ | resnet18_1d | 0.8434 | 0.9627 | 0.0930 | | DFCNN+resnet18 | 0.8567 | 0.9663 | 0.0842 | | DFCNN+resnet50 | 0.7916 | 0.9607 | 0.0983 | -
+ * sleep-edfx(only sleep time)
| Network | Label average recall | Label average accuracy | error rate | | :------------- | :------------------- | ---------------------- | ---------- | | lstm | 0.7864 | 0.9166 | 0.2085 | | resnet18_1d | xxxxxx | xxxxxx | xxxxxx | - | DFCNN+resnet18 | xxxxxx | xxxxxx | xxxxxx | + | DFCNN+resnet18 | 0.7844 | 0.9124 | 0.219 | | DFCNN+resnet50 | xxxxxx | xxxxxx | xxxxxx | -
+ * CinC Challenge 2018
\ No newline at end of file diff --git a/data.py b/data.py index a3b94a06cbdbf9cad4974d19702b9ce90b2496b7..5fd6592432d13789f8e34213e28582fd4e591446 100644 --- a/data.py +++ b/data.py @@ -96,7 +96,7 @@ def ToInputShape(data,net_name,norm=True,test_flag = False): result = Normalize(result,maxmin = 1000,avg=0,sigma=1000) result = result.reshape(batchsize,1,2700) - elif net_name in ['resnet18','densenet121','densenet201','resnet101','resnet50']: + elif net_name in ['dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']: result =[] for i in range(0,batchsize): spectrum = DSP.signal2spectrum(data[i]) diff --git a/image/DFresnet18_running_recall_sleep-edfx_sleeptime.png b/image/DFresnet18_running_recall_sleep-edfx_sleeptime.png new file mode 100644 index 0000000000000000000000000000000000000000..060784a326aed345a75fbc1ca3a92e7ffdc388cf Binary files /dev/null and b/image/DFresnet18_running_recall_sleep-edfx_sleeptime.png differ diff --git a/image/LSTM_running_recall_sleep-edfx_sleeptime.png b/image/LSTM_running_recall_sleep-edfx_sleeptime.png new file mode 100644 index 0000000000000000000000000000000000000000..50e1a97a6d0bc0ee587d23b4003093c9b1e1c4fc Binary files /dev/null and b/image/LSTM_running_recall_sleep-edfx_sleeptime.png differ diff --git a/image/confusion_mat b/image/confusion_mat index 78faa3cd96b78f3f055af699b661adce8b553c4a..7c66e830516449ef7db70e6e99fdf99ebe0e8178 100644 --- a/image/confusion_mat +++ b/image/confusion_mat @@ -57,3 +57,12 @@ confusion_mat: [ 26 699 3159 775 384] [ 5 380 971 5377 111] [ 13 34 1490 191 12379]] + +DFCNN+resnet18 +avg_recall:0.7844 avg_acc:0.9124 error:0.219 +confusion_mat: +[[ 3623 243 23 2 4] + [ 1905 12703 2452 657 94] + [ 54 671 3371 399 468] + [ 29 379 1454 4759 151] + [ 40 50 1272 124 12881]] diff --git a/models.py b/models.py index e3dac8a7e158b72ee7edfd2421f92933df14cfdd..7effe1d9fdfb8b89a8e2b3577fc9c64944ab4e97 100644 --- a/models.py +++ b/models.py @@ -12,6 +12,8 @@ def CreatNet(name): return CNN() elif name == 'resnet18_1d': return resnet18_1d() + elif name == 'dfcnn': + return dfcnn() elif name in ['resnet101','resnet50','resnet18']: if name =='resnet101': net = torchvision.models.resnet101(pretrained=False) @@ -67,6 +69,72 @@ class LSTM(nn.Module): out = self.out(x) return out +class dfcnn(nn.Module): + def __init__(self): + super(dfcnn, self).__init__() + self.layer1 = nn.Sequential( + nn.Conv2d(1, 32, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace = True), + nn.Conv2d(32, 32, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace = True), + nn.MaxPool2d(2), + ) + self.layer2 = nn.Sequential( + nn.Conv2d(32, 64, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace = True), + nn.Conv2d(64, 64, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(64), + nn.ReLU(inplace = True), + nn.MaxPool2d(2), + ) + self.layer3 = nn.Sequential( + nn.Conv2d(64, 128, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace = True), + nn.Conv2d(128, 128, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace = True), + nn.MaxPool2d(2), + ) + self.layer4 = nn.Sequential( + nn.Conv2d(128, 256, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True), + nn.Conv2d(256, 256, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace = True), + nn.MaxPool2d(2), + ) + self.layer5 = nn.Sequential( + nn.Conv2d(256, 512, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.Conv2d(512, 512, (3,3), 1, 1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace = True), + nn.MaxPool2d(2), + ) + + self.out = nn.Linear(512*7*7, 5) # fully connected layer, output 10 classes + + def forward(self, x): + x = self.layer1(x) + + x = self.layer2(x) + + x = self.layer3(x) + + x = self.layer4(x) + + x = self.layer5(x) + # x = F.avg_pool1d(x, 175) + x = x.view(x.size(0), -1) + output = self.out(x) + return output + class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() diff --git a/train.py b/train.py index a06aec935abb7bedfccbc7678683fe81a0b84ec5..fcc2979feac5c64b205d95809729c1819005f8a4 100644 --- a/train.py +++ b/train.py @@ -48,11 +48,11 @@ weight = torch.from_numpy(weight).float() if not opt.no_cuda: net.cuda() weight = weight.cuda() - cudnn.benchmark = True + # cudnn.benchmark = True # print(weight) # time.sleep(2000) optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr) -scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95) +# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.95) criterion = nn.CrossEntropyLoss(weight) @@ -127,7 +127,7 @@ for epoch in range(opt.epochs): # torch.cuda.empty_cache() evalnet(net,signals_eval,stages_eval,epoch+1,plot_result,mode = 'all') - scheduler.step() + # scheduler.step() if (epoch+1)%opt.network_save_freq == 0: torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')