From 085dde45214fee6b97883aa54825a7d55c81cfd5 Mon Sep 17 00:00:00 2001 From: HypoX64 Date: Sat, 25 Jul 2020 14:58:27 +0800 Subject: [PATCH] Add mlp --- .gitignore | 1 + models/core.py | 12 +++++++----- models/creatnet.py | 5 ++++- models/net_1d/mlp.py | 22 ++++++++++++++++++++++ tools/server.py | 2 +- util/dsp.py | 31 +++++++++++++++++++++++++++---- util/options.py | 4 ++-- util/util.py | 9 ++++++++- 8 files changed, 72 insertions(+), 14 deletions(-) create mode 100644 models/net_1d/mlp.py diff --git a/.gitignore b/.gitignore index 97fbcbc..ba7da76 100644 --- a/.gitignore +++ b/.gitignore @@ -138,6 +138,7 @@ checkpoints/ /train_backup.py /tools/client_data /tools/server_data +/trainscript.py *.pth *.edf *log* \ No newline at end of file diff --git a/models/core.py b/models/core.py index 2314738..0e17a98 100644 --- a/models/core.py +++ b/models/core.py @@ -42,7 +42,7 @@ class Core(object): self.test_flag = True if printflag: - util.writelog('network:\n'+str(self.net),self.opt,True) + #util.writelog('network:\n'+str(self.net),self.opt,True) show_paramsnumber(self.net,self.opt) if self.opt.pretrained != '': @@ -85,7 +85,8 @@ class Core(object): self.queue = Queue(self.opt.load_thread*2) process_batch_num = len(sequences)//self.opt.batchsize//self.opt.load_thread if process_batch_num == 0: - print('\033[1;33m'+'Warning: too much load thread'+'\033[0m') + if self.epoch == 1: + print('\033[1;33m'+'Warning: too much load thread'+'\033[0m') self.start_process(signals,labels,sequences) else: for i in range(self.opt.load_thread): @@ -130,8 +131,8 @@ class Core(object): loss.backward() self.optimizer.step() - self.plot_result['train'].append(epoch_loss/i) - plot.draw_loss(self.plot_result,self.epoch+i/(sequences.shape[0]/self.opt.batchsize),self.opt) + self.plot_result['train'].append(epoch_loss/(i+1)) + plot.draw_loss(self.plot_result,self.epoch+(i+1)/(sequences.shape[0]/self.opt.batchsize),self.opt) # if self.opt.model_name != 'autoencoder': # plot.draw_heatmap(confusion_mat,self.opt,name = 'current_train') @@ -142,6 +143,7 @@ class Core(object): epoch_loss = 0 confusion_mat = np.zeros((self.opt.label,self.opt.label), dtype=int) + np.random.shuffle(sequences) self.process_pool_init(signals, labels, sequences) for i in range(len(sequences)//self.opt.batchsize): signal,label = self.queue.get() @@ -160,7 +162,7 @@ class Core(object): print('epoch:'+str(self.epoch),' macro-prec,reca,F1,err,kappa: '+str(statistics.report(confusion_mat))) self.plot_result['F1'].append(statistics.report(confusion_mat)[2]) - self.plot_result['eval'].append(epoch_loss/i) + self.plot_result['eval'].append(epoch_loss/(i+1)) self.epoch +=1 self.confusion_mats.append(confusion_mat) diff --git a/models/creatnet.py b/models/creatnet.py index 604d3fd..26b158d 100644 --- a/models/creatnet.py +++ b/models/creatnet.py @@ -1,5 +1,5 @@ from torch import nn -from .net_1d import cnn_1d,lstm,resnet_1d,multi_scale_resnet_1d,micro_multi_scale_resnet_1d,autoencoder +from .net_1d import cnn_1d,lstm,resnet_1d,multi_scale_resnet_1d,micro_multi_scale_resnet_1d,autoencoder,mlp from .net_2d import densenet,dfcnn,mobilenet,resnet,squeezenet,multi_scale_resnet @@ -9,6 +9,9 @@ def creatnet(opt): #encoder if name =='autoencoder': net = autoencoder.Autoencoder(opt.input_nc, opt.feature, opt.label,opt.finesize) + #mlp + if name =='mlp': + net = mlp.mlp(opt.input_nc, opt.label, opt.finesize) #lstm elif name =='lstm': net = lstm.lstm(opt.lstm_inputsize,opt.lstm_timestep,input_nc=opt.input_nc,num_classes=opt.label) diff --git a/models/net_1d/mlp.py b/models/net_1d/mlp.py new file mode 100644 index 0000000..ca9c338 --- /dev/null +++ b/models/net_1d/mlp.py @@ -0,0 +1,22 @@ +import torch +from torch import nn +import torch.nn.functional as F + +class mlp(nn.Module): + def __init__(self, input_nc,num_classes,datasize): + super(mlp, self).__init__() + + self.net = nn.Sequential( + nn.Linear(datasize*input_nc, 128), + nn.Tanh(), + nn.Linear(128, 64), + nn.Tanh(), + nn.Linear(64, 64), + nn.Tanh(), + nn.Linear(64, num_classes), + ) + + def forward(self, x): + x = x.view(x.size(0),-1) + x = self.net(x) + return x \ No newline at end of file diff --git a/tools/server.py b/tools/server.py index 802e89c..bf5e5ca 100644 --- a/tools/server.py +++ b/tools/server.py @@ -132,4 +132,4 @@ def handlepost(): return {'return':'error'} -app.run("0.0.0.0", port= 4000, debug=True) +app.run("0.0.0.0", port= 4000, debug=False) diff --git a/util/dsp.py b/util/dsp.py index c525c02..dc53572 100644 --- a/util/dsp.py +++ b/util/dsp.py @@ -1,6 +1,7 @@ import scipy.signal import scipy.fftpack as fftpack import numpy as np +import pywt def sin(f,fs,time): x = np.linspace(0, 2*np.pi*f*time, fs*time) @@ -23,10 +24,32 @@ def medfilt(signal,x): def cleanoffset(signal): return signal - np.mean(signal) -def bpf_fir(signal,fs,fc1,fc2,numtaps=101): - b=scipy.signal.firwin(numtaps, [fc1, fc2], pass_zero=False,fs=fs) - result = scipy.signal.lfilter(b, 1, signal) - return result +def showfreq(signal,fs,fc=0): + """ + return f,fft + """ + if fc==0: + kc = int(len(signal)/2) + else: + kc = int(len(signal)/fs*fc) + signal_fft = np.abs(scipy.fftpack.fft(signal)) + f = np.linspace(0,fs/2,num=int(len(signal_fft)/2)) + return f[:kc],signal_fft[0:int(len(signal_fft)/2)][:kc] + +def bpf(signal, fs, fc1, fc2, numtaps=3, mode='iir'): + if mode == 'iir': + b,a = scipy.signal.iirfilter(numtaps, [fc1,fc2], fs=fs) + elif mode == 'fir': + b = scipy.signal.firwin(numtaps, [fc1, fc2], pass_zero=False,fs=fs) + a = 1 + return scipy.signal.lfilter(b, a, signal) + +def wave_filter(signal,wave,level,usedcoeffs): + coeffs = pywt.wavedec(signal, wave, level=level) + for i in range(len(usedcoeffs)): + if usedcoeffs[i] == 0: + coeffs[i] = np.zeros_like(coeffs[i]) + return pywt.waverec(coeffs, wave, mode='symmetric', axis=-1) def fft_filter(signal,fs,fc=[],type = 'bandpass'): ''' diff --git a/util/options.py b/util/options.py index ca9cc79..87c238e 100644 --- a/util/options.py +++ b/util/options.py @@ -48,7 +48,7 @@ class Options(): # ------------------------Network------------------------ """Available Network 1d: lstm, cnn_1d, resnet18_1d, resnet34_1d, multi_scale_resnet_1d, - micro_multi_scale_resnet_1d,autoencoder + micro_multi_scale_resnet_1d,autoencoder,mlp 2d: mobilenet, dfcnn, multi_scale_resnet, resnet18, resnet50, resnet101, densenet121, densenet201, squeezenet """ @@ -100,7 +100,7 @@ class Options(): if self.opt.model_type == 'auto': if self.opt.model_name in ['lstm', 'cnn_1d', 'resnet18_1d', 'resnet34_1d', - 'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder']: + 'multi_scale_resnet_1d','micro_multi_scale_resnet_1d','autoencoder','mlp']: self.opt.model_type = '1d' elif self.opt.model_name in ['dfcnn', 'multi_scale_resnet', 'resnet18', 'resnet50', 'resnet101','densenet121', 'densenet201', 'squeezenet', 'mobilenet']: diff --git a/util/util.py b/util/util.py index 12fcd0e..1708202 100644 --- a/util/util.py +++ b/util/util.py @@ -1,6 +1,7 @@ import os import string import random +import shutil def randomstr(num): return ''.join(random.sample(string.ascii_letters + string.digits, num)) @@ -38,4 +39,10 @@ def loadfile(path): def savefile(file,path): wf = open(path,'wb') wf.write(file) - wf.close() \ No newline at end of file + wf.close() + +def copyfile(src,dst): + try: + shutil.copyfile(src, dst) + except Exception as e: + print(e) \ No newline at end of file -- GitLab