Commit 7ac8cf6f authored by H hypox64

Delete 5-fold cross validation

Parent cf02294c
# candock
[This was originally a log repository for my graduation project](https://github.com/HypoX64/candock/tree/Graduation_Project), whose goal was to try several deep neural network architectures (LSTM, ResNet, DFCNN, etc.) for automatic sleep stage scoring on single-channel EEG.<br>The graduation project is now complete, and I will keep maintaining this project. The focus shifts to practical application, weighing computational cost against accuracy. Some pretrained models will also be provided for convenience.<br>We also believe this code can be used to classify other physiological signals (ECG, EMG, etc.). We hope it helps your research or project.<br>
![image](https://github.com/HypoX64/candock/blob/master/image/compare.png)
## How to Run
If you want to run this code (train your own model, or use a pretrained model to predict on your own data), see the following page:<br>
[How to run codes](https://github.com/HypoX64/candock/blob/master/how_to_run.md)<br>
## Datasets
Two public sleep datasets are used for training: [[CinC Challenge 2018]](https://physionet.org/physiobank/database/challenge/2018/#files) [[sleep-edfx]](https://www.physionet.org/physiobank/database/sleep-edfx/) <br>
For the CinC Challenge 2018 dataset we use only the C4-M1 channel; for the sleep-edfx (and sleep-edf) datasets we use the Fpz-Cz channel.<br>
Note:<br>
1. To obtain pretrained models for other EEG channels, download these two datasets and train with train.py. You can of course also train a model on your own data.<br>
2. For the sleep-edfx dataset, we only read the interval from 30 minutes before sleep onset to 30 minutes after waking (marked "select sleep time" in the results), to balance the proportions of the sleep stages and to speed up training.<br>
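The "select sleep time" trimming described above can be sketched as follows. This is a minimal illustration, not the repository's actual loader; the function name `select_sleep_time` and the wake label value are assumptions, while the 30-minute margin (60 epochs of 30s) follows the text.

```python
import numpy as np

def select_sleep_time(epochs, stages, wake_label=4, margin_epochs=60):
    """Keep only the span from 30 min before sleep onset to 30 min after waking.

    epochs: (N, 3000) array of 30s EEG epochs; stages: (N,) stage labels.
    margin_epochs=60 -> 60 * 30s = 30 minutes on each side.
    """
    asleep = np.where(stages != wake_label)[0]      # epochs scored as any sleep stage
    if len(asleep) == 0:                            # subject never slept: keep nothing
        return epochs[:0], stages[:0]
    start = max(asleep[0] - margin_epochs, 0)       # 30 min before sleep onset
    end = min(asleep[-1] + margin_epochs + 1, len(stages))  # 30 min after waking
    return epochs[start:end], stages[start:end]
```

Trimming this way keeps the wake epochs immediately around the sleep period, which is what rebalances the stage proportions.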
## Notes
* Data preprocessing<br>
1. Downsampling: EEG from the CinC Challenge 2018 dataset is downsampled to 100Hz.<br>
2. Normalization: we recommend normalizing each subject's EEG by its 5th-95th percentiles, i.e. the 5th percentile maps to 0 and the 95th percentile maps to 1. Note: all pretrained models were trained on data normalized this way.<br>
3. The signal is split into 30s epochs, each containing 3000 data points, and each epoch is paired with one of five sleep-stage labels: N3, N2, N1, REM, W. Label mapping: N3(S4+S3)->0 N2->1 N1->2 REM->3 W->4<br>
4. Data augmentation: during training, every epoch is randomly cropped, randomly flipped, and has its amplitude randomly scaled.<br>
5. For different network architectures, the raw EEG is preprocessed into different shapes:<br>
LSTM: the 30s EEG is FIR band-pass filtered into the θ, σ, α, δ, and β bands, which are concatenated as the input.<br>
CNN_1d family (networks tagged "1d"): nothing special, just image-domain models with their convolutions swapped to Conv1d.<br>
DFCNN family (the iFLYTEK idea: convert to a spectrogram first, then apply standard image-classification models): the 30s EEG is converted by a short-time Fourier transform into a spectrogram, which is classified by an image network. We do not recommend this approach, since generating spectrograms is computationally expensive.<br>
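Steps 2 through 4 above (percentile normalization, 30s segmentation with the N3..W label mapping, simple augmentation) can be sketched with plain NumPy. This is a simplified stand-in, not the repository's transformer.py: the function names are mine, and the augmentation shows only flip and amplitude scaling (the random crop is omitted).

```python
import numpy as np

STAGE_TO_ID = {'N3': 0, 'N2': 1, 'N1': 2, 'REM': 3, 'W': 4}  # N3 merges old S3+S4

def normalize_5_95(signal):
    """Map the subject's 5th percentile to 0 and 95th percentile to 1."""
    lo, hi = np.percentile(signal, [5, 95])
    return (signal - lo) / (hi - lo)

def split_epochs(signal, fs=100, epoch_sec=30):
    """Cut a 1-D recording into 30s epochs of fs*epoch_sec = 3000 points each."""
    step = fs * epoch_sec
    n = len(signal) // step
    return signal[:n * step].reshape(n, step)

def augment(epoch, rng):
    """Training-time augmentation: random flip and random amplitude scaling."""
    if rng.random() < 0.5:
        epoch = epoch[::-1]                  # random time reversal
    return epoch * rng.uniform(0.9, 1.1)     # random amplitude change

rng = np.random.default_rng(0)
signal = rng.standard_normal(100 * 30 * 4)   # 4 epochs' worth of fake EEG at 100Hz
epochs = split_epochs(normalize_5_95(signal))
print(epochs.shape)                          # (4, 3000)
```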
* EEG spectrograms<br>
Spectrograms for the five sleep stages, shown in order: Wake, Stage 1, Stage 2, Stage 3, REM<br>
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_REM.png)<br>
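Spectrogram generation for the DFCNN family can be sketched with a minimal NumPy short-time Fourier transform. The repository itself uses scipy.signal.stft with a Hann window, nperseg=1024 and a hop of 24 samples; the version below is a simplified, NumPy-only stand-in with a hypothetical function name, not the repo's exact pipeline (which also rebins the frequency axis).

```python
import numpy as np

def stft_magnitude(signal, nperseg=1024, hop=24):
    """Magnitude STFT: slide a Hann-windowed FFT over the signal."""
    window = np.hanning(nperseg)
    frames = []
    for start in range(0, len(signal) - nperseg + 1, hop):
        seg = signal[start:start + nperseg] * window
        frames.append(np.abs(np.fft.rfft(seg)))
    return np.array(frames).T            # (freq_bins, time_frames)

signal = np.random.randn(3000)           # one 30s epoch at 100Hz
spec = stft_magnitude(signal)
print(spec.shape)                        # (513, 83)
```

Every 30s epoch becomes a 2-D image, which is why any off-the-shelf image classifier can then be applied, at the cost of computing one FFT per hop.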
* multi_scale_resnet_1d architecture<br>
Based on [geekfeiw / Multi-Scale-1D-ResNet](https://github.com/geekfeiw/Multi-Scale-1D-ResNet); we name that network micro_multi_scale_resnet_1d.<br>
The modified [network architecture](https://github.com/HypoX64/candock/blob/master/image/multi_scale_resnet_1d_network.png)<br>
* Cross-validation<br>For better practical applicability we use subject-wise cross-validation: the training and test sets come from different subjects. Note that each subject in sleep-edfx has two recordings; we treat both recordings as the same subject, a point many papers overlook.<br>
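The subject-wise split, with both sleep-edfx recordings of a subject kept on the same side, can be sketched as follows. The function name is hypothetical and the subject-id convention (characters [3:5] of names like `SC4001E0-PSG.edf`, where SC4001/SC4002 are nights 1 and 2 of subject 00) follows the file lists used in dataloader.py.

```python
import random

def split_by_subject(filenames, test_frac=0.2, seed=0):
    """Group recordings by subject id, then split whole subjects into train/test."""
    subjects = {}
    for f in filenames:
        subjects.setdefault(f[3:5], []).append(f)   # both nights -> same bucket
    ids = sorted(subjects)
    random.Random(seed).shuffle(ids)
    n_test = max(1, round(len(ids) * test_frac))
    test_ids = set(ids[:n_test])
    train = [f for s in ids if s not in test_ids for f in subjects[s]]
    test = [f for s in test_ids for f in subjects[s]]
    return train, test

files = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf',
         'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf',
         'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf']
train, test = split_by_subject(files)
# no subject ever appears on both sides
assert {f[3:5] for f in train}.isdisjoint({f[3:5] for f in test})
```

Splitting at the recording level instead would leak one night of a subject into training and the other into testing, inflating the scores.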
* Evaluation metrics<br>
Per sleep-stage label: Accuracy = (TP+TN)/(TP+FN+TN+FP), Recall = Sensitivity = TP/(TP+FN)<br>
Overall: Top1 err. and Kappa, plus the averages of Acc and Recall.<br>
Note: the label distribution in this classification task is extremely imbalanced, so to be more convincing our averages are unweighted.
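All of these metrics can be read off a confusion matrix. A compact sketch (the function name is mine; the unweighted macro averages match the note above, and kappa follows the standard Cohen definition):

```python
import numpy as np

def metrics(cm):
    """Per-class recall/accuracy plus overall top-1 error and Cohen's kappa.

    cm[i][j] counts epochs with true class i predicted as class j.
    """
    cm = np.asarray(cm, dtype=float)
    total = cm.sum()
    tp = np.diag(cm)
    fn = cm.sum(axis=1) - tp
    fp = cm.sum(axis=0) - tp
    tn = total - tp - fn - fp
    recall = tp / (tp + fn)                       # per-class sensitivity
    accuracy = (tp + tn) / total                  # per-class accuracy
    err = 1 - tp.sum() / total                    # overall top-1 error
    pe = (cm.sum(axis=1) * cm.sum(axis=0)).sum() / total ** 2  # chance agreement
    kappa = (tp.sum() / total - pe) / (1 - pe)
    return recall.mean(), accuracy.mean(), err, kappa  # unweighted macro averages

avg_re, avg_acc, err, kappa = metrics([[50, 10], [5, 35]])
print(round(err, 3))   # 0.15
```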
* About the code<br>
The code is still being modified and updated, and is not guaranteed to work. Details will be updated when time allows after the graduation project is finished.<br>
## Partial Experimental Results
This section will be updated continuously...<br>
[[Confusion matrix]](https://github.com/HypoX64/candock/blob/master/confusion_mat)<br>
#### Subject Cross-Validation Results
## Dev Diary
* 2019/04/01 DFCNN's computational cost is enormous, the improvement is marginal, and it overfits easily... like a chicken rib: tasteless to eat, yet a pity to throw away...
* 2019/04/03 Spent a day upgrading to pytorch 1.0, then tried shrinking the input spectrograms to reduce computation...
* 2019/04/04 To be rigorous I need to add k-fold plus subject cross-validation...
* 2019/04/05 Qingming Festival... reading the literature; I'll just follow what most people do, using 5-fold and cross-dataset validation, so my method can be compared horizontally with others. No, I have to complain here: others do k-fold purely because their datasets are too small... running k-fold on these hundreds of GB of data is redundant, the results barely differ... a pure waste of compute...
* 2019/04/09 Back in my hometown. Ah!!! My thesis... I'll never finish it!
* 2019/04/13 Back at school to grind out the thesis.
* 2019/05/02 Submitted for the plagiarism check, so updates to this repo will pause for a while; today is the last big update. Next I'll open a new branch focused mainly on practical application, e.g. the balance between computational cost and accuracy. I'll also provide pretrained models for different situations and update the calling interface, aiming for a foolproof, one-call way to run our sleep staging.
Note: the label distribution in this classification task is extremely imbalanced. We hand-tuned the class weights of the classification loss, which slightly improves Average Recall but also raises the overall error. With default weights, Top1 err. drops by at least 5%, but recall for the tiny N1 stage plummets by about 20%, which is absolutely not what we want in practice. All results below were obtained with the tuned weights.<br>
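The tuned weighting follows the scheme in train.py: log inverse class frequency, an extra bump for N1 (index 2), clipped to [1, 3]. A sketch with a hypothetical function name, using the [S3 S2 S1 R W] epoch counts printed in the training logs:

```python
import numpy as np

def loss_weights(stage_counts, n1_index=2, bump=1.0, clip_lo=1.0, clip_hi=3.0):
    """Class weights for the cross-entropy loss: log(1/freq), boosted N1, clipped."""
    freq = np.asarray(stage_counts, dtype=float)
    freq = freq / freq.sum()
    w = np.log(1.0 / freq)          # rare stages get larger weights
    w[n1_index] += bump             # extra emphasis on the scarce N1 class
    return np.clip(w, clip_lo, clip_hi)

# train-set stage counts [S3 S2 S1 R W] from the logs
print(loss_weights([17382, 69813, 18120, 27920, 46573]))
```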
* [sleep-edfx](https://www.physionet.org/physiobank/database/sleep-edfx/) -> sample size = 197, select sleep time

| Network | Parameters | Top1.err. | Avg. Acc. | Avg. Re. | Need to extract feature |
| --------------------------- | ---------- | --------- | --------- | -------- | ----------------------- |
| lstm | 1.25M | 26.32% | 93.56% | 68.57% | Yes |
| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 93.87% | 72.61% | No |
| resnet18_1d | 3.85M | 24.21% | 94.07% | 72.87% | No |
| multi_scale_resnet_1d | 8.42M | 24.01% | 94.06% | 72.37% | No |
Confusion matrix layout: rows = true stage, columns = predicted stage, both in the order [S3 S2 S1 REM WAKE].
-------------------sleep-edfx(select sleep time, sample size = 197)-------------------
----------lstm----------
train:
dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:
dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
net parameters: 1.25M
final: recall,acc,sp,err,k: (0.6857, 0.8947, 0.9356, 0.2632, 0.6305)
confusion_mat:
[[ 1644 244 11 26 5]
[ 1801 11924 1933 2169 229]
[ 42 1267 2772 1689 1094]
[ 5 193 1845 3539 301]
[ 37 128 1513 495 22182]]
----------micro_multi_scale_resnet_1d----------
train:
dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:
dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
net parameters: 2.11M
final: recall,acc,sp,err,k: (0.7261, 0.8987, 0.9387, 0.2533, 0.648)
confusion_mat:
[[ 1711 211 2 6 1]
[ 1905 12958 1338 1741 99]
[ 26 1731 3074 1439 602]
[ 1 466 995 4328 97]
[ 37 237 2692 837 20554]]
----------multi_scale_resnet_1d----------
train:
dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:
dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
net parameters: 8.42M
final: recall,acc,sp,err,k: (0.7237, 0.904, 0.9406, 0.2401, 0.6631)
confusion_mat:
[[ 1629 280 17 3 1]
[ 1775 13245 1880 973 176]
[ 18 1644 3484 1044 681]
[ 11 454 1062 3925 431]
[ 38 177 2326 717 21097]]
----------resnet18_1d----------
train:
dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:
dataset statistics [S3 S2 S1 R W]: [ 1932 18067 6876 5892 24382]
net parameters: 3.85M
final: recall,acc,sp,err,k: (0.7283, 0.9031, 0.9407, 0.2421, 0.6607)
confusion_mat:
[[ 1726 171 23 8 1]
[ 1906 12100 2336 1530 179]
[ 48 1223 3386 1315 898]
[ 1 298 1200 3989 398]
[ 19 108 1630 531 22064]]
-------------------CinC Challenge 2018(sample size = 100)-------------------
DFCNN+resnet18
final: test avg_recall:0.7823 avg_acc:0.909 error:0.2276
confusion_mat:
[[1897 189 3 4 3]
[ 634 5810 614 362 126]
[ 1 409 1545 301 406]
[ 4 200 231 1809 72]
[ 3 41 377 19 3036]]
statistics of dataset [S3 S2 S1 R W]:
[2096 7546 2662 2316 3476]
@@ -107,8 +107,7 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
events, _ = mne.events_from_annotations(
raw_data, event_id=event_id, chunk_duration=30.)
stages = []; signals = []
for i in range(len(events)-1):
stages.append(events[i][2])
signals.append(eeg[events[i][0]:events[i][0]+3000])
@@ -129,38 +128,63 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
return signals.astype(np.float16),stages.astype(np.int16)
#load all data in datasets
def loaddataset(filedir,dataset_name,signal_name,num,BID,select_sleep_time,shuffle = False):
print('load dataset, please wait...')
signals_train=[];stages_train=[];signals_test=[];stages_test=[]
if dataset_name == 'cc2018':
filenames = os.listdir(filedir)
if shuffle:
random.shuffle(filenames)
else:
filenames.sort()
if num > len(filenames):
num = len(filenames)
print('num of dataset is:',num)
for cnt,filename in enumerate(filenames[:num],0):
signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID)
if cnt < round(num*0.8) :
signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)
else:
signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)
print('train subjects:',round(num*0.8),'test subjects:',round(num*0.2))
elif dataset_name == 'sleep-edfx':
if num > 197:
num = 197
filenames_sc_train = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf', 'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf', 'SC4031E0-PSG.edf', 'SC4032E0-PSG.edf', 'SC4041E0-PSG.edf', 'SC4042E0-PSG.edf', 'SC4051E0-PSG.edf', 'SC4052E0-PSG.edf', 'SC4061E0-PSG.edf', 'SC4062E0-PSG.edf', 'SC4071E0-PSG.edf', 'SC4072E0-PSG.edf', 'SC4081E0-PSG.edf', 'SC4082E0-PSG.edf', 'SC4091E0-PSG.edf', 'SC4092E0-PSG.edf', 'SC4101E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4111E0-PSG.edf', 'SC4112E0-PSG.edf', 'SC4121E0-PSG.edf', 'SC4122E0-PSG.edf', 'SC4131E0-PSG.edf', 'SC4141E0-PSG.edf', 'SC4142E0-PSG.edf', 'SC4151E0-PSG.edf', 'SC4152E0-PSG.edf', 'SC4161E0-PSG.edf', 'SC4162E0-PSG.edf', 'SC4171E0-PSG.edf', 'SC4172E0-PSG.edf', 'SC4181E0-PSG.edf', 'SC4182E0-PSG.edf', 'SC4191E0-PSG.edf', 'SC4192E0-PSG.edf', 'SC4201E0-PSG.edf', 'SC4202E0-PSG.edf', 'SC4211E0-PSG.edf', 'SC4212E0-PSG.edf', 'SC4221E0-PSG.edf', 'SC4222E0-PSG.edf', 'SC4231E0-PSG.edf', 'SC4232E0-PSG.edf', 'SC4241E0-PSG.edf', 'SC4242E0-PSG.edf', 'SC4251E0-PSG.edf', 'SC4252E0-PSG.edf', 'SC4261F0-PSG.edf', 'SC4262F0-PSG.edf', 'SC4271F0-PSG.edf', 'SC4272F0-PSG.edf', 'SC4281G0-PSG.edf', 'SC4282G0-PSG.edf', 'SC4291G0-PSG.edf', 'SC4292G0-PSG.edf', 'SC4301E0-PSG.edf', 'SC4302E0-PSG.edf', 'SC4311E0-PSG.edf', 'SC4312E0-PSG.edf', 'SC4321E0-PSG.edf', 'SC4322E0-PSG.edf', 'SC4331F0-PSG.edf', 'SC4332F0-PSG.edf', 'SC4341F0-PSG.edf', 'SC4342F0-PSG.edf', 'SC4351F0-PSG.edf', 'SC4352F0-PSG.edf', 'SC4362F0-PSG.edf', 'SC4371F0-PSG.edf', 'SC4372F0-PSG.edf', 'SC4381F0-PSG.edf', 'SC4382F0-PSG.edf', 'SC4401E0-PSG.edf', 'SC4402E0-PSG.edf', 'SC4411E0-PSG.edf', 'SC4412E0-PSG.edf', 'SC4421E0-PSG.edf', 'SC4422E0-PSG.edf', 'SC4431E0-PSG.edf', 'SC4432E0-PSG.edf', 'SC4441E0-PSG.edf', 'SC4442E0-PSG.edf', 'SC4451F0-PSG.edf', 'SC4452F0-PSG.edf', 'SC4461F0-PSG.edf', 'SC4462F0-PSG.edf', 'SC4471F0-PSG.edf', 'SC4472F0-PSG.edf', 'SC4481F0-PSG.edf', 'SC4482F0-PSG.edf', 'SC4491G0-PSG.edf', 'SC4492G0-PSG.edf', 'SC4501E0-PSG.edf', 'SC4502E0-PSG.edf', 
'SC4511E0-PSG.edf', 'SC4512E0-PSG.edf', 'SC4522E0-PSG.edf', 'SC4531E0-PSG.edf', 'SC4532E0-PSG.edf', 'SC4541F0-PSG.edf', 'SC4542F0-PSG.edf', 'SC4551F0-PSG.edf', 'SC4552F0-PSG.edf', 'SC4561F0-PSG.edf', 'SC4562F0-PSG.edf', 'SC4571F0-PSG.edf', 'SC4572F0-PSG.edf', 'SC4581G0-PSG.edf', 'SC4582G0-PSG.edf', 'SC4591G0-PSG.edf', 'SC4592G0-PSG.edf', 'SC4601E0-PSG.edf', 'SC4602E0-PSG.edf', 'SC4611E0-PSG.edf', 'SC4612E0-PSG.edf', 'SC4621E0-PSG.edf', 'SC4622E0-PSG.edf', 'SC4631E0-PSG.edf', 'SC4632E0-PSG.edf']
filenames_sc_test = ['SC4641E0-PSG.edf', 'SC4642E0-PSG.edf', 'SC4651E0-PSG.edf', 'SC4652E0-PSG.edf', 'SC4661E0-PSG.edf', 'SC4662E0-PSG.edf', 'SC4671G0-PSG.edf', 'SC4672G0-PSG.edf', 'SC4701E0-PSG.edf', 'SC4702E0-PSG.edf', 'SC4711E0-PSG.edf', 'SC4712E0-PSG.edf', 'SC4721E0-PSG.edf', 'SC4722E0-PSG.edf', 'SC4731E0-PSG.edf', 'SC4732E0-PSG.edf', 'SC4741E0-PSG.edf', 'SC4742E0-PSG.edf', 'SC4751E0-PSG.edf', 'SC4752E0-PSG.edf', 'SC4761E0-PSG.edf', 'SC4762E0-PSG.edf', 'SC4771G0-PSG.edf', 'SC4772G0-PSG.edf', 'SC4801G0-PSG.edf', 'SC4802G0-PSG.edf', 'SC4811G0-PSG.edf', 'SC4812G0-PSG.edf', 'SC4821G0-PSG.edf', 'SC4822G0-PSG.edf']
filenames_st_train = ['ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf', 'ST7041J0-PSG.edf', 'ST7042J0-PSG.edf', 'ST7051J0-PSG.edf', 'ST7052J0-PSG.edf', 'ST7061J0-PSG.edf', 'ST7062J0-PSG.edf', 'ST7071J0-PSG.edf', 'ST7072J0-PSG.edf', 'ST7081J0-PSG.edf', 'ST7082J0-PSG.edf', 'ST7091J0-PSG.edf', 'ST7092J0-PSG.edf', 'ST7101J0-PSG.edf', 'ST7102J0-PSG.edf', 'ST7111J0-PSG.edf', 'ST7112J0-PSG.edf', 'ST7121J0-PSG.edf', 'ST7122J0-PSG.edf', 'ST7131J0-PSG.edf', 'ST7132J0-PSG.edf', 'ST7141J0-PSG.edf', 'ST7142J0-PSG.edf', 'ST7151J0-PSG.edf', 'ST7152J0-PSG.edf', 'ST7161J0-PSG.edf', 'ST7162J0-PSG.edf', 'ST7171J0-PSG.edf', 'ST7172J0-PSG.edf', 'ST7181J0-PSG.edf', 'ST7182J0-PSG.edf', 'ST7191J0-PSG.edf', 'ST7192J0-PSG.edf']
filenames_st_test = ['ST7201J0-PSG.edf', 'ST7202J0-PSG.edf', 'ST7211J0-PSG.edf', 'ST7212J0-PSG.edf', 'ST7221J0-PSG.edf', 'ST7222J0-PSG.edf', 'ST7241J0-PSG.edf', 'ST7242J0-PSG.edf']
for filename in filenames_sc_train[:round(num*153/197*0.8)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)
for filename in filenames_st_train[:round(num*44/197*0.8)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)
for filename in filenames_sc_test[:round(num*153/197*0.2)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)
for filename in filenames_st_test[:round(num*44/197*0.2)]:
signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)
print('---------Each subject has two sample---------',
'\nTrain samples_SC/ST:',round(num*153/197*0.8),round(num*44/197*0.8),
'\nTest samples_SC/ST:',round(num*153/197*0.2),round(num*44/197*0.2))
elif dataset_name == 'preload':
signals_train = np.load(filedir+'/signals_train.npy')
stages_train = np.load(filedir+'/stages_train.npy')
signals_test = np.load(filedir+'/signals_test.npy')
stages_test = np.load(filedir+'/stages_test.npy')
return signals_train,stages_train,signals_test,stages_test
@@ -42,7 +42,8 @@ def BPF(signal,fs,fc1,fc2,mod = 'fir'):
b=getfir_b(fc1,fc2,fs)
result = scipy.signal.lfilter(b, 1, signal)
return result
def getfeature(signal,mod = 'fft',ch_num = 5):
result=[]
signal =signal - np.mean(signal)
eeg=signal
@@ -68,19 +69,6 @@ def getfeature(signal,mod = 'fir',ch_num = 5):
result=result.reshape(ch_num*len(signal),)
return result
# def signal2spectrum(data):
# # window : ('tukey',0.5) hann
@@ -96,7 +84,7 @@ def getfeature(signal,mod = 'fir',ch_num = 5):
def signal2spectrum(data):
# window : ('tukey',0.5) hann
zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
zxx =np.abs(zxx)[:512]
spectrum=np.zeros((256,126))
spectrum[0:128]=zxx[0:128]
@@ -3,8 +3,8 @@ import os
import numpy as np
import torch
# python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name cc2018 --signal_name 'C4-M1' --sample_num 20 --model_name lstm --batchsize 64 --epochs 20 --lr 0.0005 --no_cudnn
# python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 50 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 25 --lr 0.0005 --BID 5_95_th --select_sleep_time --no_cudnn --select_sleep_time
class Options():
def __init__(self):
@@ -16,16 +16,14 @@ class Options():
self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn')
self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models')
self.parser.add_argument('--lr', type=float, default=0.0005,help='learning rate')
self.parser.add_argument('--BID', type=str, default='5_95_th',help='Balance individualized differences 5_95_th | median |None')
self.parser.add_argument('--batchsize', type=int, default=64,help='batchsize')
self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/',
help='your dataset path')
self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx',help='Choose dataset sleep-edfx | cc2018')
self.parser.add_argument('--select_sleep_time', action='store_true', help='if input, for sleep-cassette only use sleep time to train')
self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz',help='Choose the EEG channel C4-M1 | EEG Fpz-Cz |...')
self.parser.add_argument('--sample_num', type=int, default=20,help='the amount you want to load')
self.parser.add_argument('--model_name', type=str, default='lstm',help='Choose model lstm | multi_scale_resnet_1d | resnet18 |...')
self.parser.add_argument('--epochs', type=int, default=20,help='end epoch')
self.parser.add_argument('--weight_mod', type=str, default='avg_best',help='Choose weight mode: avg_best|normal')
@@ -42,9 +40,5 @@ class Options():
self.opt.sample_num = 8
if self.opt.no_cuda:
self.opt.no_cudnn = True
return self.opt
@@ -8,7 +8,7 @@ def stage(stages):
for i in range(len(stages)):
stage_cnt[stages[i]] += 1
stage_cnt_per = stage_cnt/len(stages)
util.writelog(' dataset statistics [S3 S2 S1 R W]: '+str(stage_cnt),True)
return stage_cnt,stage_cnt_per
def reversal_label(mat):
@@ -28,26 +28,21 @@ but the data needs meet the following conditions:
3.fs = 100Hz
4.input signal data should be normalized!!
we recommend normalizing signal data using 5_95_th for each subject,
example: signals_normalized=transformer.Balance_individualized_differences(signals_origin, '5_95_th')
'''
signals_train,stages_train,signals_test,stages_test = dataloader.loaddataset(opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.sample_num,opt.BID,opt.select_sleep_time)
util.writelog('train:',True)
stage_cnt,stage_cnt_per = statistics.stage(stages_train)
util.writelog('test:',True)
_,_ = statistics.stage(stages_test)
signals_train,stages_train = transformer.batch_generator(signals_train,stages_train,opt.batchsize)
signals_test,stages_test = transformer.batch_generator(signals_test,stages_test,opt.batchsize)
batch_length = len(signals_train)
print('length of batch:',batch_length)
show_freq = int(len(stages_train)/5)
t2 = time.time()
print('load data cost time: %.2f'% (t2-t1),'s')
@@ -59,27 +54,30 @@ weight = np.array([1,1,1,1,1])
if opt.weight_mod == 'avg_best':
weight = np.log(1/stage_cnt_per)
weight[2] = weight[2]+1
weight = np.clip(weight,1,3)
print('Loss_weight:',weight)
weight = torch.from_numpy(weight).float()
# print(net)
if not opt.no_cuda:
net.cuda()
weight = weight.cuda()
if not opt.no_cudnn:
torch.backends.cudnn.benchmark = True
optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
criterion = nn.CrossEntropyLoss(weight)
def evalnet(net,signals,stages,epoch,plot_result={}):
# net.eval()
confusion_mat = np.zeros((5,5), dtype=int)
for i, (signal,stage) in enumerate(zip(signals,stages), 1):
signal=transformer.ToInputShape(signal,opt.model_name,test_flag =True)
signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
with torch.no_grad():
out = net(signal)
pred = torch.max(out, 1)[1]
@@ -97,61 +95,53 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={}):
print('begin to train ...')
final_confusion_mat = np.zeros((5,5), dtype=int)
if opt.pretrained:
net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.dataset_name+'/'+opt.model_name+'.pth'))
if not opt.no_cuda:
net.cuda()
plot_result={'train':[1.],'test':[1.]}
confusion_mats = []
for epoch in range(opt.epochs):
t1 = time.time()
confusion_mat = np.zeros((5,5), dtype=int)
print('epoch:',epoch+1)
net.train()
for i, (signal,stage) in enumerate(zip(signals_train,stages_train), 1):
signal=transformer.ToInputShape(signal,opt.model_name,test_flag =False)
signal,stage = transformer.ToTensor(signal,stage,no_cuda =opt.no_cuda)
out = net(signal)
loss = criterion(out, stage)
pred = torch.max(out, 1)[1]
optimizer.zero_grad()
loss.backward()
optimizer.step()
pred=pred.data.cpu().numpy()
stage=stage.data.cpu().numpy()
for x in range(len(pred)):
confusion_mat[stage[x]][pred[x]] += 1
if i%show_freq==0:
plot_result['train'].append(statistics.result(confusion_mat)[3])
heatmap.draw(confusion_mat,name = 'train')
statistics.show(plot_result,epoch+i/(batch_length*0.8))
confusion_mat[:]=0
plot_result,confusion_mat = evalnet(net,signals_test,stages_test,epoch+1,plot_result)
confusion_mats.append(confusion_mat)
# scheduler.step()
if (epoch+1)%opt.network_save_freq == 0:
torch.save(net.cpu().state_dict(),'./checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
print('network saved.')
if not opt.no_cuda:
net.cuda()
t2=time.time()
if epoch+1==1:
print('cost time: %.2f' % (t2-t1),'s')
pos = plot_result['test'].index(min(plot_result['test']))-1
final_confusion_mat = confusion_mats[pos]
util.writelog('final: '+'recall,acc,sp,err,k: '+str(statistics.result(final_confusion_mat)),True)
util.writelog('confusion_mat:\n'+str(final_confusion_mat),True)
statistics.stagefrommat(final_confusion_mat)
@@ -14,16 +14,6 @@ def shuffledata(data,target):
np.random.shuffle(target)
# return data,target
def batch_generator(data,target,batchsize,shuffle = True):
if shuffle:
@@ -34,41 +24,28 @@ def batch_generator(data,target,batchsize,shuffle = True):
target = target.reshape(-1,batchsize)
return data,target
def Normalize(data,maxmin,avg,sigma,is_01=False):
data = np.clip(data, -maxmin, maxmin)
if is_01:
return (data-avg)/sigma/2+0.5 #(0,1)
else:
return (data-avg)/sigma #(-1,1)
def Balance_individualized_differences(signals,BID):
if BID == 'median':
signals = (signals*8/(np.median(abs(signals))))
signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
elif BID == '5_95_th':
tmp = np.sort(signals.reshape(-1))
th_5 = -tmp[int(0.05*len(tmp))]
signals=Normalize(signals,maxmin=10e3,avg=0,sigma=th_5,is_01=True)
else:
#dataset 5_95_th(-1,1) median
#CC2018 24.75 7.438
#sleep edfx 37.4 9.71
#sleep edfx sleeptime 39.03 10.125
signals=Normalize(signals,maxmin=10e3,avg=0,sigma=30,is_01=True)
return signals
def ToTensor(data,target=None,no_cuda = False):
@@ -141,6 +118,7 @@ def ToInputShape(data,net_name,test_flag = False):
elif net_name in ['squeezenet','multi_scale_resnet','dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
result =[]
data = (data-0.5)*2
for i in range(0,batchsize):
spectrum = dsp.signal2spectrum(data[i])
spectrum = random_transform_2d(spectrum,(224,122),test_flag=test_flag)
@@ -13,4 +13,6 @@ def writelog(log,printflag = False):
print(log)
def show_paramsnumber(net):
parameters = sum(param.numel() for param in net.parameters())
parameters = round(parameters/1e6,2)
writelog('net parameters: '+str(parameters)+'M',True)