README_CN.md 4.4 KB
Newer Older
H
hypox64 已提交
1 2 3
<div align="center">    
<img src="./imgs/compare.png " alt="image" style="zoom:100%;" />
</div>
H
hypox64 已提交
4

H
hypox64 已提交
5
# candock
H
hypox64 已提交
6

H
hypox64 已提交
7
| [English](./README.md) | 中文版 |<br><br>
H
hypox64 已提交
8
一个通用的一维时序信号分析,分类框架.<br>
H
HypoX64 已提交
9
它将包含多种网络结构,并提供数据预处理,数据增强,训练,评估,测试等功能.<br>
H
hypox64 已提交
10
一些训练时的输出样例: [heatmap](./image/heatmap_eg.png)  [running_loss](./image/running_loss_eg.png)  [log.txt](./docs/log_eg.txt)<br>
H
HypoX64 已提交
11 12 13 14

## 支持的功能
### 数据预处理
通用的数据预处理方法
H
HypoX64 已提交
15 16
* Normliaze  :   5_95 | maxmin | None
* Filter           :   fft | fir | iir | wavelet | None
H
HypoX64 已提交
17 18 19

### 数据增强
多种多样的数据增强方法.注意:使用时应该结合数据的物理特性进行选择.<br>[[Time Series Data Augmentation for Deep Learning: A Survey]](https://arxiv.org/pdf/2002.12478.pdf)
H
HypoX64 已提交
20 21 22
* Base     :  scale, warp, app, aaft, iaaft, filp, crop
* Noise   :  spike, step, slope, white, pink, blue, brown, violet
* Gan      :  dcgan
H
HypoX64 已提交
23 24 25

### 网络
提供多种用于评估的网络.
H
hypox64 已提交
26 27
>1d
>
H
HypoX64 已提交
28
>>lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d, micro_multi_scale_resnet_1d,autoencoder,mlp
H
hypox64 已提交
29 30 31 32 33

>2d(stft spectrum)
>
>>mobilenet, resnet18, resnet50, resnet101, densenet121, densenet201, squeezenet, dfcnn, multi_scale_resnet,

H
HypoX64 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
### K-fold
使用K-fold使得结果更加可靠.
```--k_fold```&```--fold_index```<br>

* --k_fold
```python
# fold_num of k-fold. If 0 or 1, no k-fold and cut 80% to train and other to eval.
```
* --fold_index
```python
"""--fold_index
When --k_fold != 0 or 1:
Cut dataset into sub-set using index , and then run k-fold with sub-set
If input 'auto', it will shuffle dataset and then cut dataset equally
If input: [2,4,6,7]
when len(dataset) == 10
sub-set: dataset[0:2],dataset[2:4],dataset[4:6],dataset[6:7],dataset[7:]
-------
When --k_fold == 0 or 1:
If input 'auto', it will shuffle dataset and then cut 80% dataset to train and other to eval
If input: [5]
when len(dataset) == 10
train-set : dataset[0:5]  eval-set : dataset[5:]
"""
```
H
hypox64 已提交
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
## 关于EEG睡眠分期数据的实例
为了适应新的项目,代码已被大幅更改,不能正常运行如sleep-edfx等睡眠数据集,如果仍然需要运行,请参照下文按照输入格式标准自行加载数据,如果有时间我会修复这个问题。
当然,如果需要加载睡眠数据集也可以直接使用[老的版本](https://github.com/HypoX64/candock/tree/f24cc44933f494d2235b3bf965a04cde5e6a1ae9)<br>
感谢[@swalltail99](https://github.com/swalltail99)指出的错误,为了适应sleep-edfx数据集的读取,使用这个版本的代码时,请安装mne==0.18.0<br>

```bash
pip install mne==0.18.0
```

## 入门
### 前提要求
- Linux, Windows,mac
- CPU or NVIDIA GPU + CUDA CuDNN
- Python 3
- Pytroch 1.0+
### 依赖
This code depends on torchvision, numpy, scipy , matplotlib, available via pip install.<br>
For example:<br>

```bash
pip3 install matplotlib
```
### 克隆仓库:
```bash
git clone https://github.com/HypoX64/candock
cd candock
```
### 下载数据集以及预训练模型
[[Google Drive]](https://drive.google.com/open?id=1NTtLmT02jqlc81lhtzQ7GlPK8epuHfU5)   [[百度云,y4ks]](https://pan.baidu.com/s/1WKWZL91SekrSlhOoEC1bQA)

* 数据集包括 signals.npy(shape:18207, 1, 2000) 以及 labels.npy(shape:18207) 可以使用"np.load()"加载
* 样本量:18207,  通道数:1,  每个样本的长度:2000,  总类别数:50
* Top1 err: 2.09%
### 训练
```bash
python3 train.py --label 50 --input_nc 1 --dataset_dir ./datasets/simple_test --save_dir ./checkpoints/simple_test --model_name micro_multi_scale_resnet_1d --gpu_id 0 --batchsize 64 --k_fold 5
# 如果需要使用cpu进行训练, 请输入 --gpu_id -1
```
* 更多可选参数 [options](./util/options.py).
### 测试
```bash
python3 simple_test.py --label 50 --input_nc 1 --model_name micro_multi_scale_resnet_1d --gpu_id 0
# 如果需要使用cpu进行训练, 请输入 --gpu_id -1
```

## 使用自己的数据进行训练
* step1: 按照如下格式生成 signals.npy 以及 labels.npy.
```python
#1.type:numpydata   signals:np.float64   labels:np.int64
#2.shape  signals:[num,ch,length]    labels:[num]
H
hypox64 已提交
109
#num:samples_num, ch :channel_num,  length:length of each sample
H
hypox64 已提交
110 111 112 113 114 115
#for example:
signals = np.zeros((10,1,10),dtype='np.float64')
labels = np.array([0,0,0,0,0,1,1,1,1,1])      #0->class0    1->class1
```
* step2: 输入  ```--dataset_dir "your_dataset_dir"``` 当运行代码的时候.

H
hypox64 已提交
116
### [ More options](./util/options.py).