dataloader.py 11.0 KB
Newer Older
HypoX64's avatar
HypoX64 已提交
1 2 3
import os
import time
import random
H
hypox64 已提交
4 5 6 7 8 9

import scipy.io as sio
import numpy as np
import h5py
import mne

H
hypox64 已提交
10
import dsp
H
hypox64 已提交
11
import transformer
HypoX64's avatar
HypoX64 已提交
12

HypoX64's avatar
HypoX64 已提交
13

H
hypox64 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
def trimdata(data,num):
    """Cut ``data`` down to the longest prefix whose length is a multiple of ``num``.

    Elements at the tail that do not fill a complete group of ``num`` are
    discarded; the result is a slice of ``data``.
    """
    full_groups = int(len(data)/num)
    usable_length = num*full_groups
    return data[:usable_length]

def reducesample(data,mult):
    """Downsample ``data`` by keeping every ``mult``-th element, starting at index 0."""
    every_mult = slice(None, None, mult)
    return data[every_mult]

# delete useless labels
def del_UND(signals,stages):
    """Remove every epoch whose stage label is 5 (the UND / undefined class).

    Args:
        signals: per-epoch data, indexed along axis 0 in step with ``stages``.
        stages: 1-D sequence of integer stage labels.

    Returns:
        ``(signals, stages)`` as NumPy arrays with the label-5 entries removed.
        When nothing has to be removed the (array-converted) inputs are
        returned without copying.
    """
    signals = np.asarray(signals)
    stages = np.asarray(stages)
    # One boolean mask instead of one np.delete call per removed row: each
    # np.delete reallocates the whole array, which made the loop quadratic.
    keep = stages != 5
    if keep.all():
        return signals,stages
    return signals[keep],stages[keep]

def connectdata(signal,stage,signals=None,stages=None):
    """Append one subject's epochs and labels to the running dataset.

    Args:
        signal: array with the new epochs (concatenated along axis 0).
        stage: array with the labels belonging to ``signal``.
        signals: epochs accumulated so far; ``None`` or an empty sequence
            means nothing has been accumulated yet.
        stages: labels accumulated so far.

    Returns:
        ``(signals, stages)``: copies of ``signal``/``stage`` when the
        accumulator was empty, otherwise the concatenated arrays.
    """
    # Emptiness is tested by length. ``ndarray == []`` is an elementwise
    # comparison, which fails on mismatched shapes in recent NumPy releases,
    # and a ``[]`` default would be shared between calls.
    if signals is None or len(signals) == 0:
        return signal.copy(),stage.copy()
    signals = np.concatenate((signals, signal), axis=0)
    stages = np.concatenate((stages, stage), axis=0)
    return signals,stages

