提交 3bbb3567 编写于 作者: Z zhxfl

data load for inference

上级 59bc4c1d
...@@ -102,11 +102,14 @@ class SampleInfoBucket(object): ...@@ -102,11 +102,14 @@ class SampleInfoBucket(object):
feature_bin_path = self._feature_bin_paths[block_idx] feature_bin_path = self._feature_bin_paths[block_idx]
feature_desc_path = self._feature_desc_paths[block_idx] feature_desc_path = self._feature_desc_paths[block_idx]
label_desc_lines = open(label_desc_path).readlines()
feature_desc_lines = open(feature_desc_path).readlines() feature_desc_lines = open(feature_desc_path).readlines()
label_desc_lines = None
if label_desc_path != "":
label_desc_lines = open(label_desc_path).readlines()
sample_num = int(label_desc_lines[0].split()[1]) sample_num = int(feature_desc_lines[0].split()[1])
assert sample_num == int(feature_desc_lines[0].split()[1]) if label_desc_path != "":
assert sample_num == int(label_desc_lines[0].split()[1])
for i in xrange(sample_num): for i in xrange(sample_num):
feature_desc_split = feature_desc_lines[i + 1].split() feature_desc_split = feature_desc_lines[i + 1].split()
...@@ -115,6 +118,10 @@ class SampleInfoBucket(object): ...@@ -115,6 +118,10 @@ class SampleInfoBucket(object):
feature_frame_num = int(feature_desc_split[4]) feature_frame_num = int(feature_desc_split[4])
feature_dim = int(feature_desc_split[5]) feature_dim = int(feature_desc_split[5])
label_start = -1
label_size = -1
label_frame_num = feature_frame_num
if label_desc_path != "":
label_desc_split = label_desc_lines[i + 1].split() label_desc_split = label_desc_lines[i + 1].split()
label_start = int(label_desc_split[2]) label_start = int(label_desc_split[2])
label_size = int(label_desc_split[3]) label_size = int(label_desc_split[3])
...@@ -187,7 +194,7 @@ class AsyncDataReader(object): ...@@ -187,7 +194,7 @@ class AsyncDataReader(object):
def __init__(self, def __init__(self,
feature_file_list, feature_file_list,
label_file_list, label_file_list="",
drop_frame_len=512, drop_frame_len=512,
proc_num=10, proc_num=10,
sample_buffer_size=1024, sample_buffer_size=1024,
...@@ -221,9 +228,12 @@ class AsyncDataReader(object): ...@@ -221,9 +228,12 @@ class AsyncDataReader(object):
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
block_feature_info_lines = open(self._feature_file_list).readlines() block_feature_info_lines = open(self._feature_file_list).readlines()
block_label_info_lines = open(self._label_file_list).readlines()
assert len(block_feature_info_lines) == len(block_label_info_lines)
self._block_info_list = [] self._block_info_list = []
if self._label_file_list != "":
block_label_info_lines = open(self._label_file_list).readlines()
#block_label_info_lines = open(self._label_file_list).readlines()
assert len(block_feature_info_lines) == len(
block_label_info_lines)
for i in xrange(0, len(block_feature_info_lines), 2): for i in xrange(0, len(block_feature_info_lines), 2):
block_info = (block_feature_info_lines[i], block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1], block_feature_info_lines[i + 1],
...@@ -231,6 +241,12 @@ class AsyncDataReader(object): ...@@ -231,6 +241,12 @@ class AsyncDataReader(object):
block_label_info_lines[i + 1]) block_label_info_lines[i + 1])
self._block_info_list.append( self._block_info_list.append(
map(lambda line: line.strip(), block_info)) map(lambda line: line.strip(), block_info))
else:
for i in xrange(0, len(block_feature_info_lines), 2):
block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1], "", "")
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
if is_shuffle: if is_shuffle:
self._rng.shuffle(self._block_info_list) self._rng.shuffle(self._block_info_list)
...@@ -318,19 +334,24 @@ class AsyncDataReader(object): ...@@ -318,19 +334,24 @@ class AsyncDataReader(object):
sample_info.feature_dim, sample_info.feature_dim,
len(feature_bytes)) len(feature_bytes))
if sample_info.label_bin_path != "":
label_bytes = read_bytes(sample_info.label_bin_path, label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start, sample_info.label_start,
sample_info.label_size) sample_info.label_size)
assert sample_info.label_frame_num * 4 == len(label_bytes), ( assert sample_info.label_frame_num * 4 == len(
sample_info.label_bin_path, sample_info.label_array, label_bytes), (sample_info.label_bin_path,
sample_info.label_array,
len(label_bytes)) len(label_bytes))
label_array = struct.unpack('I' * sample_info.label_frame_num, label_array = struct.unpack(
label_bytes) 'I' * sample_info.label_frame_num, label_bytes)
label_data = np.array( label_data = np.array(
label_array, dtype='int64').reshape( label_array, dtype='int64').reshape(
(sample_info.label_frame_num, 1)) (sample_info.label_frame_num, 1))
else:
label_data = np.zeros(
(sample_info.label_frame_num, 1), dtype='int64')
feature_frame_num = sample_info.feature_frame_num feature_frame_num = sample_info.feature_frame_num
feature_dim = sample_info.feature_dim feature_dim = sample_info.feature_dim
...@@ -345,7 +366,6 @@ class AsyncDataReader(object): ...@@ -345,7 +366,6 @@ class AsyncDataReader(object):
for transformer in self._transformers: for transformer in self._transformers:
# @TODO(pkuyym) to make transfomer only accept feature_data # @TODO(pkuyym) to make transfomer only accept feature_data
sample_data = transformer.perform_trans(sample_data) sample_data = transformer.perform_trans(sample_data)
while order_id != out_order[0]: while order_id != out_order[0]:
time.sleep(0.001) time.sleep(0.001)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册