未验证 提交 11841096 编写于 作者: Z zhxfl 提交者: GitHub

Merge pull request #642 from zhxfl/fix-627

Fix 627
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
......@@ -62,10 +62,22 @@ class SampleInfoBucket(object):
label_bin_paths (list|tuple): Files containing the binary label data.
label_desc_paths (list|tuple): Files containing the description of
samples' label data.
split_perturb(int): Maximum perturbation value for length of
sub-sentence when splitting long sentence.
split_sentence_threshold(int): Sentence whose length larger than
the value will trigger split operation.
split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + rand() % split_perturb).
"""
def __init__(self, feature_bin_paths, feature_desc_paths, label_bin_paths,
label_desc_paths):
def __init__(self,
feature_bin_paths,
feature_desc_paths,
label_bin_paths,
label_desc_paths,
split_perturb=50,
split_sentence_threshold=512,
split_sub_sentence_len=256):
block_num = len(label_bin_paths)
assert len(label_desc_paths) == block_num
assert len(feature_bin_paths) == block_num
......@@ -76,6 +88,10 @@ class SampleInfoBucket(object):
self._feature_desc_paths = feature_desc_paths
self._label_bin_paths = label_bin_paths
self._label_desc_paths = label_desc_paths
self._split_perturb = split_perturb
self._split_sentence_threshold = split_sentence_threshold
self._split_sub_sentence_len = split_sub_sentence_len
self._rng = random.Random(0)
def generate_sample_info_list(self):
sample_info_list = []
......@@ -102,12 +118,45 @@ class SampleInfoBucket(object):
label_start = int(label_desc_split[2])
label_size = int(label_desc_split[3])
label_frame_num = int(label_desc_split[4])
sample_info_list.append(
SampleInfo(feature_bin_path, feature_start, feature_size,
feature_frame_num, feature_dim, label_bin_path,
label_start, label_size, label_frame_num))
assert feature_frame_num == label_frame_num
if self._split_sentence_threshold == -1 or \
self._split_perturb == -1 or \
self._split_sub_sentence_len == -1 \
or self._split_sentence_threshold >= feature_frame_num:
sample_info_list.append(
SampleInfo(feature_bin_path, feature_start,
feature_size, feature_frame_num, feature_dim,
label_bin_path, label_start, label_size,
label_frame_num))
#split sentence
else:
cur_frame_pos = 0
cur_frame_len = 0
remain_frame_num = feature_frame_num
while True:
if remain_frame_num > self._split_sentence_threshold:
cur_frame_len = self._split_sub_sentence_len + \
self._rng.randint(0, self._split_perturb)
if cur_frame_len > remain_frame_num:
cur_frame_len = remain_frame_num
else:
cur_frame_len = remain_frame_num
sample_info_list.append(
SampleInfo(
feature_bin_path, feature_start + cur_frame_pos
* feature_dim * 4, cur_frame_len * feature_dim *
4, cur_frame_len, feature_dim, label_bin_path,
label_start + cur_frame_pos * 4, cur_frame_len *
4, cur_frame_len))
remain_frame_num -= cur_frame_len
cur_frame_pos += cur_frame_len
if remain_frame_num <= 0:
break
print("generate_sample_info_list size ", len(sample_info_list))
return sample_info_list
......@@ -125,7 +174,7 @@ class DataReader(object):
label_file_list (str): File containing paths of label data file and
corresponding description file.
drop_frame_len (int): Samples whose label length above the value will be
dropped.
dropped.(Using '-1' to disable the policy)
process_num (int): Number of processes for processing data.
sample_buffer_size (int): Buffer size to indicate the maximum samples
cached.
......@@ -149,7 +198,7 @@ class DataReader(object):
sample_buffer_size=1024,
sample_info_buffer_size=1024,
batch_buffer_size=1024,
shuffle_block_num=1,
shuffle_block_num=10,
random_seed=0,
verbose=0):
self._feature_file_list = feature_file_list
......@@ -253,11 +302,21 @@ class DataReader(object):
sample_info.feature_start,
sample_info.feature_size)
assert sample_info.feature_frame_num * sample_info.feature_dim * 4 \
== len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start,
sample_info.label_size)
assert sample_info.label_frame_num * 4 == len(label_bytes)
assert sample_info.label_frame_num * 4 == len(label_bytes), (
sample_info.label_bin_path, sample_info.label_array,
len(label_bytes))
label_array = struct.unpack('I' * sample_info.label_frame_num,
label_bytes)
label_data = np.array(
......@@ -282,7 +341,8 @@ class DataReader(object):
time.sleep(0.001)
# drop long sentence
if self._drop_frame_len >= sample_data[0].shape[0]:
if self._drop_frame_len == -1 or \
self._drop_frame_len >= sample_data[0].shape[0]:
sample_queue.put(sample_data)
out_order[0] += 1
......
......@@ -176,7 +176,7 @@ def train(args):
# train data reader
train_data_reader = reader.DataReader(args.train_feature_lst,
args.train_label_lst)
args.train_label_lst, -1)
train_data_reader.set_transformers(ltrans)
# train
for pass_id in xrange(args.pass_num):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册