#load one subject data from cc2018
def loaddata_cc2018(filedir,filename,signal_name,BID,filter = True):
    """Load one subject's signal and sleep-stage labels from the cc2018 dataset.

    The subject directory ``filedir/filename`` is read for three files named
    after the directory: ``<name>.hea`` (header), ``<name>.mat`` (signals) and
    ``<name>-arousal.mat`` (stage annotations, opened with h5py).

    Args:
        filedir: dataset root directory.
        filename: name of the subject's sub-directory inside ``filedir``.
        signal_name: channel to extract, looked up among the names parsed
            from the header file.
        BID: forwarded unchanged to
            ``transformer.Balance_individualized_differences``.
        filter: when true, the raw channel is passed through ``dsp.BPF``
            before any resampling.

    Returns:
        ``(signals, stages)`` cast to ``float16`` and ``int16``; epochs
        labelled UND (5) are removed.

    Raises:
        ValueError: if ``signal_name`` is not among the header's channel names.
    """
    dirpath = os.path.join(filedir,filename)
    basename = os.path.basename(dirpath)

    #load signal
    hea_path = os.path.join(dirpath,basename+'.hea')
    signal_path = os.path.join(dirpath,basename+'.mat')
    # The first header line is skipped; on every other line the 9th
    # whitespace-separated field is taken as the channel name.
    with open(hea_path) as hea_file:
        signal_names = [line.strip().split()[8]
                        for i,line in enumerate(hea_file) if i != 0]
    signal_mat = sio.loadmat(signal_path)
    signals = signal_mat['val'][signal_names.index(signal_name)]
    if filter:
        # NOTE(review): arguments look like fs=200 Hz and a 0.2-50 Hz pass band - confirm against dsp.BPF
        signals = dsp.BPF(signals,200,0.2,50,mod = 'fir')

    #load stage
    stagepath = os.path.join(dirpath,basename+'-arousal.mat')
    # Indexing an h5py dataset reads the values into memory, so the file can
    # be closed as soon as the six indicator rows have been extracted.
    with h5py.File(stagepath,'r') as stage_file:
        sleep_stages = stage_file['data']['sleep_stages']
        N3 = sleep_stages['nonrem3'][0]
        N2 = sleep_stages['nonrem2'][0]
        N1 = sleep_stages['nonrem1'][0]
        REM = sleep_stages['rem'][0]
        W = sleep_stages['wake'][0]
        UND = sleep_stages['undefined'][0]
    # N3(S4+S3)->0  N2->1  N1->2  REM->3  W->4  UND->5
    stages = N3*0 + N2*1 + N1*2 + REM*3 + W*4 + UND*5

    #resample: keep every second sample of both signal and labels
    signals = reducesample(signals,2)
    stages = reducesample(stages,2)
    #trim to a whole number of 3000-sample epochs
    signals = trimdata(signals,3000)
    stages = trimdata(stages,3000)
    #30s per label: one row of 3000 samples and one label (the first) per epoch
    signals = signals.reshape(-1,3000)
    stages = stages[::3000]
    #Balance individualized differences
    signals = transformer.Balance_individualized_differences(signals, BID)
    #del UND
    signals,stages = del_UND(signals, stages)

    return signals.astype(np.float16),stages.astype(np.int16)
HypoX64's avatar
HypoX64 已提交
81

