提交 5a091518 编写于 作者: Z zhxfl

load batch name list

上级 3bbb3567
...@@ -31,16 +31,18 @@ class SampleInfo(object): ...@@ -31,16 +31,18 @@ class SampleInfo(object):
label_bin_path (str): File containing the label data. label_bin_path (str): File containing the label data.
label_size (int): Byte count of the sample's label data. label_size (int): Byte count of the sample's label data.
label_frame_num (int): Label number of the sample. label_frame_num (int): Label number of the sample.
sample_name (str): key of the sample
""" """
def __init__(self, feature_bin_path, feature_start, feature_size, def __init__(self, feature_bin_path, feature_start, feature_size,
feature_frame_num, feature_dim, label_bin_path, label_start, feature_frame_num, feature_dim, label_bin_path, label_start,
label_size, label_frame_num): label_size, label_frame_num, sample_name):
self.feature_bin_path = feature_bin_path self.feature_bin_path = feature_bin_path
self.feature_start = feature_start self.feature_start = feature_start
self.feature_size = feature_size self.feature_size = feature_size
self.feature_frame_num = feature_frame_num self.feature_frame_num = feature_frame_num
self.feature_dim = feature_dim self.feature_dim = feature_dim
self.sample_name = sample_name
self.label_bin_path = label_bin_path self.label_bin_path = label_bin_path
self.label_start = label_start self.label_start = label_start
...@@ -113,6 +115,7 @@ class SampleInfoBucket(object): ...@@ -113,6 +115,7 @@ class SampleInfoBucket(object):
for i in xrange(sample_num): for i in xrange(sample_num):
feature_desc_split = feature_desc_lines[i + 1].split() feature_desc_split = feature_desc_lines[i + 1].split()
sample_name = feature_desc_split[0]
feature_start = int(feature_desc_split[2]) feature_start = int(feature_desc_split[2])
feature_size = int(feature_desc_split[3]) feature_size = int(feature_desc_split[3])
feature_frame_num = int(feature_desc_split[4]) feature_frame_num = int(feature_desc_split[4])
...@@ -136,7 +139,7 @@ class SampleInfoBucket(object): ...@@ -136,7 +139,7 @@ class SampleInfoBucket(object):
SampleInfo(feature_bin_path, feature_start, SampleInfo(feature_bin_path, feature_start,
feature_size, feature_frame_num, feature_dim, feature_size, feature_frame_num, feature_dim,
label_bin_path, label_start, label_size, label_bin_path, label_start, label_size,
label_frame_num)) label_frame_num, sample_name))
#split sentence #split sentence
else: else:
cur_frame_pos = 0 cur_frame_pos = 0
...@@ -157,7 +160,7 @@ class SampleInfoBucket(object): ...@@ -157,7 +160,7 @@ class SampleInfoBucket(object):
* feature_dim * 4, cur_frame_len * feature_dim * * feature_dim * 4, cur_frame_len * feature_dim *
4, cur_frame_len, feature_dim, label_bin_path, 4, cur_frame_len, feature_dim, label_bin_path,
label_start + cur_frame_pos * 4, cur_frame_len * label_start + cur_frame_pos * 4, cur_frame_len *
4, cur_frame_len)) 4, cur_frame_len, sample_name))
remain_frame_num -= cur_frame_len remain_frame_num -= cur_frame_len
cur_frame_pos += cur_frame_len cur_frame_pos += cur_frame_len
...@@ -361,8 +364,8 @@ class AsyncDataReader(object): ...@@ -361,8 +364,8 @@ class AsyncDataReader(object):
feature_data = np.array( feature_data = np.array(
feature_array, dtype='float32').reshape(( feature_array, dtype='float32').reshape((
sample_info.feature_frame_num, sample_info.feature_dim)) sample_info.feature_frame_num, sample_info.feature_dim))
sample_data = (feature_data, label_data,
sample_data = (feature_data, label_data) sample_info.sample_name)
for transformer in self._transformers: for transformer in self._transformers:
# @TODO(pkuyym) to make transformer only accept feature_data # @TODO(pkuyym) to make transformer only accept feature_data
sample_data = transformer.perform_trans(sample_data) sample_data = transformer.perform_trans(sample_data)
...@@ -415,24 +418,26 @@ class AsyncDataReader(object): ...@@ -415,24 +418,26 @@ class AsyncDataReader(object):
batch_samples.append(sample) batch_samples.append(sample)
lod.append(lod[-1] + sample[0].shape[0]) lod.append(lod[-1] + sample[0].shape[0])
if len(batch_samples) == batch_size: if len(batch_samples) == batch_size:
feature, label = batch_to_ndarray(batch_samples, lod) feature, label, name_lst = batch_to_ndarray(
batch_samples, lod)
feature = conv_to_shared(feature) feature = conv_to_shared(feature)
label = conv_to_shared(label) label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64')) lod = conv_to_shared(np.array(lod).astype('int64'))
batch_queue.put((feature, label, lod)) batch_queue.put((feature, label, lod, name_lst))
batch_samples = [] batch_samples = []
lod = [0] lod = [0]
if len(batch_samples) >= minimum_batch_size: if len(batch_samples) >= minimum_batch_size:
(feature, label) = batch_to_ndarray(batch_samples, lod) (feature, label, name_lst) = batch_to_ndarray(batch_samples,
lod)
feature = conv_to_shared(feature) feature = conv_to_shared(feature)
label = conv_to_shared(label) label = conv_to_shared(label)
lod = conv_to_shared(np.array(lod).astype('int64')) lod = conv_to_shared(np.array(lod).astype('int64'))
batch_queue.put((feature, label, lod)) batch_queue.put((feature, label, lod, name_lst))
batch_queue.put(EpochEndSignal()) batch_queue.put(EpochEndSignal())
......
...@@ -22,7 +22,7 @@ class TestTransMeanVarianceNorm(unittest.TestCase): ...@@ -22,7 +22,7 @@ class TestTransMeanVarianceNorm(unittest.TestCase):
feature = np.zeros((2, 120), dtype="float32") feature = np.zeros((2, 120), dtype="float32")
feature.fill(1) feature.fill(1)
trans = trans_mean_variance_norm.TransMeanVarianceNorm(self._file_path) trans = trans_mean_variance_norm.TransMeanVarianceNorm(self._file_path)
(feature1, label1) = trans.perform_trans((feature, None)) (feature1, label1, name) = trans.perform_trans((feature, None, None))
(mean, var) = trans.get_mean_var() (mean, var) = trans.get_mean_var()
feature_flat1 = feature1.flatten() feature_flat1 = feature1.flatten()
feature_flat = feature.flatten() feature_flat = feature.flatten()
...@@ -70,7 +70,7 @@ class TestTransAddDelta(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestTransAddDelta(unittest.TestCase):
feature[2, 0:40].fill(3) feature[2, 0:40].fill(3)
feature[3, 0:40].fill(4) feature[3, 0:40].fill(4)
trans = trans_add_delta.TransAddDelta() trans = trans_add_delta.TransAddDelta()
(feature, label) = trans.perform_trans((feature, None)) (feature, label, name) = trans.perform_trans((feature, None, None))
self.assertAlmostEqual(feature.shape[0], 4) self.assertAlmostEqual(feature.shape[0], 4)
self.assertAlmostEqual(feature.shape[1], 120) self.assertAlmostEqual(feature.shape[1], 120)
self.assertAlmostEqual(1.0, feature[0][0]) self.assertAlmostEqual(1.0, feature[0][0])
...@@ -93,7 +93,7 @@ class TestTransSplict(unittest.TestCase): ...@@ -93,7 +93,7 @@ class TestTransSplict(unittest.TestCase):
feature[i, :].fill(i) feature[i, :].fill(i)
trans = trans_splice.TransSplice() trans = trans_splice.TransSplice()
(feature, label) = trans.perform_trans((feature, None)) (feature, label, name) = trans.perform_trans((feature, None, None))
self.assertEqual(feature.shape[1], 110) self.assertEqual(feature.shape[1], 110)
for i in xrange(8): for i in xrange(8):
......
...@@ -32,9 +32,9 @@ class TransAddDelta(object): ...@@ -32,9 +32,9 @@ class TransAddDelta(object):
Args: Args:
sample(object,tuple): contains feature numpy and label numpy sample(object,tuple): contains feature numpy and label numpy
Returns: Returns:
(feature, label) (feature, label, sample_name)
""" """
(feature, label) = sample (feature, label, sample_name) = sample
frame_dim = feature.shape[1] frame_dim = feature.shape[1]
d_frame_dim = frame_dim * 3 d_frame_dim = frame_dim * 3
head_filled = 5 head_filled = 5
...@@ -64,7 +64,8 @@ class TransAddDelta(object): ...@@ -64,7 +64,8 @@ class TransAddDelta(object):
start * d_frame_dim + 2 * frame_dim, frame_dim, nframe, start * d_frame_dim + 2 * frame_dim, frame_dim, nframe,
d_frame_dim) d_frame_dim)
mat.shape = tmp_shape mat.shape = tmp_shape
return (mat[head_filled:mat.shape[0] - tail_filled, :], label) return (mat[head_filled:mat.shape[0] - tail_filled, :], label,
sample_name)
def _regress(self, data_in, start_in, data_out, start_out, size, n, step): def _regress(self, data_in, start_in, data_out, start_out, size, n, step):
""" regress """ regress
......
...@@ -53,9 +53,9 @@ class TransMeanVarianceNorm(object): ...@@ -53,9 +53,9 @@ class TransMeanVarianceNorm(object):
Args: Args:
sample(object): input sample, contains feature numpy and label numpy sample(object): input sample, contains feature numpy and label numpy
Returns: Returns:
(feature, label) (feature, label, sample_name)
""" """
(feature, label) = sample (feature, label, sample_name) = sample
shape = feature.shape shape = feature.shape
assert len(shape) == 2 assert len(shape) == 2
nfeature_len = shape[0] * shape[1] nfeature_len = shape[0] * shape[1]
...@@ -68,4 +68,4 @@ class TransMeanVarianceNorm(object): ...@@ -68,4 +68,4 @@ class TransMeanVarianceNorm(object):
feature[ncur_idx:ncur_idx + self._nLen] = block feature[ncur_idx:ncur_idx + self._nLen] = block
ncur_idx += self._nLen ncur_idx += self._nLen
feature = feature.reshape(shape) feature = feature.reshape(shape)
return (feature, label) return (feature, label, sample_name)
...@@ -30,9 +30,9 @@ class TransSplice(object): ...@@ -30,9 +30,9 @@ class TransSplice(object):
Args: Args:
sample(object): input sample(feature, label) sample(object): input sample(feature, label)
Return: Return:
(feature, label) (feature, label, sample_name)
""" """
(feature, label) = sample (feature, label, sample_name) = sample
nframe_num = feature.shape[0] nframe_num = feature.shape[0]
nframe_dim = feature.shape[1] nframe_dim = feature.shape[1]
nnew_frame_dim = nframe_dim * ( nnew_frame_dim = nframe_dim * (
...@@ -61,4 +61,4 @@ class TransSplice(object): ...@@ -61,4 +61,4 @@ class TransSplice(object):
np.copyto(ret[i * nnew_frame_dim:(i + 1) * nnew_frame_dim], np.copyto(ret[i * nnew_frame_dim:(i + 1) * nnew_frame_dim],
mat[i * nframe_dim:i * nframe_dim + nnew_frame_dim]) mat[i * nframe_dim:i * nframe_dim + nnew_frame_dim])
ret = ret.reshape((nframe_num, nnew_frame_dim)) ret = ret.reshape((nframe_num, nnew_frame_dim))
return (ret, label) return (ret, label, sample_name)
...@@ -42,12 +42,14 @@ def batch_to_ndarray(batch_samples, lod): ...@@ -42,12 +42,14 @@ def batch_to_ndarray(batch_samples, lod):
batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32") batch_feature = np.zeros((lod[-1], frame_dim), dtype="float32")
batch_label = np.zeros((lod[-1], 1), dtype="int64") batch_label = np.zeros((lod[-1], 1), dtype="int64")
start = 0 start = 0
name_lst = []
for sample in batch_samples: for sample in batch_samples:
frame_num = sample[0].shape[0] frame_num = sample[0].shape[0]
batch_feature[start:start + frame_num, :] = sample[0] batch_feature[start:start + frame_num, :] = sample[0]
batch_label[start:start + frame_num, :] = sample[1] batch_label[start:start + frame_num, :] = sample[1]
start += frame_num start += frame_num
return (batch_feature, batch_label) name_lst.append(sample[2])
return (batch_feature, batch_label, name_lst)
def split_infer_result(infer_seq, lod): def split_infer_result(infer_seq, lod):
......
...@@ -220,7 +220,7 @@ def train(args): ...@@ -220,7 +220,7 @@ def train(args):
train_data_reader.batch_iterator(args.batch_size, train_data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)): args.minimum_batch_size)):
# load_data # load_data
(features, labels, lod) = batch_data (features, labels, lod, name_lst) = batch_data
feature_t.set(features.ndarray, place) feature_t.set(features.ndarray, place)
feature_t.set_lod([lod.ndarray]) feature_t.set_lod([lod.ndarray])
label_t.set(labels.ndarray, place) label_t.set(labels.ndarray, place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册