diff --git a/fluid/DeepASR/data_utils/data_reader.py b/fluid/DeepASR/data_utils/data_reader.py index ab1510a4e8cbdaf0fd182d703a8d2f712ed23229..a57e9d0f58d2909dcf3d1e082693461dfef37337 100644 --- a/fluid/DeepASR/data_utils/data_reader.py +++ b/fluid/DeepASR/data_utils/data_reader.py @@ -121,8 +121,6 @@ class DataReader(object): corresponding description file. label_file_list (str): File containing paths of label data file and corresponding description file. - frame_dim (int): The final feature dimension of one frame after all - augmentation applied. drop_frame_len (int): Samples whose label length above the value will be dropped. process_num (int): Number of processes for processing data. @@ -137,21 +135,18 @@ class DataReader(object): random_seed (int): Random seed. """ - def __init__( - self, - feature_file_list, - label_file_list, - frame_dim=120 * 11, # @TODO augmentor is responsible for the value - drop_frame_len=512, - process_num=10, - sample_buffer_size=1024, - sample_info_buffer_size=1024, - batch_buffer_size=1024, - shuffle_block_num=1, - random_seed=0): + def __init__(self, + feature_file_list, + label_file_list, + drop_frame_len=512, + process_num=10, + sample_buffer_size=1024, + sample_info_buffer_size=1024, + batch_buffer_size=1024, + shuffle_block_num=1, + random_seed=0): self._feature_file_list = feature_file_list self._label_file_list = label_file_list - self._frame_dim = frame_dim self._drop_frame_len = drop_frame_len self._shuffle_block_num = shuffle_block_num self._block_info_list = None @@ -300,8 +295,9 @@ class DataReader(object): def batch_iterator(self, batch_size, minimum_batch_size): def batch_to_ndarray(batch_samples, lod): - batch_feature = np.zeros( - (lod[-1], self._frame_dim), dtype="float32") + assert len(batch_samples) + frame_dim = batch_samples[0][0].shape[1] + batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32") batch_label = np.zeros((lod[-1], 1), dtype="int64") start = 0 for sample in batch_samples: