未验证 提交 77e44ab3 编写于 作者: Z zhxfl 提交者: GitHub

Merge pull request #640 from zhxfl/fix-639

Augmentation should compute frame_dim
......@@ -121,8 +121,6 @@ class DataReader(object):
corresponding description file.
label_file_list (str): File containing paths of label data file and
corresponding description file.
frame_dim (int): The final feature dimension of one frame after all
augmentation applied.
drop_frame_len (int): Samples whose label length above the value will be
dropped.
process_num (int): Number of processes for processing data.
......@@ -137,21 +135,18 @@ class DataReader(object):
random_seed (int): Random seed.
"""
def __init__(
self,
feature_file_list,
label_file_list,
frame_dim=120 * 11, # @TODO augmentor is responsible for the value
drop_frame_len=512,
process_num=10,
sample_buffer_size=1024,
sample_info_buffer_size=1024,
batch_buffer_size=1024,
shuffle_block_num=1,
random_seed=0):
def __init__(self,
feature_file_list,
label_file_list,
drop_frame_len=512,
process_num=10,
sample_buffer_size=1024,
sample_info_buffer_size=1024,
batch_buffer_size=1024,
shuffle_block_num=1,
random_seed=0):
self._feature_file_list = feature_file_list
self._label_file_list = label_file_list
self._frame_dim = frame_dim
self._drop_frame_len = drop_frame_len
self._shuffle_block_num = shuffle_block_num
self._block_info_list = None
......@@ -300,8 +295,9 @@ class DataReader(object):
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
assert len(batch_samples)
frame_dim = batch_samples[0][0].shape[1]
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册