diff --git a/fluid/DeepASR/data_utils/async_data_reader.py b/fluid/DeepASR/data_utils/async_data_reader.py
index 731c55de71e8d4b7db156f1ae72172c36eb1be7a..0c8d010755cc4a947507aeb6a65343f0d160f2be 100644
--- a/fluid/DeepASR/data_utils/async_data_reader.py
+++ b/fluid/DeepASR/data_utils/async_data_reader.py
@@ -185,6 +185,9 @@ class AsyncDataReader(object):
                                corresponding description file.
         drop_frame_len (int): Samples whose label length above the value will
                               be dropped.(Using '-1' to disable the policy)
+        split_sentence_threshold (int): Sentences whose length exceeds this
+                                        value will be split.
+                                        (Assign -1 to disable splitting.)
         proc_num (int): Number of processes for processing data.
         sample_buffer_size (int): Buffer size to indicate the maximum samples
                                   cached.
@@ -204,6 +207,7 @@ class AsyncDataReader(object):
                  feature_file_list,
                  label_file_list="",
                  drop_frame_len=512,
+                 split_sentence_threshold=512,
                  proc_num=10,
                  sample_buffer_size=1024,
                  sample_info_buffer_size=1024,
@@ -214,6 +218,7 @@ class AsyncDataReader(object):
         self._feature_file_list = feature_file_list
         self._label_file_list = label_file_list
         self._drop_frame_len = drop_frame_len
+        self._split_sentence_threshold = split_sentence_threshold
         self._shuffle_block_num = shuffle_block_num
         self._block_info_list = None
         self._rng = random.Random(random_seed)
@@ -262,7 +267,8 @@ class AsyncDataReader(object):
                     map(lambda info: info[0], bucket_block_info),
                     map(lambda info: info[1], bucket_block_info),
                     map(lambda info: info[2], bucket_block_info),
-                    map(lambda info: info[3], bucket_block_info)))
+                    map(lambda info: info[3], bucket_block_info),
+                    split_sentence_threshold=self._split_sentence_threshold))
 
     # @TODO make this configurable
     def set_transformers(self, transformers):
diff --git a/fluid/DeepASR/infer_by_ckpt.py b/fluid/DeepASR/infer_by_ckpt.py
index 36681e9a2bed9f7c2af8f35a230b9cb6558aa19e..554dd7223d19494efc684a3b60070c76672736d4 100644
--- a/fluid/DeepASR/infer_by_ckpt.py
+++ b/fluid/DeepASR/infer_by_ckpt.py
@@ -207,8 +207,11 @@ def infer_from_ckpt(args):
     label_t = fluid.LoDTensor()
 
     # infer data reader
-    infer_data_reader = reader.AsyncDataReader(args.infer_feature_lst,
-                                               args.infer_label_lst)
+    infer_data_reader = reader.AsyncDataReader(
+        args.infer_feature_lst,
+        args.infer_label_lst,
+        drop_frame_len=-1,
+        split_sentence_threshold=-1)
     infer_data_reader.set_transformers(ltrans)
     infer_costs, infer_accs = [], []
     total_edit_dist, total_ref_len = 0.0, 0
diff --git a/fluid/DeepASR/train.py b/fluid/DeepASR/train.py
index 8373c0e04f3a4a2ae87d129243b519c0e0622144..6073db0d07a436f40ac78e38ef072dd23b9dbad5 100644
--- a/fluid/DeepASR/train.py
+++ b/fluid/DeepASR/train.py
@@ -187,7 +187,7 @@ def train(args):
             return -1.0, -1.0
         # test data reader
         test_data_reader = reader.AsyncDataReader(args.val_feature_lst,
-                                                  args.val_label_lst)
+                                                  args.val_label_lst, -1)
         test_data_reader.set_transformers(ltrans)
         test_costs, test_accs = [], []
         for batch_id, batch_data in enumerate(
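
For reference, a minimal usage sketch of the reader under the defaults this patch introduces. This is not part of the patch: the .lst paths are placeholders, and the import alias mirrors the reader alias used in infer_by_ckpt.py and train.py.

# Usage sketch only -- the .lst paths below are placeholders.
import data_utils.async_data_reader as reader

# Training-style reader: both policies active with the new defaults,
# dropping samples whose label length exceeds 512 and splitting
# sentences longer than 512 frames.
train_reader = reader.AsyncDataReader(
    "train_feature.lst", "train_label.lst",
    drop_frame_len=512, split_sentence_threshold=512)

# Inference-style reader: -1 disables dropping and splitting, matching
# the infer_by_ckpt.py call site above, so utterances stay intact.
infer_reader = reader.AsyncDataReader(
    "infer_feature.lst", "infer_label.lst",
    drop_frame_len=-1, split_sentence_threshold=-1)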