diff --git a/README.md b/README.md index 38b6d653c7a35132cef2541cb20ab04af8bec078..90c2f94a71597696baaf1efc60058b48ec25f16c 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,33 @@ # candock -This is a log repository for my graduation project; its goal is to try several deep neural network architectures (e.g. LSTM, ResNet, DFCNN) for automatic sleep staging on single-channel EEG. We believe this code can also be used to classify other physiological signals (e.g. ECG, EMG). We hope it helps your research.
+[This was originally a log repository for my graduation project](); its goal was to try several deep neural network architectures (e.g. LSTM, ResNet, DFCNN) for automatic sleep staging on single-channel EEG.
The graduation project is now finished, and I will keep working on this repository. The focus shifts to putting the code to practical use, weighing computational cost against accuracy. Pretrained models will also be provided for convenience.
We also believe this code can be used to classify other physiological signals (e.g. ECG, EMG). We hope it helps your research or project.
![image](https://github.com/HypoX64/candock/blob/master/image/compare.png) + ## How to run -If you need to run this code (to train your own model or to test with a pretrained model), please see the page below
+If you need to run this code (to train your own model, or to use a pretrained model to predict on your own data), please see the page below
[How to run codes](https://github.com/HypoX64/candock/blob/master/how_to_run.md)
+ ## Datasets -Three sleep datasets were used for testing: [[CinC Challenge 2018]](https://physionet.org/physiobank/database/challenge/2018/#files) [[sleep-edf]](https://www.physionet.org/physiobank/database/sleep-edf/) [[sleep-edfx]](https://www.physionet.org/physiobank/database/sleep-edfx/)
-For the CinC Challenge 2018 dataset, the C4-M1 channel is used
For the sleep-edfx and sleep-edf datasets, the Fpz-Cz channel is used
-Note that sleep-edfx is the expanded version of sleep-edf.
+Two public sleep datasets were used for training: [[CinC Challenge 2018]](https://physionet.org/physiobank/database/challenge/2018/#files) [[sleep-edfx]](https://www.physionet.org/physiobank/database/sleep-edfx/)
+For the CinC Challenge 2018 dataset we use only the C4-M1 channel; for the sleep-edfx and sleep-edf datasets we use the Fpz-Cz channel
+Notes:
+1. To obtain pretrained models for other EEG channels, download these two datasets and train with train.py. You can also train a model on your own data.
+2. For the sleep-edfx dataset, we keep only the interval from 30 minutes before sleep onset to 30 minutes after waking (marked "select sleep time" in the results), to balance the proportions of the sleep stages and to speed up training.
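The "select sleep time" trimming in note 2 can be sketched as follows. This is a hedged illustration, not the exact logic in dataloader.py; the helper name `select_sleep_time`, the stage code for Wake, and the 60-epoch (30 min) margin are assumptions for the example.

```python
import numpy as np

def select_sleep_time(signals, stages, wake=4, margin_epochs=60):
    """Keep only the span from ~30 min before sleep onset to ~30 min after waking.

    signals: (N, 3000) array of 30 s epochs; stages: (N,) labels where `wake`
    marks the W stage. 60 epochs * 30 s = 30 min of margin on each side.
    """
    stages = np.asarray(stages)
    sleep = np.where(stages != wake)[0]     # indices of non-Wake epochs
    if len(sleep) == 0:                     # subject never fell asleep
        return signals, stages
    start = max(sleep[0] - margin_epochs, 0)
    end = min(sleep[-1] + margin_epochs + 1, len(stages))
    return signals[start:end], stages[start:end]
```

Trimming like this removes most of the long pre- and post-sleep Wake runs, which is why the W class shrinks in the "select sleep time" statistics.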
## Notes -* Dataset processing
- Each recording is split into 30 s epochs as inputs, with five labels: Sleep stage 3, 2, 1, R, W; each EEG segment is paired with its sleep-stage label
+* Data preprocessing
- Note: for the sleep-edfx dataset we keep only the interval from 30 minutes before sleep onset to 30 minutes after waking (marked "only sleep time" in the results), to balance the stage proportions and speed up training. + 1. Downsampling: EEG from the CinC Challenge 2018 dataset is downsampled to 100 Hz
-* Data preprocessing
- For different network architectures, the raw EEG is preprocessed into different input shapes:
+ 2. Normalization: we recommend normalizing each subject's EEG with its 5th-95th percentiles, i.e. the 5th percentile maps to 0 and the 95th to 1. Note: all pretrained models were trained with this normalization
+ + 3. Segmentation: the signal is split into 30 s epochs, each containing 3000 samples. There are five sleep-stage labels: N3, N2, N1, REM, W; each epoch gets one label. Label mapping: N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4
+ + 4. Data augmentation: during training, every epoch is randomly cropped, randomly flipped, and randomly rescaled in amplitude
+ + 5. For different network architectures, the raw EEG is preprocessed into different input shapes:
LSTM: apply FIR band-pass filters to the 30 s EEG to extract the θ, σ, α, δ, β bands, then concatenate them as the input
- resnet_1d: a one-dimensional ResNet (nn.Conv2d replaced with nn.Conv1d).
- DFCNN: apply a short-time Fourier transform to the 30 s EEG, use the resulting spectrogram as input, and classify it with an image classification network.
+ CNN_1d family (networks tagged 1d): nothing special, just image-domain models carried over with their convolutions swapped to Conv1d
+ DFCNN family (the iFLYTEK idea: convert the signal to a spectrogram, then apply standard image classification models): apply a short-time Fourier transform to the 30 s EEG, use the resulting spectrogram as input, and classify it with an image classification network. We do not recommend this approach, since computing the spectrograms is expensive.
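The 5th-95th percentile normalization from step 2 follows `transformer.Balance_individualized_differences` with `BID='5_95_th'` as it appears in this diff; the sketch below condenses it into one function (the clip threshold of 10e3 is taken from the source):

```python
import numpy as np

def normalize_5_95(signals):
    """Per-subject 5th-95th percentile normalization, as in transformer.py.

    The negative 5th-percentile amplitude is used as the scale, so the bulk
    of the (roughly zero-centered) signal is mapped into (0, 1).
    """
    tmp = np.sort(signals.reshape(-1))
    th_5 = -tmp[int(0.05 * len(tmp))]        # negative of the 5th percentile
    signals = np.clip(signals, -10e3, 10e3)  # drop extreme artifacts
    return signals / th_5 / 2 + 0.5          # (data - avg)/sigma/2 + 0.5, avg = 0
```

Because the scale is computed per subject, this also balances individual amplitude differences before training.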
* EEG spectrograms
Spectrograms of the five sleep stages, shown in order: Wake, Stage 1, Stage 2, Stage 3, REM
@@ -30,63 +38,27 @@ ![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_REM.png)
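The spectrograms above are produced by `signal2spectrum` in dsp.py. The STFT parameters (100 Hz, Hann window, 1024-point segments, 24-sample hop) are taken from this diff; the high-frequency binning below 256 rows is an assumption based on the commented-out variant in dsp.py, which kept low frequencies at full resolution and downsampled higher ones:

```python
import numpy as np
import scipy.signal

def signal2spectrum(data):
    """30 s of 100 Hz EEG (3000 samples) -> (256, 126) spectrogram."""
    zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024,
                            noverlap=1024 - 24, nfft=1024)[2]
    zxx = np.abs(zxx)[:512]                # magnitude, drop the topmost bin
    spectrum = np.zeros((256, zxx.shape[1]))
    spectrum[0:128] = zxx[0:128]           # 0-12.5 Hz at full resolution
    spectrum[128:192] = zxx[128:256][::2]  # 12.5-25 Hz, every 2nd bin
    spectrum[192:256] = zxx[256:512][::4]  # 25-50 Hz, every 4th bin
    return spectrum
```

The nonuniform binning keeps detail in the δ/θ/α range, where most stage-discriminative power lies, while shrinking the image fed to the classifier.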
* multi_scale_resnet_1d architecture
- This network is based on [geekfeiw / Multi-Scale-1D-ResNet](https://github.com/geekfeiw/Multi-Scale-1D-ResNet)
+ This network is based on [geekfeiw / Multi-Scale-1D-ResNet](https://github.com/geekfeiw/Multi-Scale-1D-ResNet); we name it micro_multi_scale_resnet_1d
The modified [architecture](https://github.com/HypoX64/candock/blob/master/image/multi_scale_resnet_1d_network.png)
-* Cross-validation
- To ease comparison with methods from the literature, two cross-validation schemes were used
- 1. Within a dataset: 5-fold cross-validation
- 2. Cross-validation across datasets
+* Cross-validation
For better practical applicability we use subject-wise cross-validation: the training and validation data come from different subjects. Note that in the sleep-edfx dataset each subject has two recordings; we treat both as the same subject, a detail that many papers overlook.
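A minimal sketch of a subject-wise split is shown below. The helper name `split_by_subject` is hypothetical; the actual split in dataloader.py hard-codes the sleep-edfx train/test file lists, but the invariant is the same: both recordings of a subject land on the same side.

```python
import random

def split_by_subject(recordings, train_ratio=0.8, seed=0):
    """recordings: list of (subject_id, data) pairs.

    Splits by subject, never by recording, so no subject contributes
    data to both the train and the test set.
    """
    subjects = sorted({sid for sid, _ in recordings})
    random.Random(seed).shuffle(subjects)
    n_train = round(len(subjects) * train_ratio)
    train_ids = set(subjects[:n_train])
    train = [r for r in recordings if r[0] in train_ids]
    test = [r for r in recordings if r[0] not in train_ids]
    return train, test
```

Splitting by recording instead of by subject would leak one night of a subject into training while testing on their other night, inflating the scores.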
* Evaluation metrics
- Per label:
- accuracy = (TP+TN)/(TP+FN+TN+FP)
- recall = sensitivity = (TP)/(TP+FN)
- Overall:
- Top1.err.
+ Per sleep-stage label: Accuracy = (TP+TN)/(TP+FN+TN+FP) Recall = Sensitivity = TP/(TP+FN)
+ Overall: Top-1 err., Kappa, plus the averages of Accuracy and Recall
+ Note: the label distribution in this task is extremely imbalanced, so for a more convincing evaluation our averages are unweighted. -* About the code
- The code is still under constant revision and is not guaranteed to work. Details will be updated after the graduation project is finished.
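These metrics can all be derived from the confusion matrix. The sketch below is a simplified version of what statistics.py reports (its `recall,acc,sp,err,k` tuple; specificity is omitted here), using the unweighted macro averages described above:

```python
import numpy as np

def metrics(cm):
    """cm[true][pred] confusion matrix -> (avg recall, avg acc, top-1 err, kappa)."""
    cm = np.asarray(cm, dtype=float)
    n = cm.sum()
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = n - tp - fn - fp
    recall = tp / (tp + fn)                  # per-class sensitivity
    acc = (tp + tn) / n                      # per-class accuracy
    top1_err = 1 - tp.sum() / n
    p_o = tp.sum() / n                       # observed agreement
    p_e = (cm.sum(axis=1) * cm.sum(axis=0)).sum() / n ** 2  # chance agreement
    kappa = (p_o - p_e) / (1 - p_e)          # Cohen's kappa
    return recall.mean(), acc.mean(), top1_err, kappa
```

The unweighted mean over classes is what makes the rare N1 stage count as much as the dominant W stage in the reported averages.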
## Selected results This section will be updated over time... ...
[[Confusion matrix]](https://github.com/HypoX64/candock/blob/master/confusion_mat)
-#### 5-Fold Cross-Validation Results -* sleep-edf
- - | Network | Label average recall | Label average accuracy | error rate | - | :----------------------- | :------------------- | ---------------------- | ---------- | - | lstm | | | | - | resnet18_1d | 0.8263 | 0.9601 | 0.0997 | - | DFCNN+resnet18 | 0.8261 | 0.9594 | 0.1016 | - | DFCNN+multi_scale_resnet | 0.8196 | 0.9631 | 0.0922 | - | multi_scale_resnet_1d | 0.8400 | 0.9595 | 0.1013 | - -* sleep-edfx(only sleep time)
- - | Network | Label average recall | Label average accuracy | error rate | - | :------------- | :------------------- | ---------------------- | ---------- | - | lstm | | | | - | resnet18_1d | | | | - | DFCNN+resnet18 | | | | - | DFCNN+resnet50 | | | | - -* CinC Challenge 2018(sample size = 100)
- - | Network | Label average recall | Label average accuracy | error rate | - | :------------- | :------------------- | ---------------------- | ---------- | - | lstm | | | | - | resnet18_1d | | | | - | DFCNN+resnet18 | 0.7823 | 0.909 | 0.2276 | - | DFCNN+resnet50 | | | | - #### Subject Cross-Validation Results - -## Dev diary -* 2019/04/01 DFCNN costs far too much compute, the improvement is marginal, and it overfits easily... tasteless to keep, yet a pity to throw away... -* 2019/04/03 Spent a day upgrading to PyTorch 1.0, then tried shrinking the input spectrograms to cut compute... -* 2019/04/04 Need to add k-fold plus subject cross-validation to be rigorous... -* 2019/04/05 Qingming Festival... reading the literature; I'll follow what most people do and use 5-fold plus cross-dataset validation, for easy side-by-side comparison with other methods. Though I have to complain: others run k-fold only because their datasets are tiny... running k-fold on hundreds of GB of data is redundant, the results barely differ... a pure waste of compute... -* 2019/04/09 Back in my hometown. Ah!!! My thesis... I'll never finish it! -* 2019/04/13 Back at school grinding out the thesis. -* 2019/05/02 Submitted it for the plagiarism check, so updates to this repository pause here; today is the last big update. Next I'll open a new branch focused on practical application, e.g. the balance between computational cost and accuracy. I'll also provide pretrained models for different settings and update the calling interface, so that our sleep-staging code can be used out of the box. \ No newline at end of file +Note: the label distribution in this task is extremely imbalanced, so we hand-tuned the class weights of the classification loss. This slightly improves Average Recall but also raises the overall error. With the default weights, Top-1 err. drops by at least 5%, but recall on the tiny N1 class plunges by 20%, which is absolutely not what we want in practice. All results below were obtained with the tuned weights.
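The tuned class weights come directly from train.py in this diff: log inverse class frequency, an extra boost for N1 (index 2), clipped to [1, 3]. The stage fractions below are illustrative placeholders; the real values come from `statistics.stage()` on the training set.

```python
import numpy as np

# Example stage fractions [N3, N2, N1, REM, W]; real values are computed
# from the training-set label statistics.
stage_cnt_per = np.array([0.034, 0.320, 0.122, 0.104, 0.420])

weight = np.log(1 / stage_cnt_per)  # rare classes get larger weights
weight[2] = weight[2] + 1           # extra boost for the tiny N1 class
weight = np.clip(weight, 1, 3)      # keep the weights in a sane range
```

In train.py this `weight` array is then passed to `nn.CrossEntropyLoss(torch.from_numpy(weight).float())`, which is what trades a little overall accuracy for the much better N1 recall described above.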
+* [sleep-edfx](https://www.physionet.org/physiobank/database/sleep-edfx/) ->sample size = 197, select sleep time + +| Network | Parameters | Top1.err. | Avg. Acc. | Avg. Re. | Need to extract feature | +| --------------------------- | ---------- | --------- | --------- | -------- | ----------------------- | +| lstm | 1.25M | 26.32% | 93.56% | 68.57% | Yes | +| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 93.87% | 72.61% | No | +| resnet18_1d | 3.85M | 24.21% | 94.07% | 72.87% | No | +| multi_scale_resnet_1d | 8.42M | 24.01% | 94.06% | 72.37% | No | \ No newline at end of file diff --git a/confusion_mat b/confusion_mat index 45299813f69ef4fcc0a4734ba1ea784a1b265029..f4737a6dc31d283f777da646519ed3a41b010a79 100644 --- a/confusion_mat +++ b/confusion_mat @@ -1,64 +1,76 @@ Confusion matrix: Pred - - S3 S2 S1 REM WAKE + S3 S2 S1 REM WAKE + S3 S2 True S1 REM WAKE -5-Fold Cross-Validation Results --------------------sleep-edf------------------- +-------------------sleep-edfx(select sleep time, sample size = 197)------------------- -resnet18_1d -final: test avg_recall:0.8263 avg_acc:0.9601 error:0.0997 +----------lstm---------- +train: + dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573] +test: + dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382] +net parameters: 1.25M +final: recall,acc,sp,err,k: (0.6857, 0.8947, 0.9356, 0.2632, 0.6305) confusion_mat: -[[1133 148 7 3 0] - [ 296 2968 120 195 3] - [ 3 60 368 149 18] - [ 3 99 135 1342 18] - [ 9 7 190 37 7729]] -statistics of dataset [S3 S2 S1 R W]: -[1291 3582 598 1597 7972] +[[ 1644 244 11 26 5] + [ 1801 11924 1933 2169 229] + [ 42 1267 2772 1689 1094] + [ 5 193 1845 3539 301] + [ 37 128 1513 495 22182]] -DFCNN+resnet18 -final: test avg_recall:0.8261 avg_acc:0.9594 error:0.1016 + +----------micro_multi_scale_resnet_1d---------- +train: + dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573] +test: + dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382] +net 
parameters: 2.11M +final: recall,acc,sp,err,k: (0.7261, 0.8987, 0.9387, 0.2533, 0.648) confusion_mat: -[[ 216 44 2 0 6] - [ 41 640 24 14 6] - [ 1 9 75 11 10] - [ 0 20 65 226 3] - [ 2 2 21 4 1582]] -statistics of dataset [S3 S2 S1 R W]: -[ 268 725 106 314 1611] +[[ 1711 211 2 6 1] + [ 1905 12958 1338 1741 99] + [ 26 1731 3074 1439 602] + [ 1 466 995 4328 97] + [ 37 237 2692 837 20554]] + -multi_scale_resnet_1d +----------multi_scale_resnet_1d---------- +train: + dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573] +test: + dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382] +net parameters: 8.42M +final: recall,acc,sp,err,k: (0.7237, 0.904, 0.9406, 0.2401, 0.6631) confusion_mat: -final: test avg_recall:0.84 avg_acc:0.9595 error:0.1013 -[[1159 114 10 3 1] - [ 335 2961 153 143 5] - [ 4 45 436 102 13] - [ 0 100 241 1244 9] - [ 2 2 214 27 7717]] -statistics of dataset [S3 S2 S1 R W]: -[1287 3597 600 1594 7962] +[[ 1629 280 17 3 1] + [ 1775 13245 1880 973 176] + [ 18 1644 3484 1044 681] + [ 11 454 1062 3925 431] + [ 38 177 2326 717 21097]] --------------------sleep-edfx(only sleep time)------------------- +----------resnet18_1d---------- +train: + dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573] +test: + dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382] +net parameters: 3.85M +final: recall,acc,sp,err,k: (0.7283, 0.9031, 0.9407, 0.2421, 0.6607) +confusion_mat: +[[ 1726 171 23 8 1] + [ 1906 12100 2336 1530 179] + [ 48 1223 3386 1315 898] + [ 1 298 1200 3989 398] + [ 19 108 1630 531 22064]] +-------------------CinC Challenge 2018(sample size = 994)------------------- --------------------CinC Challenge 2018(sample size = 100)------------------- -DFCNN+resnet18 -final: test avg_recall:0.7823 avg_acc:0.909 error:0.2276 -confusion_mat: -[[1897 189 3 4 3] - [ 634 5810 614 362 126] - [ 1 409 1545 301 406] - [ 4 200 231 1809 72] - [ 3 41 377 19 3036]] -statistics of dataset [S3 S2 S1 R W]: -[2096 7546 2662 2316 3476] 
diff --git a/dataloader.py b/dataloader.py index f11a54ed6af066824a2471aebb01651069677a5b..260aeee8bba8f1770833a20b22d144a807c432db 100644 --- a/dataloader.py +++ b/dataloader.py @@ -107,8 +107,7 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time): events, _ = mne.events_from_annotations( raw_data, event_id=event_id, chunk_duration=30.) - stages = [] - signals =[] + stages = [];signals =[] for i in range(len(events)-1): stages.append(events[i][2]) signals.append(eeg[events[i][0]:events[i][0]+3000]) @@ -129,38 +128,63 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time): return signals.astype(np.float16),stages.astype(np.int16) #load all data in datasets -def loaddataset(filedir,dataset_name,signal_name,num,BID,select_sleep_time,shuffle = True): +def loaddataset(filedir,dataset_name,signal_name,num,BID,select_sleep_time,shuffle = False): print('load dataset, please wait...') - filenames = os.listdir(filedir) - if shuffle: - random.shuffle(filenames) - signals=[] - stages=[] + + signals_train=[];stages_train=[];signals_test=[];stages_test=[] if dataset_name == 'cc2018': + filenames = os.listdir(filedir) + if shuffle: + random.shuffle(filenames) + else: + filenames.sort() + if num > len(filenames): num = len(filenames) print('num of dataset is:',num) for cnt,filename in enumerate(filenames[:num],0): - try: - signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID) - signals,stages = connectdata(signal,stage,signals,stages) - except Exception as e: - print(filename,e) - - elif dataset_name in ['sleep-edfx','sleep-edf']: + signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID) + if cnt < round(num*0.8) : + signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train) + else: + signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test) + print('train subjects:',round(num*0.8),'test subjects:',round(num*0.2)) + + elif dataset_name == 'sleep-edfx': if 
num > 197: num = 197 - if dataset_name == 'sleep-edf': - filenames = ['SC4002E0-PSG.edf','SC4012E0-PSG.edf','SC4102E0-PSG.edf','SC4112E0-PSG.edf', - 'ST7022J0-PSG.edf','ST7052J0-PSG.edf','ST7121J0-PSG.edf','ST7132J0-PSG.edf'] - cnt = 0 - for filename in filenames: - if 'PSG' in filename: - signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time) - signals,stages = connectdata(signal,stage,signals,stages) - cnt += 1 - if cnt == num: - break - return signals,stages \ No newline at end of file + + filenames_sc_train = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf', 'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf', 'SC4031E0-PSG.edf', 'SC4032E0-PSG.edf', 'SC4041E0-PSG.edf', 'SC4042E0-PSG.edf', 'SC4051E0-PSG.edf', 'SC4052E0-PSG.edf', 'SC4061E0-PSG.edf', 'SC4062E0-PSG.edf', 'SC4071E0-PSG.edf', 'SC4072E0-PSG.edf', 'SC4081E0-PSG.edf', 'SC4082E0-PSG.edf', 'SC4091E0-PSG.edf', 'SC4092E0-PSG.edf', 'SC4101E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4111E0-PSG.edf', 'SC4112E0-PSG.edf', 'SC4121E0-PSG.edf', 'SC4122E0-PSG.edf', 'SC4131E0-PSG.edf', 'SC4141E0-PSG.edf', 'SC4142E0-PSG.edf', 'SC4151E0-PSG.edf', 'SC4152E0-PSG.edf', 'SC4161E0-PSG.edf', 'SC4162E0-PSG.edf', 'SC4171E0-PSG.edf', 'SC4172E0-PSG.edf', 'SC4181E0-PSG.edf', 'SC4182E0-PSG.edf', 'SC4191E0-PSG.edf', 'SC4192E0-PSG.edf', 'SC4201E0-PSG.edf', 'SC4202E0-PSG.edf', 'SC4211E0-PSG.edf', 'SC4212E0-PSG.edf', 'SC4221E0-PSG.edf', 'SC4222E0-PSG.edf', 'SC4231E0-PSG.edf', 'SC4232E0-PSG.edf', 'SC4241E0-PSG.edf', 'SC4242E0-PSG.edf', 'SC4251E0-PSG.edf', 'SC4252E0-PSG.edf', 'SC4261F0-PSG.edf', 'SC4262F0-PSG.edf', 'SC4271F0-PSG.edf', 'SC4272F0-PSG.edf', 'SC4281G0-PSG.edf', 'SC4282G0-PSG.edf', 'SC4291G0-PSG.edf', 'SC4292G0-PSG.edf', 'SC4301E0-PSG.edf', 'SC4302E0-PSG.edf', 'SC4311E0-PSG.edf', 'SC4312E0-PSG.edf', 'SC4321E0-PSG.edf', 'SC4322E0-PSG.edf', 'SC4331F0-PSG.edf', 'SC4332F0-PSG.edf', 'SC4341F0-PSG.edf', 'SC4342F0-PSG.edf', 'SC4351F0-PSG.edf', 'SC4352F0-PSG.edf', 'SC4362F0-PSG.edf', 
'SC4371F0-PSG.edf', 'SC4372F0-PSG.edf', 'SC4381F0-PSG.edf', 'SC4382F0-PSG.edf', 'SC4401E0-PSG.edf', 'SC4402E0-PSG.edf', 'SC4411E0-PSG.edf', 'SC4412E0-PSG.edf', 'SC4421E0-PSG.edf', 'SC4422E0-PSG.edf', 'SC4431E0-PSG.edf', 'SC4432E0-PSG.edf', 'SC4441E0-PSG.edf', 'SC4442E0-PSG.edf', 'SC4451F0-PSG.edf', 'SC4452F0-PSG.edf', 'SC4461F0-PSG.edf', 'SC4462F0-PSG.edf', 'SC4471F0-PSG.edf', 'SC4472F0-PSG.edf', 'SC4481F0-PSG.edf', 'SC4482F0-PSG.edf', 'SC4491G0-PSG.edf', 'SC4492G0-PSG.edf', 'SC4501E0-PSG.edf', 'SC4502E0-PSG.edf', 'SC4511E0-PSG.edf', 'SC4512E0-PSG.edf', 'SC4522E0-PSG.edf', 'SC4531E0-PSG.edf', 'SC4532E0-PSG.edf', 'SC4541F0-PSG.edf', 'SC4542F0-PSG.edf', 'SC4551F0-PSG.edf', 'SC4552F0-PSG.edf', 'SC4561F0-PSG.edf', 'SC4562F0-PSG.edf', 'SC4571F0-PSG.edf', 'SC4572F0-PSG.edf', 'SC4581G0-PSG.edf', 'SC4582G0-PSG.edf', 'SC4591G0-PSG.edf', 'SC4592G0-PSG.edf', 'SC4601E0-PSG.edf', 'SC4602E0-PSG.edf', 'SC4611E0-PSG.edf', 'SC4612E0-PSG.edf', 'SC4621E0-PSG.edf', 'SC4622E0-PSG.edf', 'SC4631E0-PSG.edf', 'SC4632E0-PSG.edf'] + filenames_sc_test = ['SC4641E0-PSG.edf', 'SC4642E0-PSG.edf', 'SC4651E0-PSG.edf', 'SC4652E0-PSG.edf', 'SC4661E0-PSG.edf', 'SC4662E0-PSG.edf', 'SC4671G0-PSG.edf', 'SC4672G0-PSG.edf', 'SC4701E0-PSG.edf', 'SC4702E0-PSG.edf', 'SC4711E0-PSG.edf', 'SC4712E0-PSG.edf', 'SC4721E0-PSG.edf', 'SC4722E0-PSG.edf', 'SC4731E0-PSG.edf', 'SC4732E0-PSG.edf', 'SC4741E0-PSG.edf', 'SC4742E0-PSG.edf', 'SC4751E0-PSG.edf', 'SC4752E0-PSG.edf', 'SC4761E0-PSG.edf', 'SC4762E0-PSG.edf', 'SC4771G0-PSG.edf', 'SC4772G0-PSG.edf', 'SC4801G0-PSG.edf', 'SC4802G0-PSG.edf', 'SC4811G0-PSG.edf', 'SC4812G0-PSG.edf', 'SC4821G0-PSG.edf', 'SC4822G0-PSG.edf'] + filenames_st_train = ['ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf', 'ST7041J0-PSG.edf', 'ST7042J0-PSG.edf', 'ST7051J0-PSG.edf', 'ST7052J0-PSG.edf', 'ST7061J0-PSG.edf', 'ST7062J0-PSG.edf', 'ST7071J0-PSG.edf', 'ST7072J0-PSG.edf', 'ST7081J0-PSG.edf', 'ST7082J0-PSG.edf', 'ST7091J0-PSG.edf', 'ST7092J0-PSG.edf', 
'ST7101J0-PSG.edf', 'ST7102J0-PSG.edf', 'ST7111J0-PSG.edf', 'ST7112J0-PSG.edf', 'ST7121J0-PSG.edf', 'ST7122J0-PSG.edf', 'ST7131J0-PSG.edf', 'ST7132J0-PSG.edf', 'ST7141J0-PSG.edf', 'ST7142J0-PSG.edf', 'ST7151J0-PSG.edf', 'ST7152J0-PSG.edf', 'ST7161J0-PSG.edf', 'ST7162J0-PSG.edf', 'ST7171J0-PSG.edf', 'ST7172J0-PSG.edf', 'ST7181J0-PSG.edf', 'ST7182J0-PSG.edf', 'ST7191J0-PSG.edf', 'ST7192J0-PSG.edf'] + filenames_st_test = ['ST7201J0-PSG.edf', 'ST7202J0-PSG.edf', 'ST7211J0-PSG.edf', 'ST7212J0-PSG.edf', 'ST7221J0-PSG.edf', 'ST7222J0-PSG.edf', 'ST7241J0-PSG.edf', 'ST7242J0-PSG.edf'] + + for filename in filenames_sc_train[:round(num*153/197*0.8)]: + signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time) + signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train) + + for filename in filenames_st_train[:round(num*44/197*0.8)]: + signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time) + signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train) + + for filename in filenames_sc_test[:round(num*153/197*0.2)]: + signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time) + signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test) + + for filename in filenames_st_test[:round(num*44/197*0.2)]: + signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time) + signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test) + + print('---------Each subject has two sample---------', + '\nTrain samples_SC/ST:',round(num*153/197*0.8),round(num*44/197*0.8), + '\nTest samples_SC/ST:',round(num*153/197*0.2),round(num*44/197*0.2)) + + elif dataset_name == 'preload': + signals_train = np.load(filedir+'/signals_train.npy') + stages_train = np.load(filedir+'/stages_train.npy') + signals_test = np.load(filedir+'/signals_test.npy') + stages_test = np.load(filedir+'/stages_test.npy') + + return 
signals_train,stages_train,signals_test,stages_test \ No newline at end of file diff --git a/dsp.py b/dsp.py index e5ba0addf3739bc46d82452711e72fa31f6e6a1f..7a8d12181d66aa0ef6aea67b9312cf5f6d8c32fb 100644 --- a/dsp.py +++ b/dsp.py @@ -42,7 +42,8 @@ def BPF(signal,fs,fc1,fc2,mod = 'fir'): b=getfir_b(fc1,fc2,fs) result = scipy.signal.lfilter(b, 1, signal) return result -def getfeature(signal,mod = 'fir',ch_num = 5): + +def getfeature(signal,mod = 'fft',ch_num = 5): result=[] signal =signal - np.mean(signal) eeg=signal @@ -68,19 +69,6 @@ def getfeature(signal,mod = 'fir',ch_num = 5): result=result.reshape(ch_num*len(signal),) return result -# def signal2spectrum(signal): -# spectrum =np.zeros((224,224)) -# spectrum_y = np.zeros((224)) -# for i in range(224): -# signal_window=signal[i*9:i*9+896] -# signal_window_fft=np.abs(np.fft.fft(signal_window))[0:448] -# spectrum_y[0:112]=signal_window_fft[0:112] -# spectrum_y[112:168]=signal_window_fft[112:224][::2] -# spectrum_y[168:224]=signal_window_fft[224:448][::4] -# spectrum[:,i] = spectrum_y -# # spectrum = np.log(spectrum+1)/11 -# return spectrum - # def signal2spectrum(data): # # window : ('tukey',0.5) hann @@ -96,7 +84,7 @@ def getfeature(signal,mod = 'fir',ch_num = 5): def signal2spectrum(data): # window : ('tukey',0.5) hann - zxx = scipy.signal.stft(data, fs=100, window=('tukey',0.5), nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2] + zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2] zxx =np.abs(zxx)[:512] spectrum=np.zeros((256,126)) spectrum[0:128]=zxx[0:128] diff --git a/options.py b/options.py index 6c6614ee7619e7865b361c0d68878d16435a622e..f34dd46ee4c617f6520d58b3bf9da0275533529b 100644 --- a/options.py +++ b/options.py @@ -3,8 +3,8 @@ import os import numpy as np import torch -#python3 train.py --dataset_dir 
'/media/hypo/Hypo/physionet_org_train' --dataset_name cc2018 --signal_name 'C4-M1' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject -# python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject +# python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name cc2018 --signal_name 'C4-M1' --sample_num 20 --model_name lstm --batchsize 64 --epochs 20 --lr 0.0005 --no_cudnn +# python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 50 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 25 --lr 0.0005 --BID 5_95_th --select_sleep_time --no_cudnn --select_sleep_time class Options(): def __init__(self): @@ -16,16 +16,14 @@ class Options(): self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn') self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models') self.parser.add_argument('--lr', type=float, default=0.0005,help='learning rate') - self.parser.add_argument('--cross_validation', type=str, default='subject',help='k-fold | subject') self.parser.add_argument('--BID', type=str, default='5_95_th',help='Balance individualized differences 5_95_th | median |None') - self.parser.add_argument('--fold_num', type=int, default=5,help='k-fold') self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize') self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/', help='your dataset path') - self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx',help='Choose dataset sleep-edfx | sleep-edf | cc2018') + 
self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx',help='Choose dataset sleep-edfx | cc2018') self.parser.add_argument('--select_sleep_time', action='store_true', help='if input, for sleep-cassette only use sleep time to train') self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1 | EEG Fpz-Cz |...') - self.parser.add_argument('--sample_num', type=int, default=10,help='the amount you want to load') + self.parser.add_argument('--sample_num', type=int, default=20,help='the amount you want to load') self.parser.add_argument('--model_name', type=str, default='lstm',help='Choose model lstm | multi_scale_resnet_1d | resnet18 |...') self.parser.add_argument('--epochs', type=int, default=20,help='end epoch') self.parser.add_argument('--weight_mod', type=str, default='avg_best',help='Choose weight mode: avg_best|normal') @@ -42,9 +40,5 @@ class Options(): self.opt.sample_num = 8 if self.opt.no_cuda: self.opt.no_cudnn = True - if self.opt.fold_num == 0: - self.opt.fold_num = 1 - if self.opt.cross_validation == 'subject': - self.opt.fold_num = 1 return self.opt \ No newline at end of file diff --git a/statistics.py b/statistics.py index 36ccdb11e9a7fe4c75eccf89fa976bd68662b3a6..b3a74c17774515e72d42bc4feea7ebd99b909978 100644 --- a/statistics.py +++ b/statistics.py @@ -8,7 +8,7 @@ def stage(stages): for i in range(len(stages)): stage_cnt[stages[i]] += 1 stage_cnt_per = stage_cnt/len(stages) - util.writelog('statistics of dataset [S3 S2 S1 R W]: '+str(stage_cnt),True) + util.writelog(' dataset statistics [S3 S2 S1 R W]: '+str(stage_cnt),True) return stage_cnt,stage_cnt_per def reversal_label(mat): diff --git a/train.py b/train.py index b757ae81f7af005b8488b99e07ed0c4d82e55b84..07d890eb7daf8bddc75a0ae470c806397d6e1053 100644 --- a/train.py +++ b/train.py @@ -28,26 +28,21 @@ but the data needs meet the following conditions: 3.fs = 100Hz 4.input signal data should be normalized!! 
we recommend signal data normalized useing 5_95_th for each subject, - example: signals_subject=Balance_individualized_differences(signals_subject, '5_95_th') -5.when useing subject cross validation,we will generally believe [0:80%]and[80%:100%]data come from different subjects. + example: signals_normalized=transformer.Balance_individualized_differences(signals_origin, '5_95_th') ''' -signals,stages = dataloader.loaddataset(opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,opt.BID,opt.select_sleep_time,shuffle = True) -stage_cnt,stage_cnt_per = statistics.stage(stages) - -if opt.cross_validation =='k_fold': - signals,stages = transformer.batch_generator(signals,stages,opt.batchsize,shuffle = True) - train_sequences,test_sequences = transformer.k_fold_generator(len(stages),opt.fold_num) -elif opt.cross_validation =='subject': - util.writelog('train statistics:',True) - stage_cnt,stage_cnt_per = statistics.stage(stages[:int(0.8*len(stages))]) - util.writelog('test statistics:',True) - stage_cnt,stage_cnt_per = statistics.stage(stages[int(0.8*len(stages)):]) - signals,stages = transformer.batch_generator_subject(signals,stages,opt.batchsize,shuffle = False) - train_sequences,test_sequences = transformer.k_fold_generator(len(stages),1) - -batch_length = len(stages) +signals_train,stages_train,signals_test,stages_test = dataloader.loaddataset(opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,opt.BID,opt.select_sleep_time) + +util.writelog('train:',True) +stage_cnt,stage_cnt_per = statistics.stage(stages_train) +util.writelog('test:',True) +_,_ = statistics.stage(stages_test) +signals_train,stages_train = transformer.batch_generator(signals_train,stages_train,opt.batchsize) +signals_test,stages_test = transformer.batch_generator(signals_test,stages_test,opt.batchsize) + + +batch_length = len(signals_train) print('length of batch:',batch_length) -show_freq = int(len(train_sequences[0])/5) +show_freq = int(len(stages_train)/5) t2 = time.time() 
print('load data cost time: %.2f'% (t2-t1),'s') @@ -59,27 +54,30 @@ weight = np.array([1,1,1,1,1]) if opt.weight_mod == 'avg_best': weight = np.log(1/stage_cnt_per) weight[2] = weight[2]+1 - weight = np.clip(weight,1,5) + weight = np.clip(weight,1,3) print('Loss_weight:',weight) weight = torch.from_numpy(weight).float() # print(net) if not opt.no_cuda: net.cuda() weight = weight.cuda() +if opt.pretrained: + net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.dataset_name+'/'+opt.model_name+'.pth')) if not opt.no_cudnn: torch.backends.cudnn.benchmark = True + optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr) # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8) criterion = nn.CrossEntropyLoss(weight) -def evalnet(net,signals,stages,sequences,epoch,plot_result={}): +def evalnet(net,signals,stages,epoch,plot_result={}): # net.eval() confusion_mat = np.zeros((5,5), dtype=int) - for i, sequence in enumerate(sequences, 1): + for i, (signal,stage) in enumerate(zip(signals,stages), 1): - signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =True) - signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda) + signal=transformer.ToInputShape(signal,opt.model_name,test_flag =True) + signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda) with torch.no_grad(): out = net(signal) pred = torch.max(out, 1)[1] @@ -97,61 +95,53 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={}): print('begin to train ...') final_confusion_mat = np.zeros((5,5), dtype=int) -for fold in range(opt.fold_num): - net.load_state_dict(torch.load('./checkpoints/'+opt.model_name+'.pth')) - if opt.pretrained: - net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.dataset_name+'/'+opt.model_name+'.pth')) - if not opt.no_cuda: - net.cuda() - plot_result={'train':[1.],'test':[1.]} - confusion_mats = [] - - for epoch in range(opt.epochs): - t1 = time.time() - confusion_mat = 
np.zeros((5,5), dtype=int) - print('fold:',fold+1,'epoch:',epoch+1) - net.train() - for i, sequence in enumerate(train_sequences[fold], 1): - - signal=transformer.ToInputShape(signals[sequence],opt.model_name,test_flag =False) - signal,stage = transformer.ToTensor(signal,stages[sequence],no_cuda =opt.no_cuda) - - out = net(signal) - loss = criterion(out, stage) - pred = torch.max(out, 1)[1] - optimizer.zero_grad() - loss.backward() - optimizer.step() - - pred=pred.data.cpu().numpy() - stage=stage.data.cpu().numpy() - for x in range(len(pred)): - confusion_mat[stage[x]][pred[x]] += 1 - if i%show_freq==0: - plot_result['train'].append(statistics.result(confusion_mat)[3]) - heatmap.draw(confusion_mat,name = 'train') - statistics.show(plot_result,epoch+i/(batch_length*0.8)) - confusion_mat[:]=0 - - plot_result,confusion_mat = evalnet(net,signals,stages,test_sequences[fold],epoch+1,plot_result) - confusion_mats.append(confusion_mat) - # scheduler.step() - - if (epoch+1)%opt.network_save_freq == 0: - torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth') - print('network saved.') - if not opt.no_cuda: - net.cuda() - - t2=time.time() - if epoch+1==1: - print('cost time: %.2f' % (t2-t1),'s') - pos = plot_result['test'].index(min(plot_result['test']))-1 - final_confusion_mat = final_confusion_mat+confusion_mats[pos] - util.writelog('fold:'+str(fold+1)+' recall,acc,sp,err,k: '+str(statistics.result(confusion_mats[pos])),True) - print('------------------') - util.writelog('confusion_mat:\n'+str(confusion_mat)) +plot_result={'train':[1.],'test':[1.]} +confusion_mats = [] + +for epoch in range(opt.epochs): + t1 = time.time() + confusion_mat = np.zeros((5,5), dtype=int) + print('epoch:',epoch+1) + net.train() + for i, (signal,stage) in enumerate(zip(signals_train,stages_train), 1): + + signal=transformer.ToInputShape(signal,opt.model_name,test_flag =False) + signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda) + + out = 
net(signal) + loss = criterion(out, stage) + pred = torch.max(out, 1)[1] + optimizer.zero_grad() + loss.backward() + optimizer.step() + + pred=pred.data.cpu().numpy() + stage=stage.data.cpu().numpy() + for x in range(len(pred)): + confusion_mat[stage[x]][pred[x]] += 1 + if i%show_freq==0: + plot_result['train'].append(statistics.result(confusion_mat)[3]) + heatmap.draw(confusion_mat,name = 'train') + statistics.show(plot_result,epoch+i/(batch_length*0.8)) + confusion_mat[:]=0 + + plot_result,confusion_mat = evalnet(net,signals_test,stages_test,epoch+1,plot_result) + confusion_mats.append(confusion_mat) + # scheduler.step() + + if (epoch+1)%opt.network_save_freq == 0: + torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth') + print('network saved.') + if not opt.no_cuda: + net.cuda() + + t2=time.time() + if epoch+1==1: + print('cost time: %.2f' % (t2-t1),'s') + +pos = plot_result['test'].index(min(plot_result['test']))-1 +final_confusion_mat = confusion_mats[pos] util.writelog('final: '+'recall,acc,sp,err,k: '+str(statistics.result(final_confusion_mat)),True) util.writelog('confusion_mat:\n'+str(final_confusion_mat),True) statistics.stagefrommat(final_confusion_mat) diff --git a/transformer.py b/transformer.py index 3c712c8c2b79125ad5ae907c24b20144c2a2c611..64931e450a1755abc7bb5d00481eec756d55402d 100644 --- a/transformer.py +++ b/transformer.py @@ -14,16 +14,6 @@ def shuffledata(data,target): np.random.shuffle(target) # return data,target -def batch_generator_subject(data,target,batchsize,shuffle = True): - data_test = data[int(0.8*len(target)):] - data_train = data[0:int(0.8*len(target))] - target_test = target[int(0.8*len(target)):] - target_train = target[0:int(0.8*len(target))] - data_test,target_test = batch_generator(data_test, target_test, batchsize) - data_train,target_train = batch_generator(data_train, target_train, batchsize) - data = np.concatenate((data_train, data_test), axis=0) - target = 
np.concatenate((target_train, target_test), axis=0) - return data,target def batch_generator(data,target,batchsize,shuffle = True): if shuffle: @@ -34,41 +24,28 @@ def batch_generator(data,target,batchsize,shuffle = True): target = target.reshape(-1,batchsize) return data,target -def k_fold_generator(length,fold_num): - sequence = np.linspace(0,length-1,length,dtype='int') - if fold_num == 1: - train_sequence = sequence[0:int(0.8*length)].reshape(1,-1) - test_sequence = sequence[int(0.8*length):].reshape(1,-1) - else: - train_length = int(length/fold_num*(fold_num-1)) - test_length = int(length/fold_num) - train_sequence = np.zeros((fold_num,train_length), dtype = 'int') - test_sequence = np.zeros((fold_num,test_length), dtype = 'int') - for i in range(fold_num): - test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length] - train_sequence[i] = np.concatenate((sequence[0:test_length*i],sequence[test_length*(i+1):]),axis=0)[:train_length] - - return train_sequence,test_sequence - -def Normalize(data,maxmin,avg,sigma): +def Normalize(data,maxmin,avg,sigma,is_01=False): data = np.clip(data, -maxmin, maxmin) - return (data-avg)/sigma + if is_01: + return (data-avg)/sigma/2+0.5 #(0,1) + else: + return (data-avg)/sigma #(-1,1) def Balance_individualized_differences(signals,BID): if BID == 'median': signals = (signals*8/(np.median(abs(signals)))) - signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30) + signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True) elif BID == '5_95_th': tmp = np.sort(signals.reshape(-1)) th_5 = -tmp[int(0.05*len(tmp))] - signals=Normalize(signals,maxmin=10e3,avg=0,sigma=th_5) + signals=Normalize(signals,maxmin=10e3,avg=0,sigma=th_5,is_01=True) else: - #dataser 5_95_th median + #dataser 5_95_th(-1,1) median #CC2018 24.75 7.438 #sleep edfx 37.4 9.71 #sleep edfx sleeptime 39.03 10.125 - signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30) + signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True) return signals 
def ToTensor(data,target=None,no_cuda = False): @@ -141,6 +118,7 @@ def ToInputShape(data,net_name,test_flag = False): elif net_name in ['squeezenet','multi_scale_resnet','dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']: result =[] + data = (data-0.5)*2 for i in range(0,batchsize): spectrum = dsp.signal2spectrum(data[i]) spectrum = random_transform_2d(spectrum,(224,122),test_flag=test_flag) diff --git a/util.py b/util.py index 71cbc81eded4f35b1ea74becc7377ea54c2278e5..34c5a5f91b8133877b8408f5239a2384c7a8ab33 100644 --- a/util.py +++ b/util.py @@ -13,4 +13,6 @@ def writelog(log,printflag = False): print(log) def show_paramsnumber(net): - writelog('net parameters:'+str(sum(param.numel() for param in net.parameters())),True) + parameters = sum(param.numel() for param in net.parameters()) + parameters = round(parameters/1e6,2) + writelog('net parameters: '+str(parameters)+'M',True)