diff --git a/fluid/DeepASR/data_utils/data_reader.py b/fluid/DeepASR/data_utils/data_reader.py
index ce587b1d8756485da81075836bb2093c8ccd3755..19873a40259a1e9787d11568b79db36a5cf089a3 100644
--- a/fluid/DeepASR/data_utils/data_reader.py
+++ b/fluid/DeepASR/data_utils/data_reader.py
@@ -122,8 +122,6 @@ class DataReader(object):
                                  corresponding description file.
         label_file_list (str): File containing paths of label data file and
                                corresponding description file.
-        frame_dim (int): The final feature dimension of one frame after all
-                         augmentation applied.
         drop_frame_len (int): Samples whose label length above the value will be
                               dropped.
         process_num (int): Number of processes for processing data.
@@ -141,22 +139,19 @@ class DataReader(object):
                        suppressed. If set to 1, all complaints will be printed.
     """
 
-    def __init__(
-            self,
-            feature_file_list,
-            label_file_list,
-            frame_dim=120 * 11,  # @TODO augmentor is responsible for the value
-            drop_frame_len=512,
-            process_num=10,
-            sample_buffer_size=1024,
-            sample_info_buffer_size=1024,
-            batch_buffer_size=1024,
-            shuffle_block_num=1,
-            random_seed=0,
-            verbose=0):
+    def __init__(self,
+                 feature_file_list,
+                 label_file_list,
+                 drop_frame_len=512,
+                 process_num=10,
+                 sample_buffer_size=1024,
+                 sample_info_buffer_size=1024,
+                 batch_buffer_size=1024,
+                 shuffle_block_num=1,
+                 random_seed=0,
+                 verbose=0):
         self._feature_file_list = feature_file_list
         self._label_file_list = label_file_list
-        self._frame_dim = frame_dim
         self._drop_frame_len = drop_frame_len
         self._shuffle_block_num = shuffle_block_num
         self._block_info_list = None
@@ -308,8 +303,9 @@ class DataReader(object):
 
     def batch_iterator(self, batch_size, minimum_batch_size):
         def batch_to_ndarray(batch_samples, lod):
-            batch_feature = np.zeros(
-                (lod[-1], self._frame_dim), dtype="float32")
+            assert len(batch_samples)
+            frame_dim = batch_samples[0][0].shape[1]
+            batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
             batch_label = np.zeros((lod[-1], 1), dtype="int64")
             start = 0
             for sample in batch_samples:
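
For context, a minimal self-contained sketch of the behaviour this last hunk introduces: the frame dimension is no longer a DataReader constructor argument but is inferred from the feature array of the first sample in the batch. The diff only shows the hunk up to "for sample in batch_samples:", so the copy-and-return logic below (frame_len slicing, label reshaping, the return value) and the fake sample data are illustrative assumptions, not the repository's actual code.

    # Sketch, not the real data_reader.py: shows frame_dim being inferred
    # from the first sample instead of being passed in via __init__.
    import numpy as np

    def batch_to_ndarray(batch_samples, lod):
        assert len(batch_samples)
        # Each sample is assumed to be a (feature, label) pair of ndarrays;
        # feature has shape (num_frames, frame_dim) after augmentation.
        frame_dim = batch_samples[0][0].shape[1]
        batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
        batch_label = np.zeros((lod[-1], 1), dtype="int64")
        start = 0
        for features, labels in batch_samples:
            frame_len = features.shape[0]
            batch_feature[start:start + frame_len, :] = features
            batch_label[start:start + frame_len, :] = labels.reshape(-1, 1)
            start += frame_len
        return batch_feature, batch_label

    # Two fake utterances of 5 and 3 frames with 120 * 11 = 1320-dim features.
    samples = [
        (np.random.rand(5, 1320).astype("float32"), np.zeros(5, dtype="int64")),
        (np.random.rand(3, 1320).astype("float32"), np.zeros(3, dtype="int64")),
    ]
    lod = [0, 5, 8]
    feature, label = batch_to_ndarray(samples, lod)
    print(feature.shape, label.shape)  # (8, 1320) (8, 1)

The practical effect of the change is that batches built from features of any width (e.g. after a different augmentation pipeline) are handled without reconfiguring the reader, which is why the frame_dim argument and its @TODO comment could be dropped from __init__.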