未验证 提交 77e44ab3 编写于 作者: Z zhxfl 提交者: GitHub

Merge pull request #640 from zhxfl/fix-639

Augmentation should compute frame_dim
...@@ -121,8 +121,6 @@ class DataReader(object): ...@@ -121,8 +121,6 @@ class DataReader(object):
corresponding description file. corresponding description file.
label_file_list (str): File containing paths of label data file and label_file_list (str): File containing paths of label data file and
corresponding description file. corresponding description file.
frame_dim (int): The final feature dimension of one frame after all
augmentation applied.
drop_frame_len (int): Samples whose label length above the value will be drop_frame_len (int): Samples whose label length above the value will be
dropped. dropped.
process_num (int): Number of processes for processing data. process_num (int): Number of processes for processing data.
...@@ -137,21 +135,18 @@ class DataReader(object): ...@@ -137,21 +135,18 @@ class DataReader(object):
random_seed (int): Random seed. random_seed (int): Random seed.
""" """
def __init__( def __init__(self,
self, feature_file_list,
feature_file_list, label_file_list,
label_file_list, drop_frame_len=512,
frame_dim=120 * 11, # @TODO augmentor is responsible for the value process_num=10,
drop_frame_len=512, sample_buffer_size=1024,
process_num=10, sample_info_buffer_size=1024,
sample_buffer_size=1024, batch_buffer_size=1024,
sample_info_buffer_size=1024, shuffle_block_num=1,
batch_buffer_size=1024, random_seed=0):
shuffle_block_num=1,
random_seed=0):
self._feature_file_list = feature_file_list self._feature_file_list = feature_file_list
self._label_file_list = label_file_list self._label_file_list = label_file_list
self._frame_dim = frame_dim
self._drop_frame_len = drop_frame_len self._drop_frame_len = drop_frame_len
self._shuffle_block_num = shuffle_block_num self._shuffle_block_num = shuffle_block_num
self._block_info_list = None self._block_info_list = None
...@@ -300,8 +295,9 @@ class DataReader(object): ...@@ -300,8 +295,9 @@ class DataReader(object):
def batch_iterator(self, batch_size, minimum_batch_size): def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod): def batch_to_ndarray(batch_samples, lod):
batch_feature = np.zeros( assert len(batch_samples)
(lod[-1], self._frame_dim), dtype="float32") frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64") batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0 start = 0
for sample in batch_samples: for sample in batch_samples:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册