Hypo / candock, commit 7ac8cf6f
Commit 7ac8cf6f
Authored May 17, 2019 by hypox64
Delete 5-fold cross validation
Parent: cf02294c
Showing 9 changed files with 228 additions and 268 deletions (+228 -268)
README.md (+34 -62)
confusion_mat (+55 -43)
dataloader.py (+51 -27)
dsp.py (+3 -15)
options.py (+4 -10)
statistics.py (+1 -1)
train.py (+67 -77)
transformer.py (+10 -32)
util.py (+3 -1)
README.md
# candock
[This was originally a log repository for my graduation project](https://github.com/HypoX64/candock/tree/Graduation_Project), whose goal was to try a variety of deep neural network architectures (such as LSTM, ResNet, and DFCNN) for automatic sleep stage scoring on single-channel EEG.<br>
The graduation project is now complete, and I will keep working on this repository. The focus shifts to practical application, weighing computational cost against accuracy; some pretrained models will also be provided for convenience.<br>
We believe this code can also be used to classify other physiological signals (such as ECG and EMG). We hope it helps your research or projects.<br>
![image](https://github.com/HypoX64/candock/blob/master/image/compare.png)
## How to run
If you want to run this code (to train your own model, or to use a pretrained model to predict on your own data), see:<br>
[How to run codes](https://github.com/HypoX64/candock/blob/master/how_to_run.md)<br>
<br>
## Datasets
Two public sleep datasets are used for training:
[[CinC Challenge 2018]](https://physionet.org/physiobank/database/challenge/2018/#files)
[[sleep-edfx]](https://www.physionet.org/physiobank/database/sleep-edfx/)<br>
For the CinC Challenge 2018 dataset we use only the C4-M1 channel; for the sleep-edfx and sleep-edf datasets we use the Fpz-Cz channel. Note that sleep-edfx is the expanded version of sleep-edf.<br>
Notes:<br>
1. To obtain pretrained models for other EEG channels, download these datasets and train with train.py; you can also train a model on your own data.<br>
2. For the sleep-edfx dataset, we only read the interval from 30 minutes before sleep onset to 30 minutes after waking (marked "select sleep time" in the results), to balance the proportions of the sleep stages and speed up training.<br>
## Notes
* Data preprocessing<br>
1. Downsampling: EEG from the CinC Challenge 2018 dataset is downsampled to 100 Hz.<br>
2. Normalization: we recommend normalizing each subject's EEG by its 5th–95th percentiles, i.e. the 5th percentile maps to 0 and the 95th percentile maps to 1. Note: all pretrained models were trained on data normalized this way.<br>
3. The recordings are split into 30 s epochs as inputs, each containing 3000 data points. There are 5 sleep stage labels, N3, N2, N1, REM, and W; each epoch gets exactly one label. Label mapping: N3 (S4+S3)->0, N2->1, N1->2, REM->3, W->4.<br>
4. Data augmentation: during training, each epoch is randomly cropped, randomly flipped, and its amplitude is randomly rescaled.<br>
5. For different network architectures, the raw EEG is preprocessed into different input shapes:<br>
LSTM: the 30 s EEG is FIR band-pass filtered into the θ, σ, α, δ, and β bands, which are concatenated as the input.<br>
CNN_1d-style networks (those labeled 1d): nothing special; image-domain models are reused with nn.Conv2d replaced by nn.Conv1d.<br>
DFCNN-style networks (iFLYTEK's idea: convert to a spectrogram, then apply image classification models directly): the 30 s EEG is short-time-Fourier-transformed into a spectrogram, which an image classification network then classifies. We do not recommend this approach, since generating the spectrograms costs considerable compute.<br>
* EEG spectrograms<br>
Spectrograms for the 5 sleep stages, in order: Wake, Stage 1, Stage 2, Stage 3, REM<br>
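The 5th–95th percentile normalization in step 2 above can be sketched as follows (a minimal standalone sketch; the repository's actual implementation is `Balance_individualized_differences` in transformer.py, and `percentile_normalize` here is a hypothetical name used only for illustration):

```python
import numpy as np

def percentile_normalize(eeg, low=5, high=95):
    """Map one subject's 5th percentile to 0 and 95th percentile to 1."""
    p_low, p_high = np.percentile(eeg, [low, high])
    return (eeg - p_low) / (p_high - p_low)

# one subject's signal; values outside the 5th-95th range fall outside [0, 1]
x = np.linspace(0, 1, 101)
z = percentile_normalize(x)
```

Because the mapping is per subject, it absorbs individual differences in amplitude before the data reach the network.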
![image](https://github.com/HypoX64/candock/blob/master/image/spectrum_REM.png)<br>
* multi_scale_resnet_1d architecture<br>
This network is based on [geekfeiw / Multi-Scale-1D-ResNet](https://github.com/geekfeiw/Multi-Scale-1D-ResNet); we name our smaller variant micro_multi_scale_resnet_1d.<br>
Modified [network structure](https://github.com/HypoX64/candock/blob/master/image/multi_scale_resnet_1d_network.png)<br>
* Cross-validation<br>
For better practical applicability, we use subject cross-validation: the training and validation data come from different subjects. Note that each subject in the sleep-edfx dataset has two recordings; we treat both recordings as the same subject, a point many papers overlook.<br>
* Evaluation metrics<br>
For each sleep stage label: Accuracy = (TP+TN)/(TP+FN+TN+FP), Recall = sensitivity = TP/(TP+FN)<br>
Overall: Top1 err. and Kappa; we also average Acc and Recall over the labels.<br>
Note: the label distribution in this task is extremely imbalanced, so to be more convincing our averages are unweighted.
* About the code<br>
The code is still being revised and updated, so it is not guaranteed to work. Details will be filled in after the graduation project is finished.<br>
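The per-label metrics above can be read directly off a confusion matrix (a minimal sketch; `label_metrics` is a hypothetical helper, not the repository's statistics.py):

```python
import numpy as np

def label_metrics(mat):
    """Per-label accuracy and recall from a KxK confusion matrix (rows = true, cols = predicted)."""
    mat = np.asarray(mat, dtype=float)
    total = mat.sum()
    tp = np.diag(mat)
    fn = mat.sum(axis=1) - tp          # true class, predicted as something else
    fp = mat.sum(axis=0) - tp          # predicted class, actually something else
    tn = total - tp - fn - fp
    accuracy = (tp + tn) / total       # (TP+TN)/(TP+FN+TN+FP) per label
    recall = tp / (tp + fn)            # sensitivity per label
    # unweighted averages, as stated in the text
    return accuracy, recall, accuracy.mean(), recall.mean()

# toy 2-class example: 8 and 9 correct, 3 mistakes in total
acc, rec, avg_acc, avg_rec = label_metrics([[8, 2], [1, 9]])
```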
## Partial experimental results
This section will be updated continuously...<br>
[[Confusion matrix]](https://github.com/HypoX64/candock/blob/master/confusion_mat)<br>
#### 5-Fold Cross-Validation Results
* sleep-edf<br>
| Network | Label average recall | Label average accuracy | error rate |
| :----------------------- | :------------------- | ---------------------- | ---------- |
| lstm | | | |
| resnet18_1d | 0.8263 | 0.9601 | 0.0997 |
| DFCNN+resnet18 | 0.8261 | 0.9594 | 0.1016 |
| DFCNN+multi_scale_resnet | 0.8196 | 0.9631 | 0.0922 |
| multi_scale_resnet_1d | 0.8400 | 0.9595 | 0.1013 |
* sleep-edfx (only sleep time)<br>
| Network | Label average recall | Label average accuracy | error rate |
| :------------- | :------------------- | ---------------------- | ---------- |
| lstm | | | |
| resnet18_1d | | | |
| DFCNN+resnet18 | | | |
| DFCNN+resnet50 | | | |
* CinC Challenge 2018 (sample size = 100)<br>
| Network | Label average recall | Label average accuracy | error rate |
| :------------- | :------------------- | ---------------------- | ---------- |
| lstm | | | |
| resnet18_1d | | | |
| DFCNN+resnet18 | 0.7823 | 0.909 | 0.2276 |
| DFCNN+resnet50 | | | |
#### Subject Cross-Validation Results
## Dev log
* 2019/04/01 DFCNN costs far too much compute, the gains are marginal, and it overfits easily... hardly worth keeping, yet a shame to drop.
* 2019/04/03 Spent a day upgrading to PyTorch 1.0, then tried shrinking the input spectrograms to cut the compute...
* 2019/04/04 Need to add k-fold plus subject cross-validation to be rigorous...
* 2019/04/05 Qingming Festival... read the literature; I'll follow what most people do and use 5-fold cross-validation plus cross-dataset validation, for easy comparison with other methods. Honestly, though, others use k-fold only because their datasets are tiny; running k-fold on hundreds of GB of data is overkill, the results barely differ, and it just wastes compute...
* 2019/04/09 Went back to my hometown. Ah, my thesis... I'll never finish it!
* 2019/04/13 Back at school, grinding out the thesis.
* 2019/05/02 Submitted it for the plagiarism check, so updates to this repository will pause for now; today is the last big update. Next I'll open a new branch focused on practical application, such as the trade-off between compute and accuracy. I'll also provide pretrained models for different scenarios and update the calling interface, aiming for a foolproof way to run sleep staging with our code.
Note: the label distribution in this task is extremely imbalanced. We heavily tuned the class weights in the classification loss, which gives a small boost to Average Recall but also raises the overall error. With the default weights, Top1 err. would drop by at least 5%, but recall on the N1 stage, which makes up a tiny fraction of the data, would plummet by 20%; that is absolutely not what we want in a real application. All results below use the tuned weights.<br>
* [sleep-edfx](https://www.physionet.org/physiobank/database/sleep-edfx/) -> sample size = 197, select sleep time

| Network | Parameters | Top1.err. | Avg. Acc. | Avg. Re. | Need to extract feature |
| --------------------------- | ---------- | --------- | --------- | -------- | ----------------------- |
| lstm | 1.25M | 26.32% | 93.56% | 68.57% | Yes |
| micro_multi_scale_resnet_1d | 2.11M | 25.33% | 93.87% | 72.61% | No |
| resnet18_1d | 3.85M | 24.21% | 94.07% | 72.87% | No |
| multi_scale_resnet_1d | 8.42M | 24.01% | 94.06% | 72.37% | No |
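The tuned class weights described above follow the scheme in this commit's train.py ('avg_best' mode): log-inverse-frequency per stage, an extra +1 for N1 (index 2), clipped to [1, 3]. A standalone sketch (`stage_weights` is a hypothetical wrapper around those lines):

```python
import numpy as np

def stage_weights(stage_cnt):
    """Class weights as in train.py's 'avg_best' mode: log(1/frequency),
    boosted for N1 (index 2, the rarest stage), clipped to [1, 3]."""
    stage_cnt = np.asarray(stage_cnt, dtype=float)
    per = stage_cnt / stage_cnt.sum()
    w = np.log(1.0 / per)
    w[2] += 1.0                       # push N1's weight up
    return np.clip(w, 1.0, 3.0)

# train-set counts [S3 S2 S1 R W] taken from the confusion_mat log in this commit
w = stage_weights([17382, 69813, 18120, 27920, 46573])
```

These weights are then passed to `torch.nn.CrossEntropyLoss(weight)` as in train.py, so rare stages contribute more to the loss.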
confusion_mat
Confusion matrix layout:
                 Pred
       S3   S2   S1   REM  WAKE
  S3
  S2
  S1   (True)
  REM
  WAKE

5-Fold Cross-Validation Results (removed in this commit)
-------------------sleep-edf-------------------
resnet18_1d
final: test avg_recall:0.8263 avg_acc:0.9601 error:0.0997
confusion_mat:
[[1133  148    7    3    0]
 [ 296 2968  120  195    3]
 [   3   60  368  149   18]
 [   3   99  135 1342   18]
 [   9    7  190   37 7729]]
statistics of dataset [S3 S2 S1 R W]: [1291 3582 598 1597 7972]

DFCNN+resnet18
final: test avg_recall:0.8261 avg_acc:0.9594 error:0.1016
confusion_mat:
[[ 216   44    2    0    6]
 [  41  640   24   14    6]
 [   1    9   75   11   10]
 [   0   20   65  226    3]
 [   2    2   21    4 1582]]
statistics of dataset [S3 S2 S1 R W]: [268 725 106 314 1611]

multi_scale_resnet_1d
final: test avg_recall:0.84 avg_acc:0.9595 error:0.1013
confusion_mat:
[[1159  114   10    3    1]
 [ 335 2961  153  143    5]
 [   4   45  436  102   13]
 [   0  100  241 1244    9]
 [   2    2  214   27 7717]]
statistics of dataset [S3 S2 S1 R W]: [1287 3597 600 1594 7962]

-------------------sleep-edfx(only sleep time)-------------------
-------------------CinC Challenge 2018(sample size = 100)-------------------
DFCNN+resnet18
final: test avg_recall:0.7823 avg_acc:0.909 error:0.2276
confusion_mat:
[[1897  189    3    4    3]
 [ 634 5810  614  362  126]
 [   1  409 1545  301  406]
 [   4  200  231 1809   72]
 [   3   41  377   19 3036]]
statistics of dataset [S3 S2 S1 R W]: [2096 7546 2662 2316 3476]

Subject Cross-Validation Results
-------------------sleep-edfx(select sleep time, sample size = 197)-------------------
----------lstm----------
train: dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:  dataset statistics [S3 S2 S1 R W]: [ 1932 18067  6876  5892 24382]
net parameters: 1.25M
final: recall,acc,sp,err,k: (0.6857, 0.8947, 0.9356, 0.2632, 0.6305)
confusion_mat:
[[ 1644   244    11    26     5]
 [ 1801 11924  1933  2169   229]
 [   42  1267  2772  1689  1094]
 [    5   193  1845  3539   301]
 [   37   128  1513   495 22182]]

----------micro_multi_scale_resnet_1d----------
train: dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:  dataset statistics [S3 S2 S1 R W]: [ 1932 18067  6876  5892 24382]
net parameters: 2.11M
final: recall,acc,sp,err,k: (0.7261, 0.8987, 0.9387, 0.2533, 0.648)
confusion_mat:
[[ 1711   211     2     6     1]
 [ 1905 12958  1338  1741    99]
 [   26  1731  3074  1439   602]
 [    1   466   995  4328    97]
 [   37   237  2692   837 20554]]

----------multi_scale_resnet_1d----------
train: dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:  dataset statistics [S3 S2 S1 R W]: [ 1932 18067  6876  5892 24382]
net parameters: 8.42M
final: recall,acc,sp,err,k: (0.7237, 0.904, 0.9406, 0.2401, 0.6631)
confusion_mat:
[[ 1629   280    17     3     1]
 [ 1775 13245  1880   973   176]
 [   18  1644  3484  1044   681]
 [   11   454  1062  3925   431]
 [   38   177  2326   717 21097]]

----------resnet18_1d----------
train: dataset statistics [S3 S2 S1 R W]: [17382 69813 18120 27920 46573]
test:  dataset statistics [S3 S2 S1 R W]: [ 1932 18067  6876  5892 24382]
net parameters: 3.85M
final: recall,acc,sp,err,k: (0.7283, 0.9031, 0.9407, 0.2421, 0.6607)
confusion_mat:
[[ 1726   171    23     8     1]
 [ 1906 12100  2336  1530   179]
 [   48  1223  3386  1315   898]
 [    1   298  1200  3989   398]
 [   19   108  1630   531 22064]]
dataloader.py
```diff
@@ -107,8 +107,7 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
     events, _ = mne.events_from_annotations(raw_data, event_id=event_id, chunk_duration=30.)
-    stages = []
-    signals = []
+    stages = []; signals = []
     for i in range(len(events)-1):
         stages.append(events[i][2])
         signals.append(eeg[events[i][0]:events[i][0]+3000])
```
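The hunk above collects one 30 s epoch (3000 samples at 100 Hz) per annotation event. The same slicing can be sketched without MNE (a minimal sketch; `epochs_from_events` is a hypothetical name, and in the real code the event array comes from `mne.events_from_annotations`):

```python
import numpy as np

FS = 100           # Hz, after downsampling
EPOCH = 30 * FS    # 3000 samples per 30 s epoch

def epochs_from_events(eeg, events):
    """events: list of (onset_sample, stage_label); returns (signals, stages)."""
    signals, stages = [], []
    for onset, label in events[:-1]:   # the original loop also drops the last event
        stages.append(label)
        signals.append(eeg[onset:onset + EPOCH])
    return np.array(signals), np.array(stages)

eeg = np.zeros(9000)                   # 90 s dummy recording
sig, st = epochs_from_events(eeg, [(0, 4), (3000, 1), (6000, 2)])
```

Each row of `sig` pairs with one entry of `st`, which is the one-epoch/one-label correspondence the README describes.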
```
@@ -129,38 +128,63 @@ def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
     return signals.astype(np.float16), stages.astype(np.int16)
```

Removed version (single pooled set, shuffled by default; reconstructed from the flattened diff, so the exact placement of a few lines is inferred):

```python
#load all data in datasets
def loaddataset(filedir, dataset_name, signal_name, num, BID, select_sleep_time, shuffle=True):
    print('load dataset, please wait...')
    filenames = os.listdir(filedir)
    if shuffle:
        random.shuffle(filenames)
    signals = []
    stages = []
    if dataset_name == 'cc2018':
        if num > len(filenames):
            num = len(filenames)
        print('num of dataset is:', num)
        for cnt, filename in enumerate(filenames[:num], 0):
            try:
                signal, stage = loaddata_cc2018(filedir, filename, signal_name, BID=BID)
                signals, stages = connectdata(signal, stage, signals, stages)
            except Exception as e:
                print(filename, e)
    elif dataset_name in ['sleep-edfx', 'sleep-edf']:
        if dataset_name == 'sleep-edfx':
            if num > 197:
                num = 197
        if dataset_name == 'sleep-edf':
            filenames = ['SC4002E0-PSG.edf', 'SC4012E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4112E0-PSG.edf',
                         'ST7022J0-PSG.edf', 'ST7052J0-PSG.edf', 'ST7121J0-PSG.edf', 'ST7132J0-PSG.edf']
        cnt = 0
        for filename in filenames:
            if 'PSG' in filename:
                signal, stage = loaddata_sleep_edfx(filedir, filename, signal_name, BID, select_sleep_time)
                signals, stages = connectdata(signal, stage, signals, stages)
                cnt += 1
                if cnt == num:
                    break
    return signals, stages
```

Added version (explicit subject-level train/test split; the sleep-edfx SC and ST file lists are split so their last ~20% of subjects form the test set):

```python
def loaddataset(filedir, dataset_name, signal_name, num, BID, select_sleep_time, shuffle=False):
    print('load dataset, please wait...')
    signals_train = []; stages_train = []; signals_test = []; stages_test = []
    if dataset_name == 'cc2018':
        filenames = os.listdir(filedir)
        if shuffle:
            random.shuffle(filenames)
        else:
            filenames.sort()
        if num > len(filenames):
            num = len(filenames)
        print('num of dataset is:', num)
        for cnt, filename in enumerate(filenames[:num], 0):
            signal, stage = loaddata_cc2018(filedir, filename, signal_name, BID=BID)
            if cnt < round(num*0.8):
                signals_train, stages_train = connectdata(signal, stage, signals_train, stages_train)
            else:
                signals_test, stages_test = connectdata(signal, stage, signals_test, stages_test)
        print('train subjects:', round(num*0.8), 'test subjects:', round(num*0.2))
    elif dataset_name == 'sleep-edfx':
        if num > 197:
            num = 197
        filenames_sc_train = [
            'SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf',
            'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf', 'SC4031E0-PSG.edf', 'SC4032E0-PSG.edf',
            'SC4041E0-PSG.edf', 'SC4042E0-PSG.edf', 'SC4051E0-PSG.edf', 'SC4052E0-PSG.edf',
            'SC4061E0-PSG.edf', 'SC4062E0-PSG.edf', 'SC4071E0-PSG.edf', 'SC4072E0-PSG.edf',
            'SC4081E0-PSG.edf', 'SC4082E0-PSG.edf', 'SC4091E0-PSG.edf', 'SC4092E0-PSG.edf',
            'SC4101E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4111E0-PSG.edf', 'SC4112E0-PSG.edf',
            'SC4121E0-PSG.edf', 'SC4122E0-PSG.edf', 'SC4131E0-PSG.edf', 'SC4141E0-PSG.edf',
            'SC4142E0-PSG.edf', 'SC4151E0-PSG.edf', 'SC4152E0-PSG.edf', 'SC4161E0-PSG.edf',
            'SC4162E0-PSG.edf', 'SC4171E0-PSG.edf', 'SC4172E0-PSG.edf', 'SC4181E0-PSG.edf',
            'SC4182E0-PSG.edf', 'SC4191E0-PSG.edf', 'SC4192E0-PSG.edf', 'SC4201E0-PSG.edf',
            'SC4202E0-PSG.edf', 'SC4211E0-PSG.edf', 'SC4212E0-PSG.edf', 'SC4221E0-PSG.edf',
            'SC4222E0-PSG.edf', 'SC4231E0-PSG.edf', 'SC4232E0-PSG.edf', 'SC4241E0-PSG.edf',
            'SC4242E0-PSG.edf', 'SC4251E0-PSG.edf', 'SC4252E0-PSG.edf', 'SC4261F0-PSG.edf',
            'SC4262F0-PSG.edf', 'SC4271F0-PSG.edf', 'SC4272F0-PSG.edf', 'SC4281G0-PSG.edf',
            'SC4282G0-PSG.edf', 'SC4291G0-PSG.edf', 'SC4292G0-PSG.edf', 'SC4301E0-PSG.edf',
            'SC4302E0-PSG.edf', 'SC4311E0-PSG.edf', 'SC4312E0-PSG.edf', 'SC4321E0-PSG.edf',
            'SC4322E0-PSG.edf', 'SC4331F0-PSG.edf', 'SC4332F0-PSG.edf', 'SC4341F0-PSG.edf',
            'SC4342F0-PSG.edf', 'SC4351F0-PSG.edf', 'SC4352F0-PSG.edf', 'SC4362F0-PSG.edf',
            'SC4371F0-PSG.edf', 'SC4372F0-PSG.edf', 'SC4381F0-PSG.edf', 'SC4382F0-PSG.edf',
            'SC4401E0-PSG.edf', 'SC4402E0-PSG.edf', 'SC4411E0-PSG.edf', 'SC4412E0-PSG.edf',
            'SC4421E0-PSG.edf', 'SC4422E0-PSG.edf', 'SC4431E0-PSG.edf', 'SC4432E0-PSG.edf',
            'SC4441E0-PSG.edf', 'SC4442E0-PSG.edf', 'SC4451F0-PSG.edf', 'SC4452F0-PSG.edf',
            'SC4461F0-PSG.edf', 'SC4462F0-PSG.edf', 'SC4471F0-PSG.edf', 'SC4472F0-PSG.edf',
            'SC4481F0-PSG.edf', 'SC4482F0-PSG.edf', 'SC4491G0-PSG.edf', 'SC4492G0-PSG.edf',
            'SC4501E0-PSG.edf', 'SC4502E0-PSG.edf', 'SC4511E0-PSG.edf', 'SC4512E0-PSG.edf',
            'SC4522E0-PSG.edf', 'SC4531E0-PSG.edf', 'SC4532E0-PSG.edf', 'SC4541F0-PSG.edf',
            'SC4542F0-PSG.edf', 'SC4551F0-PSG.edf', 'SC4552F0-PSG.edf', 'SC4561F0-PSG.edf',
            'SC4562F0-PSG.edf', 'SC4571F0-PSG.edf', 'SC4572F0-PSG.edf', 'SC4581G0-PSG.edf',
            'SC4582G0-PSG.edf', 'SC4591G0-PSG.edf', 'SC4592G0-PSG.edf', 'SC4601E0-PSG.edf',
            'SC4602E0-PSG.edf', 'SC4611E0-PSG.edf', 'SC4612E0-PSG.edf', 'SC4621E0-PSG.edf',
            'SC4622E0-PSG.edf', 'SC4631E0-PSG.edf', 'SC4632E0-PSG.edf']
        filenames_sc_test = [
            'SC4641E0-PSG.edf', 'SC4642E0-PSG.edf', 'SC4651E0-PSG.edf', 'SC4652E0-PSG.edf',
            'SC4661E0-PSG.edf', 'SC4662E0-PSG.edf', 'SC4671G0-PSG.edf', 'SC4672G0-PSG.edf',
            'SC4701E0-PSG.edf', 'SC4702E0-PSG.edf', 'SC4711E0-PSG.edf', 'SC4712E0-PSG.edf',
            'SC4721E0-PSG.edf', 'SC4722E0-PSG.edf', 'SC4731E0-PSG.edf', 'SC4732E0-PSG.edf',
            'SC4741E0-PSG.edf', 'SC4742E0-PSG.edf', 'SC4751E0-PSG.edf', 'SC4752E0-PSG.edf',
            'SC4761E0-PSG.edf', 'SC4762E0-PSG.edf', 'SC4771G0-PSG.edf', 'SC4772G0-PSG.edf',
            'SC4801G0-PSG.edf', 'SC4802G0-PSG.edf', 'SC4811G0-PSG.edf', 'SC4812G0-PSG.edf',
            'SC4821G0-PSG.edf', 'SC4822G0-PSG.edf']
        filenames_st_train = [
            'ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf',
            'ST7041J0-PSG.edf', 'ST7042J0-PSG.edf', 'ST7051J0-PSG.edf', 'ST7052J0-PSG.edf',
            'ST7061J0-PSG.edf', 'ST7062J0-PSG.edf', 'ST7071J0-PSG.edf', 'ST7072J0-PSG.edf',
            'ST7081J0-PSG.edf', 'ST7082J0-PSG.edf', 'ST7091J0-PSG.edf', 'ST7092J0-PSG.edf',
            'ST7101J0-PSG.edf', 'ST7102J0-PSG.edf', 'ST7111J0-PSG.edf', 'ST7112J0-PSG.edf',
            'ST7121J0-PSG.edf', 'ST7122J0-PSG.edf', 'ST7131J0-PSG.edf', 'ST7132J0-PSG.edf',
            'ST7141J0-PSG.edf', 'ST7142J0-PSG.edf', 'ST7151J0-PSG.edf', 'ST7152J0-PSG.edf',
            'ST7161J0-PSG.edf', 'ST7162J0-PSG.edf', 'ST7171J0-PSG.edf', 'ST7172J0-PSG.edf',
            'ST7181J0-PSG.edf', 'ST7182J0-PSG.edf', 'ST7191J0-PSG.edf', 'ST7192J0-PSG.edf']
        filenames_st_test = [
            'ST7201J0-PSG.edf', 'ST7202J0-PSG.edf', 'ST7211J0-PSG.edf', 'ST7212J0-PSG.edf',
            'ST7221J0-PSG.edf', 'ST7222J0-PSG.edf', 'ST7241J0-PSG.edf', 'ST7242J0-PSG.edf']
        for filename in filenames_sc_train[:round(num*153/197*0.8)]:
            signal, stage = loaddata_sleep_edfx(filedir, filename, signal_name, BID, select_sleep_time)
            signals_train, stages_train = connectdata(signal, stage, signals_train, stages_train)
        for filename in filenames_st_train[:round(num*44/197*0.8)]:
            signal, stage = loaddata_sleep_edfx(filedir, filename, signal_name, BID, select_sleep_time)
            signals_train, stages_train = connectdata(signal, stage, signals_train, stages_train)
        for filename in filenames_sc_test[:round(num*153/197*0.2)]:
            signal, stage = loaddata_sleep_edfx(filedir, filename, signal_name, BID, select_sleep_time)
            signals_test, stages_test = connectdata(signal, stage, signals_test, stages_test)
        for filename in filenames_st_test[:round(num*44/197*0.2)]:
            signal, stage = loaddata_sleep_edfx(filedir, filename, signal_name, BID, select_sleep_time)
            signals_test, stages_test = connectdata(signal, stage, signals_test, stages_test)
        print('---------Each subject has two sample---------',
              '\nTrain samples_SC/ST:', round(num*153/197*0.8), round(num*44/197*0.8),
              '\nTest samples_SC/ST:', round(num*153/197*0.2), round(num*44/197*0.2))
    elif dataset_name == 'preload':
        signals_train = np.load(filedir+'/signals_train.npy')
        stages_train = np.load(filedir+'/stages_train.npy')
        signals_test = np.load(filedir+'/signals_test.npy')
        stages_test = np.load(filedir+'/stages_test.npy')
    return signals_train, stages_train, signals_test, stages_test
```
dsp.py
```diff
@@ -42,7 +42,8 @@ def BPF(signal,fs,fc1,fc2,mod = 'fir'):
         b = getfir_b(fc1, fc2, fs)
         result = scipy.signal.lfilter(b, 1, signal)
     return result
-def getfeature(signal, mod='fir', ch_num=5):
+def getfeature(signal, mod='fft', ch_num=5):
     result = []
     signal = signal - np.mean(signal)
     eeg = signal
```

```diff
@@ -68,19 +69,6 @@ def getfeature(signal,mod = 'fir',ch_num = 5):
     result = result.reshape(ch_num*len(signal),)
     return result
-# def signal2spectrum(signal):
-#     spectrum = np.zeros((224,224))
-#     spectrum_y = np.zeros((224))
-#     for i in range(224):
-#         signal_window = signal[i*9:i*9+896]
-#         signal_window_fft = np.abs(np.fft.fft(signal_window))[0:448]
-#         spectrum_y[0:112] = signal_window_fft[0:112]
-#         spectrum_y[112:168] = signal_window_fft[112:224][::2]
-#         spectrum_y[168:224] = signal_window_fft[224:448][::4]
-#         spectrum[:,i] = spectrum_y
-#         # spectrum = np.log(spectrum+1)/11
-#     return spectrum
 # def signal2spectrum(data):
 #     # window : ('tukey',0.5) hann
```

```diff
@@ -96,7 +84,7 @@ def getfeature(signal,mod = 'fir',ch_num = 5):
 def signal2spectrum(data):
     # window : ('tukey',0.5) hann
-    zxx = scipy.signal.stft(data, fs=100, window=('tukey',0.5), nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
+    zxx = scipy.signal.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024-24, nfft=1024, detrend=False, return_onesided=True, boundary='zeros', padded=True, axis=-1)[2]
     zxx = np.abs(zxx)[:512]
     spectrum = np.zeros((256,126))
     spectrum[0:128] = zxx[0:128]
```
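The signal2spectrum hunk above runs scipy.signal.stft with nperseg=1024 and a hop of 24 samples (noverlap=1024-24) over one 3000-sample epoch and keeps the lower 512 frequency bins. With the zero boundary extension that yields exactly 126 time frames, which matches the (256, 126) spectrum buffer. A sketch verifying the shape (assumes SciPy is installed):

```python
import numpy as np
from scipy import signal as sg

data = np.random.randn(3000)   # one 30 s epoch at 100 Hz
# hop = nperseg - noverlap = 24 samples; 'zeros' boundary pads 512 samples each side
zxx = sg.stft(data, fs=100, window='hann', nperseg=1024, noverlap=1024 - 24,
              nfft=1024, detrend=False, return_onesided=True,
              boundary='zeros', padded=True, axis=-1)[2]
zxx = np.abs(zxx)[:512]        # keep the lower 512 of the 513 one-sided bins
```

Frame count: (3000 + 1024 - 1024) / 24 + 1 = 126, so `zxx` ends up (512, 126) before the function packs it into the 256x126 image.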
options.py
```diff
@@ -3,8 +3,8 @@ import os
 import numpy as np
 import torch
-# python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name cc2018 --signal_name 'C4-M1' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject
+# python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 10 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 20 --lr 0.0005 --BID 5_95_th --select_sleep_time --cross_validation subject
-# python3 train.py --dataset_dir '/media/hypo/Hypo/physionet_org_train' --dataset_name cc2018 --signal_name 'C4-M1' --sample_num 20 --model_name lstm --batchsize 64 --epochs 20 --lr 0.0005 --no_cudnn
+# python3 train.py --dataset_dir './datasets/sleep-edfx/' --dataset_name sleep-edfx --signal_name 'EEG Fpz-Cz' --sample_num 50 --model_name lstm --batchsize 64 --network_save_freq 5 --epochs 25 --lr 0.0005 --BID 5_95_th --select_sleep_time --no_cudnn --select_sleep_time
 class Options():
     def __init__(self):
```
```diff
@@ -16,16 +16,14 @@ class Options():
         self.parser.add_argument('--no_cudnn', action='store_true', help='if input, do not use cudnn')
         self.parser.add_argument('--pretrained', action='store_true', help='if input, use pretrained models')
         self.parser.add_argument('--lr', type=float, default=0.0005, help='learning rate')
-        self.parser.add_argument('--cross_validation', type=str, default='subject', help='k-fold | subject')
         self.parser.add_argument('--BID', type=str, default='5_95_th', help='Balance individualized differences 5_95_th | median |None')
-        self.parser.add_argument('--fold_num', type=int, default=5, help='k-fold')
         self.parser.add_argument('--batchsize', type=int, default=64, help='batchsize')
         self.parser.add_argument('--dataset_dir', type=str, default='./datasets/sleep-edfx/', help='your dataset path')
-        self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx', help='Choose dataset sleep-edfx | sleep-edf | cc2018')
+        self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx', help='Choose dataset sleep-edfx | cc2018')
         self.parser.add_argument('--select_sleep_time', action='store_true', help='if input, for sleep-cassette only use sleep time to train')
         self.parser.add_argument('--signal_name', type=str, default='EEG Fpz-Cz', help='Choose the EEG channel C4-M1 | EEG Fpz-Cz |...')
-        self.parser.add_argument('--sample_num', type=int, default=10, help='the amount you want to load')
+        self.parser.add_argument('--sample_num', type=int, default=20, help='the amount you want to load')
         self.parser.add_argument('--model_name', type=str, default='lstm', help='Choose model lstm | multi_scale_resnet_1d | resnet18 |...')
         self.parser.add_argument('--epochs', type=int, default=20, help='end epoch')
         self.parser.add_argument('--weight_mod', type=str, default='avg_best', help='Choose weight mode: avg_best|normal')
```
```diff
@@ -42,9 +40,5 @@ class Options():
             self.opt.sample_num = 8
         if self.opt.no_cuda:
             self.opt.no_cudnn = True
-        if self.opt.fold_num == 0:
-            self.opt.fold_num = 1
-        if self.opt.cross_validation == 'subject':
-            self.opt.fold_num = 1
         return self.opt
```
statistics.py
```diff
@@ -8,7 +8,7 @@ def stage(stages):
     for i in range(len(stages)):
         stage_cnt[stages[i]] += 1
     stage_cnt_per = stage_cnt/len(stages)
-    util.writelog('statistics of dataset [S3 S2 S1 R W]: '+str(stage_cnt), True)
+    util.writelog('dataset statistics [S3 S2 S1 R W]: '+str(stage_cnt), True)
     return stage_cnt, stage_cnt_per

 def reversal_label(mat):
```
train.py
```
@@ -28,26 +28,21 @@ (docstring and data loading)
```

Removed version (pooled data, k-fold or 80/20 index split; reconstructed from the flattened diff):

```python
'''
...
3.fs = 100Hz
4.input signal data should be normalized!!
we recommend signal data normalized useing 5_95_th for each subject,
example: signals_subject=Balance_individualized_differences(signals_subject, '5_95_th')
5.when useing subject cross validation,we will generally believe [0:80%]and[80%:100%]data come from different subjects.
'''
signals, stages = dataloader.loaddataset(opt.dataset_dir, opt.dataset_name, opt.signal_name,
                                         opt.sample_num, opt.BID, opt.select_sleep_time, shuffle=True)
stage_cnt, stage_cnt_per = statistics.stage(stages)
if opt.cross_validation == 'k_fold':
    signals, stages = transformer.batch_generator(signals, stages, opt.batchsize, shuffle=True)
    train_sequences, test_sequences = transformer.k_fold_generator(len(stages), opt.fold_num)
elif opt.cross_validation == 'subject':
    util.writelog('train statistics:', True)
    stage_cnt, stage_cnt_per = statistics.stage(stages[:int(0.8*len(stages))])
    util.writelog('test statistics:', True)
    stage_cnt, stage_cnt_per = statistics.stage(stages[int(0.8*len(stages)):])
    signals, stages = transformer.batch_generator_subject(signals, stages, opt.batchsize, shuffle=False)
    train_sequences, test_sequences = transformer.k_fold_generator(len(stages), 1)
batch_length = len(stages)
show_freq = int(len(train_sequences[0])/5)
```

Added version (dataloader now returns an explicit subject-level train/test split):

```python
'''
...
3.fs = 100Hz
4.input signal data should be normalized!!
we recommend signal data normalized useing 5_95_th for each subject,
example: signals_normalized=transformer.Balance_individualized_differences(signals_origin, '5_95_th')
'''
signals_train, stages_train, signals_test, stages_test = dataloader.loaddataset(
    opt.dataset_dir, opt.dataset_name, opt.signal_name, opt.sample_num, opt.BID, opt.select_sleep_time)
util.writelog('train:', True)
stage_cnt, stage_cnt_per = statistics.stage(stages_train)
util.writelog('test:', True)
_, _ = statistics.stage(stages_test)
signals_train, stages_train = transformer.batch_generator(signals_train, stages_train, opt.batchsize)
signals_test, stages_test = transformer.batch_generator(signals_test, stages_test, opt.batchsize)
batch_length = len(signals_train)
print('length of batch:', batch_length)
show_freq = int(len(stages_train)/5)
t2 = time.time()
print('load data cost time: %.2f' % (t2-t1), 's')
```
```diff
@@ -59,27 +54,30 @@ weight = np.array([1,1,1,1,1])
 if opt.weight_mod == 'avg_best':
     weight = np.log(1/stage_cnt_per)
     weight[2] = weight[2]+1
-    weight = np.clip(weight, 1, 5)
+    weight = np.clip(weight, 1, 3)
 print('Loss_weight:', weight)
 weight = torch.from_numpy(weight).float()
 # print(net)
 if not opt.no_cuda:
     net.cuda()
     weight = weight.cuda()
+if opt.pretrained:
+    net.load_state_dict(torch.load('./checkpoints/pretrained/'+opt.dataset_name+'/'+opt.model_name+'.pth'))
 if not opt.no_cudnn:
     torch.backends.cudnn.benchmark = True
 optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr)
 # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)
 criterion = nn.CrossEntropyLoss(weight)

-def evalnet(net, signals, stages, sequences, epoch, plot_result={}):
+def evalnet(net, signals, stages, epoch, plot_result={}):
     # net.eval()
     confusion_mat = np.zeros((5,5), dtype=int)
-    for i, sequence in enumerate(sequences, 1):
-        signal = transformer.ToInputShape(signals[sequence], opt.model_name, test_flag=True)
-        signal, stage = transformer.ToTensor(signal, stages[sequence], no_cuda=opt.no_cuda)
+    for i, (signal, stage) in enumerate(zip(signals, stages), 1):
+        signal = transformer.ToInputShape(signal, opt.model_name, test_flag=True)
+        signal, stage = transformer.ToTensor(signal, stage, no_cuda=opt.no_cuda)
         with torch.no_grad():
             out = net(signal)
         pred = torch.max(out, 1)[1]
```
...
...
@@ -97,61 +95,53 @@ def evalnet(net,signals,stages,sequences,epoch,plot_result={}):
print
(
'begin to train ...'
)
final_confusion_mat
=
np
.
zeros
((
5
,
5
),
dtype
=
int
)
for
fold
in
range
(
opt
.
fold_num
):
net
.
load_state_dict
(
torch
.
load
(
'./checkpoints/'
+
opt
.
model_name
+
'.pth'
))
if
opt
.
pretrained
:
net
.
load_state_dict
(
torch
.
load
(
'./checkpoints/pretrained/'
+
opt
.
dataset_name
+
'/'
+
opt
.
model_name
+
'.pth'
))
if
not
opt
.
no_cuda
:
net
.
cuda
()
plot_result
=
{
'train'
:[
1.
],
'test'
:[
1.
]}
confusion_mats
=
[]
for
epoch
in
range
(
opt
.
epochs
):
t1
=
time
.
time
()
confusion_mat
=
np
.
zeros
((
5
,
5
),
dtype
=
int
)
print
(
'fold:'
,
fold
+
1
,
'epoch:'
,
epoch
+
1
)
net
.
train
()
for
i
,
sequence
in
enumerate
(
train_sequences
[
fold
],
1
):
signal
=
transformer
.
ToInputShape
(
signals
[
sequence
],
opt
.
model_name
,
test_flag
=
False
)
signal
,
stage
=
transformer
.
ToTensor
(
signal
,
stages
[
sequence
],
no_cuda
=
opt
.
no_cuda
)
out
=
net
(
signal
)
loss
=
criterion
(
out
,
stage
)
pred
=
torch
.
max
(
out
,
1
)[
1
]
optimizer
.
zero_grad
()
loss
.
backward
()
optimizer
.
step
()
pred
=
pred
.
data
.
cpu
().
numpy
()
stage
=
stage
.
data
.
cpu
().
numpy
()
for
x
in
range
(
len
(
pred
)):
confusion_mat
[
stage
[
x
]][
pred
[
x
]]
+=
1
if
i
%
show_freq
==
0
:
plot_result
[
'train'
].
append
(
statistics
.
result
(
confusion_mat
)[
3
])
heatmap
.
draw
(
confusion_mat
,
name
=
'train')

```diff
-                statistics.show(plot_result, epoch + i/(batch_length*0.8))
-                confusion_mat[:] = 0
-        plot_result, confusion_mat = evalnet(net, signals, stages, test_sequences[fold], epoch+1, plot_result)
-        confusion_mats.append(confusion_mat)
-        # scheduler.step()
-        if (epoch+1) % opt.network_save_freq == 0:
-            torch.save(net.cpu().state_dict(), './checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
-            print('network saved.')
-            if not opt.no_cuda:
-                net.cuda()
-        t2 = time.time()
-        if epoch+1 == 1:
-            print('cost time: %.2f' % (t2-t1), 's')
-    pos = plot_result['test'].index(min(plot_result['test'])) - 1
-    final_confusion_mat = final_confusion_mat + confusion_mats[pos]
-    util.writelog('fold:'+str(fold+1)+' recall,acc,sp,err,k: '+str(statistics.result(confusion_mats[pos])), True)
-    print('------------------')
-    util.writelog('confusion_mat:\n'+str(confusion_mat))
+plot_result = {'train':[1.], 'test':[1.]}
+confusion_mats = []
+for epoch in range(opt.epochs):
+    t1 = time.time()
+    confusion_mat = np.zeros((5,5), dtype=int)
+    print('epoch:', epoch+1)
+    net.train()
+    for i, (signal, stage) in enumerate(zip(signals_train, stages_train), 1):
+        signal = transformer.ToInputShape(signal, opt.model_name, test_flag=False)
+        signal, stage = transformer.ToTensor(signal, stage, no_cuda=opt.no_cuda)
+        out = net(signal)
+        loss = criterion(out, stage)
+        pred = torch.max(out, 1)[1]
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        pred = pred.data.cpu().numpy()
+        stage = stage.data.cpu().numpy()
+        for x in range(len(pred)):
+            confusion_mat[stage[x]][pred[x]] += 1
+        if i % show_freq == 0:
+            plot_result['train'].append(statistics.result(confusion_mat)[3])
+            heatmap.draw(confusion_mat, name='train')
+            statistics.show(plot_result, epoch + i/(batch_length*0.8))
+            confusion_mat[:] = 0
+    plot_result, confusion_mat = evalnet(net, signals_test, stages_test, epoch+1, plot_result)
+    confusion_mats.append(confusion_mat)
+    # scheduler.step()
+    if (epoch+1) % opt.network_save_freq == 0:
+        torch.save(net.cpu().state_dict(), './checkpoints/'+opt.model_name+'_epoch'+str(epoch+1)+'.pth')
+        print('network saved.')
+        if not opt.no_cuda:
+            net.cuda()
+    t2 = time.time()
+    if epoch+1 == 1:
+        print('cost time: %.2f' % (t2-t1), 's')
+pos = plot_result['test'].index(min(plot_result['test'])) - 1
+final_confusion_mat = confusion_mats[pos]
+util.writelog('final: recall,acc,sp,err,k: '+str(statistics.result(final_confusion_mat)), True)
+util.writelog('confusion_mat:\n'+str(final_confusion_mat), True)
+statistics.stagefrommat(final_confusion_mat)
```
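The per-batch bookkeeping in the loop above builds a 5×5 confusion matrix by indexing rows with the true stage and columns with the prediction (`confusion_mat[stage[x]][pred[x]] += 1`). A minimal standalone sketch of that accumulation, with made-up labels (`statistics.result` itself is not reproduced here):

```python
import numpy as np

def update_confusion(confusion_mat, true_stages, pred_stages):
    # Rows = true sleep stage, columns = predicted stage,
    # mirroring `confusion_mat[stage[x]][pred[x]] += 1` in the train loop.
    for t, p in zip(true_stages, pred_stages):
        confusion_mat[t][p] += 1
    return confusion_mat

mat = np.zeros((5, 5), dtype=int)
update_confusion(mat, [0, 1, 1, 2], [0, 1, 2, 2])
print(mat.trace())                            # 3 correct predictions on the diagonal
print(round(1 - mat.trace() / mat.sum(), 2))  # overall error rate: 0.25
```

Resetting the matrix with `confusion_mat[:] = 0` every `show_freq` batches means each heatmap reflects only the most recent window of training batches.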
transformer.py View file @ 7ac8cf6f

```diff
@@ -14,16 +14,6 @@ def shuffledata(data,target):
     np.random.shuffle(target)
     # return data,target
 
-def batch_generator_subject(data, target, batchsize, shuffle=True):
-    data_test = data[int(0.8*len(target)):]
-    data_train = data[0:int(0.8*len(target))]
-    target_test = target[int(0.8*len(target)):]
-    target_train = target[0:int(0.8*len(target))]
-    data_test, target_test = batch_generator(data_test, target_test, batchsize)
-    data_train, target_train = batch_generator(data_train, target_train, batchsize)
-    data = np.concatenate((data_train, data_test), axis=0)
-    target = np.concatenate((target_train, target_test), axis=0)
-    return data, target
-
 def batch_generator(data, target, batchsize, shuffle=True):
     if shuffle:
```
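The deleted `batch_generator_subject` held out the last 20% of samples as a test block, batched each part separately, and re-concatenated them. The index arithmetic can be sketched with toy data and a stand-in `batch_generator` (the stand-in below just trims and reshapes into batches; it is not the repo's real implementation):

```python
import numpy as np

def batch_generator(data, target, batchsize):
    # Stand-in helper (hypothetical): trim to a multiple of batchsize,
    # then reshape into (num_batches, batchsize) groups.
    n = (len(target) // batchsize) * batchsize
    return (data[:n].reshape(-1, batchsize, *data.shape[1:]),
            target[:n].reshape(-1, batchsize))

data = np.arange(10, dtype=float)
target = np.arange(10)
split = int(0.8 * len(target))      # first 80% -> train, last 20% -> test
train_d, train_t = batch_generator(data[:split], target[:split], 2)
test_d, test_t = batch_generator(data[split:], target[split:], 2)
print(train_t.shape, test_t.shape)  # (4, 2) (1, 2)
```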
```diff
@@ -34,41 +24,28 @@ def batch_generator(data,target,batchsize,shuffle = True):
     target = target.reshape(-1, batchsize)
     return data, target
 
-def k_fold_generator(length, fold_num):
-    sequence = np.linspace(0, length-1, length, dtype='int')
-    if fold_num == 1:
-        train_sequence = sequence[0:int(0.8*length)].reshape(1, -1)
-        test_sequence = sequence[int(0.8*length):].reshape(1, -1)
-    else:
-        train_length = int(length/fold_num*(fold_num-1))
-        test_length = int(length/fold_num)
-        train_sequence = np.zeros((fold_num, train_length), dtype='int')
-        test_sequence = np.zeros((fold_num, test_length), dtype='int')
-        for i in range(fold_num):
-            test_sequence[i] = (sequence[test_length*i:test_length*(i+1)])[:test_length]
-            train_sequence[i] = np.concatenate((sequence[0:test_length*i], sequence[test_length*(i+1):]), axis=0)[:train_length]
-    return train_sequence, test_sequence
```
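For reference, the deleted `k_fold_generator` split epoch indices into contiguous test folds, with the remaining indices forming the training set. A small standalone run (the function is copied from the removed code; the sizes are toy numbers) illustrates what it returned:

```python
import numpy as np

def k_fold_generator(length, fold_num):
    # Deleted helper: returns (fold_num, *) index arrays for train/test splits.
    sequence = np.linspace(0, length - 1, length, dtype='int')
    if fold_num == 1:
        train_sequence = sequence[0:int(0.8 * length)].reshape(1, -1)
        test_sequence = sequence[int(0.8 * length):].reshape(1, -1)
    else:
        train_length = int(length / fold_num * (fold_num - 1))
        test_length = int(length / fold_num)
        train_sequence = np.zeros((fold_num, train_length), dtype='int')
        test_sequence = np.zeros((fold_num, test_length), dtype='int')
        for i in range(fold_num):
            test_sequence[i] = sequence[test_length * i:test_length * (i + 1)][:test_length]
            train_sequence[i] = np.concatenate(
                (sequence[0:test_length * i], sequence[test_length * (i + 1):]), axis=0
            )[:train_length]
    return train_sequence, test_sequence

train_seq, test_seq = k_fold_generator(10, 5)
print(test_seq[1])   # [2 3] -- fold 1 holds out the second contiguous block
print(train_seq[1])  # [0 1 4 5 6 7 8 9]
```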
```diff
-def Normalize(data, maxmin, avg, sigma):
+def Normalize(data, maxmin, avg, sigma, is_01=False):
     data = np.clip(data, -maxmin, maxmin)
-    return (data-avg)/sigma
+    if is_01:
+        return (data-avg)/sigma/2 + 0.5  # (0,1)
+    else:
+        return (data-avg)/sigma  # (-1,1)
```
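The new `is_01` flag only changes the output range: standardized values are halved and shifted so they land roughly in (0, 1) instead of (-1, 1). A quick check with toy amplitudes (the function body is taken from the diff):

```python
import numpy as np

def Normalize(data, maxmin, avg, sigma, is_01=False):
    # Clip extreme amplitudes, then scale by a fixed sigma;
    # is_01=True remaps the roughly (-1, 1) range into roughly (0, 1).
    data = np.clip(data, -maxmin, maxmin)
    if is_01:
        return (data - avg) / sigma / 2 + 0.5  # (0,1)
    else:
        return (data - avg) / sigma            # (-1,1)

x = np.array([-30.0, 0.0, 30.0])
print(Normalize(x, maxmin=10e3, avg=0, sigma=30))              # [-1.  0.  1.]
print(Normalize(x, maxmin=10e3, avg=0, sigma=30, is_01=True))  # [0.   0.5  1.]
```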
```diff
 def Balance_individualized_differences(signals, BID):
     if BID == 'median':
         signals = (signals*8/(np.median(abs(signals))))
-        signals = Normalize(signals, maxmin=10e3, avg=0, sigma=30)
+        signals = Normalize(signals, maxmin=10e3, avg=0, sigma=30, is_01=True)
     elif BID == '5_95_th':
         tmp = np.sort(signals.reshape(-1))
         th_5 = -tmp[int(0.05*len(tmp))]
-        signals = Normalize(signals, maxmin=10e3, avg=0, sigma=th_5)
+        signals = Normalize(signals, maxmin=10e3, avg=0, sigma=th_5, is_01=True)
     else:
-        # dataset               5_95_th        median
+        # dataset               5_95_th(-1,1)  median
         # CC2018                24.75          7.438
         # sleep edfx            37.4           9.71
         # sleep edfx sleeptime  39.03          10.125
-        signals = Normalize(signals, maxmin=10e3, avg=0, sigma=30)
+        signals = Normalize(signals, maxmin=10e3, avg=0, sigma=30, is_01=True)
     return signals
 
 def ToTensor(data, target=None, no_cuda=False):
```
```diff
@@ -141,6 +118,7 @@ def ToInputShape(data,net_name,test_flag = False):
     elif net_name in ['squeezenet','multi_scale_resnet','dfcnn','resnet18','densenet121','densenet201','resnet101','resnet50']:
         result = []
+        data = (data-0.5)*2
         for i in range(0, batchsize):
             spectrum = dsp.signal2spectrum(data[i])
             spectrum = random_transform_2d(spectrum, (224,122), test_flag=test_flag)
```
util.py View file @ 7ac8cf6f

```diff
@@ -13,4 +13,6 @@ def writelog(log,printflag = False):
         print(log)
 
 def show_paramsnumber(net):
-    writelog('net parameters:'+str(sum(param.numel() for param in net.parameters())), True)
+    parameters = sum(param.numel() for param in net.parameters())
+    parameters = round(parameters/1e6, 2)
+    writelog('net parameters: '+str(parameters)+'M', True)
```
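The updated `show_paramsnumber` reports the parameter count in millions, rounded to two decimals. The arithmetic can be sketched without torch (the layer sizes below are invented for illustration):

```python
def params_in_millions(layer_param_counts):
    # Same arithmetic as show_paramsnumber: total count / 1e6, rounded to 2 dp.
    total = sum(layer_param_counts)
    return round(total / 1e6, 2)

# hypothetical network: an 11.18M backbone plus a 0.25M classifier head
print(params_in_millions([11_180_000, 250_000]))  # 11.43
```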