diff --git a/README.md b/README.md
index 38b6d653c7a35132cef2541cb20ab04af8bec078..90c2f94a71597696baaf1efc60058b48ec25f16c 100644
--- a/README.md
+++ b/README.md
@@ -1,25 +1,33 @@
# candock
-This began as a log repository for a graduation project, whose goal was to try a variety of deep neural network architectures (e.g. LSTM, ResNet, DFCNN) for automatic sleep stage classification on single-channel EEG. We believe the code can also be used to classify other physiological signals (e.g. ECG, EMG). We hope it helps your research.
+This began as a log repository for a graduation project, whose goal was to try a variety of deep neural network architectures (e.g. LSTM, ResNet, DFCNN) for automatic sleep stage classification on single-channel EEG.
The graduation project is now complete, and I will keep maintaining this repository. The focus now shifts to practical application, weighing computational cost against accuracy. Some pretrained models will also be provided for convenience.
We also believe the code can be used to classify other physiological signals (e.g. ECG, EMG). We hope it helps your research or project.
![image](https://github.com/HypoX64/candock/blob/master/image/compare.png)
+
## How to run
-To run the code (train your own model, or test with a pretrained model), see the page below:
+To run the code (train your own model, or use a pretrained model to predict on your own data), see the page below:
[How to run codes](https://github.com/HypoX64/candock/blob/master/how_to_run.md)
+
## Datasets
-Three sleep datasets were used for testing: [[CinC Challenge 2018]](https://physionet.org/physiobank/database/challenge/2018/#files) [[sleep-edf]](https://www.physionet.org/physiobank/database/sleep-edf/) [[sleep-edfx]](https://www.physionet.org/physiobank/database/sleep-edfx/)
-For the CinC Challenge 2018 dataset, the C4-M1 channel is used.
For the sleep-edfx and sleep-edf datasets, the Fpz-Cz channel is used.
-Note that sleep-edfx is the expanded version of sleep-edf.
+Two public sleep datasets were used for training: [[CinC Challenge 2018]](https://physionet.org/physiobank/database/challenge/2018/#files) [[sleep-edfx]](https://www.physionet.org/physiobank/database/sleep-edfx/)
+For the CinC Challenge 2018 dataset we use only the C4-M1 channel; for the sleep-edfx and sleep-edf datasets, the Fpz-Cz channel.
+Notes:
+1. To obtain pretrained models for other EEG channels, you need to download these two datasets and train with train.py. You can of course also train a model on your own data.
+2. For the sleep-edfx dataset, we only read the interval from 30 minutes before sleep onset to 30 minutes after waking (marked "select sleep time" in the results), to balance the proportion of each sleep stage and speed up training.
## Notes
-* Dataset processing
-  Each recording is split into 30 s epochs as inputs, with 5 labels (Sleep stage 3, 2, 1, R, W); each EEG segment is paired with its sleep-stage label.
+* Data preprocessing
-  Note: for the sleep-edfx dataset, we only read the interval from 30 minutes before sleep onset to 30 minutes after waking (marked "only sleep time" in the results), to balance the proportion of each sleep stage and speed up training.
+  1. Downsampling: EEG from the CinC Challenge 2018 dataset is downsampled to 100 Hz.
-* Data preprocessing
-  For different network architectures, the raw EEG is preprocessed into different shapes:
+  2. Normalization: we recommend 5th-95th percentile normalization of each subject's EEG, i.e. the 5th percentile maps to 0 and the 95th percentile maps to 1. Note: all pretrained models were trained on data normalized this way.
+
+  3. The data are split into 30 s epochs as inputs, 3000 samples each. There are 5 sleep-stage labels: N3, N2, N1, REM, W; each epoch maps to one label. Label mapping: N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4
+
+  4. Data augmentation: during training, each epoch is randomly cropped, randomly flipped, and randomly rescaled in amplitude.
+
+  5. For different network architectures, the raw EEG is preprocessed into different shapes:
     LSTM: apply FIR band-pass filters to the 30 s EEG to extract the θ, σ, α, δ, β bands, then concatenate them as the input.
-     resnet_1d: the 1-D form of ResNet (nn.Conv2d replaced with nn.Conv1d).
-     DFCNN: apply a short-time Fourier transform to the 30 s EEG, generate a spectrogram as input, and classify with an image-classification network.
+     CNN_1d family (networks marked 1d): nothing special; various image-domain models with Conv2d swapped for Conv1d.
+     DFCNN family (the iFLYTEK idea: convert to a spectrogram, then apply image-classification models directly): apply a short-time Fourier transform to the 30 s EEG, generate a spectrogram as input, and classify with an image-classification network. We do not recommend this approach, because generating spectrograms is computationally expensive.
* EEG spectrograms
  Spectrograms for the 5 sleep stages, in order: Wake, Stage 1, Stage 2, Stage 3, REM
@@ -30,63 +38,27 @@
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_REM.png)
* multi_scale_resnet_1d architecture
-  This network is based on [geekfeiw / Multi-Scale-1D-ResNet](https://github.com/geekfeiw/Multi-Scale-1D-ResNet)
+  This network is based on [geekfeiw / Multi-Scale-1D-ResNet](https://github.com/geekfeiw/Multi-Scale-1D-ResNet). We name our version micro_multi_scale_resnet_1d.
  Modified [architecture](https://github.com/HypoX64/candock/blob/master/image/multi_scale_resnet_1d_network.png)
-* About cross-validation
-  To ease comparison with methods in the literature, two cross-validation schemes were used:
-  1. 5-fold cross-validation within a dataset
-  2. Cross-validation across datasets
+* About cross-validation
  For better practical applicability we use subject cross-validation: the training and validation sets come from different subjects. Note that each subject in the sleep-edfx dataset has two recordings; we treat both recordings as the same subject. Many papers overlook this (a wry smile).
* About evaluation metrics
-  Per label:
-  accuracy = (TP+TN)/(TP+FN+TN+FP)
-  recall = sensitivity = (TP)/(TP+FN)
-  Overall:
-  Top1 err.
+  Per sleep-stage label: Accuracy = (TP+TN)/(TP+FN+TN+FP)   Recall = sensitivity = (TP)/(TP+FN)
+  Overall: Top1 err. and Kappa, plus the averages of Acc and Recall.
+  Note: the label distribution in this task is extremely imbalanced; to make the results more convincing, our averages are unweighted.
-* About the code
-  The code is still being revised and updated and is not guaranteed to work. Details will be filled in after the graduation project is finished.
## Selected experimental results
This section will be updated continuously...
[[Confusion matrix]](https://github.com/HypoX64/candock/blob/master/confusion_mat)
-#### 5-Fold Cross-Validation Results
-* sleep-edf
-
- | Network | Label average recall | Label average accuracy | error rate |
- | :----------------------- | :------------------- | ---------------------- | ---------- |
- | lstm | | | |
- | resnet18_1d | 0.8263 | 0.9601 | 0.0997 |
- | DFCNN+resnet18 | 0.8261 | 0.9594 | 0.1016 |
- | DFCNN+multi_scale_resnet | 0.8196 | 0.9631 | 0.0922 |
- | multi_scale_resnet_1d | 0.8400 | 0.9595 | 0.1013 |
-
-* sleep-edfx(only sleep time)
-
- | Network | Label average recall | Label average accuracy | error rate |
- | :------------- | :------------------- | ---------------------- | ---------- |
- | lstm | | | |
- | resnet18_1d | | | |
- | DFCNN+resnet18 | | | |
- | DFCNN+resnet50 | | | |
-
-* CinC Challenge 2018(sample size = 100)
-
- | Network | Label average recall | Label average accuracy | error rate |
- | :------------- | :------------------- | ---------------------- | ---------- |
- | lstm | | | |
- | resnet18_1d | | | |
- | DFCNN+resnet18 | 0.7823 | 0.909 | 0.2276 |
- | DFCNN+resnet50 | | | |
-
#### Subject Cross-Validation Results
-
-## Dev diary
-* 2019/04/01 DFCNN is far too expensive, the gains are marginal, and it overfits easily... hardly worth keeping, yet a pity to drop...
-* 2019/04/03 Spent a day upgrading to PyTorch 1.0, then tried shrinking the input spectrogram to reduce computation...
-* 2019/04/04 Need to add k-fold plus subject cross-validation to be rigorous...
-* 2019/04/05 Qingming Festival... read the literature; better to follow what most people do: 5-fold cross-validation and cross-dataset validation, which makes horizontal comparison with other methods easier. Honestly though, others use k-fold only because their datasets are tiny... running k-fold on hundreds of GB of data is redundant, the results barely differ... a pure waste of compute...
-* 2019/04/09 Back in my hometown. Ah!!! My thesis... I'll never finish it!
-* 2019/04/13 Back at school, grinding on the thesis.
-* 2019/05/02 Submitted for plagiarism check, so updates to this repository pause here; today is the last big update. Next I will open a new branch focused on practical application, e.g. the balance between computational cost and accuracy. I will also provide pretrained models for various scenarios and update the calling interface, aiming for plug-and-play sleep staging.
\ No newline at end of file
+Note: the label distribution in this task is extremely imbalanced. We hand-tuned the class weights of the classification loss, which slightly improves Average Recall but also raises the overall error. With the default weights, Top1 err. drops by at least 5%, but recall for the rare N1 stage plummets by 20%, which is absolutely not what we want in practice. All results below were obtained with the hand-tuned weights.
+* [sleep-edfx](https://www.physionet.org/physiobank/database/sleep-edfx/) -> sample size = 197, select sleep time
+
| Network                     | Parameters | Top1 err. | Avg. Acc. | Avg. Re. | Needs feature extraction |
+| --------------------------- | ---------- | --------- | --------- | -------- | ----------------------- |
+| lstm | 1.25M | 26.32% | 93.56% | 68.57% | Yes |
+| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 93.87% | 72.61% | No |
+| resnet18_1d | 3.85M | 24.21% | 94.07% | 72.87% | No |
+| multi_scale_resnet_1d | 8.42M | 24.01% | 94.06% | 72.37% | No |
\ No newline at end of file
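The 5th-95th percentile normalization recommended in the preprocessing notes can be sketched as follows. This is a minimal illustration with synthetic data, not the repository's `transformer.Balance_individualized_differences` (whose exact scaling differs slightly):

```python
import numpy as np

def percentile_normalize(signal, low=5, high=95):
    """Map the low-th percentile of the signal to 0 and the
    high-th percentile to 1 (values outside may exceed [0, 1])."""
    lo, hi = np.percentile(signal, [low, high])
    return (signal - lo) / (hi - lo)

# One 30 s epoch at 100 Hz -> 3000 samples of synthetic EEG-like data
epoch = np.random.randn(3000) * 40.0
normalized = percentile_normalize(epoch)
```

Because the mapping is linear, the 5th and 95th percentiles of the output land exactly at 0 and 1, which keeps per-subject amplitude differences from dominating training.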
diff --git a/confusion_mat b/confusion_mat
index 45299813f69ef4fcc0a4734ba1ea784a1b265029..f4737a6dc31d283f777da646519ed3a41b010a79 100644
--- a/confusion_mat
+++ b/confusion_mat
@@ -1,64 +1,76 @@
Confusion matrix:
Pred
-
- S3 S2 S1 REM WAKE
+ S3 S2 S1 REM WAKE
+ S3
S2
True S1
REM
WAKE
-5-Fold Cross-Validation Results
--------------------sleep-edf-------------------
+-------------------sleep-edfx(select sleep time, sample size = 197)-------------------
-resnet18_1d
-final: test avg_recall:0.8263 avg_acc:0.9601 error:0.0997
+----------lstm----------
+train:
+ dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
+test:
+ dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
+net parameters: 1.25M
+final: recall,acc,sp,err,k: (0.6857, 0.8947, 0.9356, 0.2632, 0.6305)
confusion_mat:
-[[1133 148 7 3 0]
- [ 296 2968 120 195 3]
- [ 3 60 368 149 18]
- [ 3 99 135 1342 18]
- [ 9 7 190 37 7729]]
-statistics of dataset [S3 S2 S1 R W]:
-[1291 3582 598 1597 7972]
+[[ 1644 244 11 26 5]
+ [ 1801 11924 1933 2169 229]
+ [ 42 1267 2772 1689 1094]
+ [ 5 193 1845 3539 301]
+ [ 37 128 1513 495 22182]]
-DFCNN+resnet18
-final: test avg_recall:0.8261 avg_acc:0.9594 error:0.1016
+
+----------micro_multi_scale_resnet_1d----------
+train:
+ dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
+test:
+ dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
+net parameters: 2.11M
+final: recall,acc,sp,err,k: (0.7261, 0.8987, 0.9387, 0.2533, 0.648)
confusion_mat:
-[[ 216 44 2 0 6]
- [ 41 640 24 14 6]
- [ 1 9 75 11 10]
- [ 0 20 65 226 3]
- [ 2 2 21 4 1582]]
-statistics of dataset [S3 S2 S1 R W]:
-[ 268 725 106 314 1611]
+[[ 1711 211 2 6 1]
+ [ 1905 12958 1338 1741 99]
+ [ 26 1731 3074 1439 602]
+ [ 1 466 995 4328 97]
+ [ 37 237 2692 837 20554]]
+
-multi_scale_resnet_1d
+----------multi_scale_resnet_1d----------
+train:
+ dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
+test:
+ dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
+net parameters: 8.42M
+final: recall,acc,sp,err,k: (0.7237, 0.904, 0.9406, 0.2401, 0.6631)
confusion_mat:
-final: test avg_recall:0.84 avg_acc:0.9595 error:0.1013
-[[1159 114 10 3 1]
- [ 335 2961 153 143 5]
- [ 4 45 436 102 13]
- [ 0 100 241 1244 9]
- [ 2 2 214 27 7717]]
-statistics of dataset [S3 S2 S1 R W]:
-[1287 3597 600 1594 7962]
+[[ 1629 280 17 3 1]
+ [ 1775 13245 1880 973 176]
+ [ 18 1644 3484 1044 681]
+ [ 11 454 1062 3925 431]
+ [ 38 177 2326 717 21097]]
--------------------sleep-edfx(only sleep time)-------------------
+----------resnet18_1d----------
+train:
+ dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
+test:
+ dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
+net parameters: 3.85M
+final: recall,acc,sp,err,k: (0.7283, 0.9031, 0.9407, 0.2421, 0.6607)
+confusion_mat:
+[[ 1726 171 23 8 1]
+ [ 1906 12100 2336 1530 179]
+ [ 48 1223 3386 1315 898]
+ [ 1 298 1200 3989 398]
+ [ 19 108 1630 531 22064]]
+-------------------CinC Challenge 2018(sample size = 994)-------------------
--------------------CinC Challenge 2018(sample size = 100)-------------------
-DFCNN+resnet18
-final: test avg_recall:0.7823 avg_acc:0.909 error:0.2276
-confusion_mat:
-[[1897 189 3 4 3]
- [ 634 5810 614 362 126]
- [ 1 409 1545 301 406]
- [ 4 200 231 1809 72]
- [ 3 41 377 19 3036]]
-statistics of dataset [S3 S2 S1 R W]:
-[2096 7546 2662 2316 3476]
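The `final: recall,acc,sp,err,k:` tuples above are all derived from a confusion matrix. A sketch of how per-class recall, Top1 error, and Cohen's kappa can be computed from one (the repository's `statistics.result` may differ in detail), shown on a small hypothetical 2x2 matrix:

```python
import numpy as np

def metrics_from_confusion(mat):
    """mat[i][j] counts true class i predicted as class j."""
    mat = np.asarray(mat, dtype=float)
    total = mat.sum()
    tp = np.diag(mat)
    recall = tp / mat.sum(axis=1)            # per-class recall (sensitivity)
    top1_err = 1.0 - tp.sum() / total        # overall Top1 error
    po = tp.sum() / total                    # observed agreement
    pe = (mat.sum(axis=1) * mat.sum(axis=0)).sum() / total ** 2  # chance agreement
    kappa = (po - pe) / (1.0 - pe)           # Cohen's kappa
    return recall.mean(), top1_err, kappa

avg_recall, err, kappa = metrics_from_confusion([[50, 10], [5, 35]])
```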
diff --git a/dataloader.py b/dataloader.py
index f11a54ed6af066824a2471aebb01651069677a5b..260aeee8bba8f1770833a20b22d144a807c432db 100644
--- a/dataloader.py
+++ b/dataloader.py
@@ -107,8 +107,7 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
events, _ = mne.events_from_annotations(
raw_data, event_id=event_id, chunk_duration=30.)
- stages = []
- signals =[]
+    stages, signals = [], []
for i in range(len(events)-1):
stages.append(events[i][2])
signals.append(eeg[events[i][0]:events[i][0]+3000])
@@ -129,38 +128,63 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
return signals.astype(np.float16),stages.astype(np.int16)
#load all data in datasets
-def loaddataset(filedir,dataset_name,signal_name,num,BID,select_sleep_time,shuffle = True):
+def loaddataset(filedir,dataset_name,signal_name,num,BID,select_sleep_time,shuffle = False):
print('load dataset, please wait...')
- filenames = os.listdir(filedir)
- if shuffle:
- random.shuffle(filenames)
- signals=[]
- stages=[]
+
+ signals_train=[];stages_train=[];signals_test=[];stages_test=[]
if dataset_name == 'cc2018':
+ filenames = os.listdir(filedir)
+ if shuffle:
+ random.shuffle(filenames)
+ else:
+ filenames.sort()
+
if num > len(filenames):
num = len(filenames)
print('num of dataset is:',num)
for cnt,filename in enumerate(filenames[:num],0):
- try:
- signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID)
- signals,stages = connectdata(signal,stage,signals,stages)
- except Exception as e:
- print(filename,e)
-
- elif dataset_name in ['sleep-edfx','sleep-edf']:
+ signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID)
+ if cnt < round(num*0.8) :
+ signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)
+ else:
+ signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)
+ print('train subjects:',round(num*0.8),'test subjects:',round(num*0.2))
+
+ elif dataset_name == 'sleep-edfx':
if num > 197:
num = 197
- if dataset_name == 'sleep-edf':
- filenames = ['SC4002E0-PSG.edf','SC4012E0-PSG.edf','SC4102E0-PSG.edf','SC4112E0-PSG.edf',
- 'ST7022J0-PSG.edf','ST7052J0-PSG.edf','ST7121J0-PSG.edf','ST7132J0-PSG.edf']
- cnt = 0
- for filename in filenames:
- if 'PSG' in filename:
- signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
- signals,stages = connectdata(signal,stage,signals,stages)
- cnt += 1
- if cnt == num:
- break
- return signals,stages
\ No newline at end of file
+
+    filenames_sc_train = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf', 'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf', 'SC4031E0-PSG.edf', 'SC4032E0-PSG.edf', 'SC4041E0-PSG.edf', 'SC4042E0-PSG.edf', 'SC4051E0-PSG.edf', 'SC4052E0-PSG.edf', 'SC4061E0-PSG.edf', 'SC4062E0-PSG.edf', 'SC4071E0-PSG.edf', 'SC4072E0-PSG.edf', 'SC4081E0-PSG.edf', 'SC4082E0-PSG.edf', 'SC4091E0-PSG.edf', 'SC4092E0-PSG.edf', 'SC4101E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4111E0-PSG.edf', 'SC4112E0-PSG.edf', 'SC4121E0-PSG.edf', 'SC4122E0-PSG.edf', 'SC4131E0-PSG.edf', 'SC4141E0-PSG.edf', 'SC4142E0-PSG.edf', 'SC4151E0-PSG.edf', 'SC4152E0-PSG.edf', 'SC4161E0-PSG.edf', 'SC4162E0-PSG.edf', 'SC4171E0-PSG.edf', 'SC4172E0-PSG.edf', 'SC4181E0-PSG.edf', 'SC4182E0-PSG.edf', 'SC4191E0-PSG.edf', 'SC4192E0-PSG.edf', 'SC4201E0-PSG.edf', 'SC4202E0-PSG.edf', 'SC4211E0-PSG.edf', 'SC4212E0-PSG.edf', 'SC4221E0-PSG.edf', 'SC4222E0-PSG.edf', 'SC4231E0-PSG.edf', 'SC4232E0-PSG.edf', 'SC4241E0-PSG.edf', 'SC4242E0-PSG.edf', 'SC4251E0-PSG.edf', 'SC4252E0-PSG.edf', 'SC4261F0-PSG.edf', 'SC4262F0-PSG.edf', 'SC4271F0-PSG.edf', 'SC4272F0-PSG.edf', 'SC4281G0-PSG.edf', 'SC4282G0-PSG.edf', 'SC4291G0-PSG.edf', 'SC4292G0-PSG.edf', 'SC4301E0-PSG.edf', 'SC4302E0-PSG.edf', 'SC4311E0-PSG.edf', 'SC4312E0-PSG.edf', 'SC4321E0-PSG.edf', 'SC4322E0-PSG.edf', 'SC4331F0-PSG.edf', 'SC4332F0-PSG.edf', 'SC4341F0-PSG.edf', 'SC4342F0-PSG.edf', 'SC4351F0-PSG.edf', 'SC4352F0-PSG.edf', 'SC4362F0-PSG.edf', 'SC4371F0-PSG.edf', 'SC4372F0-PSG.edf', 'SC4381F0-PSG.edf', 'SC4382F0-PSG.edf', 'SC4401E0-PSG.edf', 'SC4402E0-PSG.edf', 'SC4411E0-PSG.edf', 'SC4412E0-PSG.edf', 'SC4421E0-PSG.edf', 'SC4422E0-PSG.edf', 'SC4431E0-PSG.edf', 'SC4432E0-PSG.edf', 'SC4441E0-PSG.edf', 'SC4442E0-PSG.edf', 'SC4451F0-PSG.edf', 'SC4452F0-PSG.edf', 'SC4461F0-PSG.edf', 'SC4462F0-PSG.edf', 'SC4471F0-PSG.edf', 'SC4472F0-PSG.edf', 'SC4481F0-PSG.edf', 'SC4482F0-PSG.edf', 'SC4491G0-PSG.edf', 'SC4492G0-PSG.edf', 'SC4501E0-PSG.edf', 'SC4502E0-PSG.edf', 'SC4511E0-PSG.edf', 'SC4512E0-PSG.edf', 'SC4522E0-PSG.edf', 'SC4531E0-PSG.edf', 'SC4532E0-PSG.edf', 'SC4541F0-PSG.edf', 'SC4542F0-PSG.edf', 'SC4551F0-PSG.edf', 'SC4552F0-PSG.edf', 'SC4561F0-PSG.edf', 'SC4562F0-PSG.edf', 'SC4571F0-PSG.edf', 'SC4572F0-PSG.edf', 'SC4581G0-PSG.edf', 'SC4582G0-PSG.edf', 'SC4591G0-PSG.edf', 'SC4592G0-PSG.edf', 'SC4601E0-PSG.edf', 'SC4602E0-PSG.edf', 'SC4611E0-PSG.edf', 'SC4612E0-PSG.edf', 'SC4621E0-PSG.edf', 'SC4622E0-PSG.edf', 'SC4631E0-PSG.edf', 'SC4632E0-PSG.edf']
+ filenames_sc_test = ['SC4641E0-PSG.edf', 'SC4642E0-PSG.edf', 'SC4651E0-PSG.edf', 'SC4652E0-PSG.edf', 'SC4661E0-PSG.edf', 'SC4662E0-PSG.edf', 'SC4671G0-PSG.edf', 'SC4672G0-PSG.edf', 'SC4701E0-PSG.edf', 'SC4702E0-PSG.edf', 'SC4711E0-PSG.edf', 'SC4712E0-PSG.edf', 'SC4721E0-PSG.edf', 'SC4722E0-PSG.edf', 'SC4731E0-PSG.edf', 'SC4732E0-PSG.edf', 'SC4741E0-PSG.edf', 'SC4742E0-PSG.edf', 'SC4751E0-PSG.edf', 'SC4752E0-PSG.edf', 'SC4761E0-PSG.edf', 'SC4762E0-PSG.edf', 'SC4771G0-PSG.edf', 'SC4772G0-PSG.edf', 'SC4801G0-PSG.edf', 'SC4802G0-PSG.edf', 'SC4811G0-PSG.edf', 'SC4812G0-PSG.edf', 'SC4821G0-PSG.edf', 'SC4822G0-PSG.edf']
+ filenames_st_train = ['ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf', 'ST7041J0-PSG.edf', 'ST7042J0-PSG.edf', 'ST7051J0-PSG.edf', 'ST7052J0-PSG.edf', 'ST7061J0-PSG.edf', 'ST7062J0-PSG.edf', 'ST7071J0-PSG.edf', 'ST7072J0-PSG.edf', 'ST7081J0-PSG.edf', 'ST7082J0-PSG.edf', 'ST7091J0-PSG.edf', 'ST7092J0-PSG.edf', 'ST7101J0-PSG.edf', 'ST7102J0-PSG.edf', 'ST7111J0-PSG.edf', 'ST7112J0-PSG.edf', 'ST7121J0-PSG.edf', 'ST7122J0-PSG.edf', 'ST7131J0-PSG.edf', 'ST7132J0-PSG.edf', 'ST7141J0-PSG.edf', 'ST7142J0-PSG.edf', 'ST7151J0-PSG.edf', 'ST7152J0-PSG.edf', 'ST7161J0-PSG.edf', 'ST7162J0-PSG.edf', 'ST7171J0-PSG.edf', 'ST7172J0-PSG.edf', 'ST7181J0-PSG.edf', 'ST7182J0-PSG.edf', 'ST7191J0-PSG.edf', 'ST7192J0-PSG.edf']
+ filenames_st_test = ['ST7201J0-PSG.edf', 'ST7202J0-PSG.edf', 'ST7211J0-PSG.edf', 'ST7212J0-PSG.edf', 'ST7221J0-PSG.edf', 'ST7222J0-PSG.edf', 'ST7241J0-PSG.edf', 'ST7242J0-PSG.edf']
+
+ for filename in filenames_sc_train[:round(num*153/197*0.8)]:
+ signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
+ signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)
+
+ for filename in filenames_st_train[:round(num*44/197*0.8)]:
+ signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
+ signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)
+
+ for filename in filenames_sc_test[:round(num*153/197*0.2)]:
+ signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
+ signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)
+
+ for filename in filenames_st_test[:round(num*44/197*0.2)]:
+ signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
+ signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)
+
+    print('---------Each subject has two samples---------',
+ '\nTrain samples_SC/ST:',round(num*153/197*0.8),round(num*44/197*0.8),
+ '\nTest samples_SC/ST:',round(num*153/197*0.2),round(num*44/197*0.2))
+
+ elif dataset_name == 'preload':
+ signals_train = np.load(filedir+'/signals_train.npy')
+ stages_train = np.load(filedir+'/stages_train.npy')
+ signals_test = np.load(filedir+'/signals_test.npy')
+ stages_test = np.load(filedir+'/stages_test.npy')
+
+ return signals_train,stages_train,signals_test,stages_test
\ No newline at end of file
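The hard-coded train/test filename lists above implement a subject-level split: both recordings of a subject always land on the same side. A generic sketch of that idea; the `[:5]` subject key is an assumption based on the sleep-edfx naming scheme, where e.g. `SC4011E0`/`SC4012E0` are nights 1 and 2 of subject `SC401`:

```python
import random

def subject_split(filenames, train_frac=0.8, seed=0):
    """Split PSG recordings by subject so paired recordings never
    straddle the train/test boundary."""
    subjects = sorted({name[:5] for name in filenames})  # subject id = first 5 chars
    rng = random.Random(seed)
    rng.shuffle(subjects)
    cut = round(len(subjects) * train_frac)
    train_subjects = set(subjects[:cut])
    train = [f for f in filenames if f[:5] in train_subjects]
    test = [f for f in filenames if f[:5] not in train_subjects]
    return train, test

files = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf',
         'ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf']
train, test = subject_split(files, train_frac=0.5)
```

Splitting by subject id rather than by file index is what makes the "each subject has two samples" caveat above safe to honor.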
diff --git a/dsp.py b/dsp.py
index e5ba0addf3739bc46d82452711e72fa31f6e6a1f..7a8d12181d66aa0ef6aea67b9312cf5f6d8c32fb 100644
--- a/dsp.py
+++ b/dsp.py
@@ -42,7 +42,8 @@ def BPF(signal,fs,fc1,fc2,mod = 'fir'):
b=getfir_b(fc1,fc2,fs)
result = scipy.signal.lfilter(b, 1, signal)
return result
-def getfeature(signal,mod = 'fir',ch_num = 5):
+
+def getfeature(signal,mod = 'fft',ch_num = 5):
result=[]
signal =signal - np.mean(signal)
eeg=signal
@@ -68,19 +69,6 @@ def getfeature(signal,mod = 'fir',ch_num = 5):
result=result.reshape(ch_num*len(signal),)
return result
-# def signal2spectrum(signal):
-# spectrum =np.zeros((224,224))
-# spectrum_y = np.zeros((224))
-# for i in range(224):
-# signal_window=signal[i*9:i*9+896]
-# signal_window_fft=np.abs(np.fft.fft(signal_window))[0:448]
-# spectrum_y[0:112]=signal_window_fft[0:112]
-# spectrum_y[112:168]=signal_window_fft[112:224][::2]
-# spectrum_y[168:224]=signal_window_fft[224:448][::4]
-# spectrum[:,i] = spectrum_y
-# # spectrum = np.log(spectrum+1)/11
-# return spectrum
-
# def signal2spectrum(data):
# # window : ('tukey',0.5) hann
@@ -96,7 +84,7 @@ def getfeature(signal,mod = 'fir',ch_num = 5):
def signal2spectrum(data):
# window : ('tukey',0.5) hann
- zxx = scipy.signal.stft(data, fs=100, window=('tukey',0.5), nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
+ zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
zxx =np.abs(zxx)[:512]
spectrum=np.zeros((256,126))
spectrum[0:128]=zxx[0:128]
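For reference, the `signal2spectrum` step above (an STFT of a 30 s, 100 Hz epoch with a Hann window and a 24-sample hop) can be approximated with plain NumPy. This sliding-FFT sketch is an illustration of the idea, not the repository's exact `scipy.signal.stft` configuration (which also zero-pads the boundaries):

```python
import numpy as np

def simple_spectrogram(signal, nperseg=1024, step=24):
    """Naive magnitude spectrogram: Hann-windowed FFTs over a sliding
    window, roughly mirroring dsp.py (nperseg=1024, hop = 1024 - noverlap = 24)."""
    window = np.hanning(nperseg)
    frames = []
    for start in range(0, len(signal) - nperseg + 1, step):
        seg = signal[start:start + nperseg] * window
        frames.append(np.abs(np.fft.rfft(seg)))
    return np.array(frames).T  # shape: (freq_bins, time_frames)

epoch = np.random.randn(3000)      # one 30 s epoch at 100 Hz
spec = simple_spectrogram(epoch)
```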
diff --git a/options.py b/options.py
index 6c6614ee7619e7865b361c0d68878d16435a622e..f34dd46ee4c617f6520d58b3bf9da0275533529b 100644
--- a/options.py
+++ b/options.py
@@ -3,8 +3,8 @@ import os
import numpy as np
import torch
-#python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name cc2018 --signal_name 'C4-M1' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject
-# python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject
+# python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name cc2018 --signal_name 'C4-M1' --sample_num 20 --model_name lstm --batchsize 64 --epochs 20 --lr 0.0005 --no_cudnn
+# python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 50 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 25 --lr 0.0005 --BID 5_95_th --select_sleep_time --no_cudnn
class Options():
def __init__(self):
@@ -16,16 +16,14 @@ class Options():
self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn')
self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models')
self.parser.add_argument('--lr', type=float, default=0.0005,help='learning rate')
- self.parser.add_argument('--cross_validation', type=str, default='subject',help='k-fold | subject')
self.parser.add_argument('--BID', type=str, default='5_95_th',help='Balance individualized differences 5_95_th | median |None')
- self.parser.add_argument('--fold_num', type=int, default=5,help='k-fold')
self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize')
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/',
help='your dataset path')
- self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx',help='Choose dataset sleep-edfx | sleep-edf | cc2018')
+ self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx',help='Choose dataset sleep-edfx | cc2018')
self.parser.add_argument('--select_sleep_time', action='store_true', help='if input, for sleep-cassette only use sleep time to train')
self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1 | EEG Fpz-Cz |...')
- self.parser.add_argument('--sample_num', type=int, default=10,help='the amount you want to load')
+ self.parser.add_argument('--sample_num', type=int, default=20,help='the amount you want to load')
self.parser.add_argument('--model_name', type=str, default='lstm',help='Choose model lstm | multi_scale_resnet_1d | resnet18 |...')
self.parser.add_argument('--epochs', type=int, default=20,help='end epoch')
self.parser.add_argument('--weight_mod', type=str, default='avg_best',help='Choose weight mode: avg_best|normal')
@@ -42,9 +40,5 @@ class Options():
self.opt.sample_num = 8
if self.opt.no_cuda:
self.opt.no_cudnn = True
- if self.opt.fold_num == 0:
- self.opt.fold_num = 1
- if self.opt.cross_validation == 'subject':
- self.opt.fold_num = 1
return self.opt
\ No newline at end of file
diff --git a/statistics.py b/statistics.py
index 36ccdb11e9a7fe4c75eccf89fa976bd68662b3a6..b3a74c17774515e72d42bc4feea7ebd99b909978 100644
--- a/statistics.py
+++ b/statistics.py
@@ -8,7 +8,7 @@ def stage(stages):
for i in range(len(stages)):
stage_cnt[stages[i]] += 1
stage_cnt_per = stage_cnt/len(stages)
- util.writelog('statistics of dataset [S3 S2 S1 R W]: '+str(stage_cnt),True)
+ util.writelog(' dataset statistics [S3 S2 S1 R W]: '+str(stage_cnt),True)
return stage_cnt,stage_cnt_per
def reversal_label(mat):
diff --git a/train.py b/train.py
index b757ae81f7af005b8488b99e07ed0c4d82e55b84..07d890eb7daf8bddc75a0ae470c806397d6e1053 100644
--- a/train.py
+++ b/train.py
@@ -28,26 +28,21 @@ but the data needs meet the following conditions:
3.fs = 100Hz
4.input signal data should be normalized!!
we recommend normalizing signal data using 5_95_th for each subject,
-    example: signals_subject=Balance_individualized_differences(signals_subject, '5_95_th')
-5.when using subject cross validation, we generally assume the [0:80%] and [80%:100%] of the data come from different subjects.
+    example: signals_normalized=transformer.Balance_individualized_differences(signals_origin, '5_95_th')
'''
-signals,stages = dataloader.loaddataset(opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,opt.BID,opt.select_sleep_time,shuffle = True)
-stage_cnt,stage_cnt_per = statistics.stage(stages)
-
-if opt.cross_validation =='k_fold':
- signals,stages = transformer.batch_generator(signals,stages,opt.batchsize,shuffle = True)
- train_sequences,test_sequences = transformer.k_fold_generator(len(stages),opt.fold_num)
-elif opt.cross_validation =='subject':
- util.writelog('train statistics:',True)
- stage_cnt,stage_cnt_per = statistics.stage(stages[:int(0.8*len(stages))])
- util.writelog('test statistics:',True)
- stage_cnt,stage_cnt_per = statistics.stage(stages[int(0.8*len(stages)):])
- signals,stages = transformer.batch_generator_subject(signals,stages,opt.batchsize,shuffle = False)
- train_sequences,test_sequences = transformer.k_fold_generator(len(stages),1)
-
-batch_length = len(stages)
+signals_train,stages_train,signals_test,stages_test = dataloader.loaddataset(opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,opt.BID,opt.select_sleep_time)
+
+util.writelog('train:',True)
+stage_cnt,stage_cnt_per = statistics.stage(stages_train)
+util.writelog('test:',True)
+_,_ = statistics.stage(stages_test)
+signals_train,stages_train = transformer.batch_generator(signals_train,stages_train,opt.batchsize)
+signals_test,stages_test = transformer.batch_generator(signals_test,stages_test,opt.batchsize)
+
+
+batch_length = len(signals_train)
print('length of batch:',batch_length)
-show_freq = int(len(train_sequences[0])/5)
+show_freq = int(len(stages_train)/5)
t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s')
@@ -59,27 +54,30 @@ weight = np.array([1,1,1,1,1])
if opt.weight_mod == 'avg_best':
weight = np.log(1/stage_cnt_per)
weight[2] = weight[2]+1
- weight = np.clip(weight,1,5)
+ weight = np.clip(weight,1,3)
print('Loss_weight:',weight)
weight = torch.from_numpy(weight).float()
# print(net)
if not opt.no_cuda:
net.cuda()
weight = weight.cuda()
+if opt.pretrained:
+ net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.dataset_name+'/'+opt.model_name+'.pth'))
if not opt.no_cudnn:
torch.backends.cudnn.benchmark = True
+
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
criterion = nn.CrossEntropyLoss(weight)
-def evalnet(net,signals,stages,sequences,epoch,plot_result={}):
+def evalnet(net,signals,stages,epoch,plot_result={}):
# net.eval()
confusion_mat = np.zeros((5,5), dtype=int)
- for i, sequence in enumerate(sequences, 1):
+ for i, (signal,stage) in enumerate(zip(signals,stages), 1):
- signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =True)
- signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda)
+ signal=transformer.ToInputShape(signal,opt.model_name,test_flag =True)
+ signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
with torch.no_grad():
out = net(signal)
pred = torch.max(out, 1)[1]
@@ -97,61 +95,53 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={}):
print('begin to train ...')
final_confusion_mat = np.zeros((5,5), dtype=int)
-for fold in range(opt.fold_num):
- net.load_state_dict(torch.load('./checkpoints/'+opt.model_name+'.pth'))
- if opt.pretrained:
- net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.dataset_name+'/'+opt.model_name+'.pth'))
- if not opt.no_cuda:
- net.cuda()
- plot_result={'train':[1.],'test':[1.]}
- confusion_mats = []
-
- for epoch in range(opt.epochs):
- t1 = time.time()
- confusion_mat = np.zeros((5,5), dtype=int)
- print('fold:',fold+1,'epoch:',epoch+1)
- net.train()
- for i, sequence in enumerate(train_sequences[fold], 1):
-
- signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =False)
- signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda)
-
- out = net(signal)
- loss = criterion(out, stage)
- pred = torch.max(out, 1)[1]
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- pred=pred.data.cpu().numpy()
- stage=stage.data.cpu().numpy()
- for x in range(len(pred)):
- confusion_mat[stage[x]][pred[x]] += 1
- if i%show_freq==0:
- plot_result['train'].append(statistics.result(confusion_mat)[3])
- heatmap.draw(confusion_mat,name = 'train')
- statistics.show(plot_result,epoch+i/(batch_length*0.8))
- confusion_mat[:]=0
-
- plot_result,confusion_mat = evalnet(net,signals,stages,test_sequences[fold],epoch+1,plot_result)
- confusion_mats.append(confusion_mat)
- # scheduler.step()
-
- if (epoch+1)%opt.network_save_freq == 0:
- torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
- print('network saved.')
- if not opt.no_cuda:
- net.cuda()
-
- t2=time.time()
- if epoch+1==1:
- print('cost time: %.2f' % (t2-t1),'s')
- pos = plot_result['test'].index(min(plot_result['test']))-1
- final_confusion_mat = final_confusion_mat+confusion_mats[pos]
- util.writelog('fold:'+str(fold+1)+' recall,acc,sp,err,k: '+str(statistics.result(confusion_mats[pos])),True)
- print('------------------')
- util.writelog('confusion_mat:\n'+str(confusion_mat))
+plot_result={'train':[1.],'test':[1.]}
+confusion_mats = []
+
+for epoch in range(opt.epochs):
+ t1 = time.time()
+ confusion_mat = np.zeros((5,5), dtype=int)
+ print('epoch:',epoch+1)
+ net.train()
+ for i, (signal,stage) in enumerate(zip(signals_train,stages_train), 1):
+
+ signal=transformer.ToInputShape(signal,opt.model_name,test_flag =False)
+ signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
+
+ out = net(signal)
+ loss = criterion(out, stage)
+ pred = torch.max(out, 1)[1]
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ pred=pred.data.cpu().numpy()
+ stage=stage.data.cpu().numpy()
+ for x in range(len(pred)):
+ confusion_mat[stage[x]][pred[x]] += 1
+ if i%show_freq==0:
+ plot_result['train'].append(statistics.result(confusion_mat)[3])
+ heatmap.draw(confusion_mat,name = 'train')
+ statistics.show(plot_result,epoch+i/(batch_length*0.8))
+ confusion_mat[:]=0
+
+ plot_result,confusion_mat = evalnet(net,signals_test,stages_test,epoch+1,plot_result)
+ confusion_mats.append(confusion_mat)
+ # scheduler.step()
+
+ if (epoch+1)%opt.network_save_freq == 0:
+ torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
+ print('network saved.')
+ if not opt.no_cuda:
+ net.cuda()
+
+ t2=time.time()
+ if epoch+1==1:
+ print('cost time: %.2f' % (t2-t1),'s')
+
+pos = plot_result['test'].index(min(plot_result['test']))-1
+final_confusion_mat = confusion_mats[pos]
util.writelog('final: '+'recall,acc,sp,err,k: '+str(statistics.result(final_confusion_mat)),True)
util.writelog('confusion_mat:\n'+str(final_confusion_mat),True)
statistics.stagefrommat(final_confusion_mat)
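The `avg_best` weight mode above derives class weights from the stage frequencies, so rarer stages weigh more, N1 (index 2) gets an extra boost, and the result is clipped to the new [1, 3] range. A small worked example using the test-set stage counts from the logs:

```python
import numpy as np

# Stage counts [S3 S2 S1 R W] taken from the test statistics in the logs
stage_cnt = np.array([1932, 18067, 6876, 5892, 24382], dtype=float)
stage_cnt_per = stage_cnt / stage_cnt.sum()

weight = np.log(1 / stage_cnt_per)  # rarer stage -> larger weight
weight[2] += 1                      # extra boost for N1
weight = np.clip(weight, 1, 3)      # new cap of 3 (was 5)
```

With these counts, the rare S3 and the boosted N1 hit the cap of 3 while the abundant Wake stage sits at the floor of 1, which is exactly the Recall-vs-error trade described in the README.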
diff --git a/transformer.py b/transformer.py
index 3c712c8c2b79125ad5ae907c24b20144c2a2c611..64931e450a1755abc7bb5d00481eec756d55402d 100644
--- a/transformer.py
+++ b/transformer.py
@@ -14,16 +14,6 @@ def shuffledata(data,target):
np.random.shuffle(target)
# return data,target
-def batch_generator_subject(data,target,batchsize,shuffle = True):
- data_test = data[int(0.8*len(target)):]
- data_train = data[0:int(0.8*len(target))]
- target_test = target[int(0.8*len(target)):]
- target_train = target[0:int(0.8*len(target))]
- data_test,target_test = batch_generator(data_test, target_test, batchsize)
- data_train,target_train = batch_generator(data_train, target_train, batchsize)
- data = np.concatenate((data_train, data_test), axis=0)
- target = np.concatenate((target_train, target_test), axis=0)
- return data,target
def batch_generator(data,target,batchsize,shuffle = True):
if shuffle:
@@ -34,41 +24,28 @@ def batch_generator(data,target,batchsize,shuffle = True):
target = target.reshape(-1,batchsize)
return data,target
-def k_fold_generator(length,fold_num):
- sequence = np.linspace(0,length-1,length,dtype='int')
- if fold_num == 1:
- train_sequence = sequence[0:int(0.8*length)].reshape(1,-1)
- test_sequence = sequence[int(0.8*length):].reshape(1,-1)
- else:
- train_length = int(length/fold_num*(fold_num-1))
- test_length = int(length/fold_num)
- train_sequence = np.zeros((fold_num,train_length), dtype = 'int')
- test_sequence = np.zeros((fold_num,test_length), dtype = 'int')
- for i in range(fold_num):
- test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length]
- train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length]
-
- return train_sequence,test_sequence
-
-def Normalize(data,maxmin,avg,sigma):
+def Normalize(data,maxmin,avg,sigma,is_01=False):
data = np.clip(data, -maxmin, maxmin)
- return (data-avg)/sigma
+ if is_01:
+ return (data-avg)/sigma/2+0.5 #(0,1)
+ else:
+ return (data-avg)/sigma #(-1,1)
def Balance_individualized_differences(signals,BID):
if BID == 'median':
signals = (signals*8/(np.median(abs(signals))))
- signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30)
+ signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
elif BID == '5_95_th':
tmp = np.sort(signals.reshape(-1))
th_5 = -tmp[int(0.05*len(tmp))]
- signals=Normalize(signals,maxmin=10e3,avg=0,sigma=th_5)
+ signals=Normalize(signals,maxmin=10e3,avg=0,sigma=th_5,is_01=True)
else:
- #dataser 5_95_th median
+    #dataset 5_95_th(-1,1) median
#CC2018 24.75 7.438
#sleep edfx 37.4 9.71
#sleep edfx sleeptime 39.03 10.125
- signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30)
+ signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
return signals
def ToTensor(data,target=None,no_cuda = False):
@@ -141,6 +118,7 @@ def ToInputShape(data,net_name,test_flag = False):
elif net_name in ['squeezenet','multi_scale_resnet','dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
result =[]
+ data = (data-0.5)*2
for i in range(0,batchsize):
spectrum = dsp.signal2spectrum(data[i])
spectrum = random_transform_2d(spectrum,(224,122),test_flag=test_flag)
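The `is_01` switch added to `Normalize` above shifts the usual roughly (-1, 1) scaling into (0, 1). A self-contained mirror of that arithmetic, written out for clarity:

```python
import numpy as np

def normalize(data, maxmin, avg, sigma, is_01=False):
    """Mirror of transformer.Normalize: clip outliers, then scale by
    sigma; with is_01, halve and shift the result into (0, 1)."""
    data = np.clip(data, -maxmin, maxmin)
    if is_01:
        return (data - avg) / sigma / 2 + 0.5  # roughly (0, 1)
    return (data - avg) / sigma                # roughly (-1, 1)

out = normalize(np.array([0.0, 2.0]), maxmin=10, avg=0, sigma=2, is_01=True)
```

Note that `ToInputShape` undoes this shift (`(data-0.5)*2`) before the spectrogram branch, so the two ranges stay consistent.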
diff --git a/util.py b/util.py
index 71cbc81eded4f35b1ea74becc7377ea54c2278e5..34c5a5f91b8133877b8408f5239a2384c7a8ab33 100644
--- a/util.py
+++ b/util.py
@@ -13,4 +13,6 @@ def writelog(log,printflag = False):
print(log)
def show_paramsnumber(net):
- writelog('net parameters:'+str(sum(param.numel() for param in net.parameters())),True)
+ parameters = sum(param.numel() for param in net.parameters())
+ parameters = round(parameters/1e6,2)
+ writelog('net parameters: '+str(parameters)+'M',True)
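The updated `show_paramsnumber` reports parameter counts in millions; the formatting amounts to the following one-liner, shown here against the counts quoted in the results table:

```python
def params_in_millions(n_params):
    """Format a raw parameter count the way the logs show it,
    e.g. 1250000 -> '1.25M'."""
    return str(round(n_params / 1e6, 2)) + 'M'

print(params_in_millions(1250000))  # the lstm entry in the results table
```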