diff --git a/fluid/DeepASR/data_utils/augmentor/tests/__init__.py b/fluid/DeepASR/data_utils/augmentor/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90856dc44374211453f7de128c08c8004ffda912
--- /dev/null
+++ b/fluid/DeepASR/data_utils/augmentor/tests/__init__.py
@@ -0,0 +1,7 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import data_utils.augmentor.trans_mean_variance_norm as trans_mean_variance_norm
+import data_utils.augmentor.trans_add_delta as trans_add_delta
+import data_utils.augmentor.trans_splice as trans_splice
diff --git a/fluid/DeepASR/data_utils/data_reader.py b/fluid/DeepASR/data_utils/data_reader.py
index 42d460ada10ac8cf683364f6e713c9255c0f1df8..0495b7e7331cafd913c581266aa9f066b8fbbb83 100644
--- a/fluid/DeepASR/data_utils/data_reader.py
+++ b/fluid/DeepASR/data_utils/data_reader.py
@@ -62,10 +62,22 @@ class SampleInfoBucket(object):
         label_bin_paths (list|tuple): Files containing the binary label data.
         label_desc_paths (list|tuple): Files containing the description of
                                        samples' label data.
+        split_perturb (int): Maximum perturbation of a sub-sentence's length
+                             when splitting a long sentence.
+        split_sentence_threshold (int): Sentences longer than this value
+                                        trigger the split operation.
+        split_sub_sentence_len (int): Base sub-sentence length; the actual
+            length is (split_sub_sentence_len + rand() % split_perturb).
     """
 
-    def __init__(self, feature_bin_paths, feature_desc_paths, label_bin_paths,
-                 label_desc_paths):
+    def __init__(self,
+                 feature_bin_paths,
+                 feature_desc_paths,
+                 label_bin_paths,
+                 label_desc_paths,
+                 split_perturb=50,
+                 split_sentence_threshold=512,
+                 split_sub_sentence_len=256):
         block_num = len(label_bin_paths)
         assert len(label_desc_paths) == block_num
         assert len(feature_bin_paths) == block_num
@@ -76,6 +88,10 @@ class SampleInfoBucket(object):
         self._feature_desc_paths = feature_desc_paths
         self._label_bin_paths = label_bin_paths
         self._label_desc_paths = label_desc_paths
+        self._split_perturb = split_perturb
+        self._split_sentence_threshold = split_sentence_threshold
+        self._split_sub_sentence_len = split_sub_sentence_len
+        self._rng = random.Random(0)
 
     def generate_sample_info_list(self):
         sample_info_list = []
@@ -102,12 +118,45 @@ class SampleInfoBucket(object):
                 label_start = int(label_desc_split[2])
                 label_size = int(label_desc_split[3])
                 label_frame_num = int(label_desc_split[4])
-
-                sample_info_list.append(
-                    SampleInfo(feature_bin_path, feature_start, feature_size,
-                               feature_frame_num, feature_dim, label_bin_path,
-                               label_start, label_size, label_frame_num))
-
+                assert feature_frame_num == label_frame_num
+
+                if self._split_sentence_threshold == -1 or \
+                        self._split_perturb == -1 or \
+                        self._split_sub_sentence_len == -1 or \
+                        self._split_sentence_threshold >= feature_frame_num:
+                    sample_info_list.append(
+                        SampleInfo(feature_bin_path, feature_start,
+                                   feature_size, feature_frame_num, feature_dim,
+                                   label_bin_path, label_start, label_size,
+                                   label_frame_num))
+                # split a long sentence into sub-sentences
+                else:
+                    cur_frame_pos = 0
+                    cur_frame_len = 0
+                    remain_frame_num = feature_frame_num
+                    while True:
+                        if remain_frame_num > self._split_sentence_threshold:
+                            cur_frame_len = self._split_sub_sentence_len + \
+                                self._rng.randint(0, self._split_perturb)
+                            if cur_frame_len > remain_frame_num:
+                                cur_frame_len = remain_frame_num
+                        else:
+                            cur_frame_len = remain_frame_num
+
+                        sample_info_list.append(
+                            SampleInfo(
+                                feature_bin_path, feature_start + cur_frame_pos *
+                                feature_dim * 4, cur_frame_len * feature_dim *
+                                4, cur_frame_len, feature_dim, label_bin_path,
+                                label_start + cur_frame_pos * 4, cur_frame_len *
+                                4, cur_frame_len))
+
+                        remain_frame_num -= cur_frame_len
+                        cur_frame_pos += cur_frame_len
+                        if remain_frame_num <= 0:
+                            break
+
+        print("generate_sample_info_list size ", len(sample_info_list))
         return sample_info_list
 
@@ -125,7 +174,7 @@ class DataReader(object):
         label_file_list (str): File containing paths of label data file and
                                corresponding description file.
         drop_frame_len (int): Samples whose label length above the value will be
-                              dropped.
+                              dropped. (Use '-1' to disable this policy.)
         process_num (int): Number of processes for processing data.
         sample_buffer_size (int): Buffer size to indicate the maximum samples
                                   cached.
@@ -149,7 +198,7 @@ class DataReader(object):
                  sample_buffer_size=1024,
                  sample_info_buffer_size=1024,
                  batch_buffer_size=1024,
-                 shuffle_block_num=1,
+                 shuffle_block_num=10,
                  random_seed=0,
                  verbose=0):
         self._feature_file_list = feature_file_list
@@ -253,11 +302,21 @@ class DataReader(object):
                                           sample_info.feature_start,
                                           sample_info.feature_size)
 
+                assert sample_info.feature_frame_num * sample_info.feature_dim * 4 \
+                        == len(feature_bytes), \
+                        (sample_info.feature_bin_path,
+                         sample_info.feature_frame_num,
+                         sample_info.feature_dim,
+                         len(feature_bytes))
+
                 label_bytes = read_bytes(sample_info.label_bin_path,
                                          sample_info.label_start,
                                          sample_info.label_size)
 
-                assert sample_info.label_frame_num * 4 == len(label_bytes)
+                assert sample_info.label_frame_num * 4 == len(label_bytes), (
+                    sample_info.label_bin_path, sample_info.label_frame_num,
+                    len(label_bytes))
+
                 label_array = struct.unpack('I' * sample_info.label_frame_num,
                                             label_bytes)
                 label_data = np.array(
@@ -282,7 +341,8 @@ class DataReader(object):
                         time.sleep(0.001)
 
                 # drop long sentence
-                if self._drop_frame_len >= sample_data[0].shape[0]:
+                if self._drop_frame_len == -1 or \
+                        self._drop_frame_len >= sample_data[0].shape[0]:
                     sample_queue.put(sample_data)
                     out_order[0] += 1
 
diff --git a/fluid/DeepASR/model_utils/__init__.py b/fluid/DeepASR/model_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py
index 1c45f0a086332288760c359d9845aa94641cf7d7..2a4086276460e6d9a883f6565491cab1efdf5d6e 100644
--- a/fluid/DeepASR/train.py
+++ b/fluid/DeepASR/train.py
@@ -176,7 +176,7 @@ def train(args):
 
     # train data reader
     train_data_reader = reader.DataReader(args.train_feature_lst,
-                                          args.train_label_lst)
+                                          args.train_label_lst, -1)
     train_data_reader.set_transformers(ltrans)
     # train
     for pass_id in xrange(args.pass_num):
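
For reference, below is a minimal standalone sketch of the sentence-splitting policy this patch adds to `SampleInfoBucket.generate_sample_info_list()`. The byte arithmetic mirrors the diff (float32 features take `frame_num * feature_dim * 4` bytes, int32 labels take `frame_num * 4` bytes); the helper name `split_sentence`, the `namedtuple` stand-in for `SampleInfo`, and the sample values under `__main__` are hypothetical, for illustration only.

```python
from __future__ import print_function

import random
from collections import namedtuple

# Stand-in for data_utils.data_reader.SampleInfo; field order follows the diff.
SampleInfo = namedtuple('SampleInfo', [
    'feature_bin_path', 'feature_start', 'feature_size', 'feature_frame_num',
    'feature_dim', 'label_bin_path', 'label_start', 'label_size',
    'label_frame_num'
])


def split_sentence(info,
                   split_perturb=50,
                   split_sentence_threshold=512,
                   split_sub_sentence_len=256,
                   seed=0):
    """Cut one long sample into a list of sub-sentence SampleInfos."""
    # As in the patch, -1 for any knob (or a short sentence) disables splitting.
    if (-1 in (split_perturb, split_sentence_threshold, split_sub_sentence_len)
            or split_sentence_threshold >= info.feature_frame_num):
        return [info]

    rng = random.Random(seed)
    sub_samples = []
    cur_frame_pos, remain_frame_num = 0, info.feature_frame_num
    while remain_frame_num > 0:
        if remain_frame_num > split_sentence_threshold:
            # Base length plus a random perturbation in [0, split_perturb],
            # capped by the number of frames still left.
            cur_frame_len = min(
                split_sub_sentence_len + rng.randint(0, split_perturb),
                remain_frame_num)
        else:
            cur_frame_len = remain_frame_num

        sub_samples.append(
            SampleInfo(
                info.feature_bin_path,
                # float32 features: 4 bytes per value, feature_dim per frame.
                info.feature_start + cur_frame_pos * info.feature_dim * 4,
                cur_frame_len * info.feature_dim * 4,
                cur_frame_len,
                info.feature_dim,
                info.label_bin_path,
                # int32 labels: 4 bytes per frame.
                info.label_start + cur_frame_pos * 4,
                cur_frame_len * 4,
                cur_frame_len))
        remain_frame_num -= cur_frame_len
        cur_frame_pos += cur_frame_len
    return sub_samples


if __name__ == '__main__':
    # Hypothetical 1300-frame utterance with 40-dim features.
    sample = SampleInfo('feat.bin', 0, 1300 * 40 * 4, 1300, 40,
                        'label.bin', 0, 1300 * 4, 1300)
    for sub in split_sentence(sample):
        print(sub.feature_frame_num, sub.feature_start, sub.label_start)
```

Note that the `-1` sentinel appears twice with different meanings: `-1` for any of the three split knobs disables splitting in `SampleInfoBucket`, while the `-1` that `train.py` now passes as `drop_frame_len` disables the long-sentence drop in `DataReader`, presumably because splitting already bounds sub-sentence length and dropping whole utterances is no longer needed.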