提交 3bbb3567 编写于 作者: Z zhxfl

data load for inference

上级 59bc4c1d
...@@ -102,11 +102,14 @@ class SampleInfoBucket(object): ...@@ -102,11 +102,14 @@ class SampleInfoBucket(object):
feature_bin_path = self._feature_bin_paths[block_idx] feature_bin_path = self._feature_bin_paths[block_idx]
feature_desc_path = self._feature_desc_paths[block_idx] feature_desc_path = self._feature_desc_paths[block_idx]
label_desc_lines = open(label_desc_path).readlines()
feature_desc_lines = open(feature_desc_path).readlines() feature_desc_lines = open(feature_desc_path).readlines()
label_desc_lines = None
if label_desc_path != "":
label_desc_lines = open(label_desc_path).readlines()
sample_num = int(label_desc_lines[0].split()[1]) sample_num = int(feature_desc_lines[0].split()[1])
assert sample_num == int(feature_desc_lines[0].split()[1]) if label_desc_path != "":
assert sample_num == int(label_desc_lines[0].split()[1])
for i in xrange(sample_num): for i in xrange(sample_num):
feature_desc_split = feature_desc_lines[i + 1].split() feature_desc_split = feature_desc_lines[i + 1].split()
...@@ -115,11 +118,15 @@ class SampleInfoBucket(object): ...@@ -115,11 +118,15 @@ class SampleInfoBucket(object):
feature_frame_num = int(feature_desc_split[4]) feature_frame_num = int(feature_desc_split[4])
feature_dim = int(feature_desc_split[5]) feature_dim = int(feature_desc_split[5])
label_desc_split = label_desc_lines[i + 1].split() label_start = -1
label_start = int(label_desc_split[2]) label_size = -1
label_size = int(label_desc_split[3]) label_frame_num = feature_frame_num
label_frame_num = int(label_desc_split[4]) if label_desc_path != "":
assert feature_frame_num == label_frame_num label_desc_split = label_desc_lines[i + 1].split()
label_start = int(label_desc_split[2])
label_size = int(label_desc_split[3])
label_frame_num = int(label_desc_split[4])
assert feature_frame_num == label_frame_num
if self._split_sentence_threshold == -1 or \ if self._split_sentence_threshold == -1 or \
self._split_perturb == -1 or \ self._split_perturb == -1 or \
...@@ -187,7 +194,7 @@ class AsyncDataReader(object): ...@@ -187,7 +194,7 @@ class AsyncDataReader(object):
def __init__(self, def __init__(self,
feature_file_list, feature_file_list,
label_file_list, label_file_list="",
drop_frame_len=512, drop_frame_len=512,
proc_num=10, proc_num=10,
sample_buffer_size=1024, sample_buffer_size=1024,
...@@ -221,16 +228,25 @@ class AsyncDataReader(object): ...@@ -221,16 +228,25 @@ class AsyncDataReader(object):
def generate_bucket_list(self, is_shuffle): def generate_bucket_list(self, is_shuffle):
if self._block_info_list is None: if self._block_info_list is None:
block_feature_info_lines = open(self._feature_file_list).readlines() block_feature_info_lines = open(self._feature_file_list).readlines()
block_label_info_lines = open(self._label_file_list).readlines()
assert len(block_feature_info_lines) == len(block_label_info_lines)
self._block_info_list = [] self._block_info_list = []
for i in xrange(0, len(block_feature_info_lines), 2): if self._label_file_list != "":
block_info = (block_feature_info_lines[i], block_label_info_lines = open(self._label_file_list).readlines()
block_feature_info_lines[i + 1], #block_label_info_lines = open(self._label_file_list).readlines()
block_label_info_lines[i], assert len(block_feature_info_lines) == len(
block_label_info_lines[i + 1]) block_label_info_lines)
self._block_info_list.append( for i in xrange(0, len(block_feature_info_lines), 2):
map(lambda line: line.strip(), block_info)) block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1],
block_label_info_lines[i],
block_label_info_lines[i + 1])
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
else:
for i in xrange(0, len(block_feature_info_lines), 2):
block_info = (block_feature_info_lines[i],
block_feature_info_lines[i + 1], "", "")
self._block_info_list.append(
map(lambda line: line.strip(), block_info))
if is_shuffle: if is_shuffle:
self._rng.shuffle(self._block_info_list) self._rng.shuffle(self._block_info_list)
...@@ -318,19 +334,24 @@ class AsyncDataReader(object): ...@@ -318,19 +334,24 @@ class AsyncDataReader(object):
sample_info.feature_dim, sample_info.feature_dim,
len(feature_bytes)) len(feature_bytes))
label_bytes = read_bytes(sample_info.label_bin_path, if sample_info.label_bin_path != "":
sample_info.label_start, label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_size) sample_info.label_start,
sample_info.label_size)
assert sample_info.label_frame_num * 4 == len(label_bytes), (
sample_info.label_bin_path, sample_info.label_array, assert sample_info.label_frame_num * 4 == len(
len(label_bytes)) label_bytes), (sample_info.label_bin_path,
sample_info.label_array,
label_array = struct.unpack('I' * sample_info.label_frame_num, len(label_bytes))
label_bytes)
label_data = np.array( label_array = struct.unpack(
label_array, dtype='int64').reshape( 'I' * sample_info.label_frame_num, label_bytes)
(sample_info.label_frame_num, 1)) label_data = np.array(
label_array, dtype='int64').reshape(
(sample_info.label_frame_num, 1))
else:
label_data = np.zeros(
(sample_info.label_frame_num, 1), dtype='int64')
feature_frame_num = sample_info.feature_frame_num feature_frame_num = sample_info.feature_frame_num
feature_dim = sample_info.feature_dim feature_dim = sample_info.feature_dim
...@@ -345,7 +366,6 @@ class AsyncDataReader(object): ...@@ -345,7 +366,6 @@ class AsyncDataReader(object):
for transformer in self._transformers: for transformer in self._transformers:
# @TODO(pkuyym) to make transformer only accept feature_data # @TODO(pkuyym) to make transformer only accept feature_data
sample_data = transformer.perform_trans(sample_data) sample_data = transformer.perform_trans(sample_data)
while order_id != out_order[0]: while order_id != out_order[0]:
time.sleep(0.001) time.sleep(0.001)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册