diff --git a/data.py b/data.py index 35b87e28ce080bbf68642b8631f5e368df5de072..52428ae0f25824463c09c6d9ed4b48275b18d1b9 100644 --- a/data.py +++ b/data.py @@ -60,19 +60,19 @@ def random_transform_1d(data,finesize,test_flag): def random_transform_2d(img,finesize,test_flag): h,w = img.shape[:2] if test_flag: - h_move = 1 + h_move = 2 w_move = int((w-finesize)*0.5) result = img[h_move:h_move+finesize,w_move:w_move+finesize] else: #random crop - h_move = int(3*random.random()) #do not loss low freq signal infos + h_move = int(5*random.random()) #do not loss low freq signal infos w_move = int((w-finesize)*random.random()) result = img[h_move:h_move+finesize,w_move:w_move+finesize] #random flip if random.random()<0.5: result = result[:,::-1] #random amp - result = result*random.uniform(0.98,1.02)+random.uniform(-0.01,0.01) + result = result*random.uniform(0.95,1.05)+random.uniform(-0.02,0.02) return result diff --git a/dataloader.py b/dataloader.py index 8846bd231e02d33b3cc14203a82009161aeaad7a..400d04399e11096889e981121e7157ba655b22ee 100644 --- a/dataloader.py +++ b/dataloader.py @@ -6,7 +6,9 @@ import time import torch import random import DSP -import pyedflib +# import pyedflib +import mne + # CinC_Challenge_2018 def loadstages(dirpath): filepath = os.path.join(dirpath,os.path.basename(dirpath)+'-arousal.mat') @@ -86,17 +88,51 @@ def stage_str2int(stagestr): stage = 5 return stage -def loaddata_sleep_edf(filedir,filenum,filenames,signal_name,BID = 'median',filter = True): +def loaddata_sleep_edf(filedir,filenum,signal_name,BID = 'median',filter = True): filenames = os.listdir(filedir) - # f_stage_name='a' - # f_signal_name='b' for filename in filenames: if str(filenum) in filename and 'Hypnogram' in filename: f_stage_name = filename if str(filenum) in filename and 'PSG' in filename: f_signal_name = filename - # print(f_stage_name) + print(f_stage_name) + + raw_data= mne.io.read_raw_edf(os.path.join(filedir,f_signal_name),preload=True) + raw_annot = mne.read_annotations(os.path.join(filedir,f_stage_name)) + eeg = raw_data.pick_channels(['EEG Fpz-Cz']).to_data_frame().values.T.reshape(-1) + + raw_data.set_annotations(raw_annot, emit_warning=False) + event_id = {'Sleep stage 4': 0, + 'Sleep stage 3': 0, + 'Sleep stage 2': 1, + 'Sleep stage 1': 2, + 'Sleep stage R': 3, + 'Sleep stage W': 4, + 'Sleep stage ?': 5, + 'Sleep stage Movement time': 5} + events, _ = mne.events_from_annotations( + raw_data, event_id=event_id, chunk_duration=30.) + events = np.array(events) + + signals=trimdata(eeg,3000) + signals = signals.reshape(-1,3000) + stages = events[:,2] + print(signals.shape,events.shape) + # stages = stages[0:signals.shape[0]] + + stages_copy = stages.copy() + cnt = 0 + for i in range(len(stages_copy)): + if stages_copy[i] == 5 : + signals = np.delete(signals,i-cnt,axis =0) + stages = np.delete(stages,i-cnt,axis =0) + cnt += 1 + + + + # print(f_signal_name) + ''' f_stage = pyedflib.EdfReader(os.path.join(filedir,f_stage_name)) annotations = f_stage.readAnnotations() number_of_annotations = f_stage.annotations_in_file @@ -106,18 +142,15 @@ def loaddata_sleep_edf(filedir,filenum,filenames,signal_name,BID = 'median',filt for i in range(number_of_annotations): stages[int(annotations[0][i])//30:(int(annotations[0][i])+int(annotations[1][i]))//30] = stage_str2int(annotations[2][i]) - # #select sleep time - # stages[int(annotations[0][0])//30:(int(annotations[0][0])+int(annotations[1][0]))//30-120] = 5 - # stages[int(annotations[0][number_of_annotations-2])//30+120:(int(annotations[0][number_of_annotations-2])+int(annotations[1][number_of_annotations-2]))//30] = 5 - f_signal = pyedflib.EdfReader(os.path.join(filedir,f_signal_name)) signals = f_signal.readSignal(0) + signals=trimdata(signals,3000) signals = signals.reshape(-1,3000) stages = stages[0:signals.shape[0]] - #select sleep time - signals = signals[(int(annotations[0][0])+int(annotations[1][0]))//30-60:int(annotations[0][number_of_annotations-2])//30+60] - stages = stages[(int(annotations[0][0])+int(annotations[1][0]))//30-60:int(annotations[0][number_of_annotations-2])//30+60] + # #select sleep time + # signals = signals[np.clip(int(annotations[1][0])//30-60,0,9999999):int(annotations[0][number_of_annotations-2])//30+60] + # stages = stages[np.clip(int(annotations[1][0])//30-60,0,9999999):int(annotations[0][number_of_annotations-2])//30+60] #del UND stages_copy = stages.copy() @@ -127,6 +160,7 @@ def loaddata_sleep_edf(filedir,filenum,filenames,signal_name,BID = 'median',filt signals = np.delete(signals,i-cnt,axis =0) stages = np.delete(stages,i-cnt,axis =0) cnt += 1 + ''' return signals.astype(np.int16),stages.astype(np.int16) @@ -155,17 +189,9 @@ def loaddataset(filedir,dataset_name = 'CinC_Challenge_2018',signal_name = 'C4-M print(filename,e) elif dataset_name == 'sleep-edfx': cnt = 0 - signals=[] - stages=[] for filename in filenames: if 'PSG' in filename: - # try: - - signal,stage = loaddata_sleep_edf(filedir,filename[2:6],filenames,signal_name = 'FPZ-CZ') - # print(type(signal[0][0])) - # signals.append(signal) - # stages.append(stage) - + signal,stage = loaddata_sleep_edf(filedir,filename[2:6],signal_name = 'FPZ-CZ') if cnt == 0: signals =signal.copy() stages =stage.copy() @@ -175,12 +201,4 @@ def loaddataset(filedir,dataset_name = 'CinC_Challenge_2018',signal_name = 'C4-M cnt += 1 if cnt == num: break - # except Exception as e: - # print(filename,e) - # signals = signals.reshape(-1,3000) - # print(signals) - # signals = np.array(signals) - # stages = np.array(stages) - # print(signals.shape) - return signals,stages \ No newline at end of file diff --git a/download_dataset.py b/download_dataset.py index e1d95ed619d86e2a161cf4b349047d78552196b4..5106f5e299173f282b63b65cba1dfb3b192dfc96 100644 --- a/download_dataset.py +++ b/download_dataset.py @@ -4,9 +4,8 @@ import re import threading import os import json -import configparser from bs4 import BeautifulSoup - +import hashlib def RequestWeb(url): headers = {'Accept-Language':'zh-CN,zh;q=0.9', 'User-Agent':'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/64.0.3282.186 Safari/537.36' @@ -29,26 +28,47 @@ def download(url,name,path): # if chunk: # f.write(chunk) -def downloader(url,filenames,path): +def compare_md5(filepath,md5s): + if os.path.exists(filepath): + md5file=open(filepath,'rb') + md5=hashlib.md5(md5file.read()).hexdigest() + md5file.close() + if md5 in md5s: + return True + else: + print('Warning:',name,'md5 do not match, we will try again') + return False + else: + return False + + +def downloader(url,filenames,md5s,dir): for name in filenames: + filepath = os.path.join(dir,name) print('Download:',name) - while not os.path.exists(os.path.join(path,name)): + while not compare_md5(filepath,md5s): try: - download(url+name,name,path) + download(url+name,name,dir) except Exception as e: print('Warning:',name,'download failed! we will try again') -def rundownloader(url,filenames,path,ThreadNum=5): +def rundownloader(url,filenames,md5s,dir,ThreadNum=5): perthread=int(len(filenames)/ThreadNum) for i in range(0,ThreadNum): - t = threading.Thread(target=downloader,args=(url,filenames[perthread*i:perthread*(1+i)],path,)) + t = threading.Thread(target=downloader,args=(url,filenames[perthread*i:perthread*(1+i)],md5s,dir,)) t.start() - t = threading.Thread(target=downloader,args=(url,filenames[perthread*ThreadNum:],path,)) + t = threading.Thread(target=downloader,args=(url,filenames[perthread*ThreadNum:],md5s,dir,)) t.start() -savedir = './sleep-edfx/sleep-cassette' -url = 'https://physionet.org/physiobank/database/sleep-edfx/sleep-cassette/' +savedir = './sleep-edfx/sleep-telemetry' +url = 'https://physionet.org/physiobank/database/sleep-edfx/sleep-telemetry/' + + +md5s=open(os.path.join(savedir,'MD5SUMS.txt'),'rb') +md5s = md5s.read() +md5s=md5s.decode('utf-8') +md5s = md5s.split() soup,page_info=RequestWeb(url) links = soup.find_all('a',href=re.compile(r".edf")) @@ -58,15 +78,5 @@ for link in links[1:]: stop = str(link).index('') filename = str(link)[begin+2:stop] filenames.append(filename) -rundownloader(url,filenames,savedir) -''' - print('download:',filename) - try: - download(url+filename,filename,savedir) - except Exception as e: - print(filename,'download failed! ERR:',e) -''' -# had_down_files = os.listdir(savedir) -# for filename in filenames: -# if : -# pass \ No newline at end of file +rundownloader(url,filenames,md5s,savedir) + diff --git a/options.py b/options.py new file mode 100644 index 0000000000000000000000000000000000000000..79826e53671b02be7355151eccc886810ffb7c82 --- /dev/null +++ b/options.py @@ -0,0 +1,48 @@ +import argparse +import os +import numpy as np +import torch +#filedir = '/media/hypo/Hypo/physionet_org_train' +# filedir ='E:\physionet_org_train' + +#'/media/hypo/Hypo/physionet_org_train' +class Options(): + def __init__(self): + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + self.initialized = False + + def initialize(self): + self.parser.add_argument('--no_cuda', action='store_true', help='if true, do not use gpu') + self.parser.add_argument('--lr', type=float, default=0.001,help='learning rate') + self.parser.add_argument('--batchsize', type=int, default=16,help='batchsize') + self.parser.add_argument('--dataset_dir', type=str, default='./sleep-edfx/sleep-telemetry', + help='your dataset path') + self.parser.add_argument('--dataset_name', type=str, default='sleep-edfx',help='Choose dataset') + self.parser.add_argument('--signal_name', type=str, default='C4-M1',help='Choose the EEG channel') + self.parser.add_argument('--signal_num', type=int, default=44,help='the amount you want to load') + self.parser.add_argument('--model_name', type=str, default='LSTM',help='Choose model') + self.parser.add_argument('--epochs', type=int, default=20,help='end epoch') + self.parser.add_argument('--weight_mod', type=str, default='normal',help='Choose weight mod') + + self.initialized = True + + def getparse(self): + if not self.initialized: + self.initialize() + self.opt = self.parser.parse_args() + + if self.opt.dataset_name == 'CinC_Challenge_2018': + if self.opt.weight_mod == 'avg_best': + weight = np.log(np.array([1/0.15,1/0.3,1/0.08,1/0.13,1/0.18])) + elif self.opt.weight_mod == 'normal': + weight = np.array([1,1,1,1,1]) + + elif self.opt.dataset_name == 'sleep-edfx': + if self.opt.weight_mod == 'avg_best': + weight = np.log(1/np.array([0.08,0.30,0.05,0.15,0.35])) + elif self.opt.weight_mod == 'normal': + weight = np.array([1,1,1,1,1]) + self.opt.weight = torch.from_numpy(weight).float() + + + return self.opt \ No newline at end of file diff --git a/statistics.py b/statistics.py index 0fa0ef242b8154cd0d4a32e9da75c7a8080d2a2c..58210b5c5de2d7cfd680b2c6d51410e9c2084975 100644 --- a/statistics.py +++ b/statistics.py @@ -1,6 +1,11 @@ import numpy as np import matplotlib.pyplot as plt +def writelog(log): + f = open('./log','a+') + f.write(log+'\n') + # print(log) + def stage(stages): #N3->0 N2->1 N1->2 REM->3 W->4 stage_cnt=np.array([0,0,0,0,0]) diff --git a/train.py b/train.py index 98d7bd67f82f1cd822c174c9b1c042c1c8c98fb0..3e92da2e012a933ffa795278d6cdc10ad686d299 100644 --- a/train.py +++ b/train.py @@ -18,6 +18,8 @@ warnings.filterwarnings("ignore") #test avg_recall: 0.7932 avg_acc: 0.9583 error: 0.1043 opt = Options().getparse() +localtime = time.asctime(time.localtime(time.time())) +statistics.writelog('\n'+str(localtime)+'\n'+str(opt)) t1 = time.time() signals,stages = dataloader.loaddataset(opt.dataset_dir,opt.dataset_name,opt.signal_name,opt.signal_num,shuffle=True,BID='median') @@ -74,6 +76,8 @@ def evalnet(net,signals,stages,plot_result={},mode = 'part'): plot_result['test'].append(recall) heatmap.draw(confusion_mat,name = 'test') print('test avg_recall:','%.4f' % recall,'avg_acc:','%.4f' % acc,'error:','%.4f' % error) + + statistics.writelog(str(confusion_mat)+'\navg_recall:'+str(recall)+' avg_acc:'+str(acc)+' error:'+str(error)) # torch.cuda.empty_cache() return plot_result @@ -83,6 +87,7 @@ plot_result['train']=[0] plot_result['test']=[0] for epoch in range(opt.epochs): t1 = time.time() + statistics.writelog('epoch:'+str(epoch)+'\n') confusion_mat = np.zeros((5,5), dtype=int) # running_loss, running_recall = 0.0, 0.0 print('epoch:',epoch+1)