Commit 5dc83983 authored by kechxu, committed by Kecheng Xu

Prediction: add downsample and merge h5

Parent 8ffeff56
#!/usr/bin/env python
###############################################################################
# Copyright 2018 The Apollo Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import os
import glob
import argparse
import datetime
import numpy as np
import h5py
def load_hdf5(filename):
    """Load training samples from an HDF5 (.h5) file.

    Args:
        filename: path to an HDF5 file; the first dataset in the file is
            taken to hold the samples.

    Returns:
        numpy.ndarray containing the full contents of the first dataset.

    Exits the process (status 1) when the file is missing or does not have
    an .h5 extension, matching the original script's behavior.
    """
    if not os.path.exists(filename):
        print('file: %s does not exist' % filename)
        os._exit(1)
    if os.path.splitext(filename)[1] != '.h5':
        print('file: %s is not an hdf5 file' % filename)
        os._exit(1)
    h5_file = h5py.File(filename, 'r')
    # h5_file.values()[0] is not subscriptable on Python 3 (dict_values);
    # wrap in list().  Slice with [:] to copy the dataset into memory so
    # the file handle can be closed (the original leaked it and returned a
    # lazy dataset tied to the open file).
    values = list(h5_file.values())[0][:]
    h5_file.close()
    print('load data size: %d' % values.shape[0])
    return values
if __name__ == '__main__':
    # Merge all *.h5 feature files found directly under a user-supplied
    # directory into one combined sample file written to
    # <directory>/mlp_merge/mlp_<YYYY-MM-DD>.h5.
    parser = argparse.ArgumentParser(description = 'generate training samples\
    from a specified directory')
    parser.add_argument('directory', type=str,
                        help='directory contains feature files in .h5')
    args = parser.parse_args()
    path = args.directory
    print "load h5 from directory:", format(path)
    if os.path.isdir(path):
        features = None
        # NOTE(review): `labels` is never assigned again in this chunk;
        # labels appear to travel in the last column of `features` instead.
        labels = None
        h5_files = glob.glob(path + '/*.h5')
        print "Length of files:", len(h5_files)
        for i, h5_file in enumerate(h5_files):
            print "Process File", i, ":", h5_file
            feature = load_hdf5(h5_file)
            # Warn about (but do not skip) files containing infinite values.
            if np.any(np.isinf(feature)):
                print "inf data found"
            # Stack samples row-wise; the first file seeds the accumulator.
            features = np.concatenate((features, feature), axis=0) if features is not None \
                else feature
    else:
        print "Fail to find", path
        os._exit(-1)
    # Write the merged samples as a single dataset named 'data', with the
    # output file name tagged by today's date.
    date = datetime.datetime.now().strftime('%Y-%m-%d')
    sample_dir = path + '/mlp_merge'
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    sample_file = sample_dir + '/mlp_' + date + '.h5'
    print "Save samples file to:", sample_file
    h5_file = h5py.File(sample_file, 'w')
    h5_file.create_dataset('data', data=features)
    h5_file.close()
......@@ -81,6 +81,36 @@ def load_data(filename):
return samples['data']
def down_sample(data):
    """Randomly down-sample rows of `data` per label class.

    The label is the last column of `data`.  Each class keeps a row with
    probability (1 - drop_rate): label -1 (cut-in false) and 0 (go false)
    keep 10%, label 1 (go true) keeps 30%, label 2 (cut-in true) keeps all.
    Rows whose label is none of these four values are always dropped.
    Consumes exactly one uniform draw per row from numpy's global RNG.

    Args:
        data: 2-D array of samples, labels in the last column.

    Returns:
        2-D array containing only the kept rows.
    """
    labels = data[:, -1]
    num_rows = np.shape(labels)[0]
    # One uniform [0, 1) draw per row decides that row's fate.
    draws = np.random.random(num_rows)
    # (label value, drop rate) for each recognized class.
    class_drop_rates = ((-1, 0.9), (0, 0.9), (1, 0.7), (2, 0.0))
    keep = np.zeros(num_rows, dtype=bool)
    for label_value, drop_rate in class_drop_rates:
        keep |= np.logical_and(labels == label_value, draws > drop_rate)
    return data[keep, :]
def get_param_norm(feature):
"""
Normalize the samples and save normalized parameters
......@@ -105,12 +135,12 @@ def setup_model():
model.add(Dense(dim_hidden_2,
init = 'he_normal',
activation = 'relu',
W_regularizer = l2(0.005)))
W_regularizer = l2(0.01)))
model.add(Dense(dim_output,
init='he_normal',
activation = 'sigmoid',
W_regularizer = l2(0.001)))
W_regularizer = l2(0.01)))
model.compile(loss = 'binary_crossentropy',
optimizer = 'rmsprop',
......@@ -228,11 +258,14 @@ if __name__ == "__main__":
file = args.filename
data = load_data(file)
data = down_sample(data)
print "Data load success."
print "data size =", data.shape
train_data, test_data = train_test_split(data, train_data_rate)
print "training size =", train_data.shape
X_train = train_data[:, 0 : dim_input]
Y_train = train_data[:, -1]
Y_trainc = Y_train > 0.1
......@@ -245,11 +278,13 @@ if __name__ == "__main__":
X_train = (X_train - param_norm[0]) / param_norm[1]
X_test = (X_test - param_norm[0]) / param_norm[1]
model = setup_model()
model.fit(X_train, Y_trainc,
shuffle = True,
nb_epoch = 30,
nb_epoch = 20,
batch_size = 32)
print "Model trained success."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment