未验证 提交 11841096 编写于 作者: Z zhxfl 提交者: GitHub

Merge pull request #642 from zhxfl/fix-627

Fix 627
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
...@@ -62,10 +62,22 @@ class SampleInfoBucket(object): ...@@ -62,10 +62,22 @@ class SampleInfoBucket(object):
label_bin_paths (list|tuple): Files containing the binary label data. label_bin_paths (list|tuple): Files containing the binary label data.
label_desc_paths (list|tuple): Files containing the description of label_desc_paths (list|tuple): Files containing the description of
samples' label data. samples' label data.
split_perturb(int): Maximum perturbation value for length of
sub-sentence when splitting long sentence.
split_sentence_threshold(int): Sentence whose length larger than
the value will trigger split operation.
split_sub_sentence_len(int): sub-sentence length is equal to
(split_sub_sentence_len + rand() % split_perturb).
""" """
def __init__(self, feature_bin_paths, feature_desc_paths, label_bin_paths, def __init__(self,
label_desc_paths): feature_bin_paths,
feature_desc_paths,
label_bin_paths,
label_desc_paths,
split_perturb=50,
split_sentence_threshold=512,
split_sub_sentence_len=256):
block_num = len(label_bin_paths) block_num = len(label_bin_paths)
assert len(label_desc_paths) == block_num assert len(label_desc_paths) == block_num
assert len(feature_bin_paths) == block_num assert len(feature_bin_paths) == block_num
...@@ -76,6 +88,10 @@ class SampleInfoBucket(object): ...@@ -76,6 +88,10 @@ class SampleInfoBucket(object):
self._feature_desc_paths = feature_desc_paths self._feature_desc_paths = feature_desc_paths
self._label_bin_paths = label_bin_paths self._label_bin_paths = label_bin_paths
self._label_desc_paths = label_desc_paths self._label_desc_paths = label_desc_paths
self._split_perturb = split_perturb
self._split_sentence_threshold = split_sentence_threshold
self._split_sub_sentence_len = split_sub_sentence_len
self._rng = random.Random(0)
def generate_sample_info_list(self): def generate_sample_info_list(self):
sample_info_list = [] sample_info_list = []
...@@ -102,12 +118,45 @@ class SampleInfoBucket(object): ...@@ -102,12 +118,45 @@ class SampleInfoBucket(object):
label_start = int(label_desc_split[2]) label_start = int(label_desc_split[2])
label_size = int(label_desc_split[3]) label_size = int(label_desc_split[3])
label_frame_num = int(label_desc_split[4]) label_frame_num = int(label_desc_split[4])
assert feature_frame_num == label_frame_num
if self._split_sentence_threshold == -1 or \
self._split_perturb == -1 or \
self._split_sub_sentence_len == -1 \
or self._split_sentence_threshold >= feature_frame_num:
sample_info_list.append(
SampleInfo(feature_bin_path, feature_start,
feature_size, feature_frame_num, feature_dim,
label_bin_path, label_start, label_size,
label_frame_num))
#split sentence
else:
cur_frame_pos = 0
cur_frame_len = 0
remain_frame_num = feature_frame_num
while True:
if remain_frame_num > self._split_sentence_threshold:
cur_frame_len = self._split_sub_sentence_len + \
self._rng.randint(0, self._split_perturb)
if cur_frame_len > remain_frame_num:
cur_frame_len = remain_frame_num
else:
cur_frame_len = remain_frame_num
sample_info_list.append( sample_info_list.append(
SampleInfo(feature_bin_path, feature_start, feature_size, SampleInfo(
feature_frame_num, feature_dim, label_bin_path, feature_bin_path, feature_start + cur_frame_pos
label_start, label_size, label_frame_num)) * feature_dim * 4, cur_frame_len * feature_dim *
4, cur_frame_len, feature_dim, label_bin_path,
label_start + cur_frame_pos * 4, cur_frame_len *
4, cur_frame_len))
remain_frame_num -= cur_frame_len
cur_frame_pos += cur_frame_len
if remain_frame_num <= 0:
break
print("generate_sample_info_list size ", len(sample_info_list))
return sample_info_list return sample_info_list
...@@ -125,7 +174,7 @@ class DataReader(object): ...@@ -125,7 +174,7 @@ class DataReader(object):
label_file_list (str): File containing paths of label data file and label_file_list (str): File containing paths of label data file and
corresponding description file. corresponding description file.
drop_frame_len (int): Samples whose label length above the value will be drop_frame_len (int): Samples whose label length above the value will be
dropped. dropped.(Using '-1' to disable the policy)
process_num (int): Number of processes for processing data. process_num (int): Number of processes for processing data.
sample_buffer_size (int): Buffer size to indicate the maximum samples sample_buffer_size (int): Buffer size to indicate the maximum samples
cached. cached.
...@@ -149,7 +198,7 @@ class DataReader(object): ...@@ -149,7 +198,7 @@ class DataReader(object):
sample_buffer_size=1024, sample_buffer_size=1024,
sample_info_buffer_size=1024, sample_info_buffer_size=1024,
batch_buffer_size=1024, batch_buffer_size=1024,
shuffle_block_num=1, shuffle_block_num=10,
random_seed=0, random_seed=0,
verbose=0): verbose=0):
self._feature_file_list = feature_file_list self._feature_file_list = feature_file_list
...@@ -253,11 +302,21 @@ class DataReader(object): ...@@ -253,11 +302,21 @@ class DataReader(object):
sample_info.feature_start, sample_info.feature_start,
sample_info.feature_size) sample_info.feature_size)
assert sample_info.feature_frame_num * sample_info.feature_dim * 4 \
== len(feature_bytes), \
(sample_info.feature_bin_path,
sample_info.feature_frame_num,
sample_info.feature_dim,
len(feature_bytes))
label_bytes = read_bytes(sample_info.label_bin_path, label_bytes = read_bytes(sample_info.label_bin_path,
sample_info.label_start, sample_info.label_start,
sample_info.label_size) sample_info.label_size)
assert sample_info.label_frame_num * 4 == len(label_bytes) assert sample_info.label_frame_num * 4 == len(label_bytes), (
sample_info.label_bin_path, sample_info.label_array,
len(label_bytes))
label_array = struct.unpack('I' * sample_info.label_frame_num, label_array = struct.unpack('I' * sample_info.label_frame_num,
label_bytes) label_bytes)
label_data = np.array( label_data = np.array(
...@@ -282,7 +341,8 @@ class DataReader(object): ...@@ -282,7 +341,8 @@ class DataReader(object):
time.sleep(0.001) time.sleep(0.001)
# drop long sentence # drop long sentence
if self._drop_frame_len >= sample_data[0].shape[0]: if self._drop_frame_len == -1 or \
self._drop_frame_len >= sample_data[0].shape[0]:
sample_queue.put(sample_data) sample_queue.put(sample_data)
out_order[0] += 1 out_order[0] += 1
......
...@@ -176,7 +176,7 @@ def train(args): ...@@ -176,7 +176,7 @@ def train(args):
# train data reader # train data reader
train_data_reader = reader.DataReader(args.train_feature_lst, train_data_reader = reader.DataReader(args.train_feature_lst,
args.train_label_lst) args.train_label_lst, -1)
train_data_reader.set_transformers(ltrans) train_data_reader.set_transformers(ltrans)
# train # train
for pass_id in xrange(args.pass_num): for pass_id in xrange(args.pass_num):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册