提交 2c59afb0 编写于 作者: Y yangyaming

Simplify some logic.

上级 a830020d
......@@ -299,6 +299,18 @@ class DataReader(object):
w.join()
def batch_iterator(self, batch_size, minimum_batch_size):
def batch_to_ndarray(batch_samples, lod):
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
return (batch_feature, batch_label)
def batch_assembling_task(sample_generator, batch_queue):
batch_samples = []
lod = [0]
......@@ -306,29 +318,15 @@ class DataReader(object):
batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size:
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
(batch_feature, batch_label) = batch_to_ndarray(
batch_samples, lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_samples = []
lod = [0]
if len(batch_samples) >= minimum_batch_size:
batch_feature = np.zeros(
(lod[-1], self._frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0
for sample in batch_samples:
frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1]
start += frame_num
(batch_feature, batch_label) = batch_to_ndarray(batch_samples,
lod)
batch_queue.put((batch_feature, batch_label, lod))
batch_queue.put(EpochEndSignal())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册