server.py 3.1 KB
Newer Older
H
hypox64 已提交
1 2 3 4
import os
import time
import shutil
import numpy as np
H
hypox64 已提交
5
import random
H
hypox64 已提交
6 7 8 9 10 11
import torch
from torch import nn, optim
import warnings

from util import util,transformer,dataloader,statistics,plot,options
from util import array_operation as arr
H
hypox64 已提交
12
from models import creatnet,core
H
hypox64 已提交
13 14 15 16 17 18 19

# ---- options --------------------------------------------------------------
# Parse the shared command-line options plus a server-specific --ip flag.
opt = options.Options()
opt.parser.add_argument('--ip', type=str, default='', help='')
opt = opt.getparse()
torch.cuda.set_device(opt.gpu_id)

# Fixed overrides for server runs.
# NOTE(review): k_fold = 0 is later passed to transformer.k_fold_generator;
# presumably it selects a single train/test split -- confirm there.
opt.k_fold = 0
opt.save_dir = './datasets/server/tmp'
util.makedirs(opt.save_dir)

# ---- load original data ---------------------------------------------------
signals, labels = dataloader.loaddataset(opt)
label_cnt, label_cnt_per, label_num = statistics.label_statistics(labels)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)

# ---- define network -------------------------------------------------------
# NOTE(review): this rebinds the name `core` from the imported `models.core`
# module to the Core instance; all later uses of `core` go to the instance.
core = core.Core(opt)
core.network_init(printflag=True)
# ---- receive data ---------------------------------------------------------
# Re-extract the uploaded archive into a clean ./datasets/server/data tree.
# Each entry of that directory is one category (label = its index in sorted
# order); each file inside holds one whitespace-separated numeric signal.
server_dir = './datasets/server'
data_dir = os.path.join(server_dir, 'data')
if os.path.isdir(data_dir):
    shutil.rmtree(data_dir)
# shutil.unpack_archive raises on a missing or corrupt archive, instead of the
# old `os.system('unzip ...')` call whose exit status was ignored. It also
# needs neither a shell nor an external `unzip` binary.
shutil.unpack_archive(os.path.join(server_dir, 'data.zip'), server_dir, format='zip')

categorys = sorted(os.listdir(data_dir))
print('categorys:', categorys)

receive_category = len(categorys)
received_signals = []
received_labels = []
for label_idx, category in enumerate(categorys):
    category_dir = os.path.join(data_dir, category)
    samples = os.listdir(category_dir)
    if not samples:
        continue
    # NOTE(review): each category seems meant to contribute ~500 segments
    # (the merge step drops receive_category*500 original rows). Because of the
    # `- 1`, only n_samples * (500 // n_samples - 1) segments are drawn, and
    # none at all once a category holds more than 250 files -- confirm intent.
    cuts_per_sample = 500 // len(samples) - 1
    for sample in samples:
        sample_path = os.path.join(category_dir, sample)
        signal_ori = np.array(
            [float(v) for v in util.loadtxt(sample_path).split()],
            dtype=np.float64,
        )
        # Random crops of exactly opt.loadsize points (previously hard-coded to
        # 2000, which broke the reshape below for any other loadsize). A crop
        # starts at index >= 1000 and never includes the final point.
        max_start = len(signal_ori) - opt.loadsize - 1
        if cuts_per_sample > 0 and max_start < 1000:
            raise ValueError(
                f'{sample_path}: signal has {len(signal_ori)} points, '
                f'need at least {opt.loadsize + 1001} for random cropping'
            )
        for _ in range(cuts_per_sample):
            start = random.randint(1000, max_start)
            segment = arr.normliaze(
                signal_ori[start:start + opt.loadsize], '5_95', truncated=4
            )
            received_signals.append(segment)
            received_labels.append(label_idx)

# NOTE(review): with opt.input_nc > 1 this reshape would pack consecutive
# single-channel segments (possibly of different labels) into one sample, and
# the label count would no longer match -- confirm input_nc == 1 here.
received_signals = np.array(received_signals).reshape(-1, opt.input_nc, opt.loadsize)
received_labels = np.array(received_labels).reshape(-1, 1)
# ---- merge data -----------------------------------------------------------
# Drop the first receive_category*500 original samples and append the newly
# received ones in their place.
# NOTE(review): this assumes the leading receive_category*500 rows of the
# loaded dataset belong to the categories being replaced (e.g. 500 per
# category, ordered by label) -- confirm against dataloader.loaddataset.
signals = signals[receive_category * 500:]
labels = labels[receive_category * 500:]
signals = np.concatenate((signals, received_signals))
# received_labels is a (-1, 1) column. Reshape it to the loaded labels'
# trailing shape so concatenation works for both (N,) and (N, 1) label arrays;
# for (N, 1) labels this is a no-op.
labels = np.concatenate(
    (labels, np.reshape(received_labels, (-1,) + np.shape(labels)[1:]))
)
# NOTE(review): the return value is ignored, so shuffledata presumably shuffles
# both arrays in place with the same permutation -- confirm in util.transformer.
transformer.shuffledata(signals, labels)

# The label set changed, so recompute statistics and the derived options.
label_cnt, label_cnt_per, label_num = statistics.label_statistics(labels)
opt = options.get_auto_options(opt, label_cnt_per, label_num, signals.shape)
# ---- train ----------------------------------------------------------------
train_sequences, test_sequences = transformer.k_fold_generator(len(labels), opt.k_fold)

for epoch in range(opt.epochs):
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments the way time.time() can.
    t1 = time.perf_counter()
    # Only the first split is used; opt.k_fold is set to 0 earlier in this script.
    core.train(signals, labels, train_sequences[0])
    core.eval(signals, labels, test_sequences[0])
    t2 = time.perf_counter()
    # Log the per-epoch cost once, after the first epoch.
    if epoch == 0:
        util.writelog('>>> per epoch cost time:' + str(round(t2 - t1, 2)) + 's', opt, True)

core.save_traced_net()