H
hypox64 已提交
82 83 84
#load one subject data from sleep-edfx
def loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time):
    """Load one recording and its sleep-stage labels from the sleep-edfx dataset.

    Characters 2-5 of ``filename`` are used as the recording number; the PSG
    and Hypnogram files in ``filedir`` whose names contain that number are
    read with ``mne``.

    Args:
        filedir: directory holding the ``*PSG*`` and ``*Hypnogram*`` files.
        filename: PSG file name, e.g. ``'SC4001E0-PSG.edf'``.
        signal_name: channel to pick from the recording.
        BID: forwarded unchanged to
            ``transformer.Balance_individualized_differences``.
        select_sleep_time: when true, 'SC' recordings are cropped to a window
            around the annotated sleep period (see the inline comment).

    Returns:
        ``(signals, stages)`` cast to ``float16`` and ``int16``; epochs
        labelled 5 (unknown stage / movement) are removed.

    Raises:
        FileNotFoundError: if no matching PSG or Hypnogram file is found.
    """
    filenum = filename[2:6]
    f_stage_name = None
    f_signal_name = None
    # A separate loop variable keeps the ``filename`` argument intact.
    for candidate in os.listdir(filedir):
        if filenum not in candidate:
            continue
        if 'Hypnogram' in candidate:
            f_stage_name = candidate
        if 'PSG' in candidate:
            f_signal_name = candidate
    if f_signal_name is None or f_stage_name is None:
        raise FileNotFoundError(
            'no PSG/Hypnogram file pair for recording ' + filenum + ' in ' + str(filedir))

    raw_data= mne.io.read_raw_edf(os.path.join(filedir,f_signal_name),preload=True)
    raw_annot = mne.read_annotations(os.path.join(filedir,f_stage_name))
    eeg = raw_data.pick_channels([signal_name]).to_data_frame().values.T
    eeg = eeg.reshape(-1)

    raw_data.set_annotations(raw_annot, emit_warning=False)
    #N3(S4+S3)->0  N2->1  N1->2  REM->3  W->4  other->UND->5
    event_id = {'Sleep stage 4': 0,
                  'Sleep stage 3': 0,
                  'Sleep stage 2': 1,
                  'Sleep stage 1': 2,
                  'Sleep stage R': 3,
                  'Sleep stage W': 4,
                  'Sleep stage ?': 5,
                  'Movement time': 5}
    events, _ = mne.events_from_annotations(
        raw_data, event_id=event_id, chunk_duration=30.)

    # One 3000-sample window per event; the final event is skipped
    # (NOTE(review): presumably because its window may be incomplete - confirm).
    stages = [];signals =[]
    for event in events[:-1]:
        stages.append(event[2])
        signals.append(eeg[event[0]:event[0]+3000])
    stages=np.array(stages)
    signals=np.array(signals)

    #select sleep time
    if select_sleep_time and 'SC' in f_signal_name:
        # Epoch indices (30 s each): start 60 epochs before the end of the
        # first annotation, stop 60 epochs after the onset of the
        # second-to-last annotation; the start is clamped at 0.
        start = np.clip(int(raw_annot[0]['duration'])//30-60,0,9999999)
        stop = int(raw_annot[-2]['onset'])//30+60
        signals = signals[start:stop]
        stages = stages[start:stop]

    signals,stages = del_UND(signals, stages)
    print('shape:',signals.shape,stages.shape)

    signals = transformer.Balance_individualized_differences(signals, BID)

    return signals.astype(np.float16),stages.astype(np.int16)
HypoX64's avatar
HypoX64 已提交
129

H
hypox64 已提交
130
#load all data in datasets
H
hypox64 已提交
131
def loaddataset(filedir,dataset_name,signal_name,num,BID,select_sleep_time,shuffle = False):
    """Load a whole dataset and split it into train and test parts.

    Args:
        filedir: dataset directory (for ``'preload'``: the directory that
            holds the four ``.npy`` files).
        dataset_name: ``'cc2018'``, ``'sleep-edfx'`` or ``'preload'``.
        signal_name: channel name forwarded to the per-subject loaders.
        num: number of subjects/recordings to use; capped at the number of
            directory entries (cc2018) or at 197 (sleep-edfx).
        BID: forwarded to the per-subject loaders.
        select_sleep_time: forwarded to ``loaddata_sleep_edfx``.
        shuffle: cc2018 only - shuffle the subject list instead of sorting it.

    Returns:
        ``(signals_train, stages_train, signals_test, stages_test)``.
        For any other ``dataset_name`` the four values are empty lists.
    """
    print('load dataset, please wait...')

    signals_train=[];stages_train=[];signals_test=[];stages_test=[]

    if dataset_name == 'cc2018':
        filenames = os.listdir(filedir)
        if shuffle:
            random.shuffle(filenames)
        else:
            filenames.sort()

        if num > len(filenames):
            num = len(filenames)
            print('num of dataset is:',num)

        # The first 80% of the selected subjects go to train, the rest to test.
        n_train = round(num*0.8)
        for cnt,filename in enumerate(filenames[:num]):
            signal,stage = loaddata_cc2018(filedir,filename,signal_name,BID = BID)
            if cnt < n_train:
                signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)
            else:
                signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)
        print('train subjects:',n_train,'test subjects:',round(num*0.2))

    elif dataset_name == 'sleep-edfx':
        num = min(num, 197)

        filenames_sc_train = ['SC4001E0-PSG.edf', 'SC4002E0-PSG.edf', 'SC4011E0-PSG.edf', 'SC4012E0-PSG.edf', 'SC4021E0-PSG.edf', 'SC4022E0-PSG.edf', 'SC4031E0-PSG.edf', 'SC4032E0-PSG.edf', 'SC4041E0-PSG.edf', 'SC4042E0-PSG.edf', 'SC4051E0-PSG.edf', 'SC4052E0-PSG.edf', 'SC4061E0-PSG.edf', 'SC4062E0-PSG.edf', 'SC4071E0-PSG.edf', 'SC4072E0-PSG.edf', 'SC4081E0-PSG.edf', 'SC4082E0-PSG.edf', 'SC4091E0-PSG.edf', 'SC4092E0-PSG.edf', 'SC4101E0-PSG.edf', 'SC4102E0-PSG.edf', 'SC4111E0-PSG.edf', 'SC4112E0-PSG.edf', 'SC4121E0-PSG.edf', 'SC4122E0-PSG.edf', 'SC4131E0-PSG.edf', 'SC4141E0-PSG.edf', 'SC4142E0-PSG.edf', 'SC4151E0-PSG.edf', 'SC4152E0-PSG.edf', 'SC4161E0-PSG.edf', 'SC4162E0-PSG.edf', 'SC4171E0-PSG.edf', 'SC4172E0-PSG.edf', 'SC4181E0-PSG.edf', 'SC4182E0-PSG.edf', 'SC4191E0-PSG.edf', 'SC4192E0-PSG.edf', 'SC4201E0-PSG.edf', 'SC4202E0-PSG.edf', 'SC4211E0-PSG.edf', 'SC4212E0-PSG.edf', 'SC4221E0-PSG.edf', 'SC4222E0-PSG.edf', 'SC4231E0-PSG.edf', 'SC4232E0-PSG.edf', 'SC4241E0-PSG.edf', 'SC4242E0-PSG.edf', 'SC4251E0-PSG.edf', 'SC4252E0-PSG.edf', 'SC4261F0-PSG.edf', 'SC4262F0-PSG.edf', 'SC4271F0-PSG.edf', 'SC4272F0-PSG.edf', 'SC4281G0-PSG.edf', 'SC4282G0-PSG.edf', 'SC4291G0-PSG.edf', 'SC4292G0-PSG.edf', 'SC4301E0-PSG.edf', 'SC4302E0-PSG.edf', 'SC4311E0-PSG.edf', 'SC4312E0-PSG.edf', 'SC4321E0-PSG.edf', 'SC4322E0-PSG.edf', 'SC4331F0-PSG.edf', 'SC4332F0-PSG.edf', 'SC4341F0-PSG.edf', 'SC4342F0-PSG.edf', 'SC4351F0-PSG.edf', 'SC4352F0-PSG.edf', 'SC4362F0-PSG.edf', 'SC4371F0-PSG.edf', 'SC4372F0-PSG.edf', 'SC4381F0-PSG.edf', 'SC4382F0-PSG.edf', 'SC4401E0-PSG.edf', 'SC4402E0-PSG.edf', 'SC4411E0-PSG.edf', 'SC4412E0-PSG.edf', 'SC4421E0-PSG.edf', 'SC4422E0-PSG.edf', 'SC4431E0-PSG.edf', 'SC4432E0-PSG.edf', 'SC4441E0-PSG.edf', 'SC4442E0-PSG.edf', 'SC4451F0-PSG.edf', 'SC4452F0-PSG.edf', 'SC4461F0-PSG.edf', 'SC4462F0-PSG.edf', 'SC4471F0-PSG.edf', 'SC4472F0-PSG.edf', 'SC4481F0-PSG.edf', 'SC4482F0-PSG.edf', 'SC4491G0-PSG.edf', 'SC4492G0-PSG.edf', 'SC4501E0-PSG.edf', 'SC4502E0-PSG.edf',
            'SC4511E0-PSG.edf', 'SC4512E0-PSG.edf', 'SC4522E0-PSG.edf', 'SC4531E0-PSG.edf', 'SC4532E0-PSG.edf', 'SC4541F0-PSG.edf', 'SC4542F0-PSG.edf', 'SC4551F0-PSG.edf', 'SC4552F0-PSG.edf', 'SC4561F0-PSG.edf', 'SC4562F0-PSG.edf', 'SC4571F0-PSG.edf', 'SC4572F0-PSG.edf', 'SC4581G0-PSG.edf', 'SC4582G0-PSG.edf', 'SC4591G0-PSG.edf', 'SC4592G0-PSG.edf', 'SC4601E0-PSG.edf', 'SC4602E0-PSG.edf', 'SC4611E0-PSG.edf', 'SC4612E0-PSG.edf', 'SC4621E0-PSG.edf', 'SC4622E0-PSG.edf', 'SC4631E0-PSG.edf', 'SC4632E0-PSG.edf']
        filenames_sc_test = ['SC4641E0-PSG.edf', 'SC4642E0-PSG.edf', 'SC4651E0-PSG.edf', 'SC4652E0-PSG.edf', 'SC4661E0-PSG.edf', 'SC4662E0-PSG.edf', 'SC4671G0-PSG.edf', 'SC4672G0-PSG.edf', 'SC4701E0-PSG.edf', 'SC4702E0-PSG.edf', 'SC4711E0-PSG.edf', 'SC4712E0-PSG.edf', 'SC4721E0-PSG.edf', 'SC4722E0-PSG.edf', 'SC4731E0-PSG.edf', 'SC4732E0-PSG.edf', 'SC4741E0-PSG.edf', 'SC4742E0-PSG.edf', 'SC4751E0-PSG.edf', 'SC4752E0-PSG.edf', 'SC4761E0-PSG.edf', 'SC4762E0-PSG.edf', 'SC4771G0-PSG.edf', 'SC4772G0-PSG.edf', 'SC4801G0-PSG.edf', 'SC4802G0-PSG.edf', 'SC4811G0-PSG.edf', 'SC4812G0-PSG.edf', 'SC4821G0-PSG.edf', 'SC4822G0-PSG.edf']
        filenames_st_train = ['ST7011J0-PSG.edf', 'ST7012J0-PSG.edf', 'ST7021J0-PSG.edf', 'ST7022J0-PSG.edf', 'ST7041J0-PSG.edf', 'ST7042J0-PSG.edf', 'ST7051J0-PSG.edf', 'ST7052J0-PSG.edf', 'ST7061J0-PSG.edf', 'ST7062J0-PSG.edf', 'ST7071J0-PSG.edf', 'ST7072J0-PSG.edf', 'ST7081J0-PSG.edf', 'ST7082J0-PSG.edf', 'ST7091J0-PSG.edf', 'ST7092J0-PSG.edf', 'ST7101J0-PSG.edf', 'ST7102J0-PSG.edf', 'ST7111J0-PSG.edf', 'ST7112J0-PSG.edf', 'ST7121J0-PSG.edf', 'ST7122J0-PSG.edf', 'ST7131J0-PSG.edf', 'ST7132J0-PSG.edf', 'ST7141J0-PSG.edf', 'ST7142J0-PSG.edf', 'ST7151J0-PSG.edf', 'ST7152J0-PSG.edf', 'ST7161J0-PSG.edf', 'ST7162J0-PSG.edf', 'ST7171J0-PSG.edf', 'ST7172J0-PSG.edf', 'ST7181J0-PSG.edf', 'ST7182J0-PSG.edf', 'ST7191J0-PSG.edf', 'ST7192J0-PSG.edf']
        filenames_st_test = ['ST7201J0-PSG.edf', 'ST7202J0-PSG.edf', 'ST7211J0-PSG.edf', 'ST7212J0-PSG.edf', 'ST7221J0-PSG.edf', 'ST7222J0-PSG.edf', 'ST7241J0-PSG.edf', 'ST7242J0-PSG.edf']

        # ``num`` is divided between the SC and ST lists in the ratio 153:44
        # and, within each, 80/20 between train and test. Computed once so the
        # loops and the summary print cannot drift apart.
        n_sc_train = round(num*153/197*0.8)
        n_st_train = round(num*44/197*0.8)
        n_sc_test = round(num*153/197*0.2)
        n_st_test = round(num*44/197*0.2)

        for names,count in ((filenames_sc_train,n_sc_train),(filenames_st_train,n_st_train)):
            for filename in names[:count]:
                signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
                signals_train,stages_train = connectdata(signal,stage,signals_train,stages_train)

        for names,count in ((filenames_sc_test,n_sc_test),(filenames_st_test,n_st_test)):
            for filename in names[:count]:
                signal,stage = loaddata_sleep_edfx(filedir,filename,signal_name,BID,select_sleep_time)
                signals_test,stages_test = connectdata(signal,stage,signals_test,stages_test)

        print('---------Each subject has two sample---------',
            '\nTrain samples_SC/ST:',n_sc_train,n_st_train,
            '\nTest samples_SC/ST:',n_sc_test,n_st_test)

    elif dataset_name == 'preload':
        # Arrays saved by an earlier run; nothing is recomputed here.
        signals_train = np.load(os.path.join(filedir,'signals_train.npy'))
        stages_train = np.load(os.path.join(filedir,'stages_train.npy'))
        signals_test = np.load(os.path.join(filedir,'signals_test.npy'))
        stages_test = np.load(os.path.join(filedir,'stages_test.npy'))

    return signals_train,stages_train,signals_test,stages